// domain/net/client/load_balancer.rs

//! A transport that tries to distribute requests over multiple upstreams.
//!
//! It is assumed that the upstreams have similar performance. Use the
//! [super::redundant] transport to forward requests to the best upstream out of
//! upstreams that may have quite different performance.
//!
//! # Basic mode of operation
//!
//! Associated with every configured upstream is an optional burst length
//! and burst interval. The burst length divided by the burst interval gives
//! a queries per second (QPS) value. This can be used to limit the rate, and
//! especially the bursts, of queries that reach upstream servers. Once the
//! burst length has been reached, the upstream receives no new requests until
//! the burst interval has completed.
//!
//! For each upstream the object maintains an estimated response time.
//! With the configuration value `slow_rt_factor`, the upstreams that have
//! not exceeded their burst length are divided into a 'fast' and a 'slow'
//! group. The slow group consists of those upstreams that have an estimated
//! response time that is higher than `slow_rt_factor` times the lowest
//! estimated response time. Slow upstreams are considered only when all fast
//! upstreams failed to provide a suitable response.
//!
//! Within the group of fast upstreams, the ones with the lower queue
//! length are preferred. This tries to give each of the fast upstreams
//! an equal number of outstanding requests.
//!
//! Within a group of fast upstreams with the same queue length, the
//! one with the lowest estimated response time is preferred.
//!
//! # Probing
//!
//! Upstreams with high estimated response times may not get any traffic and
//! therefore their estimated response time may remain high. Probing is
//! intended to solve that problem. Using a random number generator,
//! occasionally an upstream is selected for probing. If the selected
//! upstream currently has a non-zero queue then probing is not needed and
//! no probe will happen.
//! Otherwise, the upstream to be probed is tried first, with its estimated
//! response time set to the lowest one. If the probed upstream does not
//! provide a response within that time, the otherwise best upstream also
//! gets the request. If the probed upstream provides a suitable response
//! before the next upstream, its estimate will be updated.

45use crate::base::iana::OptRcode;
46use crate::base::iana::Rcode;
47use crate::base::opt::AllOptData;
48use crate::base::Message;
49use crate::base::MessageBuilder;
50use crate::base::StaticCompressor;
51use crate::dep::octseq::OctetsInto;
52use crate::net::client::request::ComposeRequest;
53use crate::net::client::request::{Error, GetResponse, SendRequest};
54use crate::utils::config::DefMinMax;
55use bytes::Bytes;
56use futures_util::stream::FuturesUnordered;
57use futures_util::StreamExt;
58use octseq::Octets;
59use rand::random;
60use std::boxed::Box;
61use std::cmp::Ordering;
62use std::fmt::{Debug, Formatter};
63use std::future::Future;
64use std::pin::Pin;
65use std::string::String;
66use std::string::ToString;
67use std::sync::Arc;
68use std::vec::Vec;
69use tokio::sync::{mpsc, oneshot};
70use tokio::time::{sleep_until, Duration, Instant};

/*
Basic algorithm:
- try to distribute requests over all upstreams subject to some limitations.
- limit bursts
  - record the start of a burst interval when a request goes out over an
    upstream
  - record the number of requests since the start of the burst interval
  - if the burst is larger than the maximum configured by the user then the
    upstream is no longer available.
  - start a new burst interval when enough time has passed.
- prefer fast upstreams over slow upstreams
  - maintain a response time estimate for each upstream
  - upstreams with an estimated response time larger than slow_rt_factor
    times the lowest estimated response time are considered slow.
  - 'fast' upstreams are preferred over slow upstreams. However, slow upstreams
    are considered if during a single request all fast upstreams fail.
- prefer fast upstreams with a low queue length
  - maintain a counter with the number of current outstanding requests on an
    upstream.
  - prefer the upstream with the lowest count.
  - prefer the upstream with the lowest estimated response time in case
    two or more upstreams have the same count.

Execution:
- set a timer to the expected response time.
- if the timer expires before a reply arrives, send the query to the next
  lowest and set a timer
- when a reply arrives update the expected response time for the relevant
  upstream and for the ones that failed.

Probing:
- upstreams that currently have outstanding requests do not need to be
  probed.
- for idle upstreams, based on a random number generator:
  - pick a different upstream rather than the best
  - but set the timer to the expected response time of the best.
  - maybe we need a configuration parameter for the amount of head start
    given to the probed upstream.
*/

/// Capacity of the channel that transports [ChanReq].
const DEF_CHAN_CAP: usize = 8;

/// Time in milliseconds for the initial response time estimate.
const DEFAULT_RT_MS: u64 = 300;

