domain/net/client/stream.rs
1//! A client transport using a stream socket.
2
3// RFC 7766 describes DNS over TCP
4// RFC 7828 describes the edns-tcp-keepalive option
5
6use super::request::{
7 ComposeRequest, ComposeRequestMulti, Error, GetResponse,
8 GetResponseMulti, SendRequest, SendRequestMulti,
9};
10use crate::base::iana::{Rcode, Rtype};
11use crate::base::message::Message;
12use crate::base::message_builder::StreamTarget;
13use crate::base::opt::{AllOptData, OptRecord, TcpKeepalive};
14use crate::base::{ParsedName, Serial};
15use crate::rdata::AllRecordData;
16use crate::utils::config::DefMinMax;
17use bytes::{Bytes, BytesMut};
18use core::cmp;
19use octseq::Octets;
20use std::boxed::Box;
21use std::fmt::Debug;
22use std::future::Future;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26use std::vec::Vec;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28use tokio::sync::{mpsc, oneshot};
29use tokio::time::sleep;
30use tracing::trace;
31
32//------------ Configuration Constants ----------------------------------------
33
/// Default, minimum and maximum values for the response timeout.
///
/// The arguments are presumably (default, minimum, maximum) in that order,
/// going by the type's name and by the description given for the idle
/// timeout constant in this file. The `Config` setters run their argument
/// through `RESPONSE_TIMEOUT.limit()` before storing it.
///
/// Note: nsd has 120 seconds, unbound has 3 seconds.
const RESPONSE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(19),
    Duration::from_millis(1),
    Duration::from_secs(600),
);
42
/// Default, minimum and maximum values for the idle timeout.
///
/// Note that RFC 7766, Section 6.2.3 says: "DNS clients SHOULD close the
/// TCP connection of an idle session, unless an idle timeout has been
/// established using some other signalling mechanism, for example,
/// [edns-tcp-keepalive]."
/// However, RFC 7858, Section 3.4 says: "In order to amortize TCP and TLS
/// connection setup costs, clients and servers SHOULD NOT immediately close
/// a connection after each response. Instead, clients and servers SHOULD
/// reuse existing connections for subsequent queries as long as they have
/// sufficient resources.".
/// We set the default to 10 seconds, which is the same as what stubby
/// uses. The minimum is zero to allow the idle timeout to be disabled.
/// Assume that one hour is more than enough as a maximum.
const IDLE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(10),
    Duration::ZERO,
    Duration::from_secs(3600),
);
62
/// Capacity of the channel that transports `ChanReq`s.
///
/// The same capacity is also used for the per-request channel that carries
/// the responses of a streaming request back to its `RequestMulti`.
const DEF_CHAN_CAP: usize = 8;

/// Capacity of a private channel dispatching responses.
///
/// This is the channel from the reader future to `Transport::run`.
const READ_REPLY_CHAN_CAP: usize = 8;
68
69//------------ Config ---------------------------------------------------------
70
/// Configuration for a stream transport connection.
///
/// All fields are private; they are read and changed through the accessor
/// methods of this type.
#[derive(Clone, Debug)]
pub struct Config {
    /// Response timeout currently in effect.
    ///
    /// `Transport::run` overwrites this value with either the single or the
    /// streaming response timeout each time it accepts a new request.
    response_timeout: Duration,

    /// Single response timeout.
    ///
    /// Copied into `response_timeout` when the most recently accepted
    /// request expects exactly one response.
    single_response_timeout: Duration,

    /// Streaming response timeout.
    ///
    /// Copied into `response_timeout` when the most recently accepted
    /// request expects a stream of responses.
    streaming_response_timeout: Duration,

