domain/net/client/stream.rs

//! A client transport using a stream socket.
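//!
//! The sketch below shows how this transport is typically used. It is not
//! compiled as a doc test because the concrete request types
//! (`RequestMessage` and `RequestMessageMulti` from the sibling `request`
//! module) and the way the request message is built are assumptions here
//! rather than part of this module.
//!
//! ```ignore
//! use domain::net::client::request::{RequestMessage, RequestMessageMulti, SendRequest};
//! use domain::net::client::stream;
//! use tokio::net::TcpStream;
//!
//! async fn example(request_msg: RequestMessage<Vec<u8>>) {
//!     // Connect to the server and create the transport.
//!     let tcp = TcpStream::connect("192.0.2.1:53").await.unwrap();
//!     let (conn, transport) =
//!         stream::Connection::<_, RequestMessageMulti<Vec<u8>>>::new(tcp);
//!
//!     // The transport future must be running for requests to make
//!     // progress, so spawn it onto the runtime.
//!     tokio::spawn(transport.run());
//!
//!     // Send the request and wait for the response.
//!     let mut request = conn.send_request(request_msg);
//!     let reply = request.get_response().await;
//! }
//! ```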

// RFC 7766 describes DNS over TCP
// RFC 7828 describes the edns-tcp-keepalive option

use super::request::{
    ComposeRequest, ComposeRequestMulti, Error, GetResponse,
    GetResponseMulti, SendRequest, SendRequestMulti,
};
use crate::base::iana::{Rcode, Rtype};
use crate::base::message::Message;
use crate::base::message_builder::StreamTarget;
use crate::base::opt::{AllOptData, OptRecord, TcpKeepalive};
use crate::base::{ParsedName, Serial};
use crate::rdata::AllRecordData;
use crate::utils::config::DefMinMax;
use bytes::{Bytes, BytesMut};
use core::cmp;
use octseq::Octets;
use std::boxed::Box;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::vec::Vec;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot};
use tokio::time::sleep;
use tracing::trace;

//------------ Configuration Constants ----------------------------------------

/// Default response timeout.
///
/// Note: nsd has 120 seconds, unbound has 3 seconds.
const RESPONSE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(19),
    Duration::from_millis(1),
    Duration::from_secs(600),
);

/// Default idle timeout.
///
/// Note that RFC 7766, Section 6.2.3 says: "DNS clients SHOULD close the
/// TCP connection of an idle session, unless an idle timeout has been
/// established using some other signalling mechanism, for example,
/// [edns-tcp-keepalive]."
/// However, RFC 7858, Section 3.4 says: "In order to amortize TCP and TLS
/// connection setup costs, clients and servers SHOULD NOT immediately close
/// a connection after each response.  Instead, clients and servers SHOULD
/// reuse existing connections for subsequent queries as long as they have
/// sufficient resources."
/// We set the default to 10 seconds, which is the same as what stubby uses.
/// The minimum is zero to allow the idle timeout to be disabled. Assume
/// that one hour is more than enough as a maximum.
const IDLE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(10),
    Duration::ZERO,
    Duration::from_secs(3600),
);

/// Capacity of the channel that transports `ChanReq`s.
const DEF_CHAN_CAP: usize = 8;

/// Capacity of a private channel dispatching responses.
const READ_REPLY_CHAN_CAP: usize = 8;

//------------ Config ---------------------------------------------------------

/// Configuration for a stream transport connection.
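///
/// The example below is a small sketch of adjusting the configuration
/// before creating a connection; it only uses the setters defined on this
/// type.
///
/// ```no_run
/// use domain::net::client::stream;
/// use std::time::Duration;
///
/// let mut config = stream::Config::new();
/// config.set_response_timeout(Duration::from_secs(5));
/// config.set_idle_timeout(Duration::from_secs(10));
/// // Pass the configuration to `Connection::with_config` together with
/// // a connected TCP or TLS stream.
/// ```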
#[derive(Clone, Debug)]
pub struct Config {
    /// Response timeout currently in effect.
    response_timeout: Duration,

    /// Single response timeout.
    single_response_timeout: Duration,

    /// Streaming response timeout.
    streaming_response_timeout: Duration,

    /// Default idle timeout.
    ///
    /// This value is used if the other side does not send a `TcpKeepalive`
    /// option.
    idle_timeout: Duration,
}

impl Config {
    /// Creates a new, default config.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the response timeout.
    ///
    /// This is the amount of time to wait on a non-idle connection for a
    /// response to an outstanding request.
    pub fn response_timeout(&self) -> Duration {
        self.response_timeout
    }

    /// Sets the response timeout.
    ///
    /// For requests where `ComposeRequest::is_streaming` returns true, see
    /// `set_streaming_response_timeout` instead.
    ///
    /// Excessive values are quietly trimmed.
    //
    //  XXX Maybe that’s wrong and we should rather return an error?
    pub fn set_response_timeout(&mut self, timeout: Duration) {
        self.response_timeout = RESPONSE_TIMEOUT.limit(timeout);
        self.streaming_response_timeout = self.response_timeout;
    }

    /// Returns the streaming response timeout.
    pub fn streaming_response_timeout(&self) -> Duration {
        self.streaming_response_timeout
    }

    /// Sets the streaming response timeout.
    ///
    /// Only used for requests where `ComposeRequest::is_streaming` returns
    /// true, as it is typically desirable that such response streams be
    /// allowed to complete even if the individual responses arrive very
    /// slowly.
    ///
    /// Excessive values are quietly trimmed.
    pub fn set_streaming_response_timeout(&mut self, timeout: Duration) {
        self.streaming_response_timeout = RESPONSE_TIMEOUT.limit(timeout);
    }

    /// Returns the initial idle timeout.
    pub fn idle_timeout(&self) -> Duration {
        self.idle_timeout
    }

    /// Sets the initial idle timeout.
    ///
    /// By default the stream is immediately closed if there are no pending
    /// requests or responses.
    ///
    /// Set this to allow requests to be sent in sequence with delays between
    /// them, such as a SOA query followed by an AXFR request, for more
    /// efficient use of the stream per RFC 9103.
    ///
    /// Note: May be overridden by an RFC 7828 edns-tcp-keepalive timeout
    /// received from a server.
    ///
    /// Excessive values are quietly trimmed.
    pub fn set_idle_timeout(&mut self, timeout: Duration) {
        self.idle_timeout = IDLE_TIMEOUT.limit(timeout)
    }
}

impl Default for Config {
    fn default() -> Self {
        Self {
            response_timeout: RESPONSE_TIMEOUT.default(),
            single_response_timeout: RESPONSE_TIMEOUT.default(),
            streaming_response_timeout: RESPONSE_TIMEOUT.default(),
            idle_timeout: IDLE_TIMEOUT.default(),
        }
    }
}

//------------ Connection -----------------------------------------------------

/// A connection to a single stream transport.
#[derive(Debug)]
pub struct Connection<Req, ReqMulti> {
    /// The sender half of the request channel.
    sender: mpsc::Sender<ChanReq<Req, ReqMulti>>,
}

impl<Req, ReqMulti> Connection<Req, ReqMulti> {
    /// Creates a new stream transport with default configuration.
    ///
    /// Returns a connection and a future that drives the transport using
    /// the provided stream. This future needs to be run while any queries
    /// are active. This is most easily achieved by spawning it into a
    /// runtime. It terminates when the last connection is dropped.
    pub fn new<Stream>(
        stream: Stream,
    ) -> (Self, Transport<Stream, Req, ReqMulti>) {
        Self::with_config(stream, Default::default())
    }

    /// Creates a new stream transport with the given configuration.
    ///
    /// Returns a connection and a future that drives the transport using
    /// the provided stream. This future needs to be run while any queries
    /// are active. This is most easily achieved by spawning it into a
    /// runtime. It terminates when the last connection is dropped.
    pub fn with_config<Stream>(
        stream: Stream,
        config: Config,
    ) -> (Self, Transport<Stream, Req, ReqMulti>) {
        let (sender, transport) = Transport::new(stream, config);
        (Self { sender }, transport)
    }
}

impl<Req, ReqMulti> Connection<Req, ReqMulti>
where
    Req: ComposeRequest + 'static,
    ReqMulti: ComposeRequestMulti + 'static,
{
    /// Start a DNS request.
    ///
    /// This function takes a precomposed message as a parameter and
    /// returns a [`Message`] object wrapped in a [`Result`].
    async fn handle_request_impl(
        self,
        msg: Req,
    ) -> Result<Message<Bytes>, Error> {
        let (sender, receiver) = oneshot::channel();
        let sender = ReplySender::Single(Some(sender));
        let msg = ReqSingleMulti::Single(msg);
        let req = ChanReq { sender, msg };
        self.sender.send(req).await.map_err(|_| {
            // Send error. The receiver is gone, this means that the
            // connection is closed.
            Error::ConnectionClosed
        })?;
        receiver.await.map_err(|_| Error::StreamReceiveError)?
    }

    /// Start a streaming request.
    async fn handle_streaming_request_impl(
        self,
        msg: ReqMulti,
        sender: mpsc::Sender<Result<Option<Message<Bytes>>, Error>>,
    ) -> Result<(), Error> {
        let reply_sender = ReplySender::Stream(sender);
        let msg = ReqSingleMulti::Multi(msg);
        let req = ChanReq {
            sender: reply_sender,
            msg,
        };
        self.sender.send(req).await.map_err(|_| {
            // Send error. The receiver is gone, this means that the
            // connection is closed.
            Error::ConnectionClosed
        })?;
        Ok(())
    }

    /// Returns a request handler for a request.
    pub fn get_request(&self, request_msg: Req) -> Request {
        Request {
            fut: Box::pin(self.clone().handle_request_impl(request_msg)),
        }
    }

    /// Returns a multiple-response request handler for a request.
    fn get_streaming_request(&self, request_msg: ReqMulti) -> RequestMulti {
        let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
        RequestMulti {
            stream: receiver,
            fut: Some(Box::pin(
                self.clone()
                    .handle_streaming_request_impl(request_msg, sender),
            )),
        }
    }
}

impl<Req, ReqMulti> Clone for Connection<Req, ReqMulti> {
    fn clone(&self) -> Self {
        Self {
            sender: self.sender.clone(),
        }
    }
}

impl<Req, ReqMulti> SendRequest<Req> for Connection<Req, ReqMulti>
where
    Req: ComposeRequest + 'static,
    ReqMulti: ComposeRequestMulti + Debug + Send + Sync + 'static,
{
    fn send_request(
        &self,
        request_msg: Req,
    ) -> Box<dyn GetResponse + Send + Sync> {
        Box::new(self.get_request(request_msg))
    }
}

impl<Req, ReqMulti> SendRequestMulti<ReqMulti> for Connection<Req, ReqMulti>
where
    Req: ComposeRequest + Debug + Send + Sync + 'static,
    ReqMulti: ComposeRequestMulti + 'static,
{
    fn send_request(
        &self,
        request_msg: ReqMulti,
    ) -> Box<dyn GetResponseMulti + Send + Sync> {
        Box::new(self.get_streaming_request(request_msg))
    }
}

//------------ Request -------------------------------------------------------

/// An active request.
pub struct Request {
    /// The underlying future.
    fut: Pin<
        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
    >,
}

impl Request {
    /// Async function that waits for the future stored in [`Request`] to
    /// complete.
    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
        (&mut self.fut).await
    }
}

impl GetResponse for Request {
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        Box::pin(self.get_response_impl())
    }
}

impl Debug for Request {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("Request")
            .field("fut", &format_args!("_"))
            .finish()
    }
}

//------------ RequestMulti --------------------------------------------------

/// An active request that yields a stream of responses.
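///
/// The sketch below shows how a response stream is typically consumed via
/// the [`GetResponseMulti`] trait; it is not compiled as a doc test because
/// obtaining `conn` and the request value (e.g. a `RequestMessageMulti`)
/// is assumed rather than shown.
///
/// ```ignore
/// // `conn` is a stream::Connection, `request_msg` a ComposeRequestMulti.
/// let mut request = SendRequestMulti::send_request(&conn, request_msg);
/// // `Ok(None)` marks the end of the response stream.
/// while let Some(reply) = request.get_response().await? {
///     // Process `reply` here.
/// }
/// ```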
pub struct RequestMulti {
    /// Receiver for a stream of responses.
    stream: mpsc::Receiver<Result<Option<Message<Bytes>>, Error>>,

    /// The underlying future.
    #[allow(clippy::type_complexity)]
    fut: Option<
        Pin<Box<dyn Future<Output = Result<(), Error>> + Send + Sync>>,
    >,
}

impl RequestMulti {
    /// Async function that waits for the future stored in [`RequestMulti`]
    /// to complete, then receives the next response from the stream.
    async fn get_response_impl(
        &mut self,
    ) -> Result<Option<Message<Bytes>>, Error> {
        if self.fut.is_some() {
            let fut = self.fut.take().expect("Some expected");
            fut.await?;
        }

        // Fetch from the stream
        self.stream
            .recv()
            .await
            .ok_or(Error::ConnectionClosed)
            .map_err(|_| Error::ConnectionClosed)?
    }
}

impl GetResponseMulti for RequestMulti {
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Option<Message<Bytes>>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        let fut = self.get_response_impl();
        Box::pin(fut)
    }
}

impl Debug for RequestMulti {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("RequestMulti")
            .field("fut", &format_args!("_"))
            .finish()
    }
}

//------------ Transport -----------------------------------------------------

/// The underlying machinery of a stream transport.
#[derive(Debug)]
pub struct Transport<Stream, Req, ReqMulti> {
    /// The stream socket towards the remote end.
    stream: Stream,

    /// Transport configuration.
    config: Config,

    /// The receiver half of the request channel.
    receiver: mpsc::Receiver<ChanReq<Req, ReqMulti>>,
}

/// This is the type of sender in [`ChanReq`].
#[derive(Debug)]
enum ReplySender {
    /// Return channel for a single response.
    Single(Option<oneshot::Sender<ChanResp>>),

    /// Return channel for a stream of responses.
    Stream(mpsc::Sender<Result<Option<Message<Bytes>>, Error>>),
}

impl ReplySender {
    /// Send a response.
    async fn send(&mut self, resp: ChanResp) -> Result<(), ()> {
        match self {
            ReplySender::Single(sender) => match sender.take() {
                Some(sender) => sender.send(resp).map_err(|_| ()),
                None => Err(()),
            },
            ReplySender::Stream(sender) => {
                sender.send(resp.map(Some)).await.map_err(|_| ())
            }
        }
    }

    /// Send EOF on a response stream.
    async fn send_eof(&mut self) -> Result<(), ()> {
        match self {
            ReplySender::Single(_) => {
                panic!("cannot send EOF for Single");
            }
            ReplySender::Stream(sender) => {
                sender.send(Ok(None)).await.map_err(|_| ())
            }
        }
    }

    /// Reports whether this sender is in stream mode.
    fn is_stream(&self) -> bool {
        matches!(self, Self::Stream(_))
    }
}

#[derive(Debug)]
/// Enum that can either store a request for a single response or one for
/// multiple responses.
enum ReqSingleMulti<Req, ReqMulti> {
    /// Single response request.
    Single(Req),
    /// Multi-response request.
    Multi(ReqMulti),
}

/// A message from a [`Request`] to start a new request.
#[derive(Debug)]
struct ChanReq<Req, ReqMulti> {
    /// DNS request message.
    msg: ReqSingleMulti<Req, ReqMulti>,

    /// Sender to send the result back to [`Request`].
    sender: ReplySender,
}

/// A message back to [`Request`] returning a response.
type ChanResp = Result<Message<Bytes>, Error>;

/// Internal data structure of [`Transport::run`] to keep track of
/// the status of the connection.
// The types Status and ConnState are only used in Transport
struct Status {
    /// State of the connection.
    state: ConnState,

    /// Whether we need to include an edns-tcp-keepalive option in an
    /// outgoing request.
    ///
    /// Typically this is true at the start of the connection and gets
    /// cleared when we successfully managed to include the option in a
    /// request.
    send_keepalive: bool,

    /// Time we are allowed to keep the connection open when idle.
    ///
    /// Initially we set the idle timeout to the default in config. A received
    /// edns-tcp-keepalive option may change that.
    idle_timeout: Duration,
}

/// Status of the connection. Used in [`Status`].
enum ConnState {
    /// The connection is in this state from the start and when at least
    /// one active DNS request is present.
    ///
    /// The instant contains the time of the first request or the
    /// most recent response that was received.
    Active(Option<Instant>),

    /// This state represents a connection that went idle and has an
    /// idle timeout.
    ///
    /// The instant contains the time the connection went idle.
    Idle(Instant),

    /// This state represents an idle connection where either there was no
    /// idle timeout or the idle timer expired.
    IdleTimeout,

    /// A read error occurred.
    ReadError(Error),

    /// It took too long to receive a response.
    ReadTimeout,

    /// A write error occurred.
    WriteError(Error),
}

//--- Display
impl std::fmt::Display for ConnState {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            ConnState::Active(instant) => f.write_fmt(format_args!(
                "Active (since {}s ago)",
                instant
                    .map(|v| Instant::now().duration_since(v).as_secs())
                    .unwrap_or_default()
            )),
            ConnState::Idle(instant) => f.write_fmt(format_args!(
                "Idle (since {}s ago)",
                Instant::now().duration_since(*instant).as_secs()
            )),
            ConnState::IdleTimeout => f.write_str("IdleTimeout"),
            ConnState::ReadError(err) => {
                f.write_fmt(format_args!("ReadError: {err}"))
            }
            ConnState::ReadTimeout => f.write_str("ReadTimeout"),
            ConnState::WriteError(err) => {
                f.write_fmt(format_args!("WriteError: {err}"))
            }
        }
    }
}

#[derive(Debug)]
/// State of an AXFR or IXFR response stream, used for detecting the end of
/// the stream.
enum XFRState {
    /// Start of AXFR.
    AXFRInit,
    /// After the first SOA record has been encountered.
    AXFRFirstSoa(Serial),
    /// Start of IXFR.
    IXFRInit,
    /// After the first SOA record has been encountered.
    IXFRFirstSoa(Serial),
    /// After the first SOA record in a diff section has been encountered.
    IXFRFirstDiffSoa(Serial),
    /// After the second SOA record in a diff section has been encountered.
    IXFRSecondDiffSoa(Serial),
    /// End of the stream has been found.
    Done,
    /// An error has occurred.
    Error,
}

impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti> {
    /// Creates a new transport.
    fn new(
        stream: Stream,
        config: Config,
    ) -> (mpsc::Sender<ChanReq<Req, ReqMulti>>, Self) {
        let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
        (
            sender,
            Self {
                config,
                stream,
                receiver,
            },
        )
    }
}

impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti>
where
    Stream: AsyncRead + AsyncWrite,
    Req: ComposeRequest,
    ReqMulti: ComposeRequestMulti,
{
    /// Run the transport machinery.
    pub async fn run(mut self) {
        let (reply_sender, mut reply_receiver) =
            mpsc::channel::<Message<Bytes>>(READ_REPLY_CHAN_CAP);

        let (read_stream, mut write_stream) = tokio::io::split(self.stream);

        let reader_fut = Self::reader(read_stream, reply_sender);
        tokio::pin!(reader_fut);

        let mut status = Status {
            state: ConnState::Active(None),
            idle_timeout: self.config.idle_timeout,
            send_keepalive: true,
        };
        let mut query_vec =
            Queries::<(ChanReq<Req, ReqMulti>, Option<XFRState>)>::new();

        let mut reqmsg: Option<Vec<u8>> = None;
        let mut reqmsg_offset = 0;

        loop {
            let opt_timeout = match status.state {
                ConnState::Active(opt_instant) => {
                    if let Some(instant) = opt_instant {
                        let elapsed = instant.elapsed();
                        if elapsed > self.config.response_timeout {
                            Self::error(
                                Error::StreamReadTimeout,
                                &mut query_vec,
                            );
                            status.state = ConnState::ReadTimeout;
                            break;
                        }
                        Some(self.config.response_timeout - elapsed)
                    } else {
                        None
                    }
                }
                ConnState::Idle(instant) => {
                    let elapsed = instant.elapsed();
                    if elapsed >= status.idle_timeout {
                        // Move to IdleTimeout and end the loop.
                        status.state = ConnState::IdleTimeout;
                        break;
                    }
                    Some(status.idle_timeout - elapsed)
                }
                ConnState::IdleTimeout
                | ConnState::ReadError(_)
                | ConnState::WriteError(_) => None, // No timers here
                ConnState::ReadTimeout => {
                    panic!("should not be in loop with ReadTimeout");
                }
            };

            // For simplicity, make sure we always have a timeout
            let timeout = match opt_timeout {
                Some(timeout) => timeout,
                None =>
                // Just use the response timeout
                {
                    self.config.response_timeout
                }
            };

            let sleep_fut = sleep(timeout);
            let recv_fut = self.receiver.recv();

            let (do_write, msg) = match &reqmsg {
                None => {
                    let msg: &[u8] = &[];
                    (false, msg)
                }
                Some(msg) => {
                    let msg: &[u8] = msg;
                    (true, msg)
                }
            };

            tokio::select! {
                biased;
                res = &mut reader_fut => {
                    match res {
                        Ok(_) =>
                            // The reader should not terminate without error.
                            panic!("reader terminated"),
                        Err(error) => {
                            Self::error(error.clone(), &mut query_vec);
                            status.state = ConnState::ReadError(error);
                            // Reader failed. Break out of the loop
                            // and shut down.
                            break
                        }
                    }
                }
                opt_answer = reply_receiver.recv() => {
                    let answer = opt_answer.expect("reader died?");
                    // Check for an edns-tcp-keepalive option
                    let opt_record = answer.opt();
                    if let Some(ref opts) = opt_record {
                        Self::handle_opts(opts, &mut status);
                    };
                    drop(opt_record);
                    Self::demux_reply(answer, &mut status, &mut query_vec).await;
                }
                res = write_stream.write(&msg[reqmsg_offset..]),
                if do_write => {
                    match res {
                        Err(error) => {
                            let error =
                                Error::StreamWriteError(Arc::new(error));
                            Self::error(error.clone(), &mut query_vec);
                            status.state = ConnState::WriteError(error);
                            break;
                        }
                        Ok(len) => {
                            reqmsg_offset += len;
                            if reqmsg_offset >= msg.len() {
                                reqmsg = None;
                                reqmsg_offset = 0;
                            }
                        }
                    }
                }
                res = recv_fut, if !do_write => {
                    match res {
                        Some(req) => {
                            if req.sender.is_stream() {
                                self.config.response_timeout =
                                    self.config.streaming_response_timeout;
                            } else {
                                self.config.response_timeout =
                                    self.config.single_response_timeout;
                            }
                            Self::insert_req(
                                req, &mut status, &mut reqmsg, &mut query_vec
                            );
                        }
                        None => {
                            // All references to the connection object have
                            // been dropped. Shut down.
                            break;
                        }
                    }
                }
                _ = sleep_fut => {
                    // Timeout expired, just continue with the loop.
                }

            }

            // Check if the connection is idle
            match status.state {
                ConnState::Active(_) | ConnState::Idle(_) => {
                    // Keep going
                }
                ConnState::IdleTimeout => break,
                ConnState::ReadError(_)
                | ConnState::ReadTimeout
                | ConnState::WriteError(_) => {
                    panic!("Should not be here");
                }
            }
        }

        trace!("Closing TCP connection in state: {}", status.state);

        // Send FIN
        _ = write_stream.shutdown().await;
    }

    /// This function reads a DNS message from the connection and sends
    /// it to [`Transport::run`].
    ///
    /// Reading has to be done in two steps: first read a two-octet value
    /// that specifies the length of the message, and then read the body
    /// of the message in a loop.
    ///
    /// This function is not async cancellation safe.
    async fn reader(
        mut sock: tokio::io::ReadHalf<Stream>,
        sender: mpsc::Sender<Message<Bytes>>,
    ) -> Result<(), Error> {
        loop {
            let read_res = sock.read_u16().await;
            let len = match read_res {
                Ok(len) => len,
                Err(error) => {
                    return Err(Error::StreamReadError(Arc::new(error)));
                }
            } as usize;

            let mut buf = BytesMut::with_capacity(len);

            loop {
                let curlen = buf.len();
                if curlen >= len {
                    if curlen > len {
                        panic!(
                            "reader: got too much data {curlen}, expected {len}"
                        );
                    }

                    // We got what we need
                    break;
                }

                let read_res = sock.read_buf(&mut buf).await;

                match read_res {
                    Ok(readlen) => {
                        if readlen == 0 {
                            return Err(Error::StreamUnexpectedEndOfData);
                        }
                    }
                    Err(error) => {
                        return Err(Error::StreamReadError(Arc::new(error)));
                    }
                };

                // Check if we are done at the head of the loop
            }

            let reply_message = Message::<Bytes>::from_octets(buf.into());
            match reply_message {
                Ok(answer) => {
                    sender
                        .send(answer)
                        .await
                        .expect("can't send reply to run");
                }
                Err(_) => {
                    // The only possible error is short message
                    return Err(Error::ShortMessage);
                }
            }
        }
    }

    /// Reports an error to all outstanding queries.
    fn error(
        error: Error,
        query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
    ) {
        // Update all requests that are in progress. Don't wait for
        // any reply that may be on its way.
        for (mut req, _) in query_vec.drain() {
            _ = req.sender.send(Err(error.clone()));
        }
    }

    /// Handles received EDNS options.
    ///
    /// In particular, it processes the edns-tcp-keepalive option.
    fn handle_opts<Octs: Octets + AsRef<[u8]>>(
        opts: &OptRecord<Octs>,
        status: &mut Status,
    ) {
        // XXX This handles _all_ keepalive options. I think just using the
        //     first option as returned by Opt::tcp_keepalive should be good
        //     enough? -- M.
        for option in opts.opt().iter().flatten() {
            if let AllOptData::TcpKeepalive(tcpkeepalive) = option {
                Self::handle_keepalive(tcpkeepalive, status);
            }
        }
    }

    /// Demultiplexes a response and sends it to the right query.
    ///
    /// In addition, the status is updated to IdleTimeout or Idle if there
    /// are no remaining pending requests.
    async fn demux_reply(
        answer: Message<Bytes>,
        status: &mut Status,
        query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
    ) {
        // We got an answer, reset the timer
        status.state = ConnState::Active(Some(Instant::now()));

        let id = answer.header().id();

        // Get the correct query and send it the reply.
        let (mut req, mut opt_xfr_data) = match query_vec.try_remove(id) {
            Some(req) => req,
            None => {
                // No query with this ID. We should mark the connection
                // as broken.
                return;
            }
        };
        let mut send_eof = false;
        let answer = if match &req.msg {
            ReqSingleMulti::Single(msg) => msg.is_answer(answer.for_slice()),
            ReqSingleMulti::Multi(msg) => {
                let xfr_data =
                    opt_xfr_data.expect("xfr_data should be present");
                let (eof, xfr_data, is_answer) =
                    check_stream(msg, xfr_data, &answer);
                send_eof = eof;
                opt_xfr_data = Some(xfr_data);
                is_answer
            }
        } {
            Ok(answer)
        } else {
            Err(Error::WrongReplyForQuery)
        };
        _ = req.sender.send(answer).await;

        if req.sender.is_stream() {
            if send_eof {
                _ = req.sender.send_eof().await;
            } else {
                query_vec.insert_at(id, (req, opt_xfr_data));
            }
        }

        if query_vec.is_empty() {
            // Clear the activity timer. There is no need to do
            // this because state will be set to either IdleTimeout
            // or Idle just below. However, it is nicer to keep
            // this independent.
            status.state = ConnState::Active(None);

            status.state = if status.idle_timeout.is_zero() {
                // Assume that we can just move to the IdleTimeout state.
                ConnState::IdleTimeout
            } else {
                ConnState::Idle(Instant::now())
            }
        }
    }

    /// Inserts a request in query_vec and returns the request to be sent
    /// in *reqmsg.
    ///
    /// First the status is checked; an error is returned if the connection
    /// is not active or idle. Then an edns-tcp-keepalive option is added if
    /// needed.
    // Note: maybe reqmsg should be a return value.
    fn insert_req(
        mut req: ChanReq<Req, ReqMulti>,
        status: &mut Status,
        reqmsg: &mut Option<Vec<u8>>,
        query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
    ) {
        match &status.state {
            ConnState::Active(timer) => {
                // Set timer if we don't have one already
                if timer.is_none() {
                    status.state = ConnState::Active(Some(Instant::now()));
                }
            }
            ConnState::Idle(_) => {
                // Go back to active
                status.state = ConnState::Active(Some(Instant::now()));
            }
            ConnState::IdleTimeout => {
                // The connection has been closed. Report an error.
                _ = req.sender.send(Err(Error::StreamIdleTimeout));
                return;
            }
            ConnState::ReadError(error) => {
                _ = req.sender.send(Err(error.clone()));
                return;
            }
            ConnState::ReadTimeout => {
                _ = req.sender.send(Err(Error::StreamReadTimeout));
                return;
            }
            ConnState::WriteError(error) => {
                _ = req.sender.send(Err(error.clone()));
                return;
            }
        }

        let xfr_data = match &req.msg {
            ReqSingleMulti::Single(_) => None,
            ReqSingleMulti::Multi(msg) => {
                let qtype = match msg.to_message().and_then(|m| {
                    m.sole_question()
                        .map_err(|_| Error::MessageParseError)
                        .map(|q| q.qtype())
                }) {
                    Ok(msg) => msg,
                    Err(e) => {
                        _ = req.sender.send(Err(e));
                        return;
                    }
                };
                if qtype == Rtype::AXFR {
                    Some(XFRState::AXFRInit)
                } else if qtype == Rtype::IXFR {
                    Some(XFRState::IXFRInit)
                } else {
                    // Stream requests should be either AXFR or IXFR.
                    _ = req.sender.send(Err(Error::FormError));
                    return;
                }
            }
        };

        // Note that insert may fail if there are too many
        // outstanding queries. First call insert before checking
        // send_keepalive.
        let (index, (req, _)) = match query_vec.insert((req, xfr_data)) {
            Ok(res) => res,
            Err((mut req, _)) => {
                // Send an appropriate error and return.
                _ = req
                    .sender
                    .send(Err(Error::StreamTooManyOutstandingQueries));
                return;
            }
        };

        // We set the ID to the array index. Defense in depth
        // suggests that a random ID is better because it works
        // even if TCP sequence numbers could be predicted. However,
        // Section 9.3 of RFC 5452 recommends retrying over TCP
        // if many spoofed answers arrive over UDP: "TCP, by the
        // nature of its use of sequence numbers, is far more
        // resilient against forgery by third parties."

        let hdr = match &mut req.msg {
            ReqSingleMulti::Single(msg) => msg.header_mut(),
            ReqSingleMulti::Multi(msg) => msg.header_mut(),
        };
        hdr.set_id(index);

        if status.send_keepalive
            && match &mut req.msg {
                ReqSingleMulti::Single(msg) => {
                    msg.add_opt(&TcpKeepalive::new(None)).is_ok()
                }
                ReqSingleMulti::Multi(msg) => {
                    msg.add_opt(&TcpKeepalive::new(None)).is_ok()
                }
            }
        {
            status.send_keepalive = false;
        }

        match Self::convert_query(&req.msg) {
            Ok(msg) => {
                *reqmsg = Some(msg);
            }
            Err(err) => {
                // Take the sender out again and return the error.
                if let Some((mut req, _)) = query_vec.try_remove(index) {
                    _ = req.sender.send(Err(err));
                }
            }
        }
    }

    /// Handle a received edns-tcp-keepalive option.
    fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) {
        if let Some(value) = opt_value.timeout() {
            let value_dur = Duration::from(value);
            status.idle_timeout = value_dur;
        }
    }

    /// Convert the query message to a vector.
    fn convert_query(
        msg: &ReqSingleMulti<Req, ReqMulti>,
    ) -> Result<Vec<u8>, Error> {
        match msg {
            ReqSingleMulti::Single(msg) => {
                let mut target = StreamTarget::new_vec();
                msg.append_message(&mut target)
                    .map_err(|_| Error::StreamLongMessage)?;
                Ok(target.into_target())
            }
            ReqSingleMulti::Multi(msg) => {
                let target = StreamTarget::new_vec();
                let target = msg
                    .append_message(target)
                    .map_err(|_| Error::StreamLongMessage)?;
                Ok(target.finish().into_target())
            }
        }
    }
}

/// Updates the response stream state based on a response message.
fn check_stream<CRM>(
    msg: &CRM,
    mut xfr_state: XFRState,
    answer: &Message<Bytes>,
) -> (bool, XFRState, bool)
where
    CRM: ComposeRequestMulti,
{
    // First check if the reply matches the request.
    // RFC 5936, Section 2.2.2:
    // "In the first response message, this section MUST be copied from the
    // query.  In subsequent messages, this section MAY be copied from the
    // query, or it MAY be empty.  However, in an error response message
    // (see Section 2.2), this section MUST be copied as well."
    match xfr_state {
        XFRState::AXFRInit | XFRState::IXFRInit => {
            if !msg.is_answer(answer.for_slice()) {
                xfr_state = XFRState::Error;
                // If we detect an error, then keep the stream open. We are
                // likely out of sync with respect to the sender.
                return (false, xfr_state, false);
            }
        }
        XFRState::AXFRFirstSoa(_)
        | XFRState::IXFRFirstSoa(_)
        | XFRState::IXFRFirstDiffSoa(_)
        | XFRState::IXFRSecondDiffSoa(_) =>
            // No need to check anything.
            {}
        XFRState::Done => {
            // We should not be here. Switch to error state.
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        XFRState::Error =>
        // Keep the stream open.
        {
            return (false, xfr_state, false)
        }
    }

    // Then check whether the reply status signals an error.
    if answer.header().rcode() != Rcode::NOERROR {
        // Also check if this answers the question.
        if !msg.is_answer(answer.for_slice()) {
            xfr_state = XFRState::Error;
            // If we detect an error, then keep the stream open. We are
            // likely out of sync with respect to the sender.
            return (false, xfr_state, false);
        }
        return (true, xfr_state, true);
    }

    let ans_sec = match answer.answer() {
        Ok(ans) => ans,
        Err(_) => {
            // Bad message, switch to error state and end the stream.
            xfr_state = XFRState::Error;
            return (true, xfr_state, false);
        }
    };
    for rr in
        ans_sec.into_records::<AllRecordData<Bytes, ParsedName<Bytes>>>()
    {
        let rr = match rr {
            Ok(rr) => rr,
            Err(_) => {
                // Bad message, switch to error state.
                xfr_state = XFRState::Error;
                return (true, xfr_state, false);
            }
        };
        match xfr_state {
            XFRState::AXFRInit => {
                // The first record has to be a SOA record.
                if let AllRecordData::Soa(soa) = rr.data() {
                    xfr_state = XFRState::AXFRFirstSoa(soa.serial());
                    continue;
                }
                // Bad data. Switch to error state.
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            XFRState::AXFRFirstSoa(serial) => {
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        // We found a match.
                        xfr_state = XFRState::Done;
                        continue;
                    }

                    // Serial does not match. Move to error state.
                    xfr_state = XFRState::Error;
                    return (false, xfr_state, false);
                }

                // Any other record, just continue.
            }
            XFRState::IXFRInit => {
                // The first record has to be a SOA record.
                if let AllRecordData::Soa(soa) = rr.data() {
                    xfr_state = XFRState::IXFRFirstSoa(soa.serial());
                    continue;
                }
                // Bad data. Switch to error state.
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            XFRState::IXFRFirstSoa(serial) => {
                // We have three possibilities:
                // 1) The record is not a SOA. In that case the format is AXFR.
                // 2) The record is a SOA and the serial is not the current
                //    serial. That is expected for an IXFR format. Move to
                //    IXFRFirstDiffSoa.
                // 3) The record is a SOA and the serial is equal to the
                //    current serial. Treat this as a strange empty AXFR.
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        // We found a match.
                        xfr_state = XFRState::Done;
                        continue;
                    }

                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
                    continue;
                }

                // Any other record, move to AXFRFirstSoa.
                xfr_state = XFRState::AXFRFirstSoa(serial);
            }
            XFRState::IXFRFirstDiffSoa(serial) => {
                // Move to IXFRSecondDiffSoa if the record is a SOA record,
                // otherwise stay in the current state.
                if let AllRecordData::Soa(_) = rr.data() {
                    xfr_state = XFRState::IXFRSecondDiffSoa(serial);
                    continue;
                }

                // Any other record, just continue.
            }
            XFRState::IXFRSecondDiffSoa(serial) => {
                // Move to Done if the record is a SOA record and the
                // serial is the one from the first SOA record, move to
                // IXFRFirstDiffSoa for any other SOA record and
                // otherwise stay in the current state.
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        // We found a match.
                        xfr_state = XFRState::Done;
                        continue;
                    }

                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
                    continue;
                }

                // Any other record, just continue.
            }
            XFRState::Done => {
                // We got a record after we are done. Switch to error state.
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            XFRState::Error => panic!("should not be here"),
        }
    }

    // Check the final state.
    match xfr_state {
        XFRState::AXFRInit | XFRState::IXFRInit => {
            // Still in one of the init states, so the answer section was
            // empty. Switch to error state.
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        XFRState::AXFRFirstSoa(_)
        | XFRState::IXFRFirstDiffSoa(_)
        | XFRState::IXFRSecondDiffSoa(_) =>
            // Just continue.
            {}
        XFRState::IXFRFirstSoa(_) => {
            // We are still in IXFRFirstSoa. Assume the other side doesn't
            // have anything more to say. We could check the SOA serial in
            // the request. Just assume that we are done.
            xfr_state = XFRState::Done;
            return (true, xfr_state, true);
        }
        XFRState::Done => return (true, xfr_state, true),
        XFRState::Error => unreachable!(),
    }

    // (eof, xfr_state, is_answer)
    (false, xfr_state, true)
}

//------------ Queries -------------------------------------------------------

/// Mapping outstanding queries to their ID.
///
/// This is generic over anything rather than our concrete request type for
/// easier testing.
#[derive(Clone, Debug)]
struct Queries<T> {
    /// The number of elements in `vec` that are not None.
    count: usize,

    /// Index in `vec` where to start looking for a space for a new query.
    curr: usize,

    /// Vector of senders to forward a DNS reply message (or error) to.
    vec: Vec<Option<T>>,
}

impl<T> Queries<T> {
    /// Creates a new empty value.
    fn new() -> Self {
        Self {
            count: 0,
            curr: 0,
            vec: Vec::new(),
        }
    }

    /// Returns whether there are no more outstanding queries.
    fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Inserts the given query.
    ///
    /// Upon success, returns the index and a mutable reference to the stored
    /// query.
    ///
    /// Upon error, which means the set is full, returns the query.
    fn insert(&mut self, req: T) -> Result<(u16, &mut T), T> {
        // Fail if there are too many entries already in this vector.
        // We cannot have more than u16::MAX entries because the
        // index needs to fit in a u16. For efficiency we want to
        // keep the vector half empty. So we return a failure if
        // 2 * count > u16::MAX.
        if 2 * self.count > u16::MAX as usize {
            return Err(req);
        }

        // If more than half the vec is empty, we try and find the index of
        // an empty slot.
        let idx = if self.vec.len() >= 2 * self.count {
            let mut found = None;
            for idx in self.curr..self.vec.len() {
                if self.vec[idx].is_none() {
                    found = Some(idx);
                    break;
                }
            }
            found
        } else {
            None
        };

        // If we have an index, we can insert there, otherwise we need to
        // append.
        let idx = match idx {
            Some(idx) => {
                self.vec[idx] = Some(req);
                idx
            }
            None => {
                let idx = self.vec.len();
                self.vec.push(Some(req));
                idx
            }
        };

        self.count += 1;
        if idx == self.curr {
            self.curr += 1;
        }
        let req = self.vec[idx].as_mut().expect("no inserted item?");
        let idx = u16::try_from(idx).expect("query vec too large");
        Ok((idx, req))
    }

    /// Inserts the given query at a specified position. A precondition is
    /// that the slot has to be empty.
    fn insert_at(&mut self, id: u16, req: T) {
        let id = id as usize;
        self.vec[id] = Some(req);

        self.count += 1;
        if id == self.curr {
            self.curr += 1;
        }
    }

    /// Tries to remove and return the query at the given index.
    ///
    /// Returns `None` if there was no query there.
    fn try_remove(&mut self, index: u16) -> Option<T> {
        let res = self.vec.get_mut(usize::from(index))?.take()?;
        self.count = self.count.saturating_sub(1);
        self.curr = cmp::min(self.curr, index.into());
        Some(res)
    }

    /// Removes all queries and returns an iterator over them.
    fn drain(&mut self) -> impl Iterator<Item = T> + '_ {
        let res = self.vec.drain(..).flatten(); // Skips all the `None`s.
        self.count = 0;
        self.curr = 0;
        res
    }
}

//============ Tests =========================================================

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[allow(clippy::needless_range_loop)]
    fn queries_insert_remove() {
        // Insert items, remove a few, insert a few more. Check that
        // everything looks right.
        let mut idxs = [None; 20];
        let mut queries = Queries::new();

        for i in 0..12 {
            let (idx, item) = queries.insert(i).expect("test failed");
            idxs[i] = Some(idx);
            assert_eq!(i, *item);
        }
        assert_eq!(queries.count, 12);
        assert_eq!(queries.vec.iter().flatten().count(), 12);

        for i in [1, 2, 3, 4, 7, 9] {
            let item = queries
                .try_remove(idxs[i].expect("test failed"))
                .expect("test failed");
            assert_eq!(i, item);
            idxs[i] = None;
        }
        assert_eq!(queries.count, 6);
        assert_eq!(queries.vec.iter().flatten().count(), 6);

        for i in 12..20 {
            let (idx, item) = queries.insert(i).expect("test failed");
            idxs[i] = Some(idx);
            assert_eq!(i, *item);
        }
        assert_eq!(queries.count, 14);
        assert_eq!(queries.vec.iter().flatten().count(), 14);

        for i in 0..20 {
            if let Some(idx) = idxs[i] {
                let item = queries.try_remove(idx).expect("test failed");
                assert_eq!(i, item);
            }
        }
        assert_eq!(queries.count, 0);
        assert_eq!(queries.vec.iter().flatten().count(), 0);
    }

    #[test]
    fn queries_overrun() {
        // This is just a quick check that inserting too much stuff doesn’t
        // break.
        let mut queries = Queries::new();
        for i in 0..usize::from(u16::MAX) * 2 {
            let _ = queries.insert(i);
        }
    }
}