/// The initial response time estimate for unused connections.
const DEFAULT_RT: Duration = Duration::from_millis(DEFAULT_RT_MS);

/// Smoothing constant for the moving averages of the measured response
/// time and of its square. Each new sample moves an average `1 / SMOOTH_N`
/// of the way towards the sample, which behaves roughly like a window of
/// `SMOOTH_N` samples.
const SMOOTH_N: f64 = 8.;

/// Probability that a query selects a random upstream other than the best
/// one for probing. The probe only happens if that upstream is idle.
const PROBE_P: f64 = 0.05;

128//------------ Configuration Constants ----------------------------------------
129
130/// Cut off for slow upstreams.
131const DEF_SLOW_RT_FACTOR: f64 = 5.0;
132
133/// Minimum value for the cut off factor.
134const MIN_SLOW_RT_FACTOR: f64 = 1.0;
135
136/// Interval for limiting upstream query bursts.
137const BURST_INTERVAL: DefMinMax<Duration> = DefMinMax::new(
138    Duration::from_secs(1),
139    Duration::from_millis(1),
140    Duration::from_secs(3600),
141);
142
143//------------ Config ---------------------------------------------------------
144
145/// User configuration variables.
146#[derive(Clone, Copy, Debug)]
147pub struct Config {
148    /// Defer transport errors.
149    defer_transport_error: bool,
150
151    /// Defer replies that report Refused.
152    defer_refused: bool,
153
154    /// Defer replies that report ServFail.
155    defer_servfail: bool,
156
157    /// Cut-off for slow upstreams as a factor of the fastest upstream.
158    slow_rt_factor: f64,
159}
160
161impl Config {
162    /// Return the value of the defer_transport_error configuration variable.
163    pub fn defer_transport_error(&self) -> bool {
164        self.defer_transport_error
165    }
166
167    /// Set the value of the defer_transport_error configuration variable.
168    pub fn set_defer_transport_error(&mut self, value: bool) {
169        self.defer_transport_error = value
170    }
171
172    /// Return the value of the defer_refused configuration variable.
173    pub fn defer_refused(&self) -> bool {
174        self.defer_refused
175    }
176
177    /// Set the value of the defer_refused configuration variable.
178    pub fn set_defer_refused(&mut self, value: bool) {
179        self.defer_refused = value
180    }
181
182    /// Return the value of the defer_servfail configuration variable.
183    pub fn defer_servfail(&self) -> bool {
184        self.defer_servfail
185    }
186
187    /// Set the value of the defer_servfail configuration variable.
188    pub fn set_defer_servfail(&mut self, value: bool) {
189        self.defer_servfail = value
190    }
191
192    /// Set the value of the slow_rt_factor configuration variable.
193    pub fn slow_rt_factor(&self) -> f64 {
194        self.slow_rt_factor
195    }
196
197    /// Set the value of the slow_rt_factor configuration variable.
198    pub fn set_slow_rt_factor(&mut self, mut value: f64) {
199        if value < MIN_SLOW_RT_FACTOR {
200            value = MIN_SLOW_RT_FACTOR
201        };
202        self.slow_rt_factor = value;
203    }
204}
205
206impl Default for Config {
207    fn default() -> Self {
208        Self {
209            defer_transport_error: Default::default(),
210            defer_refused: Default::default(),
211            defer_servfail: Default::default(),
212            slow_rt_factor: DEF_SLOW_RT_FACTOR,
213        }
214    }
215}
216
217//------------ ConnConfig -----------------------------------------------------
218
219/// Configuration variables for each upstream.
220#[derive(Clone, Copy, Debug, Default)]
221pub struct ConnConfig {
222    /// Maximum burst of upstream queries.
223    max_burst: Option<u64>,
224
225    /// Interval over which the burst is counted.
226    burst_interval: Duration,
227}
228
229impl ConnConfig {
230    /// Create a new ConnConfig object.
231    pub fn new() -> Self {
232        Self {
233            max_burst: None,
234            burst_interval: BURST_INTERVAL.default(),
235        }
236    }
237
238    /// Return the current configuration value for the maximum burst.
239    /// None means that there is no limit.
240    pub fn max_burst(&mut self) -> Option<u64> {
241        self.max_burst
242    }
243
244    /// Set the configuration value for the maximum burst.
245    /// The value None means no limit.
246    pub fn set_max_burst(&mut self, max_burst: Option<u64>) {
247        self.max_burst = max_burst;
248    }
249
250    /// Return the current burst interval.
251    pub fn burst_interval(&mut self) -> Duration {
252        self.burst_interval
253    }
254
255    /// Set a new burst interval.
256    ///
257    /// The interval is silently limited to at least 1 millesecond and
258    /// at most 1 hour.
259    pub fn set_burst_interval(&mut self, burst_interval: Duration) {
260        self.burst_interval = BURST_INTERVAL.limit(burst_interval);
261    }
262}
263
264//------------ Connection -----------------------------------------------------
265
266/// This type represents a transport connection.
267#[derive(Debug)]
268pub struct Connection<Req>
269where
270    Req: Send + Sync,
271{
272    /// User configuation.
273    config: Config,
274
275    /// To send a request to the runner.
276    sender: mpsc::Sender<ChanReq<Req>>,
277}
278
279impl<Req: Clone + Debug + Send + Sync + 'static> Connection<Req> {
280    /// Create a new connection.
281    pub fn new() -> (Self, Transport<Req>) {
282        Self::with_config(Default::default())
283    }
284
285    /// Create a new connection with a given config.
286    pub fn with_config(config: Config) -> (Self, Transport<Req>) {
287        let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
288        (Self { config, sender }, Transport::new(receiver))
289    }
290
291    /// Add a transport connection.
292    pub async fn add(
293        &self,
294        label: &str,
295        config: &ConnConfig,
296        conn: Box<dyn SendRequest<Req> + Send + Sync>,
297    ) -> Result<(), Error> {
298        let (tx, rx) = oneshot::channel();
299        self.sender
300            .send(ChanReq::Add(AddReq {
301                label: label.to_string(),
302                max_burst: config.max_burst,
303                burst_interval: config.burst_interval,
304                conn,
305                tx,
306            }))
307            .await
308            .expect("send should not fail");
309        rx.await.expect("receive should not fail")
310    }
311
312    /// Implementation of the query method.
313    async fn request_impl(
314        self,
315        request_msg: Req,
316    ) -> Result<Message<Bytes>, Error>
317    where
318        Req: ComposeRequest,
319    {
320        let (tx, rx) = oneshot::channel();
321        self.sender
322            .send(ChanReq::GetRT(RTReq { tx }))
323            .await
324            .expect("send should not fail");
325        let conn_rt = rx.await.expect("receive should not fail")?;
326        if conn_rt.is_empty() {
327            return serve_fail(&request_msg.to_message().unwrap());
328        }
329        Query::new(self.config, request_msg, conn_rt, self.sender.clone())
330            .get_response()
331            .await
332    }
333}
334
335impl<Req> Clone for Connection<Req>
336where
337    Req: Send + Sync,
338{
339    fn clone(&self) -> Self {
340        Self {
341            config: self.config,
342            sender: self.sender.clone(),
343        }
344    }
345}
346
347impl<Req: Clone + ComposeRequest + Debug + Send + Sync + 'static>
348    SendRequest<Req> for Connection<Req>
349{
350    fn send_request(
351        &self,
352        request_msg: Req,
353    ) -> Box<dyn GetResponse + Send + Sync> {
354        Box::new(Request {
355            fut: Box::pin(self.clone().request_impl(request_msg)),
356        })
357    }
358}
359
360//------------ Request -------------------------------------------------------
361
362/// An active request.
363struct Request {
364    /// The underlying future.
365    fut: Pin<
366        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
367    >,
368}
369
370impl Request {
371    /// Async function that waits for the future stored in Query to complete.
372    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
373        (&mut self.fut).await
374    }
375}
376
377impl GetResponse for Request {
378    fn get_response(
379        &mut self,
380    ) -> Pin<
381        Box<
382            dyn Future<Output = Result<Message<Bytes>, Error>>
383                + Send
384                + Sync
385                + '_,
386        >,
387    > {
388        Box::pin(self.get_response_impl())
389    }
390}
391
392impl Debug for Request {
393    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
394        f.debug_struct("Request")
395            .field("fut", &format_args!("_"))
396            .finish()
397    }
398}
399
400//------------ Query --------------------------------------------------------
401
402/// This type represents an active query request.
403#[derive(Debug)]
404struct Query<Req>
405where
406    Req: Send + Sync,
407{
408    /// User configuration.
409    config: Config,
410
411    /// The state of the query
412    state: QueryState,
413
414    /// The request message
415    request_msg: Req,
416
417    /// List of connections identifiers and estimated response times.
418    conn_rt: Vec<ConnRT>,
419
420    /// Channel to send requests to the run function.
421    sender: mpsc::Sender<ChanReq<Req>>,
422
423    /// List of futures for outstanding requests.
424    fut_list: FuturesUnordered<
425        Pin<Box<dyn Future<Output = FutListOutput> + Send + Sync>>,
426    >,
427
428    /// Transport error that should be reported if nothing better shows
429    /// up.
430    deferred_transport_error: Option<Error>,
431
432    /// Reply that should be returned to the user if nothing better shows
433    /// up.
434    deferred_reply: Option<Message<Bytes>>,
435
436    /// The result from one of the connectons.
437    result: Option<Result<Message<Bytes>, Error>>,
438
439    /// Index of the connection that returned a result.
440    res_index: usize,
441}
442
/// The various states a query can be in.
#[derive(Debug)]
enum QueryState {
    /// The initial state.
    Init,

