domain/net/client/load_balancer.rs
1//! A transport that tries to distribute requests over multiple upstreams.
2//!
//! It is assumed that the upstreams have similar performance. Use the
//! [super::redundant] transport to forward requests to the best upstream out of
//! upstreams that may have quite different performance.
6//!
7//! Basic mode of operation
8//!
//! Associated with every upstream configured is optionally a burst length
//! and burst interval. Burst length divided by burst interval gives a
//! queries per second (QPS) value. This can be used to limit the rate and
//! especially the bursts that reach upstream servers. Once the burst
//! length has been reached, the upstream receives no new requests until
//! the burst interval has completed.
15//!
//! For each upstream the object maintains an estimated response time.
//! With the configuration value slow_rt_factor, the group of upstreams
//! that have not exceeded their burst length is divided into a 'fast'
//! and a 'slow' group. The slow group consists of those upstreams that have
//! an estimated response time that is higher than slow_rt_factor times the
//! lowest estimated response time. Slow upstreams are considered only when
//! all fast upstreams failed to provide a suitable response.
23//!
24//! Within the group of fast upstreams, the ones with the lower queue
25//! length are preferred. This tries to give each of the fast upstreams
26//! an equal number of outstanding requests.
27//!
28//! Within a group of fast upstreams with the same queue length, the
29//! one with the lowest estimated response time is preferred.
30//!
31//! Probing
32//!
//! Upstreams with high estimated response times may not get any traffic and
//! therefore the estimated response time may remain high. Probing is
//! intended to solve that problem. Using a random number generator,
//! occasionally an upstream is selected for probing. If the selected
//! upstream currently has a non-zero queue then probing is not needed and
//! no probe will happen.
//! Otherwise, the upstream to be probed is selected first with an
//! estimated response time equal to the lowest one. If the probed upstream
//! does not provide a response within that time, the otherwise best upstream
//! also gets the request. If the probed upstream provides a suitable response
//! before the next upstream then its estimated response time will be updated.
44
45use crate::base::iana::OptRcode;
46use crate::base::iana::Rcode;
47use crate::base::opt::AllOptData;
48use crate::base::Message;
49use crate::base::MessageBuilder;
50use crate::base::StaticCompressor;
51use crate::dep::octseq::OctetsInto;
52use crate::net::client::request::ComposeRequest;
53use crate::net::client::request::{Error, GetResponse, SendRequest};
54use crate::utils::config::DefMinMax;
55use bytes::Bytes;
56use futures_util::stream::FuturesUnordered;
57use futures_util::StreamExt;
58use octseq::Octets;
59use rand::random;
60use std::boxed::Box;
61use std::cmp::Ordering;
62use std::fmt::{Debug, Formatter};
63use std::future::Future;
64use std::pin::Pin;
65use std::string::String;
66use std::string::ToString;
67use std::sync::Arc;
68use std::vec::Vec;
69use tokio::sync::{mpsc, oneshot};
70use tokio::time::{sleep_until, Duration, Instant};
71
72/*
73Basic algorithm:
74- try to distribute requests over all upstreams subject to some limitations.
75- limit bursts
76 - record the start of a burst interval when a request goes out over an
77 upstream
78 - record the number of requests since the start of the burst interval
 - if the burst is larger than the maximum configured by the user then the
 upstream is no longer available.
81 - start a new burst interval when enough time has passed.
82- prefer fast upstreams over slow upstreams
83 - maintain a response time estimate for each upstream
 - upstreams with an estimated response time larger than slow_rt_factor
 times the lowest estimated response time are considered slow.
86 - 'fast' upstreams are preferred over slow upstream. However slow upstreams
87 are considered if during a single request all fast upstreams fail.
88- prefer fast upstream with a low queue length
89 - maintain a counter with the number of current outstanding requests on an
90 upstream.
91 - prefer the upstream with the lowest count.
 - prefer the upstream with the lowest estimated response time in case
 two or more upstreams have the same count.
94
95Execution:
- set a timer to the expected response time.
97- if the timer expires before reply arrives, send the query to the next lowest
98 and set a timer
99- when a reply arrives update the expected response time for the relevant
100 upstream and for the ones that failed.
101
102Probing:
103- upstream that currently have outstanding requests do not need to be
104 probed.
105- for idle upstream, based on a random number generator:
 - pick a different upstream rather than the best
 - but set the timer to the expected response time of the best.
 - maybe we need a configuration parameter for the amount of head start
 given to the probed upstream.
110*/
111
/// Number of [ChanReq] messages the channel to the runner can buffer.
const DEF_CHAN_CAP: usize = 8;

