1use crate::base::iana::{Opcode, Rcode};
3use crate::base::message::{CopyRecordsError, ShortMessage};
4use crate::base::message_builder::{
5 AdditionalBuilder, MessageBuilder, PushError,
6};
7use crate::base::opt::{ComposeOptData, LongOptData, OptRecord};
8use crate::base::wire::{Composer, ParseError};
9use crate::base::{
10 Header, Message, Rtype, StaticCompressor, UnknownRecordData,
11};
12use bytes::Bytes;
13use octseq::Octets;
14use std::boxed::Box;
15use std::fmt::Debug;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::vec::Vec;
20use std::{error, fmt};
21use tracing::trace;
22
23#[cfg(feature = "tsig")]
24use crate::tsig;
25
26pub trait ComposeRequest: Debug + Send + Sync {
30 fn append_message<Target: Composer>(
32 &self,
33 target: Target,
34 ) -> Result<AdditionalBuilder<Target>, CopyRecordsError>;
35
36 fn to_message(&self) -> Result<Message<Vec<u8>>, Error>;
38
39 fn to_vec(&self) -> Result<Vec<u8>, Error>;
42
43 fn header(&self) -> &Header;
45
46 fn header_mut(&mut self) -> &mut Header;
48
49 fn set_udp_payload_size(&mut self, value: u16);
51
52 fn set_dnssec_ok(&mut self, value: bool);
54
55 fn add_opt(
57 &mut self,
58 opt: &impl ComposeOptData,
59 ) -> Result<(), LongOptData>;
60
61 fn is_answer(&self, answer: &Message<[u8]>) -> bool;
63
64 fn dnssec_ok(&self) -> bool;
66}
67
68pub trait ComposeRequestMulti: Debug + Send + Sync {
72 fn append_message<Target: Composer>(
74 &self,
75 target: Target,
76 ) -> Result<AdditionalBuilder<Target>, CopyRecordsError>;
77
78 fn to_message(&self) -> Result<Message<Vec<u8>>, Error>;
80
81 fn header(&self) -> &Header;
83
84 fn header_mut(&mut self) -> &mut Header;
86
87 fn set_udp_payload_size(&mut self, value: u16);
89
90 fn set_dnssec_ok(&mut self, value: bool);
92
93 fn add_opt(
95 &mut self,
96 opt: &impl ComposeOptData,
97 ) -> Result<(), LongOptData>;
98
99 fn is_answer(&self, answer: &Message<[u8]>) -> bool;
101
102 fn dnssec_ok(&self) -> bool;
104}
105
106pub trait SendRequest<CR> {
113 fn send_request(
115 &self,
116 request_msg: CR,
117 ) -> Box<dyn GetResponse + Send + Sync>;
118}
119
120impl<T: SendRequest<RequestMessage<Octs>> + ?Sized, Octs: Octets>
121 SendRequest<RequestMessage<Octs>> for Box<T>
122{
123 fn send_request(
124 &self,
125 request_msg: RequestMessage<Octs>,
126 ) -> Box<dyn GetResponse + Send + Sync> {
127 (**self).send_request(request_msg)
128 }
129}
130
131pub trait SendRequestMulti<CR> {
138 fn send_request(
140 &self,
141 request_msg: CR,
142 ) -> Box<dyn GetResponseMulti + Send + Sync>;
143}
144
145impl<T: SendRequestMulti<RequestMessage<Octs>> + ?Sized, Octs: Octets>
146 SendRequestMulti<RequestMessage<Octs>> for Box<T>
147{
148 fn send_request(
149 &self,
150 request_msg: RequestMessage<Octs>,
151 ) -> Box<dyn GetResponseMulti + Send + Sync> {
152 (**self).send_request(request_msg)
153 }
154}
155
156pub trait GetResponse: Debug {
163 fn get_response(
167 &mut self,
168 ) -> Pin<
169 Box<
170 dyn Future<Output = Result<Message<Bytes>, Error>>
171 + Send
172 + Sync
173 + '_,
174 >,
175 >;
176}
177
178#[allow(clippy::type_complexity)]
184pub trait GetResponseMulti: Debug {
185 fn get_response(
189 &mut self,
190 ) -> Pin<
191 Box<
192 dyn Future<Output = Result<Option<Message<Bytes>>, Error>>
193 + Send
194 + Sync
195 + '_,
196 >,
197 >;
198}
199
200#[derive(Clone, Debug)]
204pub struct RequestMessage<Octs: AsRef<[u8]>> {
205 msg: Message<Octs>,
207
208 header: Header,
210
211 opt: Option<OptRecord<Vec<u8>>>,
213}
214
215impl<Octs: AsRef<[u8]> + Debug + Octets> RequestMessage<Octs> {
216 pub fn new(msg: impl Into<Message<Octs>>) -> Result<Self, Error> {
218 let msg = msg.into();
219
220 if msg.header().opcode() == Opcode::QUERY
224 && msg.first_question().ok_or(Error::FormError)?.qtype()
225 == Rtype::AXFR
226 {
227 return Err(Error::FormError);
228 }
229
230 let header = msg.header();
231 Ok(Self {
232 msg,
233 header,
234 opt: None,
235 })
236 }
237
238 fn opt_mut(&mut self) -> &mut OptRecord<Vec<u8>> {
242 self.opt.get_or_insert_with(Default::default)
243 }
244
245 fn append_message_impl<Target: Composer>(
247 &self,
248 mut target: MessageBuilder<Target>,
249 ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
250 let source = &self.msg;
251
252 *target.header_mut() = self.header;
253
254 let source = source.question();
255 let mut target = target.question();
256 for rr in source {
257 target.push(rr?)?;
258 }
259 let mut source = source.answer()?;
260 let mut target = target.answer();
261 for rr in &mut source {
262 let rr = rr?
263 .into_record::<UnknownRecordData<_>>()?
264 .expect("record expected");
265 target.push(rr)?;
266 }
267
268 let mut source =
269 source.next_section()?.expect("section should be present");
270 let mut target = target.authority();
271 for rr in &mut source {
272 let rr = rr?
273 .into_record::<UnknownRecordData<_>>()?
274 .expect("record expected");
275 target.push(rr)?;
276 }
277
278 let source =
279 source.next_section()?.expect("section should be present");
280 let mut target = target.additional();
281 for rr in source {
282 let rr = rr?;
283 if rr.rtype() != Rtype::OPT {
284 let rr = rr
285 .into_record::<UnknownRecordData<_>>()?
286 .expect("record expected");
287 target.push(rr)?;
288 }
289 }
290
291 if let Some(opt) = self.opt.as_ref() {
292 target.push(opt.as_record())?;
293 }
294
295 Ok(target)
296 }
297
298 fn to_message_impl(&self) -> Result<Message<Vec<u8>>, Error> {
300 let target =
301 MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
302 .expect("Vec is expected to have enough space");
303
304 let target = self.append_message_impl(target)?;
305
306 let result = target.as_builder().clone();
310 let msg = Message::from_octets(result.finish().into_target()).expect(
311 "Message should be able to parse output from MessageBuilder",
312 );
313 Ok(msg)
314 }
315}
316
317impl<Octs: AsRef<[u8]> + Debug + Octets + Send + Sync> ComposeRequest
318 for RequestMessage<Octs>
319{
320 fn append_message<Target: Composer>(
321 &self,
322 target: Target,
323 ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
324 let target = MessageBuilder::from_target(target)
325 .map_err(|_| CopyRecordsError::Push(PushError::ShortBuf))?;
326 let builder = self.append_message_impl(target)?;
327 Ok(builder)
328 }
329
330 fn to_vec(&self) -> Result<Vec<u8>, Error> {
331 let msg = self.to_message()?;
332 Ok(msg.as_octets().clone())
333 }
334
335 fn to_message(&self) -> Result<Message<Vec<u8>>, Error> {
336 self.to_message_impl()
337 }
338
339 fn header(&self) -> &Header {
340 &self.header
341 }
342
343 fn header_mut(&mut self) -> &mut Header {
344 &mut self.header
345 }
346
347 fn set_udp_payload_size(&mut self, value: u16) {
348 self.opt_mut().set_udp_payload_size(value);
349 }
350
351 fn set_dnssec_ok(&mut self, value: bool) {
352 self.opt_mut().set_dnssec_ok(value);
353 }
354
355 fn add_opt(
356 &mut self,
357 opt: &impl ComposeOptData,
358 ) -> Result<(), LongOptData> {
359 self.opt_mut().push(opt).map_err(|e| e.unlimited_buf())
360 }
361
362 fn is_answer(&self, answer: &Message<[u8]>) -> bool {
363 let answer_header = answer.header();
364 let answer_hcounts = answer.header_counts();
365
366 if !answer_header.qr() || answer_header.id() != self.header.id() {
368 trace!(
369 "Wrong QR or ID: QR={}, answer ID={}, self ID={}",
370 answer_header.qr(),
371 answer_header.id(),
372 self.header.id()
373 );
374 return false;
375 }
376
377 if answer_header.rcode() != Rcode::NOERROR
380 && answer_hcounts.qdcount() == 0
381 && answer_hcounts.ancount() == 0
382 && answer_hcounts.nscount() == 0
383 && answer_hcounts.arcount() == 0
384 {
385 return true;
387 }
388
389 if answer_hcounts.qdcount() != self.msg.header_counts().qdcount() {
392 trace!("Wrong QD count");
393 false
394 } else {
395 let res = answer.question() == self.msg.for_slice().question();
396 if !res {
397 trace!("Wrong question");
398 }
399 res
400 }
401 }
402
403 fn dnssec_ok(&self) -> bool {
404 match &self.opt {
405 None => false,
406 Some(opt) => opt.dnssec_ok(),
407 }
408 }
409}
410
411#[derive(Clone, Debug)]
415pub struct RequestMessageMulti<Octs>
416where
417 Octs: AsRef<[u8]>,
418{
419 msg: Message<Octs>,
421
422 header: Header,
424
425 opt: Option<OptRecord<Vec<u8>>>,
427}
428
429impl<Octs: AsRef<[u8]> + Debug + Octets> RequestMessageMulti<Octs> {
430 pub fn new(msg: impl Into<Message<Octs>>) -> Result<Self, Error> {
432 let msg = msg.into();
433
434 if !msg.is_xfr() {
436 return Err(Error::FormError);
437 }
438 let header = msg.header();
439 Ok(Self {
440 msg,
441 header,
442 opt: None,
443 })
444 }
445
446 fn opt_mut(&mut self) -> &mut OptRecord<Vec<u8>> {
450 self.opt.get_or_insert_with(Default::default)
451 }
452
453 fn append_message_impl<Target: Composer>(
455 &self,
456 mut target: MessageBuilder<Target>,
457 ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
458 let source = &self.msg;
459
460 *target.header_mut() = self.header;
461
462 let source = source.question();
463 let mut target = target.question();
464 for rr in source {
465 target.push(rr?)?;
466 }
467 let mut source = source.answer()?;
468 let mut target = target.answer();
469 for rr in &mut source {
470 let rr = rr?
471 .into_record::<UnknownRecordData<_>>()?
472 .expect("record expected");
473 target.push(rr)?;
474 }
475
476 let mut source =
477 source.next_section()?.expect("section should be present");
478 let mut target = target.authority();
479 for rr in &mut source {
480 let rr = rr?
481 .into_record::<UnknownRecordData<_>>()?
482 .expect("record expected");
483 target.push(rr)?;
484 }
485
486 let source =
487 source.next_section()?.expect("section should be present");
488 let mut target = target.additional();
489 for rr in source {
490 let rr = rr?;
491 if rr.rtype() != Rtype::OPT {
492 let rr = rr
493 .into_record::<UnknownRecordData<_>>()?
494 .expect("record expected");
495 target.push(rr)?;
496 }
497 }
498
499 if let Some(opt) = self.opt.as_ref() {
500 target.push(opt.as_record())?;
501 }
502
503 Ok(target)
504 }
505
506 fn to_message_impl(&self) -> Result<Message<Vec<u8>>, Error> {
508 let target =
509 MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
510 .expect("Vec is expected to have enough space");
511
512 let target = self.append_message_impl(target)?;
513
514 let result = target.as_builder().clone();
518 let msg = Message::from_octets(result.finish().into_target()).expect(
519 "Message should be able to parse output from MessageBuilder",
520 );
521 Ok(msg)
522 }
523}
524
525impl<Octs: AsRef<[u8]> + Debug + Octets + Send + Sync> ComposeRequestMulti
526 for RequestMessageMulti<Octs>
527{
528 fn append_message<Target: Composer>(
529 &self,
530 target: Target,
531 ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
532 let target = MessageBuilder::from_target(target)
533 .map_err(|_| CopyRecordsError::Push(PushError::ShortBuf))?;
534 let builder = self.append_message_impl(target)?;
535 Ok(builder)
536 }
537
538 fn to_message(&self) -> Result<Message<Vec<u8>>, Error> {
539 self.to_message_impl()
540 }
541
542 fn header(&self) -> &Header {
543 &self.header
544 }
545
546 fn header_mut(&mut self) -> &mut Header {
547 &mut self.header
548 }
549
550 fn set_udp_payload_size(&mut self, value: u16) {
551 self.opt_mut().set_udp_payload_size(value);
552 }
553
554 fn set_dnssec_ok(&mut self, value: bool) {
555 self.opt_mut().set_dnssec_ok(value);
556 }
557
558 fn add_opt(
559 &mut self,
560 opt: &impl ComposeOptData,
561 ) -> Result<(), LongOptData> {
562 self.opt_mut().push(opt).map_err(|e| e.unlimited_buf())
563 }
564
565 fn is_answer(&self, answer: &Message<[u8]>) -> bool {
566 let answer_header = answer.header();
567 let answer_hcounts = answer.header_counts();
568
569 if !answer_header.qr() || answer_header.id() != self.header.id() {
571 trace!(
572 "Wrong QR or ID: QR={}, answer ID={}, self ID={}",
573 answer_header.qr(),
574 answer_header.id(),
575 self.header.id()
576 );
577 return false;
578 }
579
580 if answer_header.rcode() != Rcode::NOERROR
583 && answer_hcounts.qdcount() == 0
584 && answer_hcounts.ancount() == 0
585 && answer_hcounts.nscount() == 0
586 && answer_hcounts.arcount() == 0
587 {
588 return true;
590 }
591
592 if self.msg.qtype() == Some(Rtype::AXFR)
602 && answer_hcounts.qdcount() == 0
603 {
604 true
605 } else if answer_hcounts.qdcount()
606 != self.msg.header_counts().qdcount()
607 {
608 trace!("Wrong QD count");
609 false
610 } else {
611 let res = answer.question() == self.msg.for_slice().question();
612 if !res {
613 trace!("Wrong question");
614 }
615 res
616 }
617 }
618
619 fn dnssec_ok(&self) -> bool {
620 match &self.opt {
621 None => false,
622 Some(opt) => opt.dnssec_ok(),
623 }
624 }
625}
626
627#[derive(Clone, Debug)]
631pub enum Error {
632 ConnectionClosed,
634
635 OptTooLong,
637
638 MessageBuilderPushError,
640
641 MessageParseError,
643
644 RedundantTransportNotFound,
646
647 FormError,
649
650 ShortMessage,
652
653 StreamLongMessage,
655
656 StreamIdleTimeout,
658
659 StreamReceiveError,
662
663 StreamReadError(Arc<std::io::Error>),
665
666 StreamReadTimeout,
668
669 StreamTooManyOutstandingQueries,
671
672 StreamWriteError(Arc<std::io::Error>),
674
675 StreamUnexpectedEndOfData,
677
678 WrongReplyForQuery,
680
681 NoTransportAvailable,
683
684 Dgram(Arc<super::dgram::QueryError>),
686
687 #[cfg(feature = "unstable-server-transport")]
688 ZoneWrite,
690
691 #[cfg(feature = "tsig")]
692 Authentication(tsig::ValidationError),
694
695 #[cfg(all(
696 feature = "unstable-validator",
697 any(feature = "ring", feature = "openssl")
698 ))]
699 Validation(crate::dnssec::validator::context::Error),
701}
702
703impl From<LongOptData> for Error {
704 fn from(_: LongOptData) -> Self {
705 Self::OptTooLong
706 }
707}
708
709impl From<ParseError> for Error {
710 fn from(_: ParseError) -> Self {
711 Self::MessageParseError
712 }
713}
714
715impl From<ShortMessage> for Error {
716 fn from(_: ShortMessage) -> Self {
717 Self::ShortMessage
718 }
719}
720
721impl From<super::dgram::QueryError> for Error {
722 fn from(err: super::dgram::QueryError) -> Self {
723 Self::Dgram(err.into())
724 }
725}
726
/// Wraps a DNSSEC validator error as [`Error::Validation`].
#[cfg(all(
    feature = "unstable-validator",
    any(feature = "ring", feature = "openssl")
))]
impl From<crate::dnssec::validator::context::Error> for Error {
    fn from(err: crate::dnssec::validator::context::Error) -> Self {
        Error::Validation(err)
    }
}
736
737impl fmt::Display for Error {
738 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
739 match self {
740 Error::ConnectionClosed => write!(f, "connection closed"),
741 Error::OptTooLong => write!(f, "OPT record is too long"),
742 Error::MessageBuilderPushError => {
743 write!(f, "PushError from MessageBuilder")
744 }
745 Error::MessageParseError => write!(f, "ParseError from Message"),
746 Error::RedundantTransportNotFound => write!(
747 f,
748 "Underlying transport not found in redundant connection"
749 ),
750 Error::ShortMessage => {
751 write!(f, "octet sequence to short to be a valid message")
752 }
753 Error::FormError => {
754 write!(f, "message violates a constraint")
755 }
756 Error::StreamLongMessage => {
757 write!(f, "message too long for stream transport")
758 }
759 Error::StreamIdleTimeout => {
760 write!(f, "stream was idle for too long")
761 }
762 Error::StreamReceiveError => write!(f, "error receiving a reply"),
763 Error::StreamReadError(_) => {
764 write!(f, "error reading from stream")
765 }
766 Error::StreamReadTimeout => {
767 write!(f, "timeout reading from stream")
768 }
769 Error::StreamTooManyOutstandingQueries => {
770 write!(f, "too many outstanding queries on stream")
771 }
772 Error::StreamWriteError(_) => {
773 write!(f, "error writing to stream")
774 }
775 Error::StreamUnexpectedEndOfData => {
776 write!(f, "unexpected end of data")
777 }
778 Error::WrongReplyForQuery => {
779 write!(f, "reply does not match query")
780 }
781 Error::NoTransportAvailable => {
782 write!(f, "no transport available")
783 }
784 Error::Dgram(err) => fmt::Display::fmt(err, f),
785
786 #[cfg(feature = "unstable-server-transport")]
787 Error::ZoneWrite => write!(f, "error writing to zone"),
788
789 #[cfg(feature = "tsig")]
790 Error::Authentication(err) => fmt::Display::fmt(err, f),
791
792 #[cfg(all(
793 feature = "unstable-validator",
794 any(feature = "ring", feature = "openssl")
795 ))]
796 Error::Validation(_) => {
797 write!(f, "error validating response")
798 }
799 }
800 }
801}
802
803impl From<CopyRecordsError> for Error {
804 fn from(err: CopyRecordsError) -> Self {
805 match err {
806 CopyRecordsError::Parse(_) => Self::MessageParseError,
807 CopyRecordsError::Push(_) => Self::MessageBuilderPushError,
808 }
809 }
810}
811
812impl error::Error for Error {
813 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
814 match self {
815 Error::ConnectionClosed => None,
816 Error::OptTooLong => None,
817 Error::MessageBuilderPushError => None,
818 Error::MessageParseError => None,
819 Error::RedundantTransportNotFound => None,
820 Error::ShortMessage => None,
821 Error::FormError => None,
822 Error::StreamLongMessage => None,
823 Error::StreamIdleTimeout => None,
824 Error::StreamReceiveError => None,
825 Error::StreamReadError(e) => Some(e),
826 Error::StreamReadTimeout => None,
827 Error::StreamTooManyOutstandingQueries => None,
828 Error::StreamWriteError(e) => Some(e),
829 Error::StreamUnexpectedEndOfData => None,
830 Error::WrongReplyForQuery => None,
831 Error::NoTransportAvailable => None,
832 Error::Dgram(err) => Some(err),
833
834 #[cfg(feature = "unstable-server-transport")]
835 Error::ZoneWrite => None,
836
837 #[cfg(feature = "tsig")]
838 Error::Authentication(e) => Some(e),
839
840 #[cfg(all(
841 feature = "unstable-validator",
842 any(feature = "ring", feature = "openssl")
843 ))]
844 Error::Validation(e) => Some(e),
845 }
846 }
847}