    /// Default idle timeout.
    ///
    /// This value is used if the other side does not send a TcpKeepalive
    /// option.
    idle_timeout: Duration,
}
89
90impl Config {
91 /// Creates a new, default config.
92 pub fn new() -> Self {
93 Default::default()
94 }
95
96 /// Returns the response timeout.
97 ///
98 /// This is the amount of time to wait on a non-idle connection for a
99 /// response to an outstanding request.
100 pub fn response_timeout(&self) -> Duration {
101 self.response_timeout
102 }
103
104 /// Sets the response timeout.
105 ///
106 /// For requests where ComposeRequest::is_streaming() returns true see
107 /// set_streaming_response_timeout() instead.
108 ///
109 /// Excessive values are quietly trimmed.
110 //
111 // XXX Maybe that’s wrong and we should rather return an error?
112 pub fn set_response_timeout(&mut self, timeout: Duration) {
113 self.response_timeout = RESPONSE_TIMEOUT.limit(timeout);
114 self.streaming_response_timeout = self.response_timeout;
115 }
116
117 /// Returns the streaming response timeout.
118 pub fn streaming_response_timeout(&self) -> Duration {
119 self.streaming_response_timeout
120 }
121
122 /// Sets the streaming response timeout.
123 ///
124 /// Only used for requests where ComposeRequest::is_streaming() returns
125 /// true as it is typically desirable that such response streams be
126 /// allowed to complete even if the individual responses arrive very
127 /// slowly.
128 ///
129 /// Excessive values are quietly trimmed.
130 pub fn set_streaming_response_timeout(&mut self, timeout: Duration) {
131 self.streaming_response_timeout = RESPONSE_TIMEOUT.limit(timeout);
132 }
133
134 /// Returns the initial idle timeout, if set.
135 pub fn idle_timeout(&self) -> Duration {
136 self.idle_timeout
137 }
138
139 /// Sets the initial idle timeout.
140 ///
141 /// By default the stream is immediately closed if there are no pending
142 /// requests or responses.
143 ///
144 /// Set this to allow requests to be sent in sequence with delays between
145 /// such as a SOA query followed by AXFR for more efficient use of the
146 /// stream per RFC 9103.
147 ///
148 /// Note: May be overridden by an RFC 7828 edns-tcp-keepalive timeout
149 /// received from a server.
150 ///
151 /// Excessive values are quietly trimmed.
152 pub fn set_idle_timeout(&mut self, timeout: Duration) {
153 self.idle_timeout = IDLE_TIMEOUT.limit(timeout)
154 }
155}
156
157impl Default for Config {
158 fn default() -> Self {
159 Self {
160 response_timeout: RESPONSE_TIMEOUT.default(),
161 single_response_timeout: RESPONSE_TIMEOUT.default(),
162 streaming_response_timeout: RESPONSE_TIMEOUT.default(),
163 idle_timeout: IDLE_TIMEOUT.default(),
164 }
165 }
166}
167
168//------------ Connection -----------------------------------------------------
169
/// A connection to a single stream transport.
///
/// This is a handle that can be cloned; each clone holds another sender for
/// the same request channel, whose receiving end is owned by the
/// `Transport` created alongside the connection.
#[derive(Debug)]
pub struct Connection<Req, ReqMulti> {
    /// The sender half of the request channel.
    sender: mpsc::Sender<ChanReq<Req, ReqMulti>>,
}
176
177impl<Req, ReqMulti> Connection<Req, ReqMulti> {
178 /// Creates a new stream transport with default configuration.
179 ///
180 /// Returns a connection and a future that drives the transport using
181 /// the provided stream. This future needs to be run while any queries
182 /// are active. This is most easly achieved by spawning it into a runtime.
183 /// It terminates when the last connection is dropped.
184 pub fn new<Stream>(
185 stream: Stream,
186 ) -> (Self, Transport<Stream, Req, ReqMulti>) {
187 Self::with_config(stream, Default::default())
188 }
189
190 /// Creates a new stream transport with the given configuration.
191 ///
192 /// Returns a connection and a future that drives the transport using
193 /// the provided stream. This future needs to be run while any queries
194 /// are active. This is most easly achieved by spawning it into a runtime.
195 /// It terminates when the last connection is dropped.
196 pub fn with_config<Stream>(
197 stream: Stream,
198 config: Config,
199 ) -> (Self, Transport<Stream, Req, ReqMulti>) {
200 let (sender, transport) = Transport::new(stream, config);
201 (Self { sender }, transport)
202 }
203}
204
205impl<Req, ReqMulti> Connection<Req, ReqMulti>
206where
207 Req: ComposeRequest + 'static,
208 ReqMulti: ComposeRequestMulti + 'static,
209{
210 /// Start a DNS request.
211 ///
212 /// This function takes a precomposed message as a parameter and
213 /// returns a [`Message`] object wrapped in a [`Result`].
214 async fn handle_request_impl(
215 self,
216 msg: Req,
217 ) -> Result<Message<Bytes>, Error> {
218 let (sender, receiver) = oneshot::channel();
219 let sender = ReplySender::Single(Some(sender));
220 let msg = ReqSingleMulti::Single(msg);
221 let req = ChanReq { sender, msg };
222 self.sender.send(req).await.map_err(|_| {
223 // Send error. The receiver is gone, this means that the
224 // connection is closed.
225 Error::ConnectionClosed
226 })?;
227 receiver.await.map_err(|_| Error::StreamReceiveError)?
228 }
229
230 /// Start a streaming request.
231 async fn handle_streaming_request_impl(
232 self,
233 msg: ReqMulti,
234 sender: mpsc::Sender<Result<Option<Message<Bytes>>, Error>>,
235 ) -> Result<(), Error> {
236 let reply_sender = ReplySender::Stream(sender);
237 let msg = ReqSingleMulti::Multi(msg);
238 let req = ChanReq {
239 sender: reply_sender,
240 msg,
241 };
242 self.sender.send(req).await.map_err(|_| {
243 // Send error. The receiver is gone, this means that the
244 // connection is closed.
245 Error::ConnectionClosed
246 })?;
247 Ok(())
248 }
249
250 /// Returns a request handler for a request.
251 pub fn get_request(&self, request_msg: Req) -> Request {
252 Request {
253 fut: Box::pin(self.clone().handle_request_impl(request_msg)),
254 }
255 }
256
257 /// Return a multiple-response request handler for a request.
258 fn get_streaming_request(&self, request_msg: ReqMulti) -> RequestMulti {
259 let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
260 RequestMulti {
261 stream: receiver,
262 fut: Some(Box::pin(
263 self.clone()
264 .handle_streaming_request_impl(request_msg, sender),
265 )),
266 }
267 }
268}
269
270impl<Req, ReqMulti> Clone for Connection<Req, ReqMulti> {
271 fn clone(&self) -> Self {
272 Self {
273 sender: self.sender.clone(),
274 }
275 }
276}
277
278impl<Req, ReqMulti> SendRequest<Req> for Connection<Req, ReqMulti>
279where
280 Req: ComposeRequest + 'static,
281 ReqMulti: ComposeRequestMulti + Debug + Send + Sync + 'static,
282{
283 fn send_request(
284 &self,
285 request_msg: Req,
286 ) -> Box<dyn GetResponse + Send + Sync> {
287 Box::new(self.get_request(request_msg))
288 }
289}
290
291impl<Req, ReqMulti> SendRequestMulti<ReqMulti> for Connection<Req, ReqMulti>
292where
293 Req: ComposeRequest + Debug + Send + Sync + 'static,
294 ReqMulti: ComposeRequestMulti + 'static,
295{
296 fn send_request(
297 &self,
298 request_msg: ReqMulti,
299 ) -> Box<dyn GetResponseMulti + Send + Sync> {
300 Box::new(self.get_streaming_request(request_msg))
301 }
302}
303
304//------------ Request -------------------------------------------------------
305
/// An active request that expects a single response.
pub struct Request {
    /// The underlying future.
    ///
    /// It submits the request to the transport and resolves to the response
    /// or to an error.
    fut: Pin<
        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
    >,
}
313
314impl Request {
315 /// Async function that waits for the future stored in Request to complete.
316 async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
317 (&mut self.fut).await
318 }
319}
320
321impl GetResponse for Request {
322 fn get_response(
323 &mut self,
324 ) -> Pin<
325 Box<
326 dyn Future<Output = Result<Message<Bytes>, Error>>
327 + Send
328 + Sync
329 + '_,
330 >,
331 > {
332 Box::pin(self.get_response_impl())
333 }
334}
335
336impl Debug for Request {
337 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
338 f.debug_struct("Request")
339 .field("fut", &format_args!("_"))
340 .finish()
341 }
342}
343
344//------------ RequestMulti --------------------------------------------------
345
/// An active request that expects a stream of responses.
pub struct RequestMulti {
    /// Receiver for a stream of responses.
    ///
    /// An item of `Ok(None)` marks the end of the stream.
    stream: mpsc::Receiver<Result<Option<Message<Bytes>>, Error>>,

