domain/net/client/
multi_stream.rs

1//! A DNS over multiple octet streams transport
2
3// To do:
4// - too many connection errors
5
6use crate::base::Message;
7use crate::net::client::protocol::AsyncConnect;
8use crate::net::client::request::{
9    ComposeRequest, Error, GetResponse, RequestMessageMulti, SendRequest,
10};
11use crate::net::client::stream;
12use crate::utils::config::DefMinMax;
13use bytes::Bytes;
14use futures_util::stream::FuturesUnordered;
15use futures_util::StreamExt;
16use rand::random;
17use std::boxed::Box;
18use std::fmt::Debug;
19use std::future::Future;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::time::Duration;
23use std::vec::Vec;
24use tokio::io;
25use tokio::io::{AsyncRead, AsyncWrite};
26use tokio::sync::{mpsc, oneshot};
27use tokio::time::timeout;
28use tokio::time::{sleep_until, Instant};
29
30//------------ Constants -----------------------------------------------------
31
/// Capacity of the channel that transports [`ChanReq`].
///
/// This value is handed to `mpsc::channel` when a [`Transport`] is created,
/// so it bounds how many commands can be queued for the transport.
const DEF_CHAN_CAP: usize = 8;
34
/// Error message when the connection is closed.
36const ERR_CONN_CLOSED: &str = "connection closed";
37
38//------------ Configuration Constants ----------------------------------------
39
/// Default response timeout together with its permitted range.
///
/// `Config` reads its initial timeout from `RESPONSE_TIMEOUT.default()` and
/// passes user-supplied values through `RESPONSE_TIMEOUT.limit()`.
// NOTE(review): the arguments are presumably (default, minimum, maximum),
// i.e. 30 s, 1 ms and 600 s, going by the `DefMinMax` name -- confirm.
const RESPONSE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(30),
    Duration::from_millis(1),
    Duration::from_secs(600),
);
46
47//------------ Config ---------------------------------------------------------
48
/// Configuration for a multi-stream transport.
///
/// Combines the response timeout used by this transport with the
/// configuration that is handed to every underlying stream transport.
#[derive(Clone, Debug)]
pub struct Config {
    /// Response timeout currently in effect.
    ///
    /// This is the total time a request may take, across all retries.
    response_timeout: Duration,

    /// Configuration of the underlying stream transport.
    stream: stream::Config,
}
58
59impl Config {
60    /// Returns the response timeout.
61    ///
62    /// This is the amount of time to wait for a request to complete.
63    pub fn response_timeout(&self) -> Duration {
64        self.response_timeout
65    }
66
67    /// Sets the response timeout.
68    ///
69    /// Excessive values are quietly trimmed.
70    pub fn set_response_timeout(&mut self, timeout: Duration) {
71        self.response_timeout = RESPONSE_TIMEOUT.limit(timeout);
72    }
73
74    /// Returns the underlying stream config.
75    pub fn stream(&self) -> &stream::Config {
76        &self.stream
77    }
78
79    /// Returns a mutable reference to the underlying stream config.
80    pub fn stream_mut(&mut self) -> &mut stream::Config {
81        &mut self.stream
82    }
83}
84
85impl From<stream::Config> for Config {
86    fn from(stream: stream::Config) -> Self {
87        Self {
88            stream,
89            response_timeout: RESPONSE_TIMEOUT.default(),
90        }
91    }
92}
93
94impl Default for Config {
95    fn default() -> Self {
96        Self {
97            stream: Default::default(),
98            response_timeout: RESPONSE_TIMEOUT.default(),
99        }
100    }
101}
102
103//------------ Connection -----------------------------------------------------
104
/// A connection to a multi-stream transport.
///
/// This is the handle through which requests are made. It talks to the
/// matching [`Transport`] over a channel and can be cloned freely; all
/// clones share the same channel.
#[derive(Debug)]
pub struct Connection<Req> {
    /// The sender half of the connection request channel.
    sender: mpsc::Sender<ChanReq<Req>>,

