domain/net/client/
redundant.rs

//! A transport that multiplexes requests over multiple redundant transports.
2
3use bytes::Bytes;
4
5use futures_util::stream::FuturesUnordered;
6use futures_util::StreamExt;
7
8use octseq::Octets;
9
10use rand::random;
11
12use std::boxed::Box;
13use std::cmp::Ordering;
14use std::fmt::{Debug, Formatter};
15use std::future::Future;
16use std::pin::Pin;
17use std::vec::Vec;
18
19use tokio::sync::{mpsc, oneshot};
20use tokio::time::{sleep_until, Duration, Instant};
21
22use crate::base::iana::OptRcode;
23use crate::base::Message;
24use crate::net::client::request::{Error, GetResponse, SendRequest};
25
/*
Basic algorithm:
- keep track of the expected response time for every upstream
- start with the upstream with the lowest expected response time
- set a timer to the expected response time.
- if the timer expires before a reply arrives, send the query to the upstream
  with the next lowest expected response time and set a timer
- when a reply arrives, update the expected response time for the relevant
  upstream and for the ones that failed.

Based on a random number generator:
- pick a different upstream rather than the best one, but set the timer to
  the expected response time of the best.
*/
40
/// Capacity of the channel that carries [ChanReq] commands to the runner.
const DEF_CHAN_CAP: usize = 8;

/// Initial response time estimate, expressed in milliseconds.
const DEFAULT_RT_MS: u64 = 300;

/// Initial response time estimate for connections that have not been used
/// yet.
const DEFAULT_RT: Duration = Duration::from_millis(DEFAULT_RT_MS);

/// Smoothing factor of the moving averages kept for the measured response
/// time and for its square: every new sample moves an average
/// `1 / SMOOTH_N` of the way towards that sample.
const SMOOTH_N: f64 = 8.0;

/// Probability of trying a slower connection first instead of the fastest.
const PROBE_P: f64 = 0.05;
56
//------------ Config ---------------------------------------------------------

/// User configuration variables.
///
/// Every option is `false` by default. An enabled option marks a kind of
/// result as "deferred": it is only returned to the user when no other
/// transport produces a result that is not deferred.
#[derive(Clone, Copy, Debug, Default)]
pub struct Config {
    /// Defer transport errors.
    defer_transport_error: bool,

    /// Defer replies that report Refused.
    defer_refused: bool,

    /// Defer replies that report ServFail.
    defer_servfail: bool,
}

impl Config {
    /// Returns whether transport errors are deferred.
    pub fn defer_transport_error(&self) -> bool {
        self.defer_transport_error
    }

    /// Sets whether transport errors are deferred.
    pub fn set_defer_transport_error(&mut self, value: bool) {
        self.defer_transport_error = value;
    }

    /// Returns whether replies with rcode Refused are deferred.
    pub fn defer_refused(&self) -> bool {
        self.defer_refused
    }

    /// Sets whether replies with rcode Refused are deferred.
    pub fn set_defer_refused(&mut self, value: bool) {
        self.defer_refused = value;
    }

    /// Returns whether replies with rcode ServFail are deferred.
    pub fn defer_servfail(&self) -> bool {
        self.defer_servfail
    }