    /// The underlying future.
    ///
    /// It submits the request to the transport. It is taken out and awaited
    /// the first time a response is requested and is `None` afterwards.
    #[allow(clippy::type_complexity)]
    fut: Option<
        Pin<Box<dyn Future<Output = Result<(), Error>> + Send + Sync>>,
    >,
}
357
358impl RequestMulti {
359 /// Async function that waits for the future stored in Request to complete.
360 async fn get_response_impl(
361 &mut self,
362 ) -> Result<Option<Message<Bytes>>, Error> {
363 if self.fut.is_some() {
364 let fut = self.fut.take().expect("Some expected");
365 fut.await?;
366 }
367
368 // Fetch from the stream
369 self.stream
370 .recv()
371 .await
372 .ok_or(Error::ConnectionClosed)
373 .map_err(|_| Error::ConnectionClosed)?
374 }
375}
376
377impl GetResponseMulti for RequestMulti {
378 fn get_response(
379 &mut self,
380 ) -> Pin<
381 Box<
382 dyn Future<Output = Result<Option<Message<Bytes>>, Error>>
383 + Send
384 + Sync
385 + '_,
386 >,
387 > {
388 let fut = self.get_response_impl();
389 Box::pin(fut)
390 }
391}
392
393impl Debug for RequestMulti {
394 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
395 f.debug_struct("Request")
396 .field("fut", &format_args!("_"))
397 .finish()
398 }
399}
400
401//------------ Transport -----------------------------------------------------
402
/// The underlying machinery of a stream transport.
///
/// A value of this type is created together with a `Connection` and does
/// nothing until its `run` method is called and the resulting future is
/// polled.
#[derive(Debug)]
pub struct Transport<Stream, Req, ReqMulti> {
    /// The stream socket towards the remote end.
    stream: Stream,

    /// Transport configuration.
    config: Config,

    /// The receiver half of request channel.
    receiver: mpsc::Receiver<ChanReq<Req, ReqMulti>>,
}
415
/// This is the type of sender in [ChanReq].
#[derive(Debug)]
enum ReplySender {
    /// Return channel for a single response.
    ///
    /// The sender is wrapped in an `Option` because sending on a oneshot
    /// channel consumes the sender; it is taken out on first use.
    Single(Option<oneshot::Sender<ChanResp>>),

    /// Return channel for a stream of responses.
    ///
    /// An item of `Ok(None)` signals the end of the stream.
    Stream(mpsc::Sender<Result<Option<Message<Bytes>>, Error>>),
}
425
426impl ReplySender {
427 /// Send a response.
428 async fn send(&mut self, resp: ChanResp) -> Result<(), ()> {
429 match self {
430 ReplySender::Single(sender) => match sender.take() {
431 Some(sender) => sender.send(resp).map_err(|_| ()),
432 None => Err(()),
433 },
434 ReplySender::Stream(sender) => {
435 sender.send(resp.map(Some)).await.map_err(|_| ())
436 }
437 }
438 }
439
440 /// Send EOF on a response stream.
441 async fn send_eof(&mut self) -> Result<(), ()> {
442 match self {
443 ReplySender::Single(_) => {
444 panic!("cannot send EOF for Single");
445 }
446 ReplySender::Stream(sender) => {
447 sender.send(Ok(None)).await.map_err(|_| ())
448 }
449 }
450 }
451
452 /// Report whether in stream mode or not.
453 fn is_stream(&self) -> bool {
454 matches!(self, Self::Stream(_))
455 }
456}
457
#[derive(Debug)]
/// Enum that can store either a request for a single response or one for
/// multiple responses.
enum ReqSingleMulti<Req, ReqMulti> {
    /// Single response request.
    Single(Req),
    /// Multi-response request.
    Multi(ReqMulti),
}
467
/// A message from a [`Request`] or [`RequestMulti`] to start a new request.
///
/// Values of this type travel over the request channel from a `Connection`
/// to the `Transport`.
#[derive(Debug)]
struct ChanReq<Req, ReqMulti> {
    /// DNS request message
    msg: ReqSingleMulti<Req, ReqMulti>,

    /// Sender to send the result back to the originator of the request.
    sender: ReplySender,
}
477
/// A message back to [`Request`] returning a response.
///
/// It holds either the response message or the error that occurred.
type ChanResp = Result<Message<Bytes>, Error>;
480
/// Internal data structure of [Transport::run] to keep track of
/// the status of the connection.
// The types Status and ConnState are only used in Transport
struct Status {
    /// State of the connection.
    state: ConnState,

    /// Do we need to include edns-tcp-keepalive in an outgoing request?
    ///
    /// Typically this is true at the start of the connection and gets
    /// cleared when we successfully managed to include the option in a
    /// request.
    send_keepalive: bool,

    /// Time we are allowed to keep the connection open when idle.
    ///
    /// Initially we set the idle timeout to the default in config. A received
    /// edns-tcp-keepalive option may change that.
    idle_timeout: Duration,
}
501
/// Status of the connection. Used in [`Status`].
enum ConnState {
    /// The connection is in this state from the start and when at least
    /// one active DNS request is present.
    ///
    /// The instant contains the time of the first request or the
    /// most recent response that was received. It is `None` while no
    /// response timer is running.
    Active(Option<Instant>),

    /// This state represents a connection that went idle and has an
    /// idle timeout.
    ///
    /// The instant contains the time the connection went idle.
    Idle(Instant),

    /// This state represents an idle connection where either there was no
    /// idle timeout or the idle timer expired.
    IdleTimeout,

    /// A read error occurred.
    ReadError(Error),

    /// It took too long to receive a response.
    ReadTimeout,