    /// Maximum amount of time to wait for a response.
    ///
    /// Copied from the [`Config`] when the connection is created.
    response_timeout: Duration,
}
114
115impl<Req> Connection<Req> {
116    /// Creates a new multi-stream transport with default configuration.
117    pub fn new<Remote>(remote: Remote) -> (Self, Transport<Remote, Req>) {
118        Self::with_config(remote, Default::default())
119    }
120
121    /// Creates a new multi-stream transport.
122    pub fn with_config<Remote>(
123        remote: Remote,
124        config: Config,
125    ) -> (Self, Transport<Remote, Req>) {
126        let response_timeout = config.response_timeout;
127        let (sender, transport) = Transport::new(remote, config);
128        (
129            Self {
130                sender,
131                response_timeout,
132            },
133            transport,
134        )
135    }
136}
137
138impl<Req: ComposeRequest + Clone + 'static> Connection<Req> {
139    /// Sends a request and receives a response.
140    pub async fn request(
141        &self,
142        request: Req,
143    ) -> Result<Message<Bytes>, Error> {
144        Request::new(self.clone(), request).get_response().await
145    }
146
147    /// Starts a request.
148    ///
149    /// This is the future that is returned by the `SendRequest` impl.
150    async fn _send_request(
151        &self,
152        request: &Req,
153    ) -> Result<Box<dyn GetResponse + Send>, Error>
154    where
155        Req: 'static,
156    {
157        let gr = Request::new(self.clone(), request.clone());
158        Ok(Box::new(gr))
159    }
160
161    /// Request a new connection.
162    async fn new_conn(
163        &self,
164        opt_id: Option<u64>,
165    ) -> Result<oneshot::Receiver<ChanResp<Req>>, Error> {
166        let (sender, receiver) = oneshot::channel();
167        let req = ChanReq {
168            cmd: ReqCmd::NewConn(opt_id, sender),
169        };
170        self.sender
171            .send(req)
172            .await
173            .map_err(|_| Error::ConnectionClosed)?;
174        Ok(receiver)
175    }
176
177    /// Request a shutdown.
178    pub async fn shutdown(&self) -> Result<(), &'static str> {
179        let req = ChanReq {
180            cmd: ReqCmd::Shutdown,
181        };
182        match self.sender.send(req).await {
183            Err(_) =>
184            // Send error. The receiver is gone, this means that the
185            // connection is closed.
186            {
187                Err(ERR_CONN_CLOSED)
188            }
189            Ok(_) => Ok(()),
190        }
191    }
192}
193
194//--- Clone
195
196impl<Req> Clone for Connection<Req> {
197    fn clone(&self) -> Self {
198        Self {
199            sender: self.sender.clone(),
200            response_timeout: self.response_timeout,
201        }
202    }
203}
204
205//--- SendRequest
206
207impl<Req> SendRequest<Req> for Connection<Req>
208where
209    Req: ComposeRequest + Clone + 'static,
210{
211    fn send_request(
212        &self,
213        request: Req,
214    ) -> Box<dyn GetResponse + Send + Sync> {
215        Box::new(Request::new(self.clone(), request))
216    }
217}
218
219//------------ Request --------------------------------------------------------
220
/// The connection side of an active request.
///
/// A request is a small state machine (see [`QueryState`]) that obtains a
/// stream connection from the transport, sends the message over it and
/// retries on failure until the response timeout has passed.
#[derive(Debug)]
struct Request<Req> {
    /// The request message.
    ///
    /// It is kept so we can compare a response with it.
    request_msg: Req,

    /// Start time of the request.
    ///
    /// The response timeout is measured from this instant.
    start: Instant,

    /// Current state of the query.
    state: QueryState<Req>,

    /// The underlying transport.
    conn: Connection<Req>,

    /// The id of the most recent connection, if any.
    ///
    /// It is passed along when asking for a new connection.
    conn_id: Option<u64>,

