domain/net/client/redundant.rs
//! A transport that multiplexes requests over multiple redundant transports.
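//!
//! The sketch below shows the intended wiring. It is a usage sketch rather
//! than a compiled doctest: `upstream1`, `upstream2` and `request_msg` are
//! assumed to exist already (any transports implementing [SendRequest] and a
//! request message of type `Req`).
//!
//! ```ignore
//! // Configure the redundant connection to defer ServFail replies.
//! let mut config = Config::default();
//! config.set_defer_servfail(true);
//! let (conn, transport) = Connection::with_config(config);
//!
//! // The transport runner must be driven by its own task.
//! tokio::spawn(transport.run());
//!
//! // Register the upstream transports.
//! conn.add(Box::new(upstream1)).await?;
//! conn.add(Box::new(upstream2)).await?;
//!
//! // Send a request; the upstream with the lowest estimated response
//! // time is tried first.
//! let mut request = conn.send_request(request_msg);
//! let reply = request.get_response().await?;
//! ```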

use bytes::Bytes;

use futures_util::stream::FuturesUnordered;
use futures_util::StreamExt;

use octseq::Octets;

use rand::random;

use std::boxed::Box;
use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::vec::Vec;

use tokio::sync::{mpsc, oneshot};
use tokio::time::{sleep_until, Duration, Instant};

use crate::base::iana::OptRcode;
use crate::base::Message;
use crate::net::client::request::{Error, GetResponse, SendRequest};

/*
Basic algorithm:
- keep track of the expected response time for every upstream
- start with the upstream with the lowest expected response time
- set a timer to the expected response time.
- if the timer expires before a reply arrives, send the query to the upstream
  with the next lowest expected response time and set a new timer
- when a reply arrives, update the expected response time for the relevant
  upstream and for the ones that failed.

Based on a random number generator:
- occasionally pick a different upstream rather than the best one, but set
  the timer to the expected response time of the best.
*/
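
// The estimate that drives scheduling (see the `ChanReq::Report` handling in
// `Transport::run` below) is an exponentially smoothed mean and mean square
// of the measured response time `rt`, with window SMOOTH_N:
//
//     mean    += (rt - mean) / SMOOTH_N
//     mean_sq += (rt * rt - mean_sq) / SMOOTH_N
//     est_rt   = mean + 3 * sqrt(mean_sq - mean * mean)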

/// Capacity of the channel that transports [ChanReq].
const DEF_CHAN_CAP: usize = 8;

/// Time in milliseconds for the initial response time estimate.
const DEFAULT_RT_MS: u64 = 300;

/// The initial response time estimate for unused connections.
const DEFAULT_RT: Duration = Duration::from_millis(DEFAULT_RT_MS);

/// Maintain a moving average for the measured response time and the
/// square of that. The window is SMOOTH_N.
const SMOOTH_N: f64 = 8.;

/// Chance to probe a worse connection.
const PROBE_P: f64 = 0.05;

//------------ Config ---------------------------------------------------------

/// User configuration variables.
#[derive(Clone, Copy, Debug, Default)]
pub struct Config {
    /// Defer transport errors.
    defer_transport_error: bool,

    /// Defer replies that report Refused.
    defer_refused: bool,

    /// Defer replies that report ServFail.
    defer_servfail: bool,
}

impl Config {
    /// Return the value of the defer_transport_error configuration variable.
    pub fn defer_transport_error(&self) -> bool {
        self.defer_transport_error
    }

    /// Set the value of the defer_transport_error configuration variable.
    pub fn set_defer_transport_error(&mut self, value: bool) {
        self.defer_transport_error = value
    }

    /// Return the value of the defer_refused configuration variable.
    pub fn defer_refused(&self) -> bool {
        self.defer_refused
    }

    /// Set the value of the defer_refused configuration variable.
    pub fn set_defer_refused(&mut self, value: bool) {
        self.defer_refused = value
    }

    /// Return the value of the defer_servfail configuration variable.
    pub fn defer_servfail(&self) -> bool {
        self.defer_servfail
    }

    /// Set the value of the defer_servfail configuration variable.
    pub fn set_defer_servfail(&mut self, value: bool) {
        self.defer_servfail = value
    }
}

//------------ Connection -----------------------------------------------------

/// This type represents a transport connection.
#[derive(Debug)]
pub struct Connection<Req>
where
    Req: Send + Sync,
{
    /// User configuration.
    config: Config,

    /// To send a request to the runner.
    sender: mpsc::Sender<ChanReq<Req>>,
}

impl<Req: Clone + Debug + Send + Sync + 'static> Connection<Req> {
    /// Create a new connection.
    pub fn new() -> (Self, Transport<Req>) {
        Self::with_config(Default::default())
    }

    /// Create a new connection with a given config.
    pub fn with_config(config: Config) -> (Self, Transport<Req>) {
        let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
        (Self { config, sender }, Transport::new(receiver))
    }

    /// Add a transport connection.
    pub async fn add(
        &self,
        conn: Box<dyn SendRequest<Req> + Send + Sync>,
    ) -> Result<(), Error> {
        let (tx, rx) = oneshot::channel();
        self.sender
            .send(ChanReq::Add(AddReq { conn, tx }))
            .await
            .expect("send should not fail");
        rx.await.expect("receive should not fail")
    }

    /// Implementation of the request method.
    async fn request_impl(
        self,
        request_msg: Req,
    ) -> Result<Message<Bytes>, Error> {
        let (tx, rx) = oneshot::channel();
        self.sender
            .send(ChanReq::GetRT(RTReq { tx }))
            .await
            .expect("send should not fail");
        let conn_rt = rx.await.expect("receive should not fail")?;
        Query::new(self.config, request_msg, conn_rt, self.sender.clone())
            .get_response()
            .await
    }
}

impl<Req> Clone for Connection<Req>
where
    Req: Send + Sync,
{
    fn clone(&self) -> Self {
        Self {
            config: self.config,
            sender: self.sender.clone(),
        }
    }
}

impl<Req: Clone + Debug + Send + Sync + 'static> SendRequest<Req>
    for Connection<Req>
{
    fn send_request(
        &self,
        request_msg: Req,
    ) -> Box<dyn GetResponse + Send + Sync> {
        Box::new(Request {
            fut: Box::pin(self.clone().request_impl(request_msg)),
        })
    }
}

//------------ Request -------------------------------------------------------

/// An active request.
struct Request {
    /// The underlying future.
    fut: Pin<
        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
    >,
}

impl Request {
    /// Async function that waits for the stored future to complete.
    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
        (&mut self.fut).await
    }
}

impl GetResponse for Request {
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        Box::pin(self.get_response_impl())
    }
}

impl Debug for Request {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        f.debug_struct("Request")
            .field("fut", &format_args!("_"))
            .finish()
    }
}

//------------ Query --------------------------------------------------------

/// This type represents an active query request.
#[derive(Debug)]
struct Query<Req>
where
    Req: Send + Sync,
{
    /// User configuration.
    config: Config,

    /// The state of the query.
    state: QueryState,

    /// The request message.
    request_msg: Req,

    /// List of connection identifiers and estimated response times.
    conn_rt: Vec<ConnRT>,

    /// Channel to send requests to the run function.
    sender: mpsc::Sender<ChanReq<Req>>,

    /// List of futures for outstanding requests.
    fut_list: FuturesUnordered<
        Pin<Box<dyn Future<Output = FutListOutput> + Send + Sync>>,
    >,

    /// Transport error that should be reported if nothing better shows
    /// up.
    deferred_transport_error: Option<Error>,

    /// Reply that should be returned to the user if nothing better shows
    /// up.
    deferred_reply: Option<Message<Bytes>>,

    /// The result from one of the connections.
    result: Option<Result<Message<Bytes>, Error>>,

    /// Index of the connection that returned a result.
    res_index: usize,
}

/// The various states a query can be in.
#[derive(Debug)]
enum QueryState {
    /// The initial state
    Init,

    /// Start a request on a specific connection.
    Probe(usize),

    /// Report the response time for a specific index in the list.
    Report(usize),

    /// Wait for one of the requests to finish.
    Wait,
}
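
// Typical state transitions in `Query::get_response`: Init moves to
// Probe(0); a Probe that times out, or that only yields a result the
// configuration says to defer, moves on to Probe(ind + 1), or to Wait once
// every upstream has been tried; a usable reply (or a non-deferred error)
// moves to Report(0), which walks all started upstreams to report their
// timing before the result is returned.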

/// The commands that can be sent to the run function.
enum ChanReq<Req>
where
    Req: Send + Sync,
{
    /// Add a connection
    Add(AddReq<Req>),

    /// Get the list of estimated response times for all connections
    GetRT(RTReq),

    /// Start a query
    Query(RequestReq<Req>),

    /// Report how long it took to get a response
    Report(TimeReport),

    /// Report that a connection failed to provide a timely response
    Failure(TimeReport),
}

impl<Req> Debug for ChanReq<Req>
where
    Req: Send + Sync,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("ChanReq").finish()
    }
}

/// Request to add a new connection
struct AddReq<Req> {
    /// New connection to add
    conn: Box<dyn SendRequest<Req> + Send + Sync>,

    /// Channel to send the reply to
    tx: oneshot::Sender<AddReply>,
}

/// Reply to an Add request
type AddReply = Result<(), Error>;

/// Request to give the estimated response times for all connections
struct RTReq /*<Octs>*/ {
    /// Channel to send the reply to
    tx: oneshot::Sender<RTReply>,
}

/// Reply to a RT request
type RTReply = Result<Vec<ConnRT>, Error>;

/// Request to start a request
struct RequestReq<Req>
where
    Req: Send + Sync,
{
    /// Identifier of connection
    id: u64,

    /// Request message
    request_msg: Req,

    /// Channel to send the reply to
    tx: oneshot::Sender<RequestReply>,
}

impl<Req: Debug> Debug for RequestReq<Req>
where
    Req: Send + Sync,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("RequestReq")
            .field("id", &self.id)
            .field("request_msg", &self.request_msg)
            .finish()
    }
}

/// Reply to a request request.
type RequestReply = Result<Box<dyn GetResponse + Send + Sync>, Error>;

/// Report the amount of time until success or failure.
#[derive(Debug)]
struct TimeReport {
    /// Identifier of the transport connection.
    id: u64,

    /// Time spent waiting for a reply.
    elapsed: Duration,
}

/// Connection statistics to compute the estimated response time.
struct ConnStats {
    /// Approximation of the windowed average of response times.
    mean: f64,

    /// Approximation of the windowed average of the square of response times.
    mean_sq: f64,
}

/// Data required to schedule requests and report timing results.
#[derive(Clone, Debug)]
struct ConnRT {
    /// Estimated response time.
    est_rt: Duration,

    /// Identifier of the connection.
    id: u64,

    /// Start of a request using this connection.
    start: Option<Instant>,
}

/// Result of the futures in fut_list.
type FutListOutput = (usize, Result<Message<Bytes>, Error>);

impl<Req: Clone + Send + Sync + 'static> Query<Req> {
    /// Create a new query object.
    fn new(
        config: Config,
        request_msg: Req,
        mut conn_rt: Vec<ConnRT>,
        sender: mpsc::Sender<ChanReq<Req>>,
    ) -> Self {
        let conn_rt_len = conn_rt.len();
        conn_rt.sort_unstable_by(conn_rt_cmp);

        // Do we want to probe a less performant upstream?
        if conn_rt_len > 1 && random::<f64>() < PROBE_P {
            let index: usize = 1 + random::<usize>() % (conn_rt_len - 1);

            // Give the probe some head start. We may need a separate
            // configuration parameter. A multiple of min_rt. Just use
            // min_rt for now.
            let min_rt = conn_rt.iter().map(|e| e.est_rt).min().unwrap();

            let mut e = conn_rt.remove(index);
            e.est_rt = min_rt;
            conn_rt.insert(0, e);
        }

        Self {
            config,
            request_msg,
            conn_rt,
            sender,
            state: QueryState::Init,
            fut_list: FuturesUnordered::new(),
            deferred_transport_error: None,
            deferred_reply: None,
            result: None,
            res_index: 0,
        }
    }

    /// Implementation of get_response.
    async fn get_response(&mut self) -> Result<Message<Bytes>, Error> {
        loop {
            match self.state {
                QueryState::Init => {
                    if self.conn_rt.is_empty() {
                        return Err(Error::NoTransportAvailable);
                    }
                    self.state = QueryState::Probe(0);
                    continue;
                }
                QueryState::Probe(ind) => {
                    self.conn_rt[ind].start = Some(Instant::now());
                    let fut = start_request(
                        ind,
                        self.conn_rt[ind].id,
                        self.sender.clone(),
                        self.request_msg.clone(),
                    );
                    self.fut_list.push(Box::pin(fut));
                    let timeout = Instant::now() + self.conn_rt[ind].est_rt;
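                    // Wait for any outstanding request to produce a result,
                    // or for the estimated response time of this upstream to
                    // pass, whichever comes first.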
                    loop {
                        tokio::select! {
                            res = self.fut_list.next() => {
                                let res = res.expect("res should not be empty");
                                match res.1 {
                                    Err(ref err) => {
                                        if self.config.defer_transport_error {
                                            if self.deferred_transport_error.is_none() {
                                                self.deferred_transport_error = Some(err.clone());
                                            }
                                            if res.0 == ind {
                                                // The current upstream finished,
                                                // try the next one, if any.
                                                self.state =
                                                    if ind+1 < self.conn_rt.len() {
                                                        QueryState::Probe(ind+1)
                                                    } else {
                                                        QueryState::Wait
                                                    };
                                                // Break out of receive loop
                                                break;
                                            }
                                            // Just continue receiving
                                            continue;
                                        }
                                        // Return error to the user.
                                    }
                                    Ok(ref msg) => {
                                        if skip(msg, &self.config) {
                                            if self.deferred_reply.is_none() {
                                                self.deferred_reply = Some(msg.clone());
                                            }
                                            if res.0 == ind {
                                                // The current upstream finished,
                                                // try the next one, if any.
                                                self.state =
                                                    if ind+1 < self.conn_rt.len() {
                                                        QueryState::Probe(ind+1)
                                                    } else {
                                                        QueryState::Wait
                                                    };
                                                // Break out of receive loop
                                                break;
                                            }
                                            // Just continue receiving
                                            continue;
                                        }
                                        // Now we have a reply that can be
                                        // returned to the user.
                                    }
                                }
                                self.result = Some(res.1);
                                self.res_index = res.0;

                                self.state = QueryState::Report(0);
                                // Break out of receive loop
                                break;
                            }
                            _ = sleep_until(timeout) => {
                                // Move to the next Probe state if there
                                // are more upstreams to try, otherwise
                                // move to the Wait state.
                                self.state =
                                    if ind+1 < self.conn_rt.len() {
                                        QueryState::Probe(ind+1)
                                    } else {
                                        QueryState::Wait
                                    };
                                // Break out of receive loop
                                break;
                            }
                        }
                    }
                    // Continue with state machine loop
                    continue;
                }
                QueryState::Report(ind) => {
                    if ind >= self.conn_rt.len()
                        || self.conn_rt[ind].start.is_none()
                    {
                        // Nothing more to report. Return result.
                        let res = self
                            .result
                            .take()
                            .expect("result should not be empty");
                        return res;
                    }

                    let start = self.conn_rt[ind]
                        .start
                        .expect("start time should not be empty");
                    let elapsed = start.elapsed();
                    let time_report = TimeReport {
                        id: self.conn_rt[ind].id,
                        elapsed,
                    };
                    let report = if ind == self.res_index {
                        // Successful entry
                        ChanReq::Report(time_report)
                    } else {
                        // Failed entry
                        ChanReq::Failure(time_report)
                    };

                    // Send could fail but we don't care.
                    let _ = self.sender.send(report).await;

                    self.state = QueryState::Report(ind + 1);
                    continue;
                }
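                // Every upstream has been tried. Wait for any outstanding
                // request to produce a usable result; if none does, fall
                // back to a deferred reply or error.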
                QueryState::Wait => {
                    loop {
                        if self.fut_list.is_empty() {
                            // We have nothing left. There should be a reply or
                            // an error. Prefer a reply over an error.
                            if self.deferred_reply.is_some() {
                                let msg = self
                                    .deferred_reply
                                    .take()
                                    .expect("just checked for Some");
                                return Ok(msg);
                            }
                            if self.deferred_transport_error.is_some() {
                                let err = self
                                    .deferred_transport_error
                                    .take()
                                    .expect("just checked for Some");
                                return Err(err);
                            }
                            panic!("either deferred_reply or deferred_error should be present");
                        }
                        let res = self.fut_list.next().await;
                        let res = res.expect("res should not be empty");
                        match res.1 {
                            Err(ref err) => {
                                if self.config.defer_transport_error {
                                    if self.deferred_transport_error.is_none()
                                    {
                                        self.deferred_transport_error =
                                            Some(err.clone());
                                    }
                                    // Just continue with the next future, or
                                    // finish if fut_list is empty.
                                    continue;
                                }
                                // Return error to the user.
                            }
                            Ok(ref msg) => {
                                if skip(msg, &self.config) {
                                    if self.deferred_reply.is_none() {
                                        self.deferred_reply =
                                            Some(msg.clone());
                                    }
                                    // Just continue with the next future, or
                                    // finish if fut_list is empty.
                                    continue;
                                }
                                // Return reply to user.
                            }
                        }
                        self.result = Some(res.1);
                        self.res_index = res.0;
                        self.state = QueryState::Report(0);
                        // Break out of loop to continue with the state machine
                        break;
                    }
                    continue;
                }
            }
        }
    }
}

//------------ Transport -----------------------------------------------------

/// Type that actually implements the connection.
#[derive(Debug)]
pub struct Transport<Req>
where
    Req: Send + Sync,
{
    /// Receive side of the channel used by the runner.
    receiver: mpsc::Receiver<ChanReq<Req>>,
}

impl<Req: Clone + Send + Sync + 'static> Transport<Req> {
    /// Implementation of the new method.
    fn new(receiver: mpsc::Receiver<ChanReq<Req>>) -> Self {
        Self { receiver }
    }

    /// Run method.
    pub async fn run(mut self) {
        let mut next_id: u64 = 10;
        let mut conn_stats: Vec<ConnStats> = Vec::new();
        let mut conn_rt: Vec<ConnRT> = Vec::new();
        let mut conns: Vec<Box<dyn SendRequest<Req> + Send + Sync>> =
            Vec::new();

        loop {
            let req = match self.receiver.recv().await {
                Some(req) => req,
                None => break, // All references to connection objects are
                               // dropped. Shutdown.
            };
            match req {
                ChanReq::Add(add_req) => {
                    let id = next_id;
                    next_id += 1;
                    conn_stats.push(ConnStats {
                        mean: (DEFAULT_RT_MS as f64) / 1000.,
                        mean_sq: 0.,
                    });
                    conn_rt.push(ConnRT {
                        id,
                        est_rt: DEFAULT_RT,
                        start: None,
                    });
                    conns.push(add_req.conn);

                    // Don't care if send fails
                    let _ = add_req.tx.send(Ok(()));
                }
                ChanReq::GetRT(rt_req) => {
                    // Don't care if send fails
                    let _ = rt_req.tx.send(Ok(conn_rt.clone()));
                }
                ChanReq::Query(request_req) => {
                    let opt_ind =
                        conn_rt.iter().position(|e| e.id == request_req.id);
                    match opt_ind {
                        Some(ind) => {
                            let query = conns[ind]
                                .send_request(request_req.request_msg);
                            // Don't care if send fails
                            let _ = request_req.tx.send(Ok(query));
                        }
                        None => {
                            // Don't care if send fails
                            let _ = request_req
                                .tx
                                .send(Err(Error::RedundantTransportNotFound));
                        }
                    }
                }
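                // A timely response arrived: fold the measured time into
                // the smoothed mean and mean square for this upstream and
                // re-derive the estimate as mean plus three standard
                // deviations.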
                ChanReq::Report(time_report) => {
                    let opt_ind =
                        conn_rt.iter().position(|e| e.id == time_report.id);
                    if let Some(ind) = opt_ind {
                        let elapsed = time_report.elapsed.as_secs_f64();
                        conn_stats[ind].mean +=
                            (elapsed - conn_stats[ind].mean) / SMOOTH_N;
                        let elapsed_sq = elapsed * elapsed;
                        conn_stats[ind].mean_sq +=
                            (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N;
                        let mean = conn_stats[ind].mean;
                        let var = conn_stats[ind].mean_sq - mean * mean;
                        let std_dev =
                            if var < 0. { 0. } else { f64::sqrt(var) };
                        let est_rt = mean + 3. * std_dev;
                        conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
                    }
                }
                ChanReq::Failure(time_report) => {
                    let opt_ind =
                        conn_rt.iter().position(|e| e.id == time_report.id);
                    if let Some(ind) = opt_ind {
                        let elapsed = time_report.elapsed.as_secs_f64();
                        if elapsed < conn_stats[ind].mean {
                            // Do not update the mean if a
                            // failure took less time than the
                            // current mean.
                            continue;
                        }
                        conn_stats[ind].mean +=
                            (elapsed - conn_stats[ind].mean) / SMOOTH_N;
                        let elapsed_sq = elapsed * elapsed;
                        conn_stats[ind].mean_sq +=
                            (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N;
                        let mean = conn_stats[ind].mean;
                        let var = conn_stats[ind].mean_sq - mean * mean;
                        let std_dev =
                            if var < 0. { 0. } else { f64::sqrt(var) };
                        let est_rt = mean + 3. * std_dev;
                        conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt);
                    }
                }
            }
        }
    }
}

//------------ Utility --------------------------------------------------------

/// Async function to send a request and wait for the reply.
///
/// This gives a single future that we can put in a list.
async fn start_request<Req>(
    index: usize,
    id: u64,
    sender: mpsc::Sender<ChanReq<Req>>,
    request_msg: Req,
) -> (usize, Result<Message<Bytes>, Error>)
where
    Req: Send + Sync,
{
    let (tx, rx) = oneshot::channel();
    sender
        .send(ChanReq::Query(RequestReq {
            id,
            request_msg,
            tx,
        }))
        .await
        .expect("send is expected to work");
    let mut request = match rx.await.expect("receive is expected to work") {
        Err(err) => return (index, Err(err)),
        Ok(request) => request,
    };
    let reply = request.get_response().await;

    (index, reply)
}

/// Compare ConnRT elements based on estimated response time.
fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT) -> Ordering {
    e1.est_rt.cmp(&e2.est_rt)
}

/// Return whether this reply should be skipped or not.
fn skip<Octs: Octets>(msg: &Message<Octs>, config: &Config) -> bool {
    // Check if we actually need to check.
    if !config.defer_refused && !config.defer_servfail {
        return false;
    }

    let opt_rcode = msg.opt_rcode();
    // OptRcode needs PartialEq
    if let OptRcode::REFUSED = opt_rcode {
        if config.defer_refused {
            return true;
        }
    }
    if let OptRcode::SERVFAIL = opt_rcode {
        if config.defer_servfail {
            return true;
        }
    }

    false
}