    /// A write error occurred.
    WriteError(Error),
}
530
531//--- Display
532impl std::fmt::Display for ConnState {
533 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
534 match self {
535 ConnState::Active(instant) => f.write_fmt(format_args!(
536 "Active (since {}s ago)",
537 instant
538 .map(|v| Instant::now().duration_since(v).as_secs())
539 .unwrap_or_default()
540 )),
541 ConnState::Idle(instant) => f.write_fmt(format_args!(
542 "Idle (since {}s ago)",
543 Instant::now().duration_since(*instant).as_secs()
544 )),
545 ConnState::IdleTimeout => f.write_str("IdleTimeout"),
546 ConnState::ReadError(err) => {
547 f.write_fmt(format_args!("ReadError: {err}"))
548 }
549 ConnState::ReadTimeout => f.write_str("ReadTimeout"),
550 ConnState::WriteError(err) => {
551 f.write_fmt(format_args!("WriteError: {err}"))
552 }
553 }
554 }
555}
556
#[derive(Debug)]
/// State of an AXFR or IXFR response stream for detecting the end of the
/// stream.
enum XFRState {
    /// Start of AXFR.
    AXFRInit,
    /// After the first SOA record has been encountered.
    AXFRFirstSoa(Serial),
    /// Start of IXFR.
    IXFRInit,
    /// After the first SOA record has been encountered.
    IXFRFirstSoa(Serial),
    /// After the first SOA record in a diff section has been encountered.
    IXFRFirstDiffSoa(Serial),
    /// After the second SOA record in a diff section has been encountered.
    IXFRSecondDiffSoa(Serial),
    /// End of the stream has been found.
    Done,
    /// An error has occurred.
    Error,
}
578
579impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti> {
580 /// Creates a new transport.
581 fn new(
582 stream: Stream,
583 config: Config,
584 ) -> (mpsc::Sender<ChanReq<Req, ReqMulti>>, Self) {
585 let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
586 (
587 sender,
588 Self {
589 config,
590 stream,
591 receiver,
592 },
593 )
594 }
595}
596
597impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti>
598where
599 Stream: AsyncRead + AsyncWrite,
600 Req: ComposeRequest,
601 ReqMulti: ComposeRequestMulti,
602{
603 /// Run the transport machinery.
604 pub async fn run(mut self) {
605 let (reply_sender, mut reply_receiver) =
606 mpsc::channel::<Message<Bytes>>(READ_REPLY_CHAN_CAP);
607
608 let (read_stream, mut write_stream) = tokio::io::split(self.stream);
609
610 let reader_fut = Self::reader(read_stream, reply_sender);
611 tokio::pin!(reader_fut);
612
613 let mut status = Status {
614 state: ConnState::Active(None),
615 idle_timeout: self.config.idle_timeout,
616 send_keepalive: true,
617 };
618 let mut query_vec =
619 Queries::<(ChanReq<Req, ReqMulti>, Option<XFRState>)>::new();
620
621 let mut reqmsg: Option<Vec<u8>> = None;
622 let mut reqmsg_offset = 0;
623
624 loop {
625 let opt_timeout = match status.state {
626 ConnState::Active(opt_instant) => {
627 if let Some(instant) = opt_instant {
628 let elapsed = instant.elapsed();
629 if elapsed > self.config.response_timeout {
630 Self::error(
631 Error::StreamReadTimeout,
632 &mut query_vec,
633 );
634 status.state = ConnState::ReadTimeout;
635 break;
636 }
637 Some(self.config.response_timeout - elapsed)
638 } else {
639 None
640 }
641 }
642 ConnState::Idle(instant) => {
643 let elapsed = instant.elapsed();
644 if elapsed >= status.idle_timeout {
645 // Move to IdleTimeout and end
646 // the loop
647 status.state = ConnState::IdleTimeout;
648 break;
649 }
650 Some(status.idle_timeout - elapsed)
651 }
652 ConnState::IdleTimeout
653 | ConnState::ReadError(_)
654 | ConnState::WriteError(_) => None, // No timers here
655 ConnState::ReadTimeout => {
656 panic!("should not be in loop with ReadTimeout");
657 }
658 };
659
660 // For simplicity, make sure we always have a timeout
661 let timeout = match opt_timeout {
662 Some(timeout) => timeout,
663 None =>
664 // Just use the response timeout
665 {
666 self.config.response_timeout
667 }
668 };
669
670 let sleep_fut = sleep(timeout);
671 let recv_fut = self.receiver.recv();
672
673 let (do_write, msg) = match &reqmsg {
674 None => {
675 let msg: &[u8] = &[];
676 (false, msg)
677 }
678 Some(msg) => {
679 let msg: &[u8] = msg;
680 (true, msg)
681 }
682 };
683
684 tokio::select! {
685 biased;
686 res = &mut reader_fut => {
687 match res {
688 Ok(_) =>
689 // The reader should not
690 // terminate without
691 // error.
692 panic!("reader terminated"),
693 Err(error) => {
694 Self::error(error.clone(), &mut query_vec);
695 status.state = ConnState::ReadError(error);
696 // Reader failed. Break
697 // out of loop and
698 // shut down
699 break
700 }
701 }
702 }
703 opt_answer = reply_receiver.recv() => {
704 let answer = opt_answer.expect("reader died?");
705 // Check for a edns-tcp-keepalive option
706 let opt_record = answer.opt();
707 if let Some(ref opts) = opt_record {
708 Self::handle_opts(opts,
709 &mut status);
710 };
711 drop(opt_record);
712 Self::demux_reply(answer, &mut status, &mut query_vec).await;
713 }
714 res = write_stream.write(&msg[reqmsg_offset..]),
715 if do_write => {
716 match res {
717 Err(error) => {
718 let error =
719 Error::StreamWriteError(Arc::new(error));
720 Self::error(error.clone(), &mut query_vec);
721 status.state =
722 ConnState::WriteError(error);
723 break;
724 }
725 Ok(len) => {
726 reqmsg_offset += len;
727 if reqmsg_offset >= msg.len() {
728 reqmsg = None;
729 reqmsg_offset = 0;
730 }
731 }
732 }
733 }
734 res = recv_fut, if !do_write => {
735 match res {
736 Some(req) => {
737 if req.sender.is_stream() {
738 self.config.response_timeout =
739 self.config.streaming_response_timeout;
740 } else {
741 self.config.response_timeout =
742 self.config.single_response_timeout;
743 }
744 Self::insert_req(
745 req, &mut status, &mut reqmsg, &mut query_vec
746 );
747 }
748 None => {
749 // All references to the connection object have
750 // been dropped. Shutdown.
751 break;
752 }
753 }
754 }
755 _ = sleep_fut => {
756 // Timeout expired, just
757 // continue with the loop
758 }
759
760 }
761
762 // Check if the connection is idle
763 match status.state {
764 ConnState::Active(_) | ConnState::Idle(_) => {
765 // Keep going
766 }
767 ConnState::IdleTimeout => break,
768 ConnState::ReadError(_)
769 | ConnState::ReadTimeout
770 | ConnState::WriteError(_) => {
771 panic!("Should not be here");
772 }
773 }
774 }
775
776 trace!("Closing TCP connecting in state: {}", status.state);
777
778 // Send FIN
779 _ = write_stream.shutdown().await;
780 }
781
782 /// This function reads a DNS message from the connection and sends
783 /// it to [Transport::run].
784 ///
785 /// Reading has to be done in two steps: first read a two octet value
786 /// the specifies the length of the message, and then read in a loop the
787 /// body of the message.
788 ///
789 /// This function is not async cancellation safe.
790 async fn reader(
791 mut sock: tokio::io::ReadHalf<Stream>,
792 sender: mpsc::Sender<Message<Bytes>>,
793 ) -> Result<(), Error> {
794 loop {
795 let read_res = sock.read_u16().await;
796 let len = match read_res {
797 Ok(len) => len,
798 Err(error) => {
799 return Err(Error::StreamReadError(Arc::new(error)));
800 }
801 } as usize;
802
803 let mut buf = BytesMut::with_capacity(len);
804
805 loop {
806 let curlen = buf.len();
807 if curlen >= len {
808 if curlen > len {
809 panic!(
810 "reader: got too much data {curlen}, expetect {len}");
811 }
812
813 // We got what we need
814 break;
815 }
816
817 let read_res = sock.read_buf(&mut buf).await;
818
819 match read_res {
820 Ok(readlen) => {
821 if readlen == 0 {
822 return Err(Error::StreamUnexpectedEndOfData);
823 }
824 }
825 Err(error) => {
826 return Err(Error::StreamReadError(Arc::new(error)));
827 }
828 };
829
830 // Check if we are done at the head of the loop
831 }
832
833 let reply_message = Message::<Bytes>::from_octets(buf.into());
834 match reply_message {
835 Ok(answer) => {
836 sender
837 .send(answer)
838 .await
839 .expect("can't send reply to run");
840 }
841 Err(_) => {
842 // The only possible error is short message
843 return Err(Error::ShortMessage);
844 }
845 }
846 }
847 }
848
849 /// Reports an error to all outstanding queries.
850 fn error(
851 error: Error,
852 query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
853 ) {
854 // Update all requests that are in progress. Don't wait for
855 // any reply that may be on its way.
856 for (mut req, _) in query_vec.drain() {
857 _ = req.sender.send(Err(error.clone()));
858 }
859 }
860
861 /// Handles received EDNS options.
862 ///
863 /// In particular, it processes the edns-tcp-keepalive option.
864 fn handle_opts<Octs: Octets + AsRef<[u8]>>(
865 opts: &OptRecord<Octs>,
866 status: &mut Status,
867 ) {
868 // XXX This handles _all_ keepalive options. I think just using the
869 // first option as returned by Opt::tcp_keepalive should be good
870 // enough? -- M.
871 for option in opts.opt().iter().flatten() {
872 if let AllOptData::TcpKeepalive(tcpkeepalive) = option {
873 Self::handle_keepalive(tcpkeepalive, status);
874 }
875 }
876 }
877
878 /// Demultiplexes a response and sends it to the right query.
879 ///
880 /// In addition, the status is updated to IdleTimeout or Idle if there
881 /// are no remaining pending requests.
882 async fn demux_reply(
883 answer: Message<Bytes>,
884 status: &mut Status,
885 query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
886 ) {
887 // We got an answer, reset the timer
888 status.state = ConnState::Active(Some(Instant::now()));
889
890 let id = answer.header().id();
891
892 // Get the correct query and send it the reply.
893 let (mut req, mut opt_xfr_data) = match query_vec.try_remove(id) {
894 Some(req) => req,
895 None => {
896 // No query with this ID. We should
897 // mark the connection as broken
898 return;
899 }
900 };
901 let mut send_eof = false;
902 let answer = if match &req.msg {
903 ReqSingleMulti::Single(msg) => msg.is_answer(answer.for_slice()),
904 ReqSingleMulti::Multi(msg) => {
905 let xfr_data =
906 opt_xfr_data.expect("xfr_data should be present");
907 let (eof, xfr_data, is_answer) =
908 check_stream(msg, xfr_data, &answer);
909 send_eof = eof;
910 opt_xfr_data = Some(xfr_data);
911 is_answer
912 }
913 } {
914 Ok(answer)
915 } else {
916 Err(Error::WrongReplyForQuery)
917 };
918 _ = req.sender.send(answer).await;
919
920 if req.sender.is_stream() {
921 if send_eof {
922 _ = req.sender.send_eof().await;
923 } else {
924 query_vec.insert_at(id, (req, opt_xfr_data));
925 }
926 }
927
928 if query_vec.is_empty() {
929 // Clear the activity timer. There is no need to do
930 // this because state will be set to either IdleTimeout
931 // or Idle just below. However, it is nicer to keep
932 // this independent.
933 status.state = ConnState::Active(None);
934
935 status.state = if status.idle_timeout.is_zero() {
936 // Assume that we can just move to IdleTimeout
937 // state
938 ConnState::IdleTimeout
939 } else {
940 ConnState::Idle(Instant::now())
941 }
942 }
943 }
944
945 /// Insert a request in query_vec and return the request to be sent
946 /// in *reqmsg.
947 ///
948 /// First the status is checked, an error is returned if not Active or
949 /// idle. Addend a edns-tcp-keepalive option if needed.
950 // Note: maybe reqmsg should be a return value.
951 fn insert_req(
952 mut req: ChanReq<Req, ReqMulti>,
953 status: &mut Status,
954 reqmsg: &mut Option<Vec<u8>>,
955 query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
956 ) {
957 match &status.state {
958 ConnState::Active(timer) => {
959 // Set timer if we don't have one already
960 if timer.is_none() {
961 status.state = ConnState::Active(Some(Instant::now()));
962 }
963 }
964 ConnState::Idle(_) => {
965 // Go back to active
966 status.state = ConnState::Active(Some(Instant::now()));
967 }
968 ConnState::IdleTimeout => {
969 // The connection has been closed. Report error
970 _ = req.sender.send(Err(Error::StreamIdleTimeout));
971 return;
972 }
973 ConnState::ReadError(error) => {
974 _ = req.sender.send(Err(error.clone()));
975 return;
976 }
977 ConnState::ReadTimeout => {
978 _ = req.sender.send(Err(Error::StreamReadTimeout));
979 return;
980 }
981 ConnState::WriteError(error) => {
982 _ = req.sender.send(Err(error.clone()));
983 return;
984 }
985 }
986
987 let xfr_data = match &req.msg {
988 ReqSingleMulti::Single(_) => None,
989 ReqSingleMulti::Multi(msg) => {
990 let qtype = match msg.to_message().and_then(|m| {
991 m.sole_question()
992 .map_err(|_| Error::MessageParseError)
993 .map(|q| q.qtype())
994 }) {
995 Ok(msg) => msg,
996 Err(e) => {
997 _ = req.sender.send(Err(e));
998 return;
999 }
1000 };
1001 if qtype == Rtype::AXFR {
1002 Some(XFRState::AXFRInit)
1003 } else if qtype == Rtype::IXFR {
1004 Some(XFRState::IXFRInit)
1005 } else {
1006 // Stream requests should be either AXFR or IXFR.
1007 _ = req.sender.send(Err(Error::FormError));
1008 return;
1009 }
1010 }
1011 };
1012
1013 // Note that insert may fail if there are too many
1014 // outstanding queries. First call insert before checking
1015 // send_keepalive.
1016 let (index, (req, _)) = match query_vec.insert((req, xfr_data)) {
1017 Ok(res) => res,
1018 Err((mut req, _)) => {
1019 // Send an appropriate error and return.
1020 _ = req
1021 .sender
1022 .send(Err(Error::StreamTooManyOutstandingQueries));
1023 return;
1024 }
1025 };
1026
1027 // We set the ID to the array index. Defense in depth
1028 // suggests that a random ID is better because it works
1029 // even if TCP sequence numbers could be predicted. However,
1030 // Section 9.3 of RFC 5452 recommends retrying over TCP
1031 // if many spoofed answers arrive over UDP: "TCP, by the
1032 // nature of its use of sequence numbers, is far more
1033 // resilient against forgery by third parties."
1034
1035 let hdr = match &mut req.msg {
1036 ReqSingleMulti::Single(msg) => msg.header_mut(),
1037 ReqSingleMulti::Multi(msg) => msg.header_mut(),
1038 };
1039 hdr.set_id(index);
1040
1041 if status.send_keepalive
1042 && match &mut req.msg {
1043 ReqSingleMulti::Single(msg) => {
1044 msg.add_opt(&TcpKeepalive::new(None)).is_ok()
1045 }
1046 ReqSingleMulti::Multi(msg) => {
1047 msg.add_opt(&TcpKeepalive::new(None)).is_ok()
1048 }
1049 }
1050 {
1051 status.send_keepalive = false;
1052 }
1053
1054 match Self::convert_query(&req.msg) {
1055 Ok(msg) => {
1056 *reqmsg = Some(msg);
1057 }
1058 Err(err) => {
1059 // Take the sender out again and return the error.
1060 if let Some((mut req, _)) = query_vec.try_remove(index) {
1061 _ = req.sender.send(Err(err));
1062 }
1063 }
1064 }
1065 }
1066
1067 /// Handle a received edns-tcp-keepalive option.
1068 fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) {
1069 if let Some(value) = opt_value.timeout() {
1070 let value_dur = Duration::from(value);
1071 status.idle_timeout = value_dur;
1072 }
1073 }
1074
1075 /// Convert the query message to a vector.
1076 fn convert_query(
1077 msg: &ReqSingleMulti<Req, ReqMulti>,
1078 ) -> Result<Vec<u8>, Error> {
1079 match msg {
1080 ReqSingleMulti::Single(msg) => {
1081 let mut target = StreamTarget::new_vec();
1082 msg.append_message(&mut target)
1083 .map_err(|_| Error::StreamLongMessage)?;
1084 Ok(target.into_target())
1085 }
1086 ReqSingleMulti::Multi(msg) => {
1087 let target = StreamTarget::new_vec();
1088 let target = msg
1089 .append_message(target)
1090 .map_err(|_| Error::StreamLongMessage)?;
1091 Ok(target.finish().into_target())
1092 }
1093 }
1094 }
1095}
1096
/// Update the response stream state based on a response message.
///
/// `msg` is the request the response stream belongs to, `xfr_state` is the
/// state reached after the previously processed response message and
/// `answer` is the newly received response message.
///
/// The function first checks that `answer` matches the request where that
/// is required, then handles error rcodes, and finally walks all records
/// in the answer section, advancing the AXFR/IXFR state machine for each
/// record.
///
/// Returns the tuple `(eof, xfr_state, is_answer)`:
///
/// * `eof`: presumably tells the caller that the response stream for this
///   request has ended; the error paths that want to "keep the stream
///   open" return `false` here.
/// * `xfr_state`: the updated state, `XFRState::Error` once anything
///   unexpected was seen.
/// * `is_answer`: `true` if `answer` was accepted as part of the response
///   stream, `false` on every path that switches to (or already is in) the
///   error state.
fn check_stream<CRM>(
    msg: &CRM,
    mut xfr_state: XFRState,
    answer: &Message<Bytes>,
) -> (bool, XFRState, bool)
where
    CRM: ComposeRequestMulti,
{
    // First check if the reply matches the request.
    // RFC 5936, Section 2.2.2:
    // "In the first response message, this section MUST be copied from the
    // query.  In subsequent messages, this section MAY be copied from the
    // query, or it MAY be empty.  However, in an error response message
    // (see Section 2.2), this section MUST be copied as well."
    match xfr_state {
        // Only the first response message is required to carry the
        // question, so only the two initial states check it.
        XFRState::AXFRInit | XFRState::IXFRInit => {
            if !msg.is_answer(answer.for_slice()) {
                xfr_state = XFRState::Error;
                // If we detect an error, then keep the stream open. We are
                // likely out of sync with respect to the sender.
                return (false, xfr_state, false);
            }
        }
        XFRState::AXFRFirstSoa(_)
        | XFRState::IXFRFirstSoa(_)
        | XFRState::IXFRFirstDiffSoa(_)
        | XFRState::IXFRSecondDiffSoa(_) =>
            // No need to check anything.
            {}
        XFRState::Done => {
            // We should not be here. Switch to error state.
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        XFRState::Error =>
        // Keep the stream open.
        {
            return (false, xfr_state, false)
        }
    }

    // Then check if the reply status is an error.
    if answer.header().rcode() != Rcode::NOERROR {
        // Also check if this answers the question.
        if !msg.is_answer(answer.for_slice()) {
            xfr_state = XFRState::Error;
            // If we detect an error, then keep the stream open. We are
            // likely out of sync with respect to the sender.
            return (false, xfr_state, false);
        }
        // A matching error response ends the stream and is passed on as
        // an answer; the state is left as it was.
        return (true, xfr_state, true);
    }

    let ans_sec = match answer.answer() {
        Ok(ans) => ans,
        Err(_) => {
            // Bad message, switch to error state.
            xfr_state = XFRState::Error;
            // NOTE(review): the original comment here said "keep the
            // stream open", but unlike the other error paths this one
            // returns `true` as the first element -- confirm which of the
            // two is intended.
            return (true, xfr_state, false);
        }
    };
    for rr in
        ans_sec.into_records::<AllRecordData<Bytes, ParsedName<Bytes>>>()
    {
        let rr = match rr {
            Ok(rr) => rr,
            Err(_) => {
                // Bad message, switch to error state.
                // NOTE(review): as above, this returns `true` as the first
                // element while other error paths return `false`.
                xfr_state = XFRState::Error;
                return (true, xfr_state, false);
            }
        };
        match xfr_state {
            XFRState::AXFRInit => {
                // The first record has to be a SOA record.
                if let AllRecordData::Soa(soa) = rr.data() {
                    // Remember the serial so the closing SOA can be
                    // recognized.
                    xfr_state = XFRState::AXFRFirstSoa(soa.serial());
                    continue;
                }
                // Bad data. Switch to error status.
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            XFRState::AXFRFirstSoa(serial) => {
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        // We found a match.
                        xfr_state = XFRState::Done;
                        continue;
                    }

                    // Serial does not match. Move to error state.
                    xfr_state = XFRState::Error;
                    return (false, xfr_state, false);
                }

                // Any other record, just continue.
            }
            XFRState::IXFRInit => {
                // The first record has to be a SOA record.
                if let AllRecordData::Soa(soa) = rr.data() {
                    xfr_state = XFRState::IXFRFirstSoa(soa.serial());
                    continue;
                }
                // Bad data. Switch to error status.
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            XFRState::IXFRFirstSoa(serial) => {
                // We have three possibilities:
                // 1) The record is not a SOA. In that case the format is AXFR.
                // 2) The record is a SOA and the serial is not the current
                //    serial. That is expected for an IXFR format. Move to
                //    IXFRFirstDiffSoa.
                // 3) The record is a SOA and the serial is equal to the
                //    current serial. Treat this as a strange empty AXFR.
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        // We found a match.
                        xfr_state = XFRState::Done;
                        continue;
                    }

                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
                    continue;
                }

                // Any other record, move to AXFRFirstSoa.
                xfr_state = XFRState::AXFRFirstSoa(serial);
            }
            XFRState::IXFRFirstDiffSoa(serial) => {
                // Move to IXFRSecondDiffSoa if the record is a SOA record,
                // otherwise stay in the current state.
                if let AllRecordData::Soa(_) = rr.data() {
                    xfr_state = XFRState::IXFRSecondDiffSoa(serial);
                    continue;
                }

                // Any other record, just continue.
            }
            XFRState::IXFRSecondDiffSoa(serial) => {
                // Move to Done if the record is a SOA record and the
                // serial is the one from the first SOA record, move to
                // IXFRFirstDiffSoa for any other SOA record and
                // otherwise stay in the current state.
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        // We found a match.
                        xfr_state = XFRState::Done;
                        continue;
                    }

                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
                    continue;
                }

                // Any other record, just continue.
            }
            XFRState::Done => {
                // We got a record after we are done. Switch to error state.
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            // Every transition to the error state inside this loop returns
            // immediately, and the checks before the loop return when the
            // incoming state is already Error.
            XFRState::Error => panic!("should not be here"),
        }
    }

    // Check the final state.
    match xfr_state {
        XFRState::AXFRInit | XFRState::IXFRInit => {
            // Still in one of the init states. So the data section was
            // empty. Switch to error state.
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        XFRState::AXFRFirstSoa(_)
        | XFRState::IXFRFirstDiffSoa(_)
        | XFRState::IXFRSecondDiffSoa(_) =>
            // Just continue.
            {}
        XFRState::IXFRFirstSoa(_) => {
            // We are still in IXFRFirstSoa. Assume the other side doesn't
            // have anything more to say. We could check the SOA serial in
            // the request. Just assume that we are done.
            xfr_state = XFRState::Done;
            return (true, xfr_state, true);
        }
        XFRState::Done => return (true, xfr_state, true),
        XFRState::Error => unreachable!(),
    }

    // The transfer is still in progress: more response messages are
    // expected.
    // (eof, xfr_state, is_answer)
    (false, xfr_state, true)
}
1293
1294//------------ Queries -------------------------------------------------------
1295
/// Mapping outstanding queries to their ID.
///
/// This is generic over anything rather than our concrete request type for
/// easier testing.
///
/// Queries are kept in a vector of optional slots. The index of a slot is
/// the ID handed out for the query stored in it, which is why indexes are
/// reported as `u16`.
#[derive(Clone, Debug)]
struct Queries<T> {
    /// The number of elements in `vec` that are not `None`.
    count: usize,

