domain/net/client/multi_stream.rs
1//! A DNS over multiple octet streams transport
2
3// To do:
4// - too many connection errors
5
6use crate::base::Message;
7use crate::net::client::protocol::AsyncConnect;
8use crate::net::client::request::{
9 ComposeRequest, Error, GetResponse, RequestMessageMulti, SendRequest,
10};
11use crate::net::client::stream;
12use crate::utils::config::DefMinMax;
13use bytes::Bytes;
14use futures_util::stream::FuturesUnordered;
15use futures_util::StreamExt;
16use rand::random;
17use std::boxed::Box;
18use std::fmt::Debug;
19use std::future::Future;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::time::Duration;
23use std::vec::Vec;
24use tokio::io;
25use tokio::io::{AsyncRead, AsyncWrite};
26use tokio::sync::{mpsc, oneshot};
27use tokio::time::timeout;
28use tokio::time::{sleep_until, Instant};
29
30//------------ Constants -----------------------------------------------------
31
/// Capacity of the channel that transports [`ChanReq`].
const DEF_CHAN_CAP: usize = 8;

/// Error message returned when the connection is closed.
const ERR_CONN_CLOSED: &str = "connection closed";
37
38//------------ Configuration Constants ----------------------------------------
39
40/// Default response timeout.
41const RESPONSE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
42 Duration::from_secs(30),
43 Duration::from_millis(1),
44 Duration::from_secs(600),
45);
46
47//------------ Config ---------------------------------------------------------
48
49/// Configuration for an multi-stream transport.
50#[derive(Clone, Debug)]
51pub struct Config {
52 /// Response timeout currently in effect.
53 response_timeout: Duration,
54
55 /// Configuration of the underlying stream transport.
56 stream: stream::Config,
57}
58
59impl Config {
60 /// Returns the response timeout.
61 ///
62 /// This is the amount of time to wait for a request to complete.
63 pub fn response_timeout(&self) -> Duration {
64 self.response_timeout
65 }
66
67 /// Sets the response timeout.
68 ///
69 /// Excessive values are quietly trimmed.
70 pub fn set_response_timeout(&mut self, timeout: Duration) {
71 self.response_timeout = RESPONSE_TIMEOUT.limit(timeout);
72 }
73
74 /// Returns the underlying stream config.
75 pub fn stream(&self) -> &stream::Config {
76 &self.stream
77 }
78
79 /// Returns a mutable reference to the underlying stream config.
80 pub fn stream_mut(&mut self) -> &mut stream::Config {
81 &mut self.stream
82 }
83}
84
85impl From<stream::Config> for Config {
86 fn from(stream: stream::Config) -> Self {
87 Self {
88 stream,
89 response_timeout: RESPONSE_TIMEOUT.default(),
90 }
91 }
92}
93
94impl Default for Config {
95 fn default() -> Self {
96 Self {
97 stream: Default::default(),
98 response_timeout: RESPONSE_TIMEOUT.default(),
99 }
100 }
101}
102
103//------------ Connection -----------------------------------------------------
104
105/// A connection to a multi-stream transport.
106#[derive(Debug)]
107pub struct Connection<Req> {
108 /// The sender half of the connection request channel.
109 sender: mpsc::Sender<ChanReq<Req>>,
110
111 /// Maximum amount of time to wait for a response.
112 response_timeout: Duration,
113}
114
115impl<Req> Connection<Req> {
116 /// Creates a new multi-stream transport with default configuration.
117 pub fn new<Remote>(remote: Remote) -> (Self, Transport<Remote, Req>) {
118 Self::with_config(remote, Default::default())
119 }
120
121 /// Creates a new multi-stream transport.
122 pub fn with_config<Remote>(
123 remote: Remote,
124 config: Config,
125 ) -> (Self, Transport<Remote, Req>) {
126 let response_timeout = config.response_timeout;
127 let (sender, transport) = Transport::new(remote, config);
128 (
129 Self {
130 sender,
131 response_timeout,
132 },
133 transport,
134 )
135 }
136}
137
138impl<Req: ComposeRequest + Clone + 'static> Connection<Req> {
139 /// Sends a request and receives a response.
140 pub async fn request(
141 &self,
142 request: Req,
143 ) -> Result<Message<Bytes>, Error> {
144 Request::new(self.clone(), request).get_response().await
145 }
146
147 /// Starts a request.
148 ///
149 /// This is the future that is returned by the `SendRequest` impl.
150 async fn _send_request(
151 &self,
152 request: &Req,
153 ) -> Result<Box<dyn GetResponse + Send>, Error>
154 where
155 Req: 'static,
156 {
157 let gr = Request::new(self.clone(), request.clone());
158 Ok(Box::new(gr))
159 }
160
161 /// Request a new connection.
162 async fn new_conn(
163 &self,
164 opt_id: Option<u64>,
165 ) -> Result<oneshot::Receiver<ChanResp<Req>>, Error> {
166 let (sender, receiver) = oneshot::channel();
167 let req = ChanReq {
168 cmd: ReqCmd::NewConn(opt_id, sender),
169 };
170 self.sender
171 .send(req)
172 .await
173 .map_err(|_| Error::ConnectionClosed)?;
174 Ok(receiver)
175 }
176
177 /// Request a shutdown.
178 pub async fn shutdown(&self) -> Result<(), &'static str> {
179 let req = ChanReq {
180 cmd: ReqCmd::Shutdown,
181 };
182 match self.sender.send(req).await {
183 Err(_) =>
184 // Send error. The receiver is gone, this means that the
185 // connection is closed.
186 {
187 Err(ERR_CONN_CLOSED)
188 }
189 Ok(_) => Ok(()),
190 }
191 }
192}
193
194//--- Clone
195
196impl<Req> Clone for Connection<Req> {
197 fn clone(&self) -> Self {
198 Self {
199 sender: self.sender.clone(),
200 response_timeout: self.response_timeout,
201 }
202 }
203}
204
205//--- SendRequest
206
207impl<Req> SendRequest<Req> for Connection<Req>
208where
209 Req: ComposeRequest + Clone + 'static,
210{
211 fn send_request(
212 &self,
213 request: Req,
214 ) -> Box<dyn GetResponse + Send + Sync> {
215 Box::new(Request::new(self.clone(), request))
216 }
217}
218
219//------------ Request --------------------------------------------------------
220
221/// The connection side of an active request.
222#[derive(Debug)]
223struct Request<Req> {
224 /// The request message.
225 ///
226 /// It is kept so we can compare a response with it.
227 request_msg: Req,
228
229 /// Start time of the request.
230 start: Instant,
231
232 /// Current state of the query.
233 state: QueryState<Req>,
234
235 /// The underlying transport.
236 conn: Connection<Req>,
237
238 /// The id of the most recent connection, if any.
239 conn_id: Option<u64>,
240
241 /// Number of retries with delay.
242 delayed_retry_count: u64,
243}
244
245/// The states of the query state machine.
246#[derive(Debug)]
247enum QueryState<Req> {
248 /// Request a new connection.
249 RequestConn,
250
251 /// Receive a new connection from the receiver.
252 ReceiveConn(oneshot::Receiver<ChanResp<Req>>),
253
254 /// Start a query using the given stream transport.
255 StartQuery(Arc<stream::Connection<Req, RequestMessageMulti<Vec<u8>>>>),
256
257 /// Get the result of the query.
258 GetResult(stream::Request),
259
260 /// Wait until trying again.
261 ///
262 /// The instant represents when the error occurred, the duration how
263 /// long to wait.
264 Delay(Instant, Duration),
265
266 /// A response has been received and the query is done.
267 Done,
268}
269
270/// The response to a connection request.
271type ChanResp<Req> = Result<ChanRespOk<Req>, Arc<std::io::Error>>;
272
273/// The successful response to a connection request.
274#[derive(Debug)]
275struct ChanRespOk<Req> {
276 /// The id of this connection.
277 id: u64,
278
279 /// The new stream transport to use for sending a request.
280 conn: Arc<stream::Connection<Req, RequestMessageMulti<Vec<u8>>>>,
281}
282
283impl<Req> Request<Req> {
284 /// Creates a new query.
285 fn new(conn: Connection<Req>, request_msg: Req) -> Self {
286 Self {
287 conn,
288 request_msg,
289 start: Instant::now(),
290 state: QueryState::RequestConn,
291 conn_id: None,
292 delayed_retry_count: 0,
293 }
294 }
295}
296
impl<Req: ComposeRequest + Clone + 'static> Request<Req> {
    /// Get the result of a DNS request.
    ///
    /// Drives the request through its state machine:
    /// - ask the transport for a stream connection;
    /// - start the query on it;
    /// - wait for the answer.
    ///
    /// Failed attempts are retried after a randomized delay (see
    /// `retry_time`). The configured response timeout is measured from the
    /// creation of the request and covers all attempts.
    ///
    /// This function is cancellation safe. If its future is dropped before
    /// it is resolved, you can call it again to get a new future.
    ///
    /// # Errors
    ///
    /// - `Error::StreamReadTimeout` once the response timeout has expired.
    /// - `Error::WrongReplyForQuery` when the stream transport reports it.
    /// - `Error::ConnectionClosed` (from `new_conn`) when the transport's
    ///   command channel is closed.
    /// - `Error::StreamReceiveError` when the transport drops the reply
    ///   channel without answering.
    ///
    /// # Panics
    ///
    /// Panics if called again after failing with `Error::ConnectionClosed`
    /// from `new_conn` or with `Error::StreamReceiveError`. Those paths set
    /// the state to `QueryState::Done`.
    pub async fn get_response(&mut self) -> Result<Message<Bytes>, Error> {
        loop {
            // The timeout covers the whole request, including all retries.
            let elapsed = self.start.elapsed();
            if elapsed >= self.conn.response_timeout {
                return Err(Error::StreamReadTimeout);
            }
            let remaining = self.conn.response_timeout - elapsed;

            match self.state {
                QueryState::RequestConn => {
                    // Pass the id of the last connection used (if any) so
                    // the transport can decide whether to replace it.
                    let to =
                        timeout(remaining, self.conn.new_conn(self.conn_id))
                            .await
                            .map_err(|_| Error::StreamReadTimeout)?;

                    let rx = match to {
                        Ok(rx) => rx,
                        Err(err) => {
                            self.state = QueryState::Done;
                            return Err(err);
                        }
                    };
                    self.state = QueryState::ReceiveConn(rx);
                }
                QueryState::ReceiveConn(ref mut receiver) => {
                    let to = timeout(remaining, receiver)
                        .await
                        .map_err(|_| Error::StreamReadTimeout)?;
                    let res = match to {
                        Ok(res) => res,
                        Err(_) => {
                            // Assume receive error
                            self.state = QueryState::Done;
                            return Err(Error::StreamReceiveError);
                        }
                    };

                    // Another Result. This time from executing the request
                    match res {
                        Err(_) => {
                            // The transport could not provide a connection:
                            // back off and ask again later.
                            self.delayed_retry_count += 1;
                            let retry_time =
                                retry_time(self.delayed_retry_count);
                            self.state =
                                QueryState::Delay(Instant::now(), retry_time);
                            continue;
                        }
                        Ok(ok_res) => {
                            let id = ok_res.id;
                            let conn = ok_res.conn;

                            // Remember which connection we got so a later
                            // retry can tell the transport to replace it.
                            self.conn_id = Some(id);
                            self.state = QueryState::StartQuery(conn);
                            continue;
                        }
                    }
                }
                QueryState::StartQuery(ref mut conn) => {
                    // Each attempt hands a fresh copy of the message to the
                    // stream transport.
                    self.state = QueryState::GetResult(
                        conn.get_request(self.request_msg.clone()),
                    );
                    continue;
                }
                QueryState::GetResult(ref mut query) => {
                    let to = timeout(remaining, query.get_response())
                        .await
                        .map_err(|_| Error::StreamReadTimeout)?;
                    match to {
                        Ok(reply) => {
                            // NOTE(review): the state stays `GetResult`
                            // here; it is not set to `Done`.
                            return Ok(reply);
                        }
                        // XXX This replicates the previous behavior. But
                        // maybe we should have a whole category of
                        // fatal errors where retrying doesn’t make any
                        // sense?
                        Err(Error::WrongReplyForQuery) => {
                            return Err(Error::WrongReplyForQuery)
                        }
                        Err(Error::ConnectionClosed) => {
                            // The stream may immediately return that the
                            // connection was already closed. Do not delay
                            // the first time. The counter includes all
                            // earlier failures of this request, so this
                            // only skips the delay if nothing failed before.
                            self.delayed_retry_count += 1;
                            if self.delayed_retry_count == 1 {
                                self.state = QueryState::RequestConn;
                            } else {
                                let retry_time =
                                    retry_time(self.delayed_retry_count);
                                self.state = QueryState::Delay(
                                    Instant::now(),
                                    retry_time,
                                );
                            }
                        }
                        Err(_) => {
                            // Any other error: back off, then ask for a
                            // connection again.
                            self.delayed_retry_count += 1;
                            let retry_time =
                                retry_time(self.delayed_retry_count);
                            self.state =
                                QueryState::Delay(Instant::now(), retry_time);
                        }
                    }
                }
                QueryState::Delay(instant, duration) => {
                    // The wait itself is bounded by the remaining response
                    // time.
                    if timeout(remaining, sleep_until(instant + duration))
                        .await
                        .is_err()
                    {
                        return Err(Error::StreamReadTimeout);
                    };
                    self.state = QueryState::RequestConn;
                }
                QueryState::Done => {
                    panic!("Already done");
                }
            }
        }
    }
}
421
422impl<Req: ComposeRequest + Clone + 'static> GetResponse for Request<Req> {
423 fn get_response(
424 &mut self,
425 ) -> Pin<
426 Box<
427 dyn Future<Output = Result<Message<Bytes>, Error>>
428 + Send
429 + Sync
430 + '_,
431 >,
432 > {
433 Box::pin(Self::get_response(self))
434 }
435}
436
437//------------ Transport ------------------------------------------------
438
439/// The actual implementation of [Connection].
440#[derive(Debug)]
441pub struct Transport<Remote, Req> {
442 /// User configuration values.
443 config: Config,
444
445 /// The remote destination.
446 stream: Remote,
447
448 /// Underlying stream connection.
449 conn_state: SingleConnState3<Req>,
450
451 /// Current connection id.
452 conn_id: u64,
453
454 /// Receiver part of the channel.
455 receiver: mpsc::Receiver<ChanReq<Req>>,
456}
457
458#[derive(Debug)]
459/// A request to [Connection::run] either for a new stream or to
460/// shutdown.
461struct ChanReq<Req> {
462 /// A requests consists of a command.
463 cmd: ReqCmd<Req>,
464}
465
466#[derive(Debug)]
467/// Commands that can be requested.
468enum ReqCmd<Req> {
469 /// Request for a (new) connection.
470 ///
471 /// The id of the previous connection (if any) is passed as well as a
472 /// channel to send the reply.
473 NewConn(Option<u64>, ReplySender<Req>),
474
475 /// Shutdown command.
476 Shutdown,
477}
478
479/// This is the type of sender in [ReqCmd].
480type ReplySender<Req> = oneshot::Sender<ChanResp<Req>>;
481
482/// State of the current underlying stream transport.
483#[derive(Debug)]
484enum SingleConnState3<Req> {
485 /// No current stream transport.
486 None,
487
488 /// Current stream transport.
489 Some(Arc<stream::Connection<Req, RequestMessageMulti<Vec<u8>>>>),
490
491 /// State that deals with an error getting a new octet stream from
492 /// a connection stream.
493 Err(ErrorState),
494}
495
496/// State associated with a failed attempt to create a new stream
497/// transport.
498#[derive(Clone, Debug)]
499struct ErrorState {
500 /// The error we got from the most recent attempt.
501 error: Arc<std::io::Error>,
502
503 /// How many times we tried so far.
504 retries: u64,
505
506 /// When we got an error.
507 timer: Instant,
508
509 /// Time to wait before trying to create a new connection.
510 timeout: Duration,
511}
512
513impl<Remote, Req> Transport<Remote, Req> {
514 /// Creates a new transport.
515 fn new(
516 stream: Remote,
517 config: Config,
518 ) -> (mpsc::Sender<ChanReq<Req>>, Self) {
519 let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
520 (
521 sender,
522 Self {
523 config,
524 stream,
525 conn_state: SingleConnState3::None,
526 conn_id: 0,
527 receiver,
528 },
529 )
530 }
531}
532
533impl<Remote, Req: ComposeRequest> Transport<Remote, Req>
534where
535 Remote: AsyncConnect,
536 Remote::Connection: AsyncRead + AsyncWrite,
537 Req: ComposeRequest,
538{
539 /// Run the transport machinery.
540 pub async fn run(mut self) {
541 let mut curr_cmd: Option<ReqCmd<Req>> = None;
542 let mut do_stream = false;
543 let mut runners = FuturesUnordered::new();
544 let mut stream_fut: Pin<
545 Box<
546 dyn Future<
547 Output = Result<Remote::Connection, std::io::Error>,
548 > + Send,
549 >,
550 > = Box::pin(stream_nop());
551 let mut opt_chan = None;
552
553 loop {
554 if let Some(req) = curr_cmd {
555 assert!(!do_stream);
556 curr_cmd = None;
557 match req {
558 ReqCmd::NewConn(opt_id, chan) => {
559 if let SingleConnState3::Err(error_state) =
560 &self.conn_state
561 {
562 if error_state.timer.elapsed()
563 < error_state.timeout
564 {
565 let resp =
566 ChanResp::Err(error_state.error.clone());
567
568 // Ignore errors. We don't care if the receiver
569 // is gone
570 _ = chan.send(resp);
571 continue;
572 }
573
574 // Try to set up a new connection
575 }
576
577 // Check if the command has an id greather than the
578 // current id.
579 if let Some(id) = opt_id {
580 if id >= self.conn_id {
581 // We need a new connection. Remove the
582 // current one. This is the best place to
583 // increment conn_id.
584 self.conn_id += 1;
585 self.conn_state = SingleConnState3::None;
586 }
587 }
588 // If we still have a connection then we can reply
589 // immediately.
590 if let SingleConnState3::Some(conn) = &self.conn_state
591 {
592 let resp = ChanResp::Ok(ChanRespOk {
593 id: self.conn_id,
594 conn: conn.clone(),
595 });
596 // Ignore errors. We don't care if the receiver
597 // is gone
598 _ = chan.send(resp);
599 } else {
600 opt_chan = Some(chan);
601 stream_fut = Box::pin(self.stream.connect());
602 do_stream = true;
603 }
604 }
605 ReqCmd::Shutdown => break,
606 }
607 }
608
609 if do_stream {
610 let runners_empty = runners.is_empty();
611
612 loop {
613 tokio::select! {
614 res_conn = stream_fut.as_mut() => {
615 do_stream = false;
616 stream_fut = Box::pin(stream_nop());
617
618 let stream = match res_conn {
619 Ok(stream) => stream,
620 Err(error) => {
621 let error = Arc::new(error);
622 match self.conn_state {
623 SingleConnState3::None =>
624 self.conn_state =
625 SingleConnState3::Err(ErrorState {
626 error: error.clone(),
627 retries: 0,
628 timer: Instant::now(),
629 timeout: retry_time(0),
630 }),
631 SingleConnState3::Some(_) =>
632 panic!("Illegal Some state"),
633 SingleConnState3::Err(error_state) => {
634 self.conn_state =
635 SingleConnState3::Err(ErrorState {
636 error:
637 error_state.error.clone(),
638 retries: error_state.retries+1,
639 timer: Instant::now(),
640 timeout: retry_time(
641 error_state.retries+1),
642 });
643 }
644 }
645
646 let resp = ChanResp::Err(error);
647 let loc_opt_chan = opt_chan.take();
648
649 // Ignore errors. We don't care if the receiver
650 // is gone
651 _ = loc_opt_chan.expect("weird, no channel?")
652 .send(resp);
653 break;
654 }
655 };
656 let (conn, tran) = stream::Connection::with_config(
657 stream, self.config.stream.clone()
658 );
659 let conn = Arc::new(conn);
660 runners.push(Box::pin(tran.run()));
661
662 let resp = ChanResp::Ok(ChanRespOk {
663 id: self.conn_id,
664 conn: conn.clone(),
665 });
666 self.conn_state = SingleConnState3::Some(conn);
667
668 let loc_opt_chan = opt_chan.take();
669
670 // Ignore errors. We don't care if the receiver
671 // is gone
672 _ = loc_opt_chan.expect("weird, no channel?")
673 .send(resp);
674 break;
675 }
676 _ = runners.next(), if !runners_empty => {
677 }
678 }
679 }
680 continue;
681 }
682
683 assert!(curr_cmd.is_none());
684 let recv_fut = self.receiver.recv();
685 let runners_empty = runners.is_empty();
686 tokio::select! {
687 msg = recv_fut => {
688 if msg.is_none() {
689 // All references to the connection object have been
690 // dropped. Shutdown.
691 break;
692 }
693 curr_cmd = Some(msg.expect("None is checked before").cmd);
694 }
695 _ = runners.next(), if !runners_empty => {
696 }
697 }
698 }
699
700 // Avoid new queries
701 drop(self.receiver);
702
703 // Wait for existing stream runners to terminate
704 while !runners.is_empty() {
705 runners.next().await;
706 }
707 }
708}
709
710//------------ Utility --------------------------------------------------------
711
712/// Compute the retry timeout based on the number of retries so far.
713///
714/// The computation is a random value (in microseconds) between zero and
715/// two to the power of the number of retries.
716fn retry_time(retries: u64) -> Duration {
717 let to_secs = if retries > 6 { 60 } else { 1 << retries };
718 let to_usecs = to_secs * 1000000;
719 let rnd: f64 = random();
720 let to_usecs = to_usecs as f64 * rnd;
721 Duration::from_micros(to_usecs as u64)
722}
723
/// Returns a future that fails immediately with a "nop" I/O error.
///
/// Its output type matches that of a connect future, so it can occupy the
/// connect-future slot while no connect attempt is in progress.
async fn stream_nop<IO>() -> Result<IO, std::io::Error> {
    let placeholder = io::Error::other("nop");
    Err(placeholder)
}