/// Initial response time estimate, expressed in milliseconds.
const DEFAULT_RT_MS: u64 = 300;

/// Initial response time estimate assigned to connections that have not
/// been used yet.
const DEFAULT_RT: Duration = Duration::from_millis(DEFAULT_RT_MS);

/// Window size of the moving averages that are maintained for the measured
/// response time and for the square of the measured response time.
const SMOOTH_N: f64 = 8.0;

/// Probability of probing a worse connection.
const PROBE_P: f64 = 0.05;
127
128//------------ Configuration Constants ----------------------------------------
129
/// Default cut-off factor, relative to the fastest upstream, beyond which
/// an upstream counts as slow.
const DEF_SLOW_RT_FACTOR: f64 = 5.;

/// Lower bound for the slow upstream cut-off factor.
const MIN_SLOW_RT_FACTOR: f64 = 1.;
135
136/// Interval for limiting upstream query bursts.
137const BURST_INTERVAL: DefMinMax<Duration> = DefMinMax::new(
138 Duration::from_secs(1),
139 Duration::from_millis(1),
140 Duration::from_secs(3600),
141);
142
143//------------ Config ---------------------------------------------------------
144
145/// User configuration variables.
146#[derive(Clone, Copy, Debug)]
147pub struct Config {
148 /// Defer transport errors.
149 defer_transport_error: bool,
150
151 /// Defer replies that report Refused.
152 defer_refused: bool,
153
154 /// Defer replies that report ServFail.
155 defer_servfail: bool,
156
157 /// Cut-off for slow upstreams as a factor of the fastest upstream.
158 slow_rt_factor: f64,
159}
160
161impl Config {
162 /// Return the value of the defer_transport_error configuration variable.
163 pub fn defer_transport_error(&self) -> bool {
164 self.defer_transport_error
165 }
166
167 /// Set the value of the defer_transport_error configuration variable.
168 pub fn set_defer_transport_error(&mut self, value: bool) {
169 self.defer_transport_error = value
170 }
171
172 /// Return the value of the defer_refused configuration variable.
173 pub fn defer_refused(&self) -> bool {
174 self.defer_refused
175 }
176
177 /// Set the value of the defer_refused configuration variable.
178 pub fn set_defer_refused(&mut self, value: bool) {
179 self.defer_refused = value
180 }
181
182 /// Return the value of the defer_servfail configuration variable.
183 pub fn defer_servfail(&self) -> bool {
184 self.defer_servfail
185 }
186
187 /// Set the value of the defer_servfail configuration variable.
188 pub fn set_defer_servfail(&mut self, value: bool) {
189 self.defer_servfail = value
190 }
191
192 /// Set the value of the slow_rt_factor configuration variable.
193 pub fn slow_rt_factor(&self) -> f64 {
194 self.slow_rt_factor
195 }
196
197 /// Set the value of the slow_rt_factor configuration variable.
198 pub fn set_slow_rt_factor(&mut self, mut value: f64) {
199 if value < MIN_SLOW_RT_FACTOR {
200 value = MIN_SLOW_RT_FACTOR
201 };
202 self.slow_rt_factor = value;
203 }
204}
205
206impl Default for Config {
207 fn default() -> Self {
208 Self {
209 defer_transport_error: Default::default(),
210 defer_refused: Default::default(),
211 defer_servfail: Default::default(),
212 slow_rt_factor: DEF_SLOW_RT_FACTOR,
213 }
214 }
215}
216
217//------------ ConnConfig -----------------------------------------------------
218
219/// Configuration variables for each upstream.
220#[derive(Clone, Copy, Debug, Default)]
221pub struct ConnConfig {
222 /// Maximum burst of upstream queries.
223 max_burst: Option<u64>,
224
225 /// Interval over which the burst is counted.
226 burst_interval: Duration,
227}
228
229impl ConnConfig {
230 /// Create a new ConnConfig object.
231 pub fn new() -> Self {
232 Self {
233 max_burst: None,
234 burst_interval: BURST_INTERVAL.default(),
235 }
236 }
237
238 /// Return the current configuration value for the maximum burst.
239 /// None means that there is no limit.
240 pub fn max_burst(&mut self) -> Option<u64> {
241 self.max_burst
242 }
243
244 /// Set the configuration value for the maximum burst.
245 /// The value None means no limit.
246 pub fn set_max_burst(&mut self, max_burst: Option<u64>) {
247 self.max_burst = max_burst;
248 }
249
250 /// Return the current burst interval.
251 pub fn burst_interval(&mut self) -> Duration {
252 self.burst_interval
253 }
254
255 /// Set a new burst interval.
256 ///
257 /// The interval is silently limited to at least 1 millesecond and
258 /// at most 1 hour.
259 pub fn set_burst_interval(&mut self, burst_interval: Duration) {
260 self.burst_interval = BURST_INTERVAL.limit(burst_interval);
261 }
262}
263
264//------------ Connection -----------------------------------------------------
265
266/// This type represents a transport connection.
267#[derive(Debug)]
268pub struct Connection<Req>
269where
270 Req: Send + Sync,
271{
272 /// User configuation.
273 config: Config,
274
275 /// To send a request to the runner.
276 sender: mpsc::Sender<ChanReq<Req>>,
277}
278
279impl<Req: Clone + Debug + Send + Sync + 'static> Connection<Req> {
280 /// Create a new connection.
281 pub fn new() -> (Self, Transport<Req>) {
282 Self::with_config(Default::default())
283 }
284
285 /// Create a new connection with a given config.
286 pub fn with_config(config: Config) -> (Self, Transport<Req>) {
287 let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
288 (Self { config, sender }, Transport::new(receiver))
289 }
290
291 /// Add a transport connection.
292 pub async fn add(
293 &self,
294 label: &str,
295 config: &ConnConfig,
296 conn: Box<dyn SendRequest<Req> + Send + Sync>,
297 ) -> Result<(), Error> {
298 let (tx, rx) = oneshot::channel();
299 self.sender
300 .send(ChanReq::Add(AddReq {
301 label: label.to_string(),
302 max_burst: config.max_burst,
303 burst_interval: config.burst_interval,
304 conn,
305 tx,
306 }))
307 .await
308 .expect("send should not fail");
309 rx.await.expect("receive should not fail")
310 }
311
312 /// Implementation of the query method.
313 async fn request_impl(
314 self,
315 request_msg: Req,
316 ) -> Result<Message<Bytes>, Error>
317 where
318 Req: ComposeRequest,
319 {
320 let (tx, rx) = oneshot::channel();
321 self.sender
322 .send(ChanReq::GetRT(RTReq { tx }))
323 .await
324 .expect("send should not fail");
325 let conn_rt = rx.await.expect("receive should not fail")?;
326 if conn_rt.is_empty() {
327 return serve_fail(&request_msg.to_message().unwrap());
328 }
329 Query::new(self.config, request_msg, conn_rt, self.sender.clone())
330 .get_response()
331 .await
332 }
333}
334
335impl<Req> Clone for Connection<Req>
336where
337 Req: Send + Sync,
338{
339 fn clone(&self) -> Self {
340 Self {
341 config: self.config,
342 sender: self.sender.clone(),
343 }
344 }
345}
346
347impl<Req: Clone + ComposeRequest + Debug + Send + Sync + 'static>
348 SendRequest<Req> for Connection<Req>
349{
350 fn send_request(
351 &self,
352 request_msg: Req,
353 ) -> Box<dyn GetResponse + Send + Sync> {
354 Box::new(Request {
355 fut: Box::pin(self.clone().request_impl(request_msg)),
356 })
357 }
358}
359
360//------------ Request -------------------------------------------------------
361
362/// An active request.
363struct Request {
364 /// The underlying future.
365 fut: Pin<
366 Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
367 >,
368}
369
370impl Request {
371 /// Async function that waits for the future stored in Query to complete.
372 async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
373 (&mut self.fut).await
374 }
375}
376
377impl GetResponse for Request {
378 fn get_response(
379 &mut self,
380 ) -> Pin<
381 Box<
382 dyn Future<Output = Result<Message<Bytes>, Error>>
383 + Send
384 + Sync
385 + '_,
386 >,
387 > {
388 Box::pin(self.get_response_impl())
389 }
390}
391
392impl Debug for Request {
393 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
394 f.debug_struct("Request")
395 .field("fut", &format_args!("_"))
396 .finish()
397 }
398}
399
400//------------ Query --------------------------------------------------------
401
402/// This type represents an active query request.
403#[derive(Debug)]
404struct Query<Req>
405where
406 Req: Send + Sync,
407{
408 /// User configuration.
409 config: Config,
410
411 /// The state of the query
412 state: QueryState,
413
414 /// The request message
415 request_msg: Req,
416
417 /// List of connections identifiers and estimated response times.
418 conn_rt: Vec<ConnRT>,
419
420 /// Channel to send requests to the run function.
421 sender: mpsc::Sender<ChanReq<Req>>,
422
423 /// List of futures for outstanding requests.
424 fut_list: FuturesUnordered<
425 Pin<Box<dyn Future<Output = FutListOutput> + Send + Sync>>,
426 >,
427
428 /// Transport error that should be reported if nothing better shows
429 /// up.
430 deferred_transport_error: Option<Error>,
431
432 /// Reply that should be returned to the user if nothing better shows
433 /// up.
434 deferred_reply: Option<Message<Bytes>>,
435
436 /// The result from one of the connectons.
437 result: Option<Result<Message<Bytes>, Error>>,
438
439 /// Index of the connection that returned a result.
440 res_index: usize,
441}
442
/// States of the state machine that drives a query.
#[derive(Debug)]
enum QueryState {
    /// Nothing has happened yet.
    Init,