    /// Index in `vec` where to start looking for a free slot for a new
    /// query. The methods below maintain that every slot before this index
    /// is occupied.
    curr: usize,

    /// The slots holding the outstanding queries, indexed by query ID.
    vec: Vec<Option<T>>,
}

impl<T> Queries<T> {
    /// Creates a new empty value.
    fn new() -> Self {
        Self {
            count: 0,
            curr: 0,
            vec: Vec::new(),
        }
    }

    /// Returns whether there are no more outstanding queries.
    fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Inserts the given query.
    ///
    /// Upon success, returns the index and a mutable reference to the stored
    /// query.
    ///
    /// Upon error, which means the set is full, returns the query.
    fn insert(&mut self, req: T) -> Result<(u16, &mut T), T> {
        // Fail if there are too many entries already in this vector.
        // We cannot have more than u16::MAX entries because the
        // index needs to fit in an u16. For efficiency we want to
        // keep the vector half empty. So we return a failure if
        // 2*count > u16::MAX
        if 2 * self.count > usize::from(u16::MAX) {
            return Err(req);
        }

        // If at least half the vec is empty, we try and find the index of
        // an empty slot, starting at `curr`.
        let free = if self.vec.len() >= 2 * self.count {
            self.vec
                .iter()
                .skip(self.curr)
                .position(Option::is_none)
                .map(|offset| self.curr + offset)
        } else {
            None
        };

        // If we have an index, we can insert there, otherwise we need to
        // append.
        let idx = match free {
            Some(idx) => {
                self.vec[idx] = Some(req);
                idx
            }
            None => {
                self.vec.push(Some(req));
                self.vec.len() - 1
            }
        };

        self.count += 1;
        if idx == self.curr {
            // The slot at `curr` is now taken, start the next search after
            // it.
            self.curr += 1;
        }
        let req = self.vec[idx].as_mut().expect("no inserted item?");
        let idx = u16::try_from(idx).expect("query vec too large");
        Ok((idx, req))
    }