    /// Start a request on the connection at this index in `conn_rt`.
    Probe(usize),

    /// Report the response time for this index in `conn_rt`.
    Report(usize),

    /// Wait for one of the outstanding requests to finish.
    Wait,
}

459/// The commands that can be sent to the run function.
460enum ChanReq<Req>
461where
462    Req: Send + Sync,
463{
464    /// Add a connection
465    Add(AddReq<Req>),
466
467    /// Get the list of estimated response times for all connections
468    GetRT(RTReq),
469
470    /// Start a query
471    Query(RequestReq<Req>),
472
473    /// Report how long it took to get a response
474    Report(TimeReport),
475
476    /// Report that a connection failed to provide a timely response
477    Failure(TimeReport),
478}
479
480impl<Req> Debug for ChanReq<Req>
481where
482    Req: Send + Sync,
483{
484    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
485        f.debug_struct("ChanReq").finish()
486    }
487}
488
489/// Request to add a new connection
490struct AddReq<Req> {
491    /// Name of new connection
492    label: String,
493
494    /// Maximum length of a burst.
495    max_burst: Option<u64>,
496
497    /// Interval over which bursts are counted.
498    burst_interval: Duration,
499
500    /// New connection to add
501    conn: Box<dyn SendRequest<Req> + Send + Sync>,
502
503    /// Channel to send the reply to
504    tx: oneshot::Sender<AddReply>,
505}
506
507/// Reply to an Add request
508type AddReply = Result<(), Error>;
509
510/// Request to give the estimated response times for all connections
511struct RTReq /*<Octs>*/ {
512    /// Channel to send the reply to
513    tx: oneshot::Sender<RTReply>,
514}
515
516/// Reply to a RT request
517type RTReply = Result<Vec<ConnRT>, Error>;
518
519/// Request to start a request
520struct RequestReq<Req>
521where
522    Req: Send + Sync,
523{
524    /// Identifier of connection
525    id: u64,
526
527    /// Request message
528    request_msg: Req,
529
530    /// Channel to send the reply to
531    tx: oneshot::Sender<RequestReply>,
532}
533
534impl<Req: Debug> Debug for RequestReq<Req>
535where
536    Req: Send + Sync,
537{
538    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
539        f.debug_struct("RequestReq")
540            .field("id", &self.id)
541            .field("request_msg", &self.request_msg)
542            .finish()
543    }
544}
545
546/// Reply to a request request.
547type RequestReply =
548    Result<(Box<dyn GetResponse + Send + Sync>, Arc<()>), Error>;
549
/// Report the amount of time until success or failure.
#[derive(Debug)]
struct TimeReport {
    /// Identifier of the transport connection.
    id: u64,