    /// Start a request on the connection at the given index in the list
    /// of connections.
    Probe(usize),

    /// Report the response time for the given index in the list of
    /// connections.
    Report(usize),

    /// Wait for one of the outstanding requests to finish.
    Wait,
}
458
459/// The commands that can be sent to the run function.
460enum ChanReq<Req>
461where
462 Req: Send + Sync,
463{
464 /// Add a connection
465 Add(AddReq<Req>),
466
467 /// Get the list of estimated response times for all connections
468 GetRT(RTReq),
469
470 /// Start a query
471 Query(RequestReq<Req>),
472
473 /// Report how long it took to get a response
474 Report(TimeReport),
475
476 /// Report that a connection failed to provide a timely response
477 Failure(TimeReport),
478}
479
480impl<Req> Debug for ChanReq<Req>
481where
482 Req: Send + Sync,
483{
484 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
485 f.debug_struct("ChanReq").finish()
486 }
487}
488
489/// Request to add a new connection
490struct AddReq<Req> {
491 /// Name of new connection
492 label: String,
493
494 /// Maximum length of a burst.
495 max_burst: Option<u64>,
496
497 /// Interval over which bursts are counted.
498 burst_interval: Duration,
499
500 /// New connection to add
501 conn: Box<dyn SendRequest<Req> + Send + Sync>,
502
503 /// Channel to send the reply to
504 tx: oneshot::Sender<AddReply>,
505}
506
507/// Reply to an Add request
508type AddReply = Result<(), Error>;
509
510/// Request to give the estimated response times for all connections
511struct RTReq /*<Octs>*/ {
512 /// Channel to send the reply to
513 tx: oneshot::Sender<RTReply>,
514}
515
516/// Reply to a RT request
517type RTReply = Result<Vec<ConnRT>, Error>;
518
519/// Request to start a request
520struct RequestReq<Req>
521where
522 Req: Send + Sync,
523{
524 /// Identifier of connection
525 id: u64,
526
527 /// Request message
528 request_msg: Req,
529
530 /// Channel to send the reply to
531 tx: oneshot::Sender<RequestReply>,
532}
533
534impl<Req: Debug> Debug for RequestReq<Req>
535where
536 Req: Send + Sync,
537{
538 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
539 f.debug_struct("RequestReq")
540 .field("id", &self.id)
541 .field("request_msg", &self.request_msg)
542 .finish()
543 }
544}
545
546/// Reply to a request request.
547type RequestReply =
548 Result<(Box<dyn GetResponse + Send + Sync>, Arc<()>), Error>;
549
/// The amount of time a connection took until success or failure.
#[derive(Debug)]
struct TimeReport {
    /// Identifier of the transport connection.
    id: u64,