    /// Inserts the given query at a specified position.
    ///
    /// The slot is expected to be empty. If it is not, the query stored
    /// there is replaced and dropped, and the number of outstanding queries
    /// is left unchanged so that it keeps matching the number of occupied
    /// slots. If `id` lies beyond the end of the vector (for instance
    /// because the queries were drained in the meantime), the vector is
    /// grown with empty slots instead of panicking.
    fn insert_at(&mut self, id: u16, req: T) {
        let id = usize::from(id);
        if id >= self.vec.len() {
            self.vec.resize_with(id + 1, || None);
        }

        // Only a previously empty slot adds to the number of queries.
        if self.vec[id].replace(req).is_none() {
            self.count += 1;
        }
        if id == self.curr {
            self.curr += 1;
        }
    }

    /// Tries to remove and return the query at the given index.
    ///
    /// Returns `None` if there was no query there.
    fn try_remove(&mut self, index: u16) -> Option<T> {
        let res = self.vec.get_mut(usize::from(index))?.take()?;
        self.count = self.count.saturating_sub(1);
        // The freed slot may now be the first free one.
        self.curr = cmp::min(self.curr, usize::from(index));
        Some(res)
    }

    /// Removes all queries and returns an iterator over them.
    ///
    /// The counters are reset right away; the slots themselves are removed
    /// from the vector as the returned iterator is consumed or dropped.
    fn drain(&mut self) -> impl Iterator<Item = T> + '_ {
        let res = self.vec.drain(..).flatten(); // Skips all the `None`s.
        self.count = 0;
        self.curr = 0;
        res
    }
}
1411
1412//============ Tests =========================================================
1413
#[cfg(test)]
mod test {
    use super::*;

