domain/net/client/stream.rs

1//! A client transport using a stream socket.
2
3// RFC 7766 describes DNS over TCP
4// RFC 7828 describes the edns-tcp-keepalive option
5
6use super::request::{
7    ComposeRequest, ComposeRequestMulti, Error, GetResponse,
8    GetResponseMulti, SendRequest, SendRequestMulti,
9};
10use crate::base::iana::{Rcode, Rtype};
11use crate::base::message::Message;
12use crate::base::message_builder::StreamTarget;
13use crate::base::opt::{AllOptData, OptRecord, TcpKeepalive};
14use crate::base::{ParsedName, Serial};
15use crate::rdata::AllRecordData;
16use crate::utils::config::DefMinMax;
17use bytes::{Bytes, BytesMut};
18use core::cmp;
19use octseq::Octets;
20use std::boxed::Box;
21use std::fmt::Debug;
22use std::future::Future;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26use std::vec::Vec;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28use tokio::sync::{mpsc, oneshot};
29use tokio::time::sleep;
30use tracing::trace;
31
32//------------ Configuration Constants ----------------------------------------
33
34/// Default response timeout.
35///
36/// Note: nsd has 120 seconds, unbound has 3 seconds.
37const RESPONSE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
38    Duration::from_secs(19),
39    Duration::from_millis(1),
40    Duration::from_secs(600),
41);
42
43/// Default idle timeout.
44///
/// Note that RFC 7766, Section 6.2.3 says: "DNS clients SHOULD close the
46/// TCP connection of an idle session, unless an idle timeout has been
47/// established using some other signalling mechanism, for example,
48/// [edns-tcp-keepalive]."
49/// However, RFC 7858, Section 3.4 says: "In order to amortize TCP and TLS
50/// connection setup costs, clients and servers SHOULD NOT immediately close
51/// a connection after each response.  Instead, clients and servers SHOULD
52/// reuse existing connections for subsequent queries as long as they have
53/// sufficient resources.".
/// We set the default to 10 seconds, which is the same as what Stubby
/// uses. The minimum is zero so that the idle timeout can be disabled.
/// Assume that one hour is more than enough as a maximum.
57const IDLE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
58    Duration::from_secs(10),
59    Duration::ZERO,
60    Duration::from_secs(3600),
61);
62
63/// Capacity of the channel that transports `ChanReq`s.
64const DEF_CHAN_CAP: usize = 8;
65
66/// Capacity of a private channel dispatching responses.
67const READ_REPLY_CHAN_CAP: usize = 8;
68
69//------------ Config ---------------------------------------------------------
70
71/// Configuration for a stream transport connection.
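///
/// All setters clamp excessive values to the bounds given by the
/// configuration constants above. A minimal sketch of building a config
/// (the import path is an assumption based on this module's location in
/// the crate):
///
/// ```no_run
/// use std::time::Duration;
/// use domain::net::client::stream::Config;
///
/// let mut config = Config::new();
/// config.set_response_timeout(Duration::from_secs(5));
/// config.set_idle_timeout(Duration::from_secs(30));
/// ```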
72#[derive(Clone, Debug)]
73pub struct Config {
74    /// Response timeout currently in effect.
75    response_timeout: Duration,
76
77    /// Single response timeout.
78    single_response_timeout: Duration,
79
80    /// Streaming response timeout.
81    streaming_response_timeout: Duration,
82
83    /// Default idle timeout.
84    ///
85    /// This value is used if the other side does not send a TcpKeepalive
86    /// option.
87    idle_timeout: Duration,
88}
89
90impl Config {
91    /// Creates a new, default config.
92    pub fn new() -> Self {
93        Default::default()
94    }
95
96    /// Returns the response timeout.
97    ///
98    /// This is the amount of time to wait on a non-idle connection for a
99    /// response to an outstanding request.
100    pub fn response_timeout(&self) -> Duration {
101        self.response_timeout
102    }
103
104    /// Sets the response timeout.
105    ///
    /// For requests where `ComposeRequest::is_streaming` returns true, see
    /// [`Self::set_streaming_response_timeout`] instead.
108    ///
109    /// Excessive values are quietly trimmed.
110    //
111    //  XXX Maybe that’s wrong and we should rather return an error?
112    pub fn set_response_timeout(&mut self, timeout: Duration) {
        self.response_timeout = RESPONSE_TIMEOUT.limit(timeout);
        // Keep the per-kind timeouts used by the transport in sync.
        self.single_response_timeout = self.response_timeout;
        self.streaming_response_timeout = self.response_timeout;
115    }
116
117    /// Returns the streaming response timeout.
118    pub fn streaming_response_timeout(&self) -> Duration {
119        self.streaming_response_timeout
120    }
121
122    /// Sets the streaming response timeout.
123    ///
    /// Only used for requests where `ComposeRequest::is_streaming` returns
    /// true, as it is typically desirable that such response streams be
    /// allowed to complete even if the individual responses arrive very
    /// slowly.
128    ///
129    /// Excessive values are quietly trimmed.
130    pub fn set_streaming_response_timeout(&mut self, timeout: Duration) {
131        self.streaming_response_timeout = RESPONSE_TIMEOUT.limit(timeout);
132    }
133
    /// Returns the initial idle timeout.
135    pub fn idle_timeout(&self) -> Duration {
136        self.idle_timeout
137    }
138
139    /// Sets the initial idle timeout.
140    ///
141    /// By default the stream is immediately closed if there are no pending
142    /// requests or responses.
143    ///
    /// Set this to allow requests to be sent in sequence with delays in
    /// between, such as a SOA query followed by an AXFR request, for more
    /// efficient use of the stream per RFC 9103.
147    ///
148    /// Note: May be overridden by an RFC 7828 edns-tcp-keepalive timeout
149    /// received from a server.
150    ///
151    /// Excessive values are quietly trimmed.
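    ///
    /// A sketch of the RFC 9103 style use case (`tcp_stream` and the XFR
    /// requests are placeholders, not part of this module):
    ///
    /// ```ignore
    /// let mut config = Config::new();
    /// // Keep the stream open between the SOA probe and the follow-up AXFR.
    /// config.set_idle_timeout(Duration::from_secs(30));
    /// let (conn, transport) = Connection::with_config(tcp_stream, config);
    /// ```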
152    pub fn set_idle_timeout(&mut self, timeout: Duration) {
153        self.idle_timeout = IDLE_TIMEOUT.limit(timeout)
154    }
155}
156
157impl Default for Config {
158    fn default() -> Self {
159        Self {
160            response_timeout: RESPONSE_TIMEOUT.default(),
161            single_response_timeout: RESPONSE_TIMEOUT.default(),
162            streaming_response_timeout: RESPONSE_TIMEOUT.default(),
163            idle_timeout: IDLE_TIMEOUT.default(),
164        }
165    }
166}
167
168//------------ Connection -----------------------------------------------------
169
170/// A connection to a single stream transport.
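///
/// A connection is cheap to clone: every clone shares the channel to the
/// same underlying [`Transport`]. A rough usage sketch (assuming a
/// connected `tokio::net::TcpStream` and request types implementing
/// [`ComposeRequest`] and [`ComposeRequestMulti`], for example the
/// `RequestMessage` types of the sibling `request` module):
///
/// ```ignore
/// let (conn, transport) = Connection::new(tcp_stream);
/// tokio::spawn(transport.run());
///
/// let mut request = conn.send_request(request_msg);
/// let reply = request.get_response().await?;
/// ```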
171#[derive(Debug)]
172pub struct Connection<Req, ReqMulti> {
173    /// The sender half of the request channel.
174    sender: mpsc::Sender<ChanReq<Req, ReqMulti>>,
175}
176
177impl<Req, ReqMulti> Connection<Req, ReqMulti> {
178    /// Creates a new stream transport with default configuration.
179    ///
180    /// Returns a connection and a future that drives the transport using
181    /// the provided stream. This future needs to be run while any queries
    /// are active. This is most easily achieved by spawning it into a runtime.
183    /// It terminates when the last connection is dropped.
184    pub fn new<Stream>(
185        stream: Stream,
186    ) -> (Self, Transport<Stream, Req, ReqMulti>) {
187        Self::with_config(stream, Default::default())
188    }
189
190    /// Creates a new stream transport with the given configuration.
191    ///
192    /// Returns a connection and a future that drives the transport using
193    /// the provided stream. This future needs to be run while any queries
    /// are active. This is most easily achieved by spawning it into a runtime.
195    /// It terminates when the last connection is dropped.
196    pub fn with_config<Stream>(
197        stream: Stream,
198        config: Config,
199    ) -> (Self, Transport<Stream, Req, ReqMulti>) {
200        let (sender, transport) = Transport::new(stream, config);
201        (Self { sender }, transport)
202    }
203}
204
205impl<Req, ReqMulti> Connection<Req, ReqMulti>
206where
207    Req: ComposeRequest + 'static,
208    ReqMulti: ComposeRequestMulti + 'static,
209{
210    /// Start a DNS request.
211    ///
212    /// This function takes a precomposed message as a parameter and
213    /// returns a [`Message`] object wrapped in a [`Result`].
214    async fn handle_request_impl(
215        self,
216        msg: Req,
217    ) -> Result<Message<Bytes>, Error> {
218        let (sender, receiver) = oneshot::channel();
219        let sender = ReplySender::Single(Some(sender));
220        let msg = ReqSingleMulti::Single(msg);
221        let req = ChanReq { sender, msg };
222        self.sender.send(req).await.map_err(|_| {
223            // Send error. The receiver is gone, this means that the
224            // connection is closed.
225            Error::ConnectionClosed
226        })?;
227        receiver.await.map_err(|_| Error::StreamReceiveError)?
228    }
229
230    /// Start a streaming request.
231    async fn handle_streaming_request_impl(
232        self,
233        msg: ReqMulti,
234        sender: mpsc::Sender<Result<Option<Message<Bytes>>, Error>>,
235    ) -> Result<(), Error> {
236        let reply_sender = ReplySender::Stream(sender);
237        let msg = ReqSingleMulti::Multi(msg);
238        let req = ChanReq {
239            sender: reply_sender,
240            msg,
241        };
242        self.sender.send(req).await.map_err(|_| {
243            // Send error. The receiver is gone, this means that the
244            // connection is closed.
245            Error::ConnectionClosed
246        })?;
247        Ok(())
248    }
249
250    /// Returns a request handler for a request.
251    pub fn get_request(&self, request_msg: Req) -> Request {
252        Request {
253            fut: Box::pin(self.clone().handle_request_impl(request_msg)),
254        }
255    }
256
    /// Returns a multiple-response request handler for a request.
258    fn get_streaming_request(&self, request_msg: ReqMulti) -> RequestMulti {
259        let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
260        RequestMulti {
261            stream: receiver,
262            fut: Some(Box::pin(
263                self.clone()
264                    .handle_streaming_request_impl(request_msg, sender),
265            )),
266        }
267    }
268}
269
270impl<Req, ReqMulti> Clone for Connection<Req, ReqMulti> {
271    fn clone(&self) -> Self {
272        Self {
273            sender: self.sender.clone(),
274        }
275    }
276}
277
278impl<Req, ReqMulti> SendRequest<Req> for Connection<Req, ReqMulti>
279where
280    Req: ComposeRequest + 'static,
281    ReqMulti: ComposeRequestMulti + Debug + Send + Sync + 'static,
282{
283    fn send_request(
284        &self,
285        request_msg: Req,
286    ) -> Box<dyn GetResponse + Send + Sync> {
287        Box::new(self.get_request(request_msg))
288    }
289}
290
291impl<Req, ReqMulti> SendRequestMulti<ReqMulti> for Connection<Req, ReqMulti>
292where
293    Req: ComposeRequest + Debug + Send + Sync + 'static,
294    ReqMulti: ComposeRequestMulti + 'static,
295{
296    fn send_request(
297        &self,
298        request_msg: ReqMulti,
299    ) -> Box<dyn GetResponseMulti + Send + Sync> {
300        Box::new(self.get_streaming_request(request_msg))
301    }
302}
303
304//------------ Request -------------------------------------------------------
305
306/// An active request.
307pub struct Request {
308    /// The underlying future.
309    fut: Pin<
310        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
311    >,
312}
313
314impl Request {
315    /// Async function that waits for the future stored in Request to complete.
316    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
317        (&mut self.fut).await
318    }
319}
320
321impl GetResponse for Request {
322    fn get_response(
323        &mut self,
324    ) -> Pin<
325        Box<
326            dyn Future<Output = Result<Message<Bytes>, Error>>
327                + Send
328                + Sync
329                + '_,
330        >,
331    > {
332        Box::pin(self.get_response_impl())
333    }
334}
335
336impl Debug for Request {
337    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338        f.debug_struct("Request")
339            .field("fut", &format_args!("_"))
340            .finish()
341    }
342}
343
344//------------ RequestMulti --------------------------------------------------
345
/// An active request that yields a stream of responses.
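///
/// Responses are pulled one at a time via
/// [`GetResponseMulti::get_response`]; `Ok(None)` marks the end of the
/// stream. A rough sketch (`conn` is a [`Connection`] and `xfr_request`
/// implements [`ComposeRequestMulti`]):
///
/// ```ignore
/// let mut request = conn.send_request(xfr_request);
/// while let Some(reply) = request.get_response().await? {
///     // Process one AXFR/IXFR response message.
/// }
/// ```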
347pub struct RequestMulti {
348    /// Receiver for a stream of responses.
349    stream: mpsc::Receiver<Result<Option<Message<Bytes>>, Error>>,
350
351    /// The underlying future.
352    #[allow(clippy::type_complexity)]
353    fut: Option<
354        Pin<Box<dyn Future<Output = Result<(), Error>> + Send + Sync>>,
355    >,
356}
357
358impl RequestMulti {
    /// Async function that first awaits the future stored in
    /// [`RequestMulti`] and then receives the next response from the stream.
360    async fn get_response_impl(
361        &mut self,
362    ) -> Result<Option<Message<Bytes>>, Error> {
363        if self.fut.is_some() {
364            let fut = self.fut.take().expect("Some expected");
365            fut.await?;
366        }
367
        // Fetch the next response from the stream.
        self.stream
            .recv()
            .await
            .ok_or(Error::ConnectionClosed)?
375}
376
377impl GetResponseMulti for RequestMulti {
378    fn get_response(
379        &mut self,
380    ) -> Pin<
381        Box<
382            dyn Future<Output = Result<Option<Message<Bytes>>, Error>>
383                + Send
384                + Sync
385                + '_,
386        >,
387    > {
388        let fut = self.get_response_impl();
389        Box::pin(fut)
390    }
391}
392
393impl Debug for RequestMulti {
394    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395        f.debug_struct("Request")
396            .field("fut", &format_args!("_"))
397            .finish()
398    }
399}
400
401//------------ Transport -----------------------------------------------------
402
403/// The underlying machinery of a stream transport.
404#[derive(Debug)]
405pub struct Transport<Stream, Req, ReqMulti> {
406    /// The stream socket towards the remote end.
407    stream: Stream,
408
409    /// Transport configuration.
410    config: Config,
411
    /// The receiver half of the request channel.
413    receiver: mpsc::Receiver<ChanReq<Req, ReqMulti>>,
414}
415
416/// This is the type of sender in [ChanReq].
417#[derive(Debug)]
418enum ReplySender {
419    /// Return channel for a single response.
420    Single(Option<oneshot::Sender<ChanResp>>),
421
422    /// Return channel for a stream of responses.
423    Stream(mpsc::Sender<Result<Option<Message<Bytes>>, Error>>),
424}
425
426impl ReplySender {
427    /// Send a response.
428    async fn send(&mut self, resp: ChanResp) -> Result<(), ()> {
429        match self {
430            ReplySender::Single(sender) => match sender.take() {
431                Some(sender) => sender.send(resp).map_err(|_| ()),
432                None => Err(()),
433            },
434            ReplySender::Stream(sender) => {
435                sender.send(resp.map(Some)).await.map_err(|_| ())
436            }
437        }
438    }
439
440    /// Send EOF on a response stream.
441    async fn send_eof(&mut self) -> Result<(), ()> {
442        match self {
443            ReplySender::Single(_) => {
444                panic!("cannot send EOF for Single");
445            }
446            ReplySender::Stream(sender) => {
447                sender.send(Ok(None)).await.map_err(|_| ())
448            }
449        }
450    }
451
    /// Reports whether this sender is in stream (multi-response) mode.
453    fn is_stream(&self) -> bool {
454        matches!(self, Self::Stream(_))
455    }
456}
457
458#[derive(Debug)]
459/// Enum that can either store a request for a single response or one for
460/// multiple responses.
461enum ReqSingleMulti<Req, ReqMulti> {
462    /// Single response request.
463    Single(Req),
464    /// Multi-response request.
465    Multi(ReqMulti),
466}
467
468/// A message from a [`Request`] to start a new request.
469#[derive(Debug)]
470struct ChanReq<Req, ReqMulti> {
471    /// DNS request message
472    msg: ReqSingleMulti<Req, ReqMulti>,
473
474    /// Sender to send result back to [`Request`]
475    sender: ReplySender,
476}
477
478/// A message back to [`Request`] returning a response.
479type ChanResp = Result<Message<Bytes>, Error>;
480
/// Internal data structure of [Transport::run] to keep track of
482/// the status of the connection.
483// The types Status and ConnState are only used in Transport
484struct Status {
485    /// State of the connection.
486    state: ConnState,
487
    /// Whether we need to include an edns-tcp-keepalive option in an
    /// outgoing request.
489    ///
490    /// Typically this is true at the start of the connection and gets
    /// cleared once we have successfully included the option in a
492    /// request.
493    send_keepalive: bool,
494
    /// Time we are allowed to keep the connection open when idle.
496    ///
497    /// Initially we set the idle timeout to the default in config. A received
498    /// edns-tcp-keepalive option may change that.
499    idle_timeout: Duration,
500}
501
502/// Status of the connection. Used in [`Status`].
503enum ConnState {
504    /// The connection is in this state from the start and when at least
505    /// one active DNS request is present.
506    ///
507    /// The instant contains the time of the first request or the
508    /// most recent response that was received.
509    Active(Option<Instant>),
510
    /// This state represents a connection that went idle and has an
512    /// idle timeout.
513    ///
514    /// The instant contains the time the connection went idle.
515    Idle(Instant),
516
    /// This state represents an idle connection where either there was no
518    /// idle timeout or the idle timer expired.
519    IdleTimeout,
520
521    /// A read error occurred.
522    ReadError(Error),
523
524    /// It took too long to receive a response.
525    ReadTimeout,
526
527    /// A write error occurred.
528    WriteError(Error),
529}
530
531//--- Display
532impl std::fmt::Display for ConnState {
533    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
534        match self {
535            ConnState::Active(instant) => f.write_fmt(format_args!(
536                "Active (since {}s ago)",
537                instant
538                    .map(|v| Instant::now().duration_since(v).as_secs())
539                    .unwrap_or_default()
540            )),
541            ConnState::Idle(instant) => f.write_fmt(format_args!(
542                "Idle (since {}s ago)",
543                Instant::now().duration_since(*instant).as_secs()
544            )),
545            ConnState::IdleTimeout => f.write_str("IdleTimeout"),
546            ConnState::ReadError(err) => {
547                f.write_fmt(format_args!("ReadError: {err}"))
548            }
549            ConnState::ReadTimeout => f.write_str("ReadTimeout"),
550            ConnState::WriteError(err) => {
551                f.write_fmt(format_args!("WriteError: {err}"))
552            }
553        }
554    }
555}
556
557#[derive(Debug)]
/// State of an AXFR or IXFR response stream for detecting the end of the
559/// stream.
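///
/// An AXFR stream ends when the SOA record that opened it appears a second
/// time; an IXFR stream ends when its final diff section closes with that
/// same SOA serial (RFC 5936 and RFC 1995). The variants below track how
/// far along that pattern the responses have progressed.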
560enum XFRState {
561    /// Start of AXFR.
562    AXFRInit,
563    /// After the first SOA record has been encountered.
564    AXFRFirstSoa(Serial),
565    /// Start of IXFR.
566    IXFRInit,
567    /// After the first SOA record has been encountered.
568    IXFRFirstSoa(Serial),
569    /// After the first SOA record in a diff section has been encountered.
570    IXFRFirstDiffSoa(Serial),
571    /// After the second SOA record in a diff section has been encountered.
572    IXFRSecondDiffSoa(Serial),
573    /// End of the stream has been found.
574    Done,
    /// An error has occurred.
576    Error,
577}
578
579impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti> {
580    /// Creates a new transport.
581    fn new(
582        stream: Stream,
583        config: Config,
584    ) -> (mpsc::Sender<ChanReq<Req, ReqMulti>>, Self) {
585        let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
586        (
587            sender,
588            Self {
589                config,
590                stream,
591                receiver,
592            },
593        )
594    }
595}
596
597impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti>
598where
599    Stream: AsyncRead + AsyncWrite,
600    Req: ComposeRequest,
601    ReqMulti: ComposeRequestMulti,
602{
603    /// Run the transport machinery.
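    ///
    /// The future resolves once every [`Connection`] clone that refers to
    /// this transport has been dropped (the request channel then closes),
    /// or once the connection hits an idle timeout or an unrecoverable
    /// read or write error.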
604    pub async fn run(mut self) {
605        let (reply_sender, mut reply_receiver) =
606            mpsc::channel::<Message<Bytes>>(READ_REPLY_CHAN_CAP);
607
608        let (read_stream, mut write_stream) = tokio::io::split(self.stream);
609
610        let reader_fut = Self::reader(read_stream, reply_sender);
611        tokio::pin!(reader_fut);
612
613        let mut status = Status {
614            state: ConnState::Active(None),
615            idle_timeout: self.config.idle_timeout,
616            send_keepalive: true,
617        };
618        let mut query_vec =
619            Queries::<(ChanReq<Req, ReqMulti>, Option<XFRState>)>::new();
620
621        let mut reqmsg: Option<Vec<u8>> = None;
622        let mut reqmsg_offset = 0;
623
624        loop {
625            let opt_timeout = match status.state {
626                ConnState::Active(opt_instant) => {
627                    if let Some(instant) = opt_instant {
628                        let elapsed = instant.elapsed();
629                        if elapsed > self.config.response_timeout {
630                            Self::error(
631                                Error::StreamReadTimeout,
632                                &mut query_vec,
633                            )
634                            .await;
635                            status.state = ConnState::ReadTimeout;
636                            break;
637                        }
638                        Some(self.config.response_timeout - elapsed)
639                    } else {
640                        None
641                    }
642                }
643                ConnState::Idle(instant) => {
644                    let elapsed = instant.elapsed();
645                    if elapsed >= status.idle_timeout {
646                        // Move to IdleTimeout and end
647                        // the loop
648                        status.state = ConnState::IdleTimeout;
649                        break;
650                    }
651                    Some(status.idle_timeout - elapsed)
652                }
653                ConnState::IdleTimeout
654                | ConnState::ReadError(_)
655                | ConnState::WriteError(_) => None, // No timers here
656                ConnState::ReadTimeout => {
657                    panic!("should not be in loop with ReadTimeout");
658                }
659            };
660
661            // For simplicity, make sure we always have a timeout
662            let timeout = match opt_timeout {
663                Some(timeout) => timeout,
                None => {
                    // Just use the response timeout.
                    self.config.response_timeout
                }
669            };
670
671            let sleep_fut = sleep(timeout);
672            let recv_fut = self.receiver.recv();
673
674            let (do_write, msg) = match &reqmsg {
675                None => {
676                    let msg: &[u8] = &[];
677                    (false, msg)
678                }
679                Some(msg) => {
680                    let msg: &[u8] = msg;
681                    (true, msg)
682                }
683            };
684
685            tokio::select! {
686                biased;
687                res = &mut reader_fut => {
688                    // The reader might have sent replies before dying
689                    while let Ok(answer) = reply_receiver.try_recv() {
690                        Self::demux_reply(answer, &mut status, &mut query_vec).await;
691                    }
692
693                    match res {
694                        Ok(_) =>
695                            // The reader should not
696                            // terminate without
697                            // error.
698                            panic!("reader terminated"),
699                        Err(error) => {
700                            Self::error(error.clone(), &mut query_vec).await;
701                            status.state = ConnState::ReadError(error);
702                            // Reader failed. Break
703                            // out of loop and
704                            // shut down
705                            break
706                        }
707                    }
708                }
709                opt_answer = reply_receiver.recv() => {
710                    let answer = opt_answer.expect("reader died?");
711                    Self::demux_reply(answer, &mut status, &mut query_vec).await;
712                }
                res = write_stream.write(&msg[reqmsg_offset..]),
                        if do_write => {
                    match res {
                        Err(error) => {
                            let error =
                                Error::StreamWriteError(Arc::new(error));
                            Self::error(error.clone(), &mut query_vec).await;
                            status.state = ConnState::WriteError(error);
                            break;
                        }
                        Ok(len) => {
                            reqmsg_offset += len;
                            if reqmsg_offset >= msg.len() {
                                reqmsg = None;
                                reqmsg_offset = 0;
                            }
                        }
                    }
                }
733                res = recv_fut, if !do_write => {
734                    match res {
735                        Some(req) => {
736                            if req.sender.is_stream() {
737                                self.config.response_timeout =
738                                    self.config.streaming_response_timeout;
739                            } else {
740                                self.config.response_timeout =
741                                    self.config.single_response_timeout;
742                            }
743                            Self::insert_req(
744                                req, &mut status, &mut reqmsg, &mut query_vec
745                            );
746                        }
747                        None => {
748                            // All references to the connection object have
749                            // been dropped. Shutdown.
750                            break;
751                        }
752                    }
753                }
754                _ = sleep_fut => {
755                    // Timeout expired, just
756                    // continue with the loop
757                }
758
759            }
760
761            // Check if the connection is idle
762            match status.state {
763                ConnState::Active(_) | ConnState::Idle(_) => {
764                    // Keep going
765                }
766                ConnState::IdleTimeout => break,
767                ConnState::ReadError(_)
768                | ConnState::ReadTimeout
769                | ConnState::WriteError(_) => {
770                    panic!("Should not be here");
771                }
772            }
773        }
774
775        trace!("Closing TCP connecting in state: {}", status.state);
776
777        // Send FIN
778        _ = write_stream.shutdown().await;
779    }
780
781    /// This function reads a DNS message from the connection and sends
782    /// it to [Transport::run].
783    ///
784    /// Reading has to be done in two steps: first read a two octet value
    /// that specifies the length of the message, and then read in a loop the
786    /// body of the message.
787    ///
788    /// This function is not async cancellation safe.
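    ///
    /// The framing is the standard two-octet, network byte order length
    /// prefix used for DNS over TCP (RFC 1035, Section 4.2.2). For example,
    /// a 300 octet DNS message arrives prefixed by the bytes `0x01 0x2C`.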
789    async fn reader(
790        mut sock: tokio::io::ReadHalf<Stream>,
791        sender: mpsc::Sender<Message<Bytes>>,
792    ) -> Result<(), Error> {
793        loop {
794            let read_res = sock.read_u16().await;
795            let len = match read_res {
796                Ok(len) => len,
797                Err(error) => {
798                    return Err(Error::StreamReadError(Arc::new(error)));
799                }
800            } as usize;
801
802            let mut buf = BytesMut::with_capacity(len);
803
804            loop {
805                let curlen = buf.len();
806                if curlen >= len {
807                    if curlen > len {
808                        panic!(
809                        "reader: got too much data {curlen}, expetect {len}");
                        "reader: got too much data {curlen}, expected {len}");
811
812                    // We got what we need
813                    break;
814                }
815
816                let read_res = sock.read_buf(&mut buf).await;
817
818                match read_res {
819                    Ok(readlen) => {
820                        if readlen == 0 {
821                            return Err(Error::StreamUnexpectedEndOfData);
822                        }
823                    }
824                    Err(error) => {
825                        return Err(Error::StreamReadError(Arc::new(error)));
826                    }
827                };
828
829                // Check if we are done at the head of the loop
830            }
831
832            let reply_message = Message::<Bytes>::from_octets(buf.into());
833            match reply_message {
834                Ok(answer) => {
835                    sender
836                        .send(answer)
837                        .await
838                        .expect("can't send reply to run");
839                }
840                Err(_) => {
841                    // The only possible error is short message
842                    return Err(Error::ShortMessage);
843                }
844            }
845        }
846    }
847
848    /// Reports an error to all outstanding queries.
849    async fn error(
850        error: Error,
851        query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
852    ) {
853        // Update all requests that are in progress. Don't wait for
854        // any reply that may be on its way.
855        for (mut req, _) in query_vec.drain() {
856            _ = req.sender.send(Err(error.clone())).await;
857        }
858    }
859
860    /// Handles received EDNS options.
861    ///
862    /// In particular, it processes the edns-tcp-keepalive option.
863    fn handle_opts<Octs: Octets + AsRef<[u8]>>(
864        opts: &OptRecord<Octs>,
865        status: &mut Status,
866    ) {
867        // XXX This handles _all_ keepalive options. I think just using the
868        //     first option as returned by Opt::tcp_keepalive should be good
869        //     enough? -- M.
870        for option in opts.opt().iter().flatten() {
871            if let AllOptData::TcpKeepalive(tcpkeepalive) = option {
872                Self::handle_keepalive(tcpkeepalive, status);
873            }
874        }
875    }
876
877    /// Demultiplexes a response and sends it to the right query.
878    ///
879    /// In addition, the status is updated to IdleTimeout or Idle if there
880    /// are no remaining pending requests.
881    async fn demux_reply(
882        answer: Message<Bytes>,
883        status: &mut Status,
884        query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
885    ) {
886        // Check for an edns-tcp-keepalive option
887        if let Some(opts) = answer.opt() {
888            Self::handle_opts(&opts, status);
889        };
890
891        // We got an answer, reset the timer
892        status.state = ConnState::Active(Some(Instant::now()));
893
894        let id = answer.header().id();
895
896        // Get the correct query and send it the reply.
897        let (mut req, mut opt_xfr_data) = match query_vec.try_remove(id) {
898            Some(req) => req,
899            None => {
900                // No query with this ID. We should
901                // mark the connection as broken
902                return;
903            }
904        };
905        let mut send_eof = false;
906        let answer = if match &req.msg {
907            ReqSingleMulti::Single(msg) => msg.is_answer(answer.for_slice()),
908            ReqSingleMulti::Multi(msg) => {
909                let xfr_data =
910                    opt_xfr_data.expect("xfr_data should be present");
911                let (eof, xfr_data, is_answer) =
912                    check_stream(msg, xfr_data, &answer);
913                send_eof = eof;
914                opt_xfr_data = Some(xfr_data);
915                is_answer
916            }
917        } {
918            Ok(answer)
919        } else {
920            Err(Error::WrongReplyForQuery)
921        };
922        _ = req.sender.send(answer).await;
923
924        if req.sender.is_stream() {
925            if send_eof {
926                _ = req.sender.send_eof().await;
927            } else {
928                query_vec.insert_at(id, (req, opt_xfr_data));
929            }
930        }
931
932        if query_vec.is_empty() {
933            // Clear the activity timer. There is no need to do
934            // this because state will be set to either IdleTimeout
935            // or Idle just below. However, it is nicer to keep
936            // this independent.
937            status.state = ConnState::Active(None);
938
939            status.state = if status.idle_timeout.is_zero() {
940                // Assume that we can just move to IdleTimeout
941                // state
942                ConnState::IdleTimeout
943            } else {
944                ConnState::Idle(Instant::now())
945            }
946        }
947    }
948
    /// Inserts a request into query_vec and stores the serialized request
    /// to be sent in `*reqmsg`.
    ///
    /// First the status is checked; an error is reported to the request if
    /// the connection is not active or idle. An edns-tcp-keepalive option
    /// is appended if needed.
954    // Note: maybe reqmsg should be a return value.
955    fn insert_req(
956        mut req: ChanReq<Req, ReqMulti>,
957        status: &mut Status,
958        reqmsg: &mut Option<Vec<u8>>,
959        query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
960    ) {
961        match &status.state {
962            ConnState::Active(timer) => {
963                // Set timer if we don't have one already
964                if timer.is_none() {
965                    status.state = ConnState::Active(Some(Instant::now()));
966                }
967            }
968            ConnState::Idle(_) => {
969                // Go back to active
970                status.state = ConnState::Active(Some(Instant::now()));
971            }
972            ConnState::IdleTimeout => {
973                // The connection has been closed. Report error
974                _ = req.sender.send(Err(Error::StreamIdleTimeout));
975                return;
976            }
977            ConnState::ReadError(error) => {
978                _ = req.sender.send(Err(error.clone()));
979                return;
980            }
981            ConnState::ReadTimeout => {
982                _ = req.sender.send(Err(Error::StreamReadTimeout));
983                return;
984            }
985            ConnState::WriteError(error) => {
986                _ = req.sender.send(Err(error.clone()));
987                return;
988            }
989        }
990
991        let xfr_data = match &req.msg {
992            ReqSingleMulti::Single(_) => None,
993            ReqSingleMulti::Multi(msg) => {
994                let qtype = match msg.to_message().and_then(|m| {
995                    m.sole_question()
996                        .map_err(|_| Error::MessageParseError)
997                        .map(|q| q.qtype())
998                }) {
999                    Ok(msg) => msg,
1000                    Err(e) => {
1001                        _ = req.sender.send(Err(e));
1002                        return;
1003                    }
1004                };
1005                if qtype == Rtype::AXFR {
1006                    Some(XFRState::AXFRInit)
1007                } else if qtype == Rtype::IXFR {
1008                    Some(XFRState::IXFRInit)
1009                } else {
1010                    // Stream requests should be either AXFR or IXFR.
1011                    _ = req.sender.send(Err(Error::FormError));
1012                    return;
1013                }
1014            }
1015        };
1016
1017        // Note that insert may fail if there are too many
1018        // outstanding queries. First call insert before checking
1019        // send_keepalive.
1020        let (index, (req, _)) = match query_vec.insert((req, xfr_data)) {
1021            Ok(res) => res,
1022            Err((mut req, _)) => {
1023                // Send an appropriate error and return.
1024                _ = req
1025                    .sender
1026                    .send(Err(Error::StreamTooManyOutstandingQueries));
1027                return;
1028            }
1029        };
1030
1031        // We set the ID to the array index. Defense in depth
1032        // suggests that a random ID is better because it works
1033        // even if TCP sequence numbers could be predicted. However,
1034        // Section 9.3 of RFC 5452 recommends retrying over TCP
1035        // if many spoofed answers arrive over UDP: "TCP, by the
1036        // nature of its use of sequence numbers, is far more
1037        // resilient against forgery by third parties."
1038
1039        let hdr = match &mut req.msg {
1040            ReqSingleMulti::Single(msg) => msg.header_mut(),
1041            ReqSingleMulti::Multi(msg) => msg.header_mut(),
1042        };
1043        hdr.set_id(index);
1044
1045        if status.send_keepalive
1046            && match &mut req.msg {
1047                ReqSingleMulti::Single(msg) => {
1048                    msg.add_opt(&TcpKeepalive::new(None)).is_ok()
1049                }
1050                ReqSingleMulti::Multi(msg) => {
1051                    msg.add_opt(&TcpKeepalive::new(None)).is_ok()
1052                }
1053            }
1054        {
1055            status.send_keepalive = false;
1056        }
1057
1058        match Self::convert_query(&req.msg) {
1059            Ok(msg) => {
1060                *reqmsg = Some(msg);
1061            }
1062            Err(err) => {
1063                // Take the sender out again and return the error.
1064                if let Some((mut req, _)) = query_vec.try_remove(index) {
1065                    _ = req.sender.send(Err(err));
1066                }
1067            }
1068        }
1069    }
1070
1071    /// Handle a received edns-tcp-keepalive option.
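    ///
    /// Per RFC 7828 the timeout carried in the option is expressed in units
    /// of 100 milliseconds; the conversion to a [`Duration`] happens in the
    /// `From` impl used below.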
1072    fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) {
1073        if let Some(value) = opt_value.timeout() {
1074            let value_dur = Duration::from(value);
1075            status.idle_timeout = value_dur;
1076        }
1077    }
1078
1079    /// Convert the query message to a vector.
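    ///
    /// The returned vector already starts with the two-octet length prefix
    /// required for stream transports; [`StreamTarget`] reserves room for
    /// it while the message is being composed.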
1080    fn convert_query(
1081        msg: &ReqSingleMulti<Req, ReqMulti>,
1082    ) -> Result<Vec<u8>, Error> {
1083        match msg {
1084            ReqSingleMulti::Single(msg) => {
1085                let mut target = StreamTarget::new_vec();
1086                msg.append_message(&mut target)
1087                    .map_err(|_| Error::StreamLongMessage)?;
1088                Ok(target.into_target())
1089            }
1090            ReqSingleMulti::Multi(msg) => {
1091                let target = StreamTarget::new_vec();
1092                let target = msg
1093                    .append_message(target)
1094                    .map_err(|_| Error::StreamLongMessage)?;
1095                Ok(target.finish().into_target())
1096            }
1097        }
1098    }
1099}
1100
/// Updates the response stream state based on a response message.
///
/// Returns a tuple `(eof, new_state, is_answer)`: whether the stream is
/// complete, the updated [`XFRState`], and whether the message is an
/// answer to the request.
1102fn check_stream<CRM>(
1103    msg: &CRM,
1104    mut xfr_state: XFRState,
1105    answer: &Message<Bytes>,
1106) -> (bool, XFRState, bool)
1107where
1108    CRM: ComposeRequestMulti,
1109{
1110    // First check if the reply matches the request.
1111    // RFC 5936, Section 2.2.2:
1112    // "In the first response message, this section MUST be copied from the
1113    // query.  In subsequent messages, this section MAY be copied from the
1114    // query, or it MAY be empty.  However, in an error response message
1115    // (see Section 2.2), this section MUST be copied as well."
1116    match xfr_state {
1117        XFRState::AXFRInit | XFRState::IXFRInit => {
1118            if !msg.is_answer(answer.for_slice()) {
1119                xfr_state = XFRState::Error;
1120                // If we detect an error, then keep the stream open. We are
1121                // likely out of sync with respect to the sender.
1122                return (false, xfr_state, false);
1123            }
1124        }
1125        XFRState::AXFRFirstSoa(_)
1126        | XFRState::IXFRFirstSoa(_)
1127        | XFRState::IXFRFirstDiffSoa(_)
1128        | XFRState::IXFRSecondDiffSoa(_) =>
1129            // No need to check anything.
1130            {}
1131        XFRState::Done => {
1132            // We should not be here. Switch to error state.
1133            xfr_state = XFRState::Error;
1134            return (false, xfr_state, false);
1135        }
1136        XFRState::Error =>
1137        // Keep the stream open.
1138        {
1139            return (false, xfr_state, false)
1140        }
1141    }
1142
    // Then check if the reply reports an error.
1144    if answer.header().rcode() != Rcode::NOERROR {
1145        // Also check if this answers the question.
1146        if !msg.is_answer(answer.for_slice()) {
1147            xfr_state = XFRState::Error;
1148            // If we detect an error, then keep the stream open. We are
1149            // likely out of sync with respect to the sender.
1150            return (false, xfr_state, false);
1151        }
1152        return (true, xfr_state, true);
1153    }
1154
1155    let ans_sec = match answer.answer() {
1156        Ok(ans) => ans,
1157        Err(_) => {
1158            // Bad message, switch to error state.
1159            xfr_state = XFRState::Error;
            // Report the error and end the stream.
1161            return (true, xfr_state, false);
1162        }
1163    };
1164    for rr in
1165        ans_sec.into_records::<AllRecordData<Bytes, ParsedName<Bytes>>>()
1166    {
1167        let rr = match rr {
1168            Ok(rr) => rr,
1169            Err(_) => {
1170                // Bad message, switch to error state.
1171                xfr_state = XFRState::Error;
1172                return (true, xfr_state, false);
1173            }
1174        };
1175        match xfr_state {
1176            XFRState::AXFRInit => {
1177                // The first record has to be a SOA record.
1178                if let AllRecordData::Soa(soa) = rr.data() {
1179                    xfr_state = XFRState::AXFRFirstSoa(soa.serial());
1180                    continue;
1181                }
1182                // Bad data. Switch to error status.
1183                xfr_state = XFRState::Error;
1184                return (false, xfr_state, false);
1185            }
1186            XFRState::AXFRFirstSoa(serial) => {
1187                if let AllRecordData::Soa(soa) = rr.data() {
1188                    if serial == soa.serial() {
1189                        // We found a match.
1190                        xfr_state = XFRState::Done;
1191                        continue;
1192                    }
1193
1194                    // Serial does not match. Move to error state.
1195                    xfr_state = XFRState::Error;
1196                    return (false, xfr_state, false);
1197                }
1198
1199                // Any other record, just continue.
1200            }
1201            XFRState::IXFRInit => {
1202                // The first record has to be a SOA record.
1203                if let AllRecordData::Soa(soa) = rr.data() {
1204                    xfr_state = XFRState::IXFRFirstSoa(soa.serial());
1205                    continue;
1206                }
1207                // Bad data. Switch to error status.
1208                xfr_state = XFRState::Error;
1209                return (false, xfr_state, false);
1210            }
1211            XFRState::IXFRFirstSoa(serial) => {
1212                // We have three possibilities:
1213                // 1) The record is not a SOA. In that case the format is AXFR.
1214                // 2) The record is a SOA and the serial is not the current
1215                //    serial. That is expected for an IXFR format. Move to
1216                //    IXFRFirstDiffSoa.
1217                // 3) The record is a SOA and the serial is equal to the
1218                //    current serial. Treat this as a strange empty AXFR.
1219                if let AllRecordData::Soa(soa) = rr.data() {
1220                    if serial == soa.serial() {
1221                        // We found a match.
1222                        xfr_state = XFRState::Done;
1223                        continue;
1224                    }
1225
1226                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
1227                    continue;
1228                }
1229
1230                // Any other record, move to AXFRFirstSoa.
1231                xfr_state = XFRState::AXFRFirstSoa(serial);
1232            }
1233            XFRState::IXFRFirstDiffSoa(serial) => {
1234                // Move to IXFRSecondDiffSoa if the record is a SOA record,
1235                // otherwise stay in the current state.
1236                if let AllRecordData::Soa(_) = rr.data() {
1237                    xfr_state = XFRState::IXFRSecondDiffSoa(serial);
1238                    continue;
1239                }
1240
1241                // Any other record, just continue.
1242            }
1243            XFRState::IXFRSecondDiffSoa(serial) => {
1244                // Move to Done if the record is a SOA record and the
1245                // serial is the one from the first SOA record, move to
1246                // IXFRFirstDiffSoa for any other SOA record and
1247                // otherwise stay in the current state.
1248                if let AllRecordData::Soa(soa) = rr.data() {
1249                    if serial == soa.serial() {
1250                        // We found a match.
1251                        xfr_state = XFRState::Done;
1252                        continue;
1253                    }
1254
1255                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
1256                    continue;
1257                }
1258
1259                // Any other record, just continue.
1260            }
1261            XFRState::Done => {
1262                // We got a record after we are done. Switch to error state.
1263                xfr_state = XFRState::Error;
1264                return (false, xfr_state, false);
1265            }
1266            XFRState::Error => panic!("should not be here"),
1267        }
1268    }
1269
1270    // Check the final state.
1271    match xfr_state {
1272        XFRState::AXFRInit | XFRState::IXFRInit => {
            // Still in one of the init states, so the answer section was empty.
1274            // Switch to error state.
1275            xfr_state = XFRState::Error;
1276            return (false, xfr_state, false);
1277        }
1278        XFRState::AXFRFirstSoa(_)
1279        | XFRState::IXFRFirstDiffSoa(_)
1280        | XFRState::IXFRSecondDiffSoa(_) =>
1281            // Just continue.
1282            {}
1283        XFRState::IXFRFirstSoa(_) => {
1284            // We are still in IXFRFirstSoa. Assume the other side doesn't
1285            // have anything more to say. We could check the SOA serial in
1286            // the request. Just assume that we are done.
1287            xfr_state = XFRState::Done;
1288            return (true, xfr_state, true);
1289        }
1290        XFRState::Done => return (true, xfr_state, true),
1291        XFRState::Error => unreachable!(),
1292    }
1293
1294    // (eof, xfr_data, is_answer)
1295    (false, xfr_state, true)
1296}
1297
1298//------------ Queries -------------------------------------------------------
1299
1300/// Mapping outstanding queries to their ID.
1301///
1302/// This is generic over anything rather than our concrete request type for
1303/// easier testing.
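///
/// A sketch of the intended use (the returned index doubles as the DNS
/// message ID of the outstanding request):
///
/// ```ignore
/// let mut queries = Queries::new();
/// if let Ok((id, _slot)) = queries.insert(request) {
///     // ... later, when the response carrying this ID arrives:
///     let request = queries.try_remove(id);
/// }
/// ```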
1304#[derive(Clone, Debug)]
1305struct Queries<T> {
1306    /// The number of elements in `vec` that are not None.
1307    count: usize,
1308
1309    /// Index in `vec` where to look for a space for a new query.
1310    curr: usize,
1311
1312    /// Vector of senders to forward a DNS reply message (or error) to.
1313    vec: Vec<Option<T>>,
1314}
1315
1316impl<T> Queries<T> {
1317    /// Creates a new empty value.
1318    fn new() -> Self {
1319        Self {
1320            count: 0,
1321            curr: 0,
1322            vec: Vec::new(),
1323        }
1324    }
1325
1326    /// Returns whether there are no more outstanding queries.
1327    fn is_empty(&self) -> bool {
1328        self.count == 0
1329    }
1330
1331    /// Inserts the given query.
1332    ///
1333    /// Upon success, returns the index and a mutable reference to the stored
1334    /// query.
1335    ///
1336    /// Upon error, which means the set is full, returns the query.
1337    fn insert(&mut self, req: T) -> Result<(u16, &mut T), T> {
        // Fail if there are too many entries already in this vector.
        // We cannot have more than u16::MAX entries because the
        // index needs to fit in a u16. For efficiency we want to
1341        // keep the vector half empty. So we return a failure if
1342        // 2*count > u16::MAX
1343        if 2 * self.count > u16::MAX as usize {
1344            return Err(req);
1345        }
1346
1347        // If more than half the vec is empty, we try and find the index of
1348        // an empty slot.
1349        let idx = if self.vec.len() >= 2 * self.count {
1350            let mut found = None;
1351            for idx in self.curr..self.vec.len() {
1352                if self.vec[idx].is_none() {
1353                    found = Some(idx);
1354                    break;
1355                }
1356            }
1357            found
1358        } else {
1359            None
1360        };
1361
1362        // If we have an index, we can insert there, otherwise we need to
1363        // append.
1364        let idx = match idx {
1365            Some(idx) => {
1366                self.vec[idx] = Some(req);
1367                idx
1368            }
1369            None => {
1370                let idx = self.vec.len();
1371                self.vec.push(Some(req));
1372                idx
1373            }
1374        };
1375
1376        self.count += 1;
1377        if idx == self.curr {
1378            self.curr += 1;
1379        }
1380        let req = self.vec[idx].as_mut().expect("no inserted item?");
1381        let idx = u16::try_from(idx).expect("query vec too large");
1382        Ok((idx, req))
1383    }
1384
    /// Inserts the given query at a specified position. A pre-condition
    /// is that the slot has to be empty.
1387    fn insert_at(&mut self, id: u16, req: T) {
1388        let id = id as usize;
1389        self.vec[id] = Some(req);
1390
1391        self.count += 1;
1392        if id == self.curr {
1393            self.curr += 1;
1394        }
1395    }
1396
1397    /// Tries to remove and return the query at the given index.
1398    ///
1399    /// Returns `None` if there was no query there.
1400    fn try_remove(&mut self, index: u16) -> Option<T> {
1401        let res = self.vec.get_mut(usize::from(index))?.take()?;
1402        self.count = self.count.saturating_sub(1);
1403        self.curr = cmp::min(self.curr, index.into());
1404        Some(res)
1405    }
1406
1407    /// Removes all queries and returns an iterator over them.
1408    fn drain(&mut self) -> impl Iterator<Item = T> + '_ {
1409        let res = self.vec.drain(..).flatten(); // Skips all the `None`s.
1410        self.count = 0;
1411        self.curr = 0;
1412        res
1413    }
1414}
1415
1416//============ Tests =========================================================
1417
1418#[cfg(test)]
1419mod test {
1420    use super::*;
1421
1422    #[test]
1423    #[allow(clippy::needless_range_loop)]
1424    fn queries_insert_remove() {
1425        // Insert items, remove a few, insert a few more. Check that
1426        // everything looks right.
1427        let mut idxs = [None; 20];
1428        let mut queries = Queries::new();
1429
1430        for i in 0..12 {
1431            let (idx, item) = queries.insert(i).expect("test failed");
1432            idxs[i] = Some(idx);
1433            assert_eq!(i, *item);
1434        }
1435        assert_eq!(queries.count, 12);
1436        assert_eq!(queries.vec.iter().flatten().count(), 12);
1437
1438        for i in [1, 2, 3, 4, 7, 9] {
1439            let item = queries
1440                .try_remove(idxs[i].expect("test failed"))
1441                .expect("test failed");
1442            assert_eq!(i, item);
1443            idxs[i] = None;
1444        }
1445        assert_eq!(queries.count, 6);
1446        assert_eq!(queries.vec.iter().flatten().count(), 6);
1447
1448        for i in 12..20 {
1449            let (idx, item) = queries.insert(i).expect("test failed");
1450            idxs[i] = Some(idx);
1451            assert_eq!(i, *item);
1452        }
1453        assert_eq!(queries.count, 14);
1454        assert_eq!(queries.vec.iter().flatten().count(), 14);
1455
1456        for i in 0..20 {
1457            if let Some(idx) = idxs[i] {
1458                let item = queries.try_remove(idx).expect("test failed");
1459                assert_eq!(i, item);
1460            }
1461        }
1462        assert_eq!(queries.count, 0);
1463        assert_eq!(queries.vec.iter().flatten().count(), 0);
1464    }
1465
1466    #[test]
1467    fn queries_overrun() {
        // This is just a quick check that inserting too much stuff doesn't
1469        // break.
1470        let mut queries = Queries::new();
1471        for i in 0..usize::from(u16::MAX) * 2 {
1472            let _ = queries.insert(i);
1473        }
1474    }
1475}