    /// Time spent waiting for a reply.
    elapsed: Duration,
}

560/// Connection statistics to compute the estimated response time.
561struct ConnStats {
562    /// Name of the connection.
563    _label: String,
564
565    /// Aproximation of the windowed average of response times.
566    mean: f64,
567
568    /// Aproximation of the windowed average of the square of response times.
569    mean_sq: f64,
570
571    /// Maximum upstream query burst.
572    max_burst: Option<u64>,
573
574    /// burst length,
575    burst_interval: Duration,
576
577    /// Start of the current burst
578    burst_start: Instant,
579
580    /// Number of queries since the start of the burst.
581    burst: u64,
582
583    /// Use the number of references to an Arc as queue length. The number
584    /// of references is one higher than then actual queue length.
585    queue_length_plus_one: Arc<()>,
586}
587
588impl ConnStats {
589    /// Update response time statistics.
590    fn update(&mut self, elapsed: Duration) {
591        let elapsed = elapsed.as_secs_f64();
592        self.mean += (elapsed - self.mean) / SMOOTH_N;
593        let elapsed_sq = elapsed * elapsed;
594        self.mean_sq += (elapsed_sq - self.mean_sq) / SMOOTH_N;
595    }
596
597    /// Get an estimated response time.
598    fn est_rt(&self) -> f64 {
599        let mean = self.mean;
600        let var = self.mean_sq - mean * mean;
601        let std_dev = f64::sqrt(var.max(0.));
602        mean + 3. * std_dev
603    }
604}
605
/// Data required to schedule requests and report timing results.
#[derive(Clone, Debug)]
struct ConnRT {
    /// Estimated response time.
    est_rt: Duration,