    /// Time spent waiting for a reply.
    elapsed: Duration,
}
559
560/// Connection statistics to compute the estimated response time.
561struct ConnStats {
562 /// Name of the connection.
563 _label: String,
564
565 /// Aproximation of the windowed average of response times.
566 mean: f64,
567
568 /// Aproximation of the windowed average of the square of response times.
569 mean_sq: f64,
570
571 /// Maximum upstream query burst.
572 max_burst: Option<u64>,
573
574 /// burst length,
575 burst_interval: Duration,
576
577 /// Start of the current burst
578 burst_start: Instant,
579
580 /// Number of queries since the start of the burst.
581 burst: u64,
582
583 /// Use the number of references to an Arc as queue length. The number
584 /// of references is one higher than then actual queue length.
585 queue_length_plus_one: Arc<()>,
586}
587
588impl ConnStats {
589 /// Update response time statistics.
590 fn update(&mut self, elapsed: Duration) {
591 let elapsed = elapsed.as_secs_f64();
592 self.mean += (elapsed - self.mean) / SMOOTH_N;
593 let elapsed_sq = elapsed * elapsed;
594 self.mean_sq += (elapsed_sq - self.mean_sq) / SMOOTH_N;
595 }
596
597 /// Get an estimated response time.
598 fn est_rt(&self) -> f64 {
599 let mean = self.mean;
600 let var = self.mean_sq - mean * mean;
601 let std_dev = f64::sqrt(var.max(0.));
602 mean + 3. * std_dev
603 }
604}
605
/// What a query needs to know about a connection in order to schedule
/// requests and to report timing results.
#[derive(Clone, Debug)]
struct ConnRT {
    /// Estimated response time.
    est_rt: Duration,