    /// Number of retries with delay.
    ///
    /// Incremented on every failed attempt; the value determines the
    /// length of the next delay.
    delayed_retry_count: u64,
}
244
/// The states of the query state machine.
///
/// A request starts in `RequestConn` and returns to it, possibly via
/// `Delay`, whenever an attempt fails.
#[derive(Debug)]
enum QueryState<Req> {
    /// Request a new connection.
    RequestConn,

    /// Receive a new connection from the receiver.
    ReceiveConn(oneshot::Receiver<ChanResp<Req>>),

    /// Start a query using the given stream transport.
    StartQuery(Arc<stream::Connection<Req, RequestMessageMulti<Vec<u8>>>>),

    /// Get the result of the query.
    GetResult(stream::Request),

    /// Wait until trying again.
    ///
    /// The instant represents when the error occurred, the duration how
    /// long to wait.
    Delay(Instant, Duration),

    /// A response has been received and the query is done.
    ///
    /// Polling a request in this state panics.
    Done,
}
269
/// The response to a connection request.
///
/// Either a usable stream connection or the I/O error that prevented one
/// from being established. The error is reference counted so that the same
/// error can be handed to several requesters.
type ChanResp<Req> = Result<ChanRespOk<Req>, Arc<std::io::Error>>;
272
/// The successful response to a connection request.
#[derive(Debug)]
struct ChanRespOk<Req> {
    /// The id of this connection.
    ///
    /// The requester remembers it and passes it back when it asks for a
    /// replacement connection.
    id: u64,