    /// Identifier of the connection.
    id: u64,

    /// Start of a request using this connection, or `None` if no request
    /// has been started on it for the current query.
    start: Option<Instant>,

    /// Queue length of the connection, i.e. its number of outstanding
    /// requests. Unlike `ConnStats::queue_length_plus_one`, this is the
    /// queue length itself: `Query::new` treats zero as idle.
    queue_length: usize,
}

623/// Result of the futures in fut_list.
624type FutListOutput = (usize, Result<Message<Bytes>, Error>);
625
626impl<Req: Clone + Send + Sync + 'static> Query<Req> {
627    /// Create a new query object.
628    fn new(
629        config: Config,
630        request_msg: Req,
631        mut conn_rt: Vec<ConnRT>,
632        sender: mpsc::Sender<ChanReq<Req>>,
633    ) -> Self {
634        let conn_rt_len = conn_rt.len();
635        let min_rt = conn_rt.iter().map(|e| e.est_rt).min().unwrap();
636        let slow_rt = min_rt.as_secs_f64() * config.slow_rt_factor;
637        conn_rt.sort_unstable_by(|e1, e2| conn_rt_cmp(e1, e2, slow_rt));
638
639        // Do we want to probe a less performant upstream? We only need to
640        // probe upstreams with a queue length of zero. If the queue length
641        // is non-zero then the upstream recently got work and does not need
642        // to be probed.
643        if conn_rt_len > 1 && random::<f64>() < PROBE_P {
644            let index: usize = 1 + random::<usize>() % (conn_rt_len - 1);
645
646            if conn_rt[index].queue_length == 0 {
647                // Give the probe some head start. We may need a separate
648                // configuration parameter. A multiple of min_rt. Just use
649                // min_rt for now.
650                let mut e = conn_rt.remove(index);
651                e.est_rt = min_rt;
652                conn_rt.insert(0, e);
653            }
654        }
655
656        Self {
657            config,
658            request_msg,
659            conn_rt,
660            sender,
661            state: QueryState::Init,
662            fut_list: FuturesUnordered::new(),
663            deferred_transport_error: None,
664            deferred_reply: None,
665            result: None,
666            res_index: 0,
667        }
668    }
669
670    /// Implementation of get_response.
671    async fn get_response(&mut self) -> Result<Message<Bytes>, Error> {
672        loop {
673            match self.state {
674                QueryState::Init => {
675                    if self.conn_rt.is_empty() {
676                        return Err(Error::NoTransportAvailable);
677                    }
678                    self.state = QueryState::Probe(0);
679                    continue;
680                }
681                QueryState::Probe(ind) => {
682                    self.conn_rt[ind].start = Some(Instant::now());
683                    let fut = start_request(
684                        ind,
685                        self.conn_rt[ind].id,
686                        self.sender.clone(),
687                        self.request_msg.clone(),
688                    );
689                    self.fut_list.push(Box::pin(fut));
690                    let timeout = Instant::now() + self.conn_rt[ind].est_rt;
691                    loop {
692                        tokio::select! {
693                            res = self.fut_list.next() => {
694                                let res = res.expect("res should not be empty");
695                                match res.1 {
696                                    Err(ref err) => {
697                                        if self.config.defer_transport_error {
698                                            if self.deferred_transport_error.is_none() {
699                                                self.deferred_transport_error = Some(err.clone());
700                                            }
701                                            if res.0 == ind {
702                                                // The current upstream finished,
703                                                // try the next one, if any.
704                                                self.state =
705                                                if ind+1 < self.conn_rt.len() {
706                                                    QueryState::Probe(ind+1)
707                                                }
708                                                else
709                                                {
710                                                    QueryState::Wait
711                                                };
712                                                // Break out of receive loop
713                                                break;
714                                            }
715                                            // Just continue receiving
716                                            continue;
717                                        }
718                                        // Return error to the user.
719                                    }
720                                    Ok(ref msg) => {
721                                        if skip(msg, &self.config) {
722                                            if self.deferred_reply.is_none() {
723                                                self.deferred_reply = Some(msg.clone());
724                                            }
725                                            if res.0 == ind {
726                                                // The current upstream finished,
727                                                // try the next one, if any.
728                                                self.state =
729                                                    if ind+1 < self.conn_rt.len() {
730                                                        QueryState::Probe(ind+1)
731                                                    }
732                                                    else
733                                                    {
734                                                        QueryState::Wait
735                                                    };
736                                                // Break out of receive loop
737                                                break;
738                                            }
739                                            // Just continue receiving
740                                            continue;
741                                        }
742                                        // Now we have a reply that can be
743                                        // returned to the user.
744                                    }
745                                }
746                                self.result = Some(res.1);
747                                self.res_index = res.0;
748
749                                self.state = QueryState::Report(0);
750                                // Break out of receive loop
751                                break;
752                            }
753                            _ = sleep_until(timeout) => {
754                                // Move to the next Probe state if there
755                                // are more upstreams to try, otherwise
756                                // move to the Wait state.
757                                self.state =
758                                if ind+1 < self.conn_rt.len() {
759                                    QueryState::Probe(ind+1)
760                                }
761                                else {
762                                    QueryState::Wait
763                                };
764                                // Break out of receive loop
765                                break;
766                            }
767                        }
768                    }
769                    // Continue with state machine loop
770                    continue;
771                }
772                QueryState::Report(ind) => {
773                    if ind >= self.conn_rt.len()
774                        || self.conn_rt[ind].start.is_none()
775                    {
776                        // Nothing more to report. Return result.
777                        let res = self
778                            .result
779                            .take()
780                            .expect("result should not be empty");
781                        return res;
782                    }
783
784                    let start = self.conn_rt[ind]
785                        .start
786                        .expect("start time should not be empty");
787                    let elapsed = start.elapsed();
788                    let time_report = TimeReport {
789                        id: self.conn_rt[ind].id,
790                        elapsed,
791                    };
792                    let report = if ind == self.res_index {
793                        // Succesfull entry
794                        ChanReq::Report(time_report)
795                    } else {
796                        // Failed entry
797                        ChanReq::Failure(time_report)
798                    };
799
800                    // Send could fail but we don't care.
801                    let _ = self.sender.send(report).await;
802
803                    self.state = QueryState::Report(ind + 1);
804                    continue;
805                }
806                QueryState::Wait => {
807                    loop {
808                        if self.fut_list.is_empty() {
809                            // We have nothing left. There should be a reply or
810                            // an error. Prefer a reply over an error.
811                            if self.deferred_reply.is_some() {
812                                let msg = self
813                                    .deferred_reply
814                                    .take()
815                                    .expect("just checked for Some");
816                                return Ok(msg);
817                            }
818                            if self.deferred_transport_error.is_some() {
819                                let err = self
820                                    .deferred_transport_error
821                                    .take()
822                                    .expect("just checked for Some");
823                                return Err(err);
824                            }
825                            panic!("either deferred_reply or deferred_error should be present");
826                        }
827                        let res = self.fut_list.next().await;
828                        let res = res.expect("res should not be empty");
829                        match res.1 {
830                            Err(ref err) => {
831                                if self.config.defer_transport_error {
832                                    if self.deferred_transport_error.is_none()
833                                    {
834                                        self.deferred_transport_error =
835                                            Some(err.clone());
836                                    }
837                                    // Just continue with the next future, or
838                                    // finish if fut_list is empty.
839                                    continue;
840                                }
841                                // Return error to the user.
842                            }
843                            Ok(ref msg) => {
844                                if skip(msg, &self.config) {
845                                    if self.deferred_reply.is_none() {
846                                        self.deferred_reply =
847                                            Some(msg.clone());
848                                    }
849                                    // Just continue with the next future, or
850                                    // finish if fut_list is empty.
851                                    continue;
852                                }
853                                // Return reply to user.
854                            }
855                        }
856                        self.result = Some(res.1);
857                        self.res_index = res.0;
858                        self.state = QueryState::Report(0);
859                        // Break out of loop to continue with the state machine
860                        break;
861                    }
862                    continue;
863                }
864            }
865        }
866    }
867}
868
869//------------ Transport -----------------------------------------------------
870
871/// Type that actually implements the connection.
872#[derive(Debug)]
873pub struct Transport<Req>
874where
875    Req: Send + Sync,
876{
877    /// Receive side of the channel used by the runner.
878    receiver: mpsc::Receiver<ChanReq<Req>>,
879}
880
881impl<Req: Clone + Send + Sync + 'static> Transport<Req> {
882    /// Implementation of the new method.
883    fn new(receiver: mpsc::Receiver<ChanReq<Req>>) -> Self {
884        Self { receiver }
885    }
886
887    /// Run method.
888    pub async fn run(mut self) {
889        let mut next_id: u64 = 10;
890        let mut conn_stats: Vec<ConnStats> = Vec::new();
891        let mut conn_rt: Vec<ConnRT> = Vec::new();
892        let mut conns: Vec<Box<dyn SendRequest<Req> + Send + Sync>> =
893            Vec::new();
894
895        loop {
896            let req = match self.receiver.recv().await {
897                Some(req) => req,
898                None => break, // All references to connection objects are
899                               // dropped. Shutdown.
900            };
901            match req {
902                ChanReq::Add(add_req) => {
903                    let id = next_id;
904                    next_id += 1;
905                    conn_stats.push(ConnStats {
906                        _label: add_req.label,
907                        mean: (DEFAULT_RT_MS as f64) / 1000.,
908                        mean_sq: 0.,
909                        max_burst: add_req.max_burst,
910                        burst_interval: add_req.burst_interval,
911                        burst_start: Instant::now(),
912                        burst: 0,
913                        queue_length_plus_one: Arc::new(()),
914                    });
915                    conn_rt.push(ConnRT {
916                        id,
917                        est_rt: DEFAULT_RT,
918                        start: None,
919                        queue_length: 42, // To spot errors.
920                    });
921                    conns.push(add_req.conn);
922
923                    // Don't care if send fails
924                    let _ = add_req.tx.send(Ok(()));
925                }
926                ChanReq::GetRT(rt_req) => {
927                    let mut tmp_conn_rt = conn_rt.clone();
928
929                    // Remove entries that exceed the QPS limit. Loop
930                    // backward to efficiently remove them.
931                    for i in (0..tmp_conn_rt.len()).rev() {
932                        // Fill-in current queue length.
933                        tmp_conn_rt[i].queue_length = Arc::strong_count(
934                            &conn_stats[i].queue_length_plus_one,
935                        ) - 1;
936                        if let Some(max_burst) = conn_stats[i].max_burst {
937                            if conn_stats[i].burst_start.elapsed()
938                                > conn_stats[i].burst_interval
939                            {
940                                conn_stats[i].burst_start = Instant::now();
941                                conn_stats[i].burst = 0;
942                            }
943                            if conn_stats[i].burst > max_burst {
944                                tmp_conn_rt.swap_remove(i);
945                            }
946                        } else {
947                            // No limit.
948                        }
949                    }
950                    // Don't care if send fails
951                    let _ = rt_req.tx.send(Ok(tmp_conn_rt));
952                }
953                ChanReq::Query(request_req) => {
954                    let opt_ind =
955                        conn_rt.iter().position(|e| e.id == request_req.id);
956                    match opt_ind {
957                        Some(ind) => {
958                            // Leave resetting qps_num to GetRT.
959                            conn_stats[ind].burst += 1;
960                            let query = conns[ind]
961                                .send_request(request_req.request_msg);
962                            // Don't care if send fails
963                            let _ = request_req.tx.send(Ok((
964                                query,
965                                conn_stats[ind].queue_length_plus_one.clone(),
966                            )));
967                        }
968                        None => {
969                            // Don't care if send fails
970                            let _ = request_req
971                                .tx
972                                .send(Err(Error::RedundantTransportNotFound));
973                        }
974                    }
975                }
976                ChanReq::Report(time_report) => {
977                    let opt_ind =
978                        conn_rt.iter().position(|e| e.id == time_report.id);
979                    if let Some(ind) = opt_ind {
980                        conn_stats[ind].update(time_report.elapsed);
981
982                        let est_rt = conn_stats[ind].est_rt();
983                        conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
984                    }
985                }
986                ChanReq::Failure(time_report) => {
987                    let opt_ind =
988                        conn_rt.iter().position(|e| e.id == time_report.id);
989                    if let Some(ind) = opt_ind {
990                        let elapsed = time_report.elapsed.as_secs_f64();
991                        if elapsed < conn_stats[ind].mean {
992                            // Do not update the mean if a
993                            // failure took less time than the
994                            // current mean.
995                            continue;
996                        }
997                        conn_stats[ind].update(time_report.elapsed);
998                        let est_rt = conn_stats[ind].est_rt();
999                        conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
1000                    }
1001                }
1002            }
1003        }
1004    }
1005}
1006
1007//------------ Utility --------------------------------------------------------
1008
1009/// Async function to send a request and wait for the reply.
1010///
1011/// This gives a single future that we can put in a list.
1012async fn start_request<Req>(
1013    index: usize,
1014    id: u64,
1015    sender: mpsc::Sender<ChanReq<Req>>,
1016    request_msg: Req,
1017) -> (usize, Result<Message<Bytes>, Error>)
1018where
1019    Req: Send + Sync,
1020{
1021    let (tx, rx) = oneshot::channel();
1022    sender
1023        .send(ChanReq::Query(RequestReq {
1024            id,
1025            request_msg,
1026            tx,
1027        }))
1028        .await
1029        .expect("receiver still exists");
1030    let (mut request, qlp1) =
1031        match rx.await.expect("receive is expected to work") {
1032            Err(err) => return (index, Err(err)),
1033            Ok((request, qlp1)) => (request, qlp1),
1034        };
1035    let reply = request.get_response().await;
1036
1037    drop(qlp1);
1038    (index, reply)
1039}
1040
1041/// Compare ConnRT elements based on estimated response time.
1042fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT, slow_rt: f64) -> Ordering {
1043    let e1_slow = e1.est_rt.as_secs_f64() > slow_rt;
1044    let e2_slow = e2.est_rt.as_secs_f64() > slow_rt;
1045
1046    match (e1_slow, e2_slow) {
1047        (true, true) => {
1048            // Normal case. First check queue lengths. Then check est_rt.
1049            e1.queue_length
1050                .cmp(&e2.queue_length)
1051                .then(e1.est_rt.cmp(&e2.est_rt))
1052        }
1053        (true, false) => Ordering::Greater,
1054        (false, true) => Ordering::Less,
1055        (false, false) => e1.est_rt.cmp(&e2.est_rt),
1056    }
1057}
1058
1059/// Return if this reply should be skipped or not.
1060fn skip<Octs: Octets>(msg: &Message<Octs>, config: &Config) -> bool {
1061    // Check if we actually need to check.
1062    if !config.defer_refused && !config.defer_servfail {
1063        return false;
1064    }
1065
1066    let opt_rcode = msg.opt_rcode();
1067    // OptRcode needs PartialEq
1068    if let OptRcode::REFUSED = opt_rcode {
1069        if config.defer_refused {
1070            return true;
1071        }
1072    }
1073    if let OptRcode::SERVFAIL = opt_rcode {
1074        if config.defer_servfail {
1075            return true;
1076        }
1077    }
1078
1079    false
1080}
1081
/// Generate a SERVFAIL reply message based on `msg`.
///
/// What the result contains:
/// - The header of `msg`, with the rcode set to SERVFAIL and the AD bit
///   cleared.
/// - The question section of `msg`.
/// - No answer or authority records.
/// - If `msg` has an OPT record, an OPT record with the same DO bit, UDP
///   payload size, EDNS version and options.
///
/// NOTE(review): the QR bit is copied from `msg` unchanged. If callers pass
/// a query rather than a reply, the result is not marked as a response.
/// Confirm against the call sites.
///
/// # Errors
///
/// Returns an error if a question in `msg` fails to parse.
///
/// # Panics
///
/// Panics if an EDNS option in `msg` fails to parse. Also panics if one of
/// the builder operations marked "should not fail" does fail.
// This needs to be consolidated with the one in validator and the one in
// MessageBuilder.
fn serve_fail<Octs>(msg: &Message<Octs>) -> Result<Message<Bytes>, Error>
where
    Octs: AsRef<[u8]> + Octets,
{
    let mut target =
        MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
            .expect("Vec is expected to have enough space");

    // Alias used for reading the question section below. `msg` itself is
    // still used afterwards to fetch the OPT record.
    let source = msg;

    // Start from the original header, then turn it into a SERVFAIL. AD is
    // cleared because no answer data is returned.
    *target.header_mut() = msg.header();
    target.header_mut().set_rcode(Rcode::SERVFAIL);
    target.header_mut().set_ad(false);

    // Copy the question section verbatim.
    let source = source.question();
    let mut target = target.question();
    for rr in source {
        target.push(rr?).expect("should not fail");
    }
    // Skip straight to the additional section: answer and authority stay
    // empty.
    let mut target = target.additional();

    // Mirror the EDNS parameters and options of the original message.
    if let Some(opt) = msg.opt() {
        target
            .opt(|ob| {
                ob.set_dnssec_ok(opt.dnssec_ok());
                // XXX something is missing ob.set_rcode(opt.rcode());
                // NOTE(review): the OPT extended rcode is left at the
                // builder's default. Confirm that it yields zero upper bits,
                // as required for a plain SERVFAIL.
                ob.set_udp_payload_size(opt.udp_payload_size());
                ob.set_version(opt.version());
                for o in opt.opt().iter() {
                    let x: AllOptData<_, _> = o.expect("should not fail");
                    ob.push(&x).expect("should not fail");
                }
                Ok(())
            })
            .expect("should not fail");
    }

    // Finish a copy of the underlying MessageBuilder and re-parse the
    // resulting octets as a Bytes-backed Message.
    let result = target.as_builder().clone();
    let msg = Message::<Bytes>::from_octets(
        result.finish().into_target().octets_into(),
    )
    .expect("Message should be able to parse output from MessageBuilder");
    Ok(msg)
}