    /// Identifier of the connection.
    id: u64,

    /// Time at which a request was started on this connection, if any.
    start: Option<Instant>,

    /// Number of requests that are outstanding on this connection. The
    /// value is filled in when the list of connections is handed out.
    queue_length: usize,
}
622
623/// Result of the futures in fut_list.
624type FutListOutput = (usize, Result<Message<Bytes>, Error>);
625
626impl<Req: Clone + Send + Sync + 'static> Query<Req> {
627 /// Create a new query object.
628 fn new(
629 config: Config,
630 request_msg: Req,
631 mut conn_rt: Vec<ConnRT>,
632 sender: mpsc::Sender<ChanReq<Req>>,
633 ) -> Self {
634 let conn_rt_len = conn_rt.len();
635 let min_rt = conn_rt.iter().map(|e| e.est_rt).min().unwrap();
636 let slow_rt = min_rt.as_secs_f64() * config.slow_rt_factor;
637 conn_rt.sort_unstable_by(|e1, e2| conn_rt_cmp(e1, e2, slow_rt));
638
639 // Do we want to probe a less performant upstream? We only need to
640 // probe upstreams with a queue length of zero. If the queue length
641 // is non-zero then the upstream recently got work and does not need
642 // to be probed.
643 if conn_rt_len > 1 && random::<f64>() < PROBE_P {
644 let index: usize = 1 + random::<usize>() % (conn_rt_len - 1);
645
646 if conn_rt[index].queue_length == 0 {
647 // Give the probe some head start. We may need a separate
648 // configuration parameter. A multiple of min_rt. Just use
649 // min_rt for now.
650 let mut e = conn_rt.remove(index);
651 e.est_rt = min_rt;
652 conn_rt.insert(0, e);
653 }
654 }
655
656 Self {
657 config,
658 request_msg,
659 conn_rt,
660 sender,
661 state: QueryState::Init,
662 fut_list: FuturesUnordered::new(),
663 deferred_transport_error: None,
664 deferred_reply: None,
665 result: None,
666 res_index: 0,
667 }
668 }
669
670 /// Implementation of get_response.
671 async fn get_response(&mut self) -> Result<Message<Bytes>, Error> {
672 loop {
673 match self.state {
674 QueryState::Init => {
675 if self.conn_rt.is_empty() {
676 return Err(Error::NoTransportAvailable);
677 }
678 self.state = QueryState::Probe(0);
679 continue;
680 }
681 QueryState::Probe(ind) => {
682 self.conn_rt[ind].start = Some(Instant::now());
683 let fut = start_request(
684 ind,
685 self.conn_rt[ind].id,
686 self.sender.clone(),
687 self.request_msg.clone(),
688 );
689 self.fut_list.push(Box::pin(fut));
690 let timeout = Instant::now() + self.conn_rt[ind].est_rt;
691 loop {
692 tokio::select! {
693 res = self.fut_list.next() => {
694 let res = res.expect("res should not be empty");
695 match res.1 {
696 Err(ref err) => {
697 if self.config.defer_transport_error {
698 if self.deferred_transport_error.is_none() {
699 self.deferred_transport_error = Some(err.clone());
700 }
701 if res.0 == ind {
702 // The current upstream finished,
703 // try the next one, if any.
704 self.state =
705 if ind+1 < self.conn_rt.len() {
706 QueryState::Probe(ind+1)
707 }
708 else
709 {
710 QueryState::Wait
711 };
712 // Break out of receive loop
713 break;
714 }
715 // Just continue receiving
716 continue;
717 }
718 // Return error to the user.
719 }
720 Ok(ref msg) => {
721 if skip(msg, &self.config) {
722 if self.deferred_reply.is_none() {
723 self.deferred_reply = Some(msg.clone());
724 }
725 if res.0 == ind {
726 // The current upstream finished,
727 // try the next one, if any.
728 self.state =
729 if ind+1 < self.conn_rt.len() {
730 QueryState::Probe(ind+1)
731 }
732 else
733 {
734 QueryState::Wait
735 };
736 // Break out of receive loop
737 break;
738 }
739 // Just continue receiving
740 continue;
741 }
742 // Now we have a reply that can be
743 // returned to the user.
744 }
745 }
746 self.result = Some(res.1);
747 self.res_index = res.0;
748
749 self.state = QueryState::Report(0);
750 // Break out of receive loop
751 break;
752 }
753 _ = sleep_until(timeout) => {
754 // Move to the next Probe state if there
755 // are more upstreams to try, otherwise
756 // move to the Wait state.
757 self.state =
758 if ind+1 < self.conn_rt.len() {
759 QueryState::Probe(ind+1)
760 }
761 else {
762 QueryState::Wait
763 };
764 // Break out of receive loop
765 break;
766 }
767 }
768 }
769 // Continue with state machine loop
770 continue;
771 }
772 QueryState::Report(ind) => {
773 if ind >= self.conn_rt.len()
774 || self.conn_rt[ind].start.is_none()
775 {
776 // Nothing more to report. Return result.
777 let res = self
778 .result
779 .take()
780 .expect("result should not be empty");
781 return res;
782 }
783
784 let start = self.conn_rt[ind]
785 .start
786 .expect("start time should not be empty");
787 let elapsed = start.elapsed();
788 let time_report = TimeReport {
789 id: self.conn_rt[ind].id,
790 elapsed,
791 };
792 let report = if ind == self.res_index {
793 // Succesfull entry
794 ChanReq::Report(time_report)
795 } else {
796 // Failed entry
797 ChanReq::Failure(time_report)
798 };
799
800 // Send could fail but we don't care.
801 let _ = self.sender.send(report).await;
802
803 self.state = QueryState::Report(ind + 1);
804 continue;
805 }
806 QueryState::Wait => {
807 loop {
808 if self.fut_list.is_empty() {
809 // We have nothing left. There should be a reply or
810 // an error. Prefer a reply over an error.
811 if self.deferred_reply.is_some() {
812 let msg = self
813 .deferred_reply
814 .take()
815 .expect("just checked for Some");
816 return Ok(msg);
817 }
818 if self.deferred_transport_error.is_some() {
819 let err = self
820 .deferred_transport_error
821 .take()
822 .expect("just checked for Some");
823 return Err(err);
824 }
825 panic!("either deferred_reply or deferred_error should be present");
826 }
827 let res = self.fut_list.next().await;
828 let res = res.expect("res should not be empty");
829 match res.1 {
830 Err(ref err) => {
831 if self.config.defer_transport_error {
832 if self.deferred_transport_error.is_none()
833 {
834 self.deferred_transport_error =
835 Some(err.clone());
836 }
837 // Just continue with the next future, or
838 // finish if fut_list is empty.
839 continue;
840 }
841 // Return error to the user.
842 }
843 Ok(ref msg) => {
844 if skip(msg, &self.config) {
845 if self.deferred_reply.is_none() {
846 self.deferred_reply =
847 Some(msg.clone());
848 }
849 // Just continue with the next future, or
850 // finish if fut_list is empty.
851 continue;
852 }
853 // Return reply to user.
854 }
855 }
856 self.result = Some(res.1);
857 self.res_index = res.0;
858 self.state = QueryState::Report(0);
859 // Break out of loop to continue with the state machine
860 break;
861 }
862 continue;
863 }
864 }
865 }
866 }
867}
868
869//------------ Transport -----------------------------------------------------
870
871/// Type that actually implements the connection.
872#[derive(Debug)]
873pub struct Transport<Req>
874where
875 Req: Send + Sync,
876{
877 /// Receive side of the channel used by the runner.
878 receiver: mpsc::Receiver<ChanReq<Req>>,
879}
880
881impl<Req: Clone + Send + Sync + 'static> Transport<Req> {
882 /// Implementation of the new method.
883 fn new(receiver: mpsc::Receiver<ChanReq<Req>>) -> Self {
884 Self { receiver }
885 }
886
887 /// Run method.
888 pub async fn run(mut self) {
889 let mut next_id: u64 = 10;
890 let mut conn_stats: Vec<ConnStats> = Vec::new();
891 let mut conn_rt: Vec<ConnRT> = Vec::new();
892 let mut conns: Vec<Box<dyn SendRequest<Req> + Send + Sync>> =
893 Vec::new();
894
895 loop {
896 let req = match self.receiver.recv().await {
897 Some(req) => req,
898 None => break, // All references to connection objects are
899 // dropped. Shutdown.
900 };
901 match req {
902 ChanReq::Add(add_req) => {
903 let id = next_id;
904 next_id += 1;
905 conn_stats.push(ConnStats {
906 _label: add_req.label,
907 mean: (DEFAULT_RT_MS as f64) / 1000.,
908 mean_sq: 0.,
909 max_burst: add_req.max_burst,
910 burst_interval: add_req.burst_interval,
911 burst_start: Instant::now(),
912 burst: 0,
913 queue_length_plus_one: Arc::new(()),
914 });
915 conn_rt.push(ConnRT {
916 id,
917 est_rt: DEFAULT_RT,
918 start: None,
919 queue_length: 42, // To spot errors.
920 });
921 conns.push(add_req.conn);
922
923 // Don't care if send fails
924 let _ = add_req.tx.send(Ok(()));
925 }
926 ChanReq::GetRT(rt_req) => {
927 let mut tmp_conn_rt = conn_rt.clone();
928
929 // Remove entries that exceed the QPS limit. Loop
930 // backward to efficiently remove them.
931 for i in (0..tmp_conn_rt.len()).rev() {
932 // Fill-in current queue length.
933 tmp_conn_rt[i].queue_length = Arc::strong_count(
934 &conn_stats[i].queue_length_plus_one,
935 ) - 1;
936 if let Some(max_burst) = conn_stats[i].max_burst {
937 if conn_stats[i].burst_start.elapsed()
938 > conn_stats[i].burst_interval
939 {
940 conn_stats[i].burst_start = Instant::now();
941 conn_stats[i].burst = 0;
942 }
943 if conn_stats[i].burst > max_burst {
944 tmp_conn_rt.swap_remove(i);
945 }
946 } else {
947 // No limit.
948 }
949 }
950 // Don't care if send fails
951 let _ = rt_req.tx.send(Ok(tmp_conn_rt));
952 }
953 ChanReq::Query(request_req) => {
954 let opt_ind =
955 conn_rt.iter().position(|e| e.id == request_req.id);
956 match opt_ind {
957 Some(ind) => {
958 // Leave resetting qps_num to GetRT.
959 conn_stats[ind].burst += 1;
960 let query = conns[ind]
961 .send_request(request_req.request_msg);
962 // Don't care if send fails
963 let _ = request_req.tx.send(Ok((
964 query,
965 conn_stats[ind].queue_length_plus_one.clone(),
966 )));
967 }
968 None => {
969 // Don't care if send fails
970 let _ = request_req
971 .tx
972 .send(Err(Error::RedundantTransportNotFound));
973 }
974 }
975 }
976 ChanReq::Report(time_report) => {
977 let opt_ind =
978 conn_rt.iter().position(|e| e.id == time_report.id);
979 if let Some(ind) = opt_ind {
980 conn_stats[ind].update(time_report.elapsed);
981
982 let est_rt = conn_stats[ind].est_rt();
983 conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
984 }
985 }
986 ChanReq::Failure(time_report) => {
987 let opt_ind =
988 conn_rt.iter().position(|e| e.id == time_report.id);
989 if let Some(ind) = opt_ind {
990 let elapsed = time_report.elapsed.as_secs_f64();
991 if elapsed < conn_stats[ind].mean {
992 // Do not update the mean if a
993 // failure took less time than the
994 // current mean.
995 continue;
996 }
997 conn_stats[ind].update(time_report.elapsed);
998 let est_rt = conn_stats[ind].est_rt();
999 conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
1000 }
1001 }
1002 }
1003 }
1004 }
1005}
1006
1007//------------ Utility --------------------------------------------------------
1008
1009/// Async function to send a request and wait for the reply.
1010///
1011/// This gives a single future that we can put in a list.
1012async fn start_request<Req>(
1013 index: usize,
1014 id: u64,
1015 sender: mpsc::Sender<ChanReq<Req>>,
1016 request_msg: Req,
1017) -> (usize, Result<Message<Bytes>, Error>)
1018where
1019 Req: Send + Sync,
1020{
1021 let (tx, rx) = oneshot::channel();
1022 sender
1023 .send(ChanReq::Query(RequestReq {
1024 id,
1025 request_msg,
1026 tx,
1027 }))
1028 .await
1029 .expect("receiver still exists");
1030 let (mut request, qlp1) =
1031 match rx.await.expect("receive is expected to work") {
1032 Err(err) => return (index, Err(err)),
1033 Ok((request, qlp1)) => (request, qlp1),
1034 };
1035 let reply = request.get_response().await;
1036
1037 drop(qlp1);
1038 (index, reply)
1039}
1040
1041/// Compare ConnRT elements based on estimated response time.
1042fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT, slow_rt: f64) -> Ordering {
1043 let e1_slow = e1.est_rt.as_secs_f64() > slow_rt;
1044 let e2_slow = e2.est_rt.as_secs_f64() > slow_rt;
1045
1046 match (e1_slow, e2_slow) {
1047 (true, true) => {
1048 // Normal case. First check queue lengths. Then check est_rt.
1049 e1.queue_length
1050 .cmp(&e2.queue_length)
1051 .then(e1.est_rt.cmp(&e2.est_rt))
1052 }
1053 (true, false) => Ordering::Greater,
1054 (false, true) => Ordering::Less,
1055 (false, false) => e1.est_rt.cmp(&e2.est_rt),
1056 }
1057}
1058
1059/// Return if this reply should be skipped or not.
1060fn skip<Octs: Octets>(msg: &Message<Octs>, config: &Config) -> bool {
1061 // Check if we actually need to check.
1062 if !config.defer_refused && !config.defer_servfail {
1063 return false;
1064 }
1065
1066 let opt_rcode = msg.opt_rcode();
1067 // OptRcode needs PartialEq
1068 if let OptRcode::REFUSED = opt_rcode {
1069 if config.defer_refused {
1070 return true;
1071 }
1072 }
1073 if let OptRcode::SERVFAIL = opt_rcode {
1074 if config.defer_servfail {
1075 return true;
1076 }
1077 }
1078
1079 false
1080}
1081
1082/// Generate a SERVFAIL reply message.
1083// This needs to be consolodated with the one in validator and the one in
1084// MessageBuilder.
1085fn serve_fail<Octs>(msg: &Message<Octs>) -> Result<Message<Bytes>, Error>
1086where
1087 Octs: AsRef<[u8]> + Octets,
1088{
1089 let mut target =
1090 MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
1091 .expect("Vec is expected to have enough space");
1092
1093 let source = msg;
1094
1095 *target.header_mut() = msg.header();
1096 target.header_mut().set_rcode(Rcode::SERVFAIL);
1097 target.header_mut().set_ad(false);
1098
1099 let source = source.question();
1100 let mut target = target.question();
1101 for rr in source {
1102 target.push(rr?).expect("should not fail");
1103 }
1104 let mut target = target.additional();
1105
1106 if let Some(opt) = msg.opt() {
1107 target
1108 .opt(|ob| {
1109 ob.set_dnssec_ok(opt.dnssec_ok());
1110 // XXX something is missing ob.set_rcode(opt.rcode());
1111 ob.set_udp_payload_size(opt.udp_payload_size());
1112 ob.set_version(opt.version());
1113 for o in opt.opt().iter() {
1114 let x: AllOptData<_, _> = o.expect("should not fail");
1115 ob.push(&x).expect("should not fail");
1116 }
1117 Ok(())
1118 })
1119 .expect("should not fail");
1120 }
1121
1122 let result = target.as_builder().clone();
1123 let msg = Message::<Bytes>::from_octets(
1124 result.finish().into_target().octets_into(),
1125 )
1126 .expect("Message should be able to parse output from MessageBuilder");
1127 Ok(msg)
1128}