    /// Sets whether replies with rcode ServFail are deferred.
    pub fn set_defer_servfail(&mut self, value: bool) {
        self.defer_servfail = value;
    }
}
103
104//------------ Connection -----------------------------------------------------
105
106/// This type represents a transport connection.
107#[derive(Debug)]
108pub struct Connection<Req>
109where
110    Req: Send + Sync,
111{
112    /// User configuation.
113    config: Config,
114
115    /// To send a request to the runner.
116    sender: mpsc::Sender<ChanReq<Req>>,
117}
118
119impl<Req: Clone + Debug + Send + Sync + 'static> Connection<Req> {
120    /// Create a new connection.
121    pub fn new() -> (Self, Transport<Req>) {
122        Self::with_config(Default::default())
123    }
124
125    /// Create a new connection with a given config.
126    pub fn with_config(config: Config) -> (Self, Transport<Req>) {
127        let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
128        (Self { config, sender }, Transport::new(receiver))
129    }
130
131    /// Add a transport connection.
132    pub async fn add(
133        &self,
134        conn: Box<dyn SendRequest<Req> + Send + Sync>,
135    ) -> Result<(), Error> {
136        let (tx, rx) = oneshot::channel();
137        self.sender
138            .send(ChanReq::Add(AddReq { conn, tx }))
139            .await
140            .expect("send should not fail");
141        rx.await.expect("receive should not fail")
142    }
143
144    /// Implementation of the query method.
145    async fn request_impl(
146        self,
147        request_msg: Req,
148    ) -> Result<Message<Bytes>, Error> {
149        let (tx, rx) = oneshot::channel();
150        self.sender
151            .send(ChanReq::GetRT(RTReq { tx }))
152            .await
153            .expect("send should not fail");
154        let conn_rt = rx.await.expect("receive should not fail")?;
155        Query::new(self.config, request_msg, conn_rt, self.sender.clone())
156            .get_response()
157            .await
158    }
159}
160
161impl<Req> Clone for Connection<Req>
162where
163    Req: Send + Sync,
164{
165    fn clone(&self) -> Self {
166        Self {
167            config: self.config,
168            sender: self.sender.clone(),
169        }
170    }
171}
172
173impl<Req: Clone + Debug + Send + Sync + 'static> SendRequest<Req>
174    for Connection<Req>
175{
176    fn send_request(
177        &self,
178        request_msg: Req,
179    ) -> Box<dyn GetResponse + Send + Sync> {
180        Box::new(Request {
181            fut: Box::pin(self.clone().request_impl(request_msg)),
182        })
183    }
184}
185
186//------------ Request -------------------------------------------------------
187
188/// An active request.
189struct Request {
190    /// The underlying future.
191    fut: Pin<
192        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
193    >,
194}
195
196impl Request {
197    /// Async function that waits for the future stored in Query to complete.
198    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
199        (&mut self.fut).await
200    }
201}
202
203impl GetResponse for Request {
204    fn get_response(
205        &mut self,
206    ) -> Pin<
207        Box<
208            dyn Future<Output = Result<Message<Bytes>, Error>>
209                + Send
210                + Sync
211                + '_,
212        >,
213    > {
214        Box::pin(self.get_response_impl())
215    }
216}
217
218impl Debug for Request {
219    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
220        f.debug_struct("Request")
221            .field("fut", &format_args!("_"))
222            .finish()
223    }
224}
225
226//------------ Query --------------------------------------------------------
227
228/// This type represents an active query request.
229#[derive(Debug)]
230struct Query<Req>
231where
232    Req: Send + Sync,
233{
234    /// User configuration.
235    config: Config,
236
237    /// The state of the query
238    state: QueryState,
239
240    /// The request message
241    request_msg: Req,
242
243    /// List of connections identifiers and estimated response times.
244    conn_rt: Vec<ConnRT>,
245
246    /// Channel to send requests to the run function.
247    sender: mpsc::Sender<ChanReq<Req>>,
248
249    /// List of futures for outstanding requests.
250    fut_list: FuturesUnordered<
251        Pin<Box<dyn Future<Output = FutListOutput> + Send + Sync>>,
252    >,
253
254    /// Transport error that should be reported if nothing better shows
255    /// up.
256    deferred_transport_error: Option<Error>,
257
258    /// Reply that should be returned to the user if nothing better shows
259    /// up.
260    deferred_reply: Option<Message<Bytes>>,
261
262    /// The result from one of the connectons.
263    result: Option<Result<Message<Bytes>, Error>>,
264
265    /// Index of the connection that returned a result.
266    res_index: usize,
267}
268
/// The states of the [Query] state machine.
#[derive(Debug)]
enum QueryState {
    /// Initial state; nothing has been sent yet.
    Init,

    /// Start the request on the connection at this index in the list.
    Probe(usize),

    /// Report the response time of the connection at this index in the list.
    Report(usize),