    /// Counts the slots that currently hold a query.
    fn occupied<T>(queries: &Queries<T>) -> usize {
        queries.vec.iter().filter(|slot| slot.is_some()).count()
    }

    #[test]
    fn queries_insert_remove() {
        // Fill in a batch of items, take some out again, add another
        // batch and finally remove whatever is left. After every step
        // the bookkeeping has to agree with the actual slot contents.
        let mut slots: [Option<u16>; 20] = [None; 20];
        let mut queries = Queries::new();

        for (value, slot) in slots.iter_mut().enumerate().take(12) {
            let (idx, stored) = queries.insert(value).expect("test failed");
            assert_eq!(value, *stored);
            *slot = Some(idx);
        }
        assert_eq!(queries.count, 12);
        assert_eq!(occupied(&queries), 12);

        for value in [1, 2, 3, 4, 7, 9] {
            let idx = slots[value].take().expect("test failed");
            let removed = queries.try_remove(idx).expect("test failed");
            assert_eq!(value, removed);
        }
        assert_eq!(queries.count, 6);
        assert_eq!(occupied(&queries), 6);

        for (value, slot) in slots.iter_mut().enumerate().skip(12) {
            let (idx, stored) = queries.insert(value).expect("test failed");
            assert_eq!(value, *stored);
            *slot = Some(idx);
        }
        assert_eq!(queries.count, 14);
        assert_eq!(occupied(&queries), 14);

        for (value, slot) in slots.iter().enumerate() {
            if let Some(idx) = *slot {
                let removed = queries.try_remove(idx).expect("test failed");
                assert_eq!(value, removed);
            }
        }
        assert_eq!(queries.count, 0);
        assert_eq!(occupied(&queries), 0);
    }

    #[test]
    fn queries_overrun() {
        // A quick check that attempting far more insertions than there
        // is room for does not break anything. Rejected insertions are
        // deliberately ignored.
        let mut queries = Queries::new();
        let attempts = usize::from(u16::MAX) * 2;
        for value in 0..attempts {
            let _ = queries.insert(value);
        }
    }
}