    /// The new stream transport to use for sending a request.
    conn: Arc<stream::Connection<Req, RequestMessageMulti<Vec<u8>>>>,
}
282
283impl<Req> Request<Req> {
284    /// Creates a new query.
285    fn new(conn: Connection<Req>, request_msg: Req) -> Self {
286        Self {
287            conn,
288            request_msg,
289            start: Instant::now(),
290            state: QueryState::RequestConn,
291            conn_id: None,
292            delayed_retry_count: 0,
293        }
294    }
295}
296
297impl<Req: ComposeRequest + Clone + 'static> Request<Req> {
298    /// Get the result of a DNS request.
299    ///
300    /// This function is cancellation safe. If its future is dropped before
301    /// it is resolved, you can call it again to get a new future.
302    pub async fn get_response(&mut self) -> Result<Message<Bytes>, Error> {
303        loop {
304            let elapsed = self.start.elapsed();
305            if elapsed >= self.conn.response_timeout {
306                return Err(Error::StreamReadTimeout);
307            }
308            let remaining = self.conn.response_timeout - elapsed;
309
310            match self.state {
311                QueryState::RequestConn => {
312                    let to =
313                        timeout(remaining, self.conn.new_conn(self.conn_id))
314                            .await
315                            .map_err(|_| Error::StreamReadTimeout)?;
316
317                    let rx = match to {
318                        Ok(rx) => rx,
319                        Err(err) => {
320                            self.state = QueryState::Done;
321                            return Err(err);
322                        }
323                    };
324                    self.state = QueryState::ReceiveConn(rx);
325                }
326                QueryState::ReceiveConn(ref mut receiver) => {
327                    let to = timeout(remaining, receiver)
328                        .await
329                        .map_err(|_| Error::StreamReadTimeout)?;
330                    let res = match to {
331                        Ok(res) => res,
332                        Err(_) => {
333                            // Assume receive error
334                            self.state = QueryState::Done;
335                            return Err(Error::StreamReceiveError);
336                        }
337                    };
338
339                    // Another Result. This time from executing the request
340                    match res {
341                        Err(_) => {
342                            self.delayed_retry_count += 1;
343                            let retry_time =
344                                retry_time(self.delayed_retry_count);
345                            self.state =
346                                QueryState::Delay(Instant::now(), retry_time);
347                            continue;
348                        }
349                        Ok(ok_res) => {
350                            let id = ok_res.id;
351                            let conn = ok_res.conn;
352
353                            self.conn_id = Some(id);
354                            self.state = QueryState::StartQuery(conn);
355                            continue;
356                        }
357                    }
358                }
359                QueryState::StartQuery(ref mut conn) => {
360                    self.state = QueryState::GetResult(
361                        conn.get_request(self.request_msg.clone()),
362                    );
363                    continue;
364                }
365                QueryState::GetResult(ref mut query) => {
366                    let to = timeout(remaining, query.get_response())
367                        .await
368                        .map_err(|_| Error::StreamReadTimeout)?;
369                    match to {
370                        Ok(reply) => {
371                            return Ok(reply);
372                        }
373                        // XXX This replicates the previous behavior. But
374                        //     maybe we should have a whole category of
375                        //     fatal errors where retrying doesn’t make any
376                        //     sense?
377                        Err(Error::WrongReplyForQuery) => {
378                            return Err(Error::WrongReplyForQuery)
379                        }
380                        Err(Error::ConnectionClosed) => {
381                            // The stream may immedately return that the
382                            // connection was already closed. Do not delay
383                            // the first time.
384                            self.delayed_retry_count += 1;
385                            if self.delayed_retry_count == 1 {
386                                self.state = QueryState::RequestConn;
387                            } else {
388                                let retry_time =
389                                    retry_time(self.delayed_retry_count);
390                                self.state = QueryState::Delay(
391                                    Instant::now(),
392                                    retry_time,
393                                );
394                            }
395                        }
396                        Err(_) => {
397                            self.delayed_retry_count += 1;
398                            let retry_time =
399                                retry_time(self.delayed_retry_count);
400                            self.state =
401                                QueryState::Delay(Instant::now(), retry_time);
402                        }
403                    }
404                }
405                QueryState::Delay(instant, duration) => {
406                    if timeout(remaining, sleep_until(instant + duration))
407                        .await
408                        .is_err()
409                    {
410                        return Err(Error::StreamReadTimeout);
411                    };
412                    self.state = QueryState::RequestConn;
413                }
414                QueryState::Done => {
415                    panic!("Already done");
416                }
417            }
418        }
419    }
420}
421
422impl<Req: ComposeRequest + Clone + 'static> GetResponse for Request<Req> {
423    fn get_response(
424        &mut self,
425    ) -> Pin<
426        Box<
427            dyn Future<Output = Result<Message<Bytes>, Error>>
428                + Send
429                + Sync
430                + '_,
431        >,
432    > {
433        Box::pin(Self::get_response(self))
434    }
435}
436
437//------------ Transport ------------------------------------------------
438
/// The actual implementation of [Connection].
///
/// The transport owns the receiving end of the command channel and the
/// current stream connection. It only does something while
/// [`Transport::run`] is being polled.
#[derive(Debug)]
pub struct Transport<Remote, Req> {
    /// User configuration values.
    config: Config,

    /// The remote destination.
    ///
    /// New octet streams are obtained by calling `connect` on it.
    stream: Remote,

    /// Underlying stream connection.
    conn_state: SingleConnState3<Req>,

    /// Current connection id.
    ///
    /// Incremented whenever a requester reports that the current
    /// connection has to be replaced.
    conn_id: u64,

    /// Receiver part of the channel.
    receiver: mpsc::Receiver<ChanReq<Req>>,
}
457
#[derive(Debug)]
/// A request to [`Transport::run`] either for a new stream or to
/// shutdown.
struct ChanReq<Req> {
    /// A request consists of a command.
    cmd: ReqCmd<Req>,
}
465
#[derive(Debug)]
/// Commands that can be requested.
enum ReqCmd<Req> {
    /// Request for a (new) connection.
    ///
    /// The id of the previous connection (if any) is passed as well as a
    /// channel to send the reply.
    NewConn(Option<u64>, ReplySender<Req>),

    /// Shutdown command.
    ///
    /// Makes the transport stop accepting commands.
    Shutdown,
}
478
/// This is the type of sender in [ReqCmd].
///
/// The transport uses it to deliver exactly one [`ChanResp`] to the
/// requester.
type ReplySender<Req> = oneshot::Sender<ChanResp<Req>>;
481
/// State of the current underlying stream transport.
#[derive(Debug)]
enum SingleConnState3<Req> {
    /// No current stream transport.
    None,

    /// Current stream transport.
    ///
    /// The connection is shared with every requester that asked for it.
    Some(Arc<stream::Connection<Req, RequestMessageMulti<Vec<u8>>>>),

    /// State that deals with an error getting a new octet stream from
    /// a connection stream.
    Err(ErrorState),
}
495
/// State associated with a failed attempt to create a new stream
/// transport.
///
/// While less than `timeout` has passed since `timer`, requests for a
/// connection are answered with `error` instead of a new attempt.
#[derive(Clone, Debug)]
struct ErrorState {
    /// The error we got from the most recent attempt.
    error: Arc<std::io::Error>,

    /// How many times we tried so far.
    retries: u64,

    /// When we got an error.
    timer: Instant,

    /// Time to wait before trying to create a new connection.
    timeout: Duration,
}
512
513impl<Remote, Req> Transport<Remote, Req> {
514    /// Creates a new transport.
515    fn new(
516        stream: Remote,
517        config: Config,
518    ) -> (mpsc::Sender<ChanReq<Req>>, Self) {
519        let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
520        (
521            sender,
522            Self {
523                config,
524                stream,
525                conn_state: SingleConnState3::None,
526                conn_id: 0,
527                receiver,
528            },
529        )
530    }
531}
532
533impl<Remote, Req: ComposeRequest> Transport<Remote, Req>
534where
535    Remote: AsyncConnect,
536    Remote::Connection: AsyncRead + AsyncWrite,
537    Req: ComposeRequest,
538{
539    /// Run the transport machinery.
540    pub async fn run(mut self) {
541        let mut curr_cmd: Option<ReqCmd<Req>> = None;
542        let mut do_stream = false;
543        let mut runners = FuturesUnordered::new();
544        let mut stream_fut: Pin<
545            Box<
546                dyn Future<
547                        Output = Result<Remote::Connection, std::io::Error>,
548                    > + Send,
549            >,
550        > = Box::pin(stream_nop());
551        let mut opt_chan = None;
552
553        loop {
554            if let Some(req) = curr_cmd {
555                assert!(!do_stream);
556                curr_cmd = None;
557                match req {
558                    ReqCmd::NewConn(opt_id, chan) => {
559                        if let SingleConnState3::Err(error_state) =
560                            &self.conn_state
561                        {
562                            if error_state.timer.elapsed()
563                                < error_state.timeout
564                            {
565                                let resp =
566                                    ChanResp::Err(error_state.error.clone());
567
568                                // Ignore errors. We don't care if the receiver
569                                // is gone
570                                _ = chan.send(resp);
571                                continue;
572                            }
573
574                            // Try to set up a new connection
575                        }
576
577                        // Check if the command has an id greather than the
578                        // current id.
579                        if let Some(id) = opt_id {
580                            if id >= self.conn_id {
581                                // We need a new connection. Remove the
582                                // current one. This is the best place to
583                                // increment conn_id.
584                                self.conn_id += 1;
585                                self.conn_state = SingleConnState3::None;
586                            }
587                        }
588                        // If we still have a connection then we can reply
589                        // immediately.
590                        if let SingleConnState3::Some(conn) = &self.conn_state
591                        {
592                            let resp = ChanResp::Ok(ChanRespOk {
593                                id: self.conn_id,
594                                conn: conn.clone(),
595                            });
596                            // Ignore errors. We don't care if the receiver
597                            // is gone
598                            _ = chan.send(resp);
599                        } else {
600                            opt_chan = Some(chan);
601                            stream_fut = Box::pin(self.stream.connect());
602                            do_stream = true;
603                        }
604                    }
605                    ReqCmd::Shutdown => break,
606                }
607            }
608
609            if do_stream {
610                let runners_empty = runners.is_empty();
611
612                loop {
613                    tokio::select! {
614                        res_conn = stream_fut.as_mut() => {
615                            do_stream = false;
616                            stream_fut = Box::pin(stream_nop());
617
618                            let stream = match res_conn {
619                                Ok(stream) => stream,
620                                Err(error) => {
621                                    let error = Arc::new(error);
622                                    match self.conn_state {
623                                        SingleConnState3::None =>
624                                            self.conn_state =
625                                            SingleConnState3::Err(ErrorState {
626                                                error: error.clone(),
627                                                retries: 0,
628                                                timer: Instant::now(),
629                                                timeout: retry_time(0),
630                                            }),
631                                        SingleConnState3::Some(_) =>
632                                            panic!("Illegal Some state"),
633                                        SingleConnState3::Err(error_state) => {
634                                            self.conn_state =
635                                            SingleConnState3::Err(ErrorState {
636                                                error:
637                                                    error_state.error.clone(),
638                                                retries: error_state.retries+1,
639                                                timer: Instant::now(),
640                                                timeout: retry_time(
641                                                error_state.retries+1),
642                                            });
643                                        }
644                                    }
645
646                                    let resp = ChanResp::Err(error);
647                                    let loc_opt_chan = opt_chan.take();
648
649                                    // Ignore errors. We don't care if the receiver
650                                    // is gone
651                                    _ = loc_opt_chan.expect("weird, no channel?")
652                                        .send(resp);
653                                    break;
654                                }
655                            };
656                            let (conn, tran) = stream::Connection::with_config(
657                                stream, self.config.stream.clone()
658                            );
659                            let conn = Arc::new(conn);
660                            runners.push(Box::pin(tran.run()));
661
662                            let resp = ChanResp::Ok(ChanRespOk {
663                                id: self.conn_id,
664                                conn: conn.clone(),
665                            });
666                            self.conn_state = SingleConnState3::Some(conn);
667
668                            let loc_opt_chan = opt_chan.take();
669
670                            // Ignore errors. We don't care if the receiver
671                            // is gone
672                            _ = loc_opt_chan.expect("weird, no channel?")
673                                .send(resp);
674                            break;
675                        }
676                        _ = runners.next(), if !runners_empty => {
677                            }
678                    }
679                }
680                continue;
681            }
682
683            assert!(curr_cmd.is_none());
684            let recv_fut = self.receiver.recv();
685            let runners_empty = runners.is_empty();
686            tokio::select! {
687                msg = recv_fut => {
688                    if msg.is_none() {
689            // All references to the connection object have been
690            // dropped. Shutdown.
691                        break;
692                    }
693                    curr_cmd = Some(msg.expect("None is checked before").cmd);
694                }
695                _ = runners.next(), if !runners_empty => {
696                    }
697            }
698        }
699
700        // Avoid new queries
701        drop(self.receiver);
702
703        // Wait for existing stream runners to terminate
704        while !runners.is_empty() {
705            runners.next().await;
706        }
707    }
708}
709
710//------------ Utility --------------------------------------------------------
711
712/// Compute the retry timeout based on the number of retries so far.
713///
714/// The computation is a random value (in microseconds) between zero and
715/// two to the power of the number of retries.
716fn retry_time(retries: u64) -> Duration {
717    let to_secs = if retries > 6 { 60 } else { 1 << retries };
718    let to_usecs = to_secs * 1000000;
719    let rnd: f64 = random();
720    let to_usecs = to_usecs as f64 * rnd;
721    Duration::from_micros(to_usecs as u64)
722}
723
/// Placeholder future with the same output type as the future returned by
/// a connection stream.
///
/// It is used to fill the slot for the pending connect while no connect is
/// in progress. When polled it resolves immediately to an I/O error with
/// the message "nop".
async fn stream_nop<IO>() -> Result<IO, std::io::Error> {
    let placeholder = io::Error::other("nop");
    Err(placeholder)
}