    /// Wait for one of the outstanding requests to finish.
    Wait,
}
284
285/// The commands that can be sent to the run function.
286enum ChanReq<Req>
287where
288    Req: Send + Sync,
289{
290    /// Add a connection
291    Add(AddReq<Req>),
292
293    /// Get the list of estimated response times for all connections
294    GetRT(RTReq),
295
296    /// Start a query
297    Query(RequestReq<Req>),
298
299    /// Report how long it took to get a response
300    Report(TimeReport),
301
302    /// Report that a connection failed to provide a timely response
303    Failure(TimeReport),
304}
305
306impl<Req> Debug for ChanReq<Req>
307where
308    Req: Send + Sync,
309{
310    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
311        f.debug_struct("ChanReq").finish()
312    }
313}
314
315/// Request to add a new connection
316struct AddReq<Req> {
317    /// New connection to add
318    conn: Box<dyn SendRequest<Req> + Send + Sync>,
319
320    /// Channel to send the reply to
321    tx: oneshot::Sender<AddReply>,
322}
323
324/// Reply to an Add request
325type AddReply = Result<(), Error>;
326
327/// Request to give the estimated response times for all connections
328struct RTReq /*<Octs>*/ {
329    /// Channel to send the reply to
330    tx: oneshot::Sender<RTReply>,
331}
332
333/// Reply to a RT request
334type RTReply = Result<Vec<ConnRT>, Error>;
335
336/// Request to start a request
337struct RequestReq<Req>
338where
339    Req: Send + Sync,
340{
341    /// Identifier of connection
342    id: u64,
343
344    /// Request message
345    request_msg: Req,
346
347    /// Channel to send the reply to
348    tx: oneshot::Sender<RequestReply>,
349}
350
351impl<Req: Debug> Debug for RequestReq<Req>
352where
353    Req: Send + Sync,
354{
355    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
356        f.debug_struct("RequestReq")
357            .field("id", &self.id)
358            .field("request_msg", &self.request_msg)
359            .finish()
360    }
361}
362
363/// Reply to a request request.
364type RequestReply = Result<Box<dyn GetResponse + Send + Sync>, Error>;
365
366/// Report the amount of time until success or failure.
367#[derive(Debug)]
368struct TimeReport {
369    /// Identifier of the transport connection.
370    id: u64,
371
372    /// Time spend waiting for a reply.
373    elapsed: Duration,
374}
375
376/// Connection statistics to compute the estimated response time.
377struct ConnStats {
378    /// Aproximation of the windowed average of response times.
379    mean: f64,
380
381    /// Aproximation of the windowed average of the square of response times.
382    mean_sq: f64,
383}
384
385/// Data required to schedule requests and report timing results.
386#[derive(Clone, Debug)]
387struct ConnRT {
388    /// Estimated response time.
389    est_rt: Duration,
390
391    /// Identifier of the connection.
392    id: u64,
393
394    /// Start of a request using this connection.
395    start: Option<Instant>,
396}
397
398/// Result of the futures in fut_list.
399type FutListOutput = (usize, Result<Message<Bytes>, Error>);
400
401impl<Req: Clone + Send + Sync + 'static> Query<Req> {
402    /// Create a new query object.
403    fn new(
404        config: Config,
405        request_msg: Req,
406        mut conn_rt: Vec<ConnRT>,
407        sender: mpsc::Sender<ChanReq<Req>>,
408    ) -> Self {
409        let conn_rt_len = conn_rt.len();
410        conn_rt.sort_unstable_by(conn_rt_cmp);
411
412        // Do we want to probe a less performant upstream?
413        if conn_rt_len > 1 && random::<f64>() < PROBE_P {
414            let index: usize = 1 + random::<usize>() % (conn_rt_len - 1);
415
416            // Give the probe some head start. We may need a separate
417            // configuration parameter. A multiple of min_rt. Just use
418            // min_rt for now.
419            let min_rt = conn_rt.iter().map(|e| e.est_rt).min().unwrap();
420
421            let mut e = conn_rt.remove(index);
422            e.est_rt = min_rt;
423            conn_rt.insert(0, e);
424        }
425
426        Self {
427            config,
428            request_msg,
429            conn_rt,
430            sender,
431            state: QueryState::Init,
432            fut_list: FuturesUnordered::new(),
433            deferred_transport_error: None,
434            deferred_reply: None,
435            result: None,
436            res_index: 0,
437        }
438    }
439
440    /// Implementation of get_response.
441    async fn get_response(&mut self) -> Result<Message<Bytes>, Error> {
442        loop {
443            match self.state {
444                QueryState::Init => {
445                    if self.conn_rt.is_empty() {
446                        return Err(Error::NoTransportAvailable);
447                    }
448                    self.state = QueryState::Probe(0);
449                    continue;
450                }
451                QueryState::Probe(ind) => {
452                    self.conn_rt[ind].start = Some(Instant::now());
453                    let fut = start_request(
454                        ind,
455                        self.conn_rt[ind].id,
456                        self.sender.clone(),
457                        self.request_msg.clone(),
458                    );
459                    self.fut_list.push(Box::pin(fut));
460                    let timeout = Instant::now() + self.conn_rt[ind].est_rt;
461                    loop {
462                        tokio::select! {
463                            res = self.fut_list.next() => {
464                                let res = res.expect("res should not be empty");
465                                match res.1 {
466                                    Err(ref err) => {
467                                        if self.config.defer_transport_error {
468                                            if self.deferred_transport_error.is_none() {
469                                                self.deferred_transport_error = Some(err.clone());
470                                            }
471                                            if res.0 == ind {
472                                                // The current upstream finished,
473                                                // try the next one, if any.
474                                                self.state =
475                                                if ind+1 < self.conn_rt.len() {
476                                                    QueryState::Probe(ind+1)
477                                                }
478                                                else
479                                                {
480                                                    QueryState::Wait
481                                                };
482                                                // Break out of receive loop
483                                                break;
484                                            }
485                                            // Just continue receiving
486                                            continue;
487                                        }
488                                        // Return error to the user.
489                                    }
490                                    Ok(ref msg) => {
491                                        if skip(msg, &self.config) {
492                                            if self.deferred_reply.is_none() {
493                                                self.deferred_reply = Some(msg.clone());
494                                            }
495                                            if res.0 == ind {
496                                                // The current upstream finished,
497                                                // try the next one, if any.
498                                                self.state =
499                                                    if ind+1 < self.conn_rt.len() {
500                                                        QueryState::Probe(ind+1)
501                                                    }
502                                                    else
503                                                    {
504                                                        QueryState::Wait
505                                                    };
506                                                // Break out of receive loop
507                                                break;
508                                            }
509                                            // Just continue receiving
510                                            continue;
511                                        }
512                                        // Now we have a reply that can be
513                                        // returned to the user.
514                                    }
515                                }
516                                self.result = Some(res.1);
517                                self.res_index= res.0;
518
519                                self.state = QueryState::Report(0);
520                                // Break out of receive loop
521                                break;
522                            }
523                            _ = sleep_until(timeout) => {
524                                // Move to the next Probe state if there
525                                // are more upstreams to try, otherwise
526                                // move to the Wait state.
527                                self.state =
528                                if ind+1 < self.conn_rt.len() {
529                                    QueryState::Probe(ind+1)
530                                }
531                                else {
532                                    QueryState::Wait
533                                };
534                                // Break out of receive loop
535                                break;
536                            }
537                        }
538                    }
539                    // Continue with state machine loop
540                    continue;
541                }
542                QueryState::Report(ind) => {
543                    if ind >= self.conn_rt.len()
544                        || self.conn_rt[ind].start.is_none()
545                    {
546                        // Nothing more to report. Return result.
547                        let res = self
548                            .result
549                            .take()
550                            .expect("result should not be empty");
551                        return res;
552                    }
553
554                    let start = self.conn_rt[ind]
555                        .start
556                        .expect("start time should not be empty");
557                    let elapsed = start.elapsed();
558                    let time_report = TimeReport {
559                        id: self.conn_rt[ind].id,
560                        elapsed,
561                    };
562                    let report = if ind == self.res_index {
563                        // Succesfull entry
564                        ChanReq::Report(time_report)
565                    } else {
566                        // Failed entry
567                        ChanReq::Failure(time_report)
568                    };
569
570                    // Send could fail but we don't care.
571                    let _ = self.sender.send(report).await;
572
573                    self.state = QueryState::Report(ind + 1);
574                    continue;
575                }
576                QueryState::Wait => {
577                    loop {
578                        if self.fut_list.is_empty() {
579                            // We have nothing left. There should be a reply or
580                            // an error. Prefer a reply over an error.
581                            if self.deferred_reply.is_some() {
582                                let msg = self
583                                    .deferred_reply
584                                    .take()
585                                    .expect("just checked for Some");
586                                return Ok(msg);
587                            }
588                            if self.deferred_transport_error.is_some() {
589                                let err = self
590                                    .deferred_transport_error
591                                    .take()
592                                    .expect("just checked for Some");
593                                return Err(err);
594                            }
595                            panic!("either deferred_reply or deferred_error should be present");
596                        }
597                        let res = self.fut_list.next().await;
598                        let res = res.expect("res should not be empty");
599                        match res.1 {
600                            Err(ref err) => {
601                                if self.config.defer_transport_error {
602                                    if self.deferred_transport_error.is_none()
603                                    {
604                                        self.deferred_transport_error =
605                                            Some(err.clone());
606                                    }
607                                    // Just continue with the next future, or
608                                    // finish if fut_list is empty.
609                                    continue;
610                                }
611                                // Return error to the user.
612                            }
613                            Ok(ref msg) => {
614                                if skip(msg, &self.config) {
615                                    if self.deferred_reply.is_none() {
616                                        self.deferred_reply =
617                                            Some(msg.clone());
618                                    }
619                                    // Just continue with the next future, or
620                                    // finish if fut_list is empty.
621                                    continue;
622                                }
623                                // Return reply to user.
624                            }
625                        }
626                        self.result = Some(res.1);
627                        self.res_index = res.0;
628                        self.state = QueryState::Report(0);
629                        // Break out of loop to continue with the state machine
630                        break;
631                    }
632                    continue;
633                }
634            }
635        }
636    }
637}
638
639//------------ Transport -----------------------------------------------------
640
641/// Type that actually implements the connection.
642#[derive(Debug)]
643pub struct Transport<Req>
644where
645    Req: Send + Sync,
646{
647    /// Receive side of the channel used by the runner.
648    receiver: mpsc::Receiver<ChanReq<Req>>,
649}
650
651impl<Req: Clone + Send + Sync + 'static> Transport<Req> {
652    /// Implementation of the new method.
653    fn new(receiver: mpsc::Receiver<ChanReq<Req>>) -> Self {
654        Self { receiver }
655    }
656
657    /// Run method.
658    pub async fn run(mut self) {
659        let mut next_id: u64 = 10;
660        let mut conn_stats: Vec<ConnStats> = Vec::new();
661        let mut conn_rt: Vec<ConnRT> = Vec::new();
662        let mut conns: Vec<Box<dyn SendRequest<Req> + Send + Sync>> =
663            Vec::new();
664
665        loop {
666            let req = match self.receiver.recv().await {
667                Some(req) => req,
668                None => break, // All references to connection objects are
669                               // dropped. Shutdown.
670            };
671            match req {
672                ChanReq::Add(add_req) => {
673                    let id = next_id;
674                    next_id += 1;
675                    conn_stats.push(ConnStats {
676                        mean: (DEFAULT_RT_MS as f64) / 1000.,
677                        mean_sq: 0.,
678                    });
679                    conn_rt.push(ConnRT {
680                        id,
681                        est_rt: DEFAULT_RT,
682                        start: None,
683                    });
684                    conns.push(add_req.conn);
685
686                    // Don't care if send fails
687                    let _ = add_req.tx.send(Ok(()));
688                }
689                ChanReq::GetRT(rt_req) => {
690                    // Don't care if send fails
691                    let _ = rt_req.tx.send(Ok(conn_rt.clone()));
692                }
693                ChanReq::Query(request_req) => {
694                    let opt_ind =
695                        conn_rt.iter().position(|e| e.id == request_req.id);
696                    match opt_ind {
697                        Some(ind) => {
698                            let query = conns[ind]
699                                .send_request(request_req.request_msg);
700                            // Don't care if send fails
701                            let _ = request_req.tx.send(Ok(query));
702                        }
703                        None => {
704                            // Don't care if send fails
705                            let _ = request_req
706                                .tx
707                                .send(Err(Error::RedundantTransportNotFound));
708                        }
709                    }
710                }
711                ChanReq::Report(time_report) => {
712                    let opt_ind =
713                        conn_rt.iter().position(|e| e.id == time_report.id);
714                    if let Some(ind) = opt_ind {
715                        let elapsed = time_report.elapsed.as_secs_f64();
716                        conn_stats[ind].mean +=
717                            (elapsed - conn_stats[ind].mean) / SMOOTH_N;
718                        let elapsed_sq = elapsed * elapsed;
719                        conn_stats[ind].mean_sq +=
720                            (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N;
721                        let mean = conn_stats[ind].mean;
722                        let var = conn_stats[ind].mean_sq - mean * mean;
723                        let std_dev =
724                            if var < 0. { 0. } else { f64::sqrt(var) };
725                        let est_rt = mean + 3. * std_dev;
726                        conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
727                    }
728                }
729                ChanReq::Failure(time_report) => {
730                    let opt_ind =
731                        conn_rt.iter().position(|e| e.id == time_report.id);
732                    if let Some(ind) = opt_ind {
733                        let elapsed = time_report.elapsed.as_secs_f64();
734                        if elapsed < conn_stats[ind].mean {
735                            // Do not update the mean if a
736                            // failure took less time than the
737                            // current mean.
738                            continue;
739                        }
740                        conn_stats[ind].mean +=
741                            (elapsed - conn_stats[ind].mean) / SMOOTH_N;
742                        let elapsed_sq = elapsed * elapsed;
743                        conn_stats[ind].mean_sq +=
744                            (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N;
745                        let mean = conn_stats[ind].mean;
746                        let var = conn_stats[ind].mean_sq - mean * mean;
747                        let std_dev =
748                            if var < 0. { 0. } else { f64::sqrt(var) };
749                        let est_rt = mean + 3. * std_dev;
750                        conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
751                    }
752                }
753            }
754        }
755    }
756}
757
758//------------ Utility --------------------------------------------------------
759
760/// Async function to send a request and wait for the reply.
761///
762/// This gives a single future that we can put in a list.
763async fn start_request<Req>(
764    index: usize,
765    id: u64,
766    sender: mpsc::Sender<ChanReq<Req>>,
767    request_msg: Req,
768) -> (usize, Result<Message<Bytes>, Error>)
769where
770    Req: Send + Sync,
771{
772    let (tx, rx) = oneshot::channel();
773    sender
774        .send(ChanReq::Query(RequestReq {
775            id,
776            request_msg,
777            tx,
778        }))
779        .await
780        .expect("send is expected to work");
781    let mut request = match rx.await.expect("receive is expected to work") {
782        Err(err) => return (index, Err(err)),
783        Ok(request) => request,
784    };
785    let reply = request.get_response().await;
786
787    (index, reply)
788}
789
790/// Compare ConnRT elements based on estimated response time.
791fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT) -> Ordering {
792    e1.est_rt.cmp(&e2.est_rt)
793}
794
795/// Return if this reply should be skipped or not.
796fn skip<Octs: Octets>(msg: &Message<Octs>, config: &Config) -> bool {
797    // Check if we actually need to check.
798    if !config.defer_refused && !config.defer_servfail {
799        return false;
800    }
801
802    let opt_rcode = msg.opt_rcode();
803    // OptRcode needs PartialEq
804    if let OptRcode::REFUSED = opt_rcode {
805        if config.defer_refused {
806            return true;
807        }
808    }
809    if let OptRcode::SERVFAIL = opt_rcode {
810        if config.defer_servfail {
811            return true;
812        }
813    }
814
815    false
816}