1use super::request::{
7 ComposeRequest, ComposeRequestMulti, Error, GetResponse,
8 GetResponseMulti, SendRequest, SendRequestMulti,
9};
10use crate::base::iana::{Rcode, Rtype};
11use crate::base::message::Message;
12use crate::base::message_builder::StreamTarget;
13use crate::base::opt::{AllOptData, OptRecord, TcpKeepalive};
14use crate::base::{ParsedName, Serial};
15use crate::rdata::AllRecordData;
16use crate::utils::config::DefMinMax;
17use bytes::{Bytes, BytesMut};
18use core::cmp;
19use octseq::Octets;
20use std::boxed::Box;
21use std::fmt::Debug;
22use std::future::Future;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26use std::vec::Vec;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28use tokio::sync::{mpsc, oneshot};
29use tokio::time::sleep;
30use tracing::trace;
31
/// Default, minimum and maximum for the response timeouts.
///
/// NOTE(review): assumes `DefMinMax::new` takes `(default, min, max)` in
/// that order, i.e. a 19 s default limited to the range 1 ms .. 600 s —
/// confirm in `crate::utils::config`.
const RESPONSE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(19),
    Duration::from_millis(1),
    Duration::from_secs(600),
);

/// Default, minimum and maximum for the idle timeout.
///
/// A zero idle timeout makes the transport close the connection as soon as
/// no requests are outstanding. NOTE(review): same `(default, min, max)`
/// argument-order assumption as above (10 s default, 0 .. 1 h).
const IDLE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(10),
    Duration::ZERO,
    Duration::from_secs(3600),
);

/// Capacity of the request channel to the transport and of the response
/// channel created for each streaming request.
const DEF_CHAN_CAP: usize = 8;

/// Capacity of the channel that carries parsed replies from the reader to
/// the transport's run loop.
const READ_REPLY_CHAN_CAP: usize = 8;
68
/// Configuration for a stream transport connection.
#[derive(Clone, Debug)]
pub struct Config {
    /// Response timeout currently in effect.
    ///
    /// The transport overwrites this with either `single_response_timeout`
    /// or `streaming_response_timeout` for every request it accepts.
    response_timeout: Duration,

    /// Response timeout for requests that expect a single response.
    single_response_timeout: Duration,

    /// Response timeout for requests that expect a stream of responses.
    streaming_response_timeout: Duration,

    /// Default idle timeout.
    ///
    /// Used until the other side sends a TCP keepalive option with a
    /// timeout of its own.
    idle_timeout: Duration,
}
89
90impl Config {
91 pub fn new() -> Self {
93 Default::default()
94 }
95
96 pub fn response_timeout(&self) -> Duration {
101 self.response_timeout
102 }
103
104 pub fn set_response_timeout(&mut self, timeout: Duration) {
113 self.response_timeout = RESPONSE_TIMEOUT.limit(timeout);
114 self.streaming_response_timeout = self.response_timeout;
115 }
116
117 pub fn streaming_response_timeout(&self) -> Duration {
119 self.streaming_response_timeout
120 }
121
122 pub fn set_streaming_response_timeout(&mut self, timeout: Duration) {
131 self.streaming_response_timeout = RESPONSE_TIMEOUT.limit(timeout);
132 }
133
134 pub fn idle_timeout(&self) -> Duration {
136 self.idle_timeout
137 }
138
139 pub fn set_idle_timeout(&mut self, timeout: Duration) {
153 self.idle_timeout = IDLE_TIMEOUT.limit(timeout)
154 }
155}
156
157impl Default for Config {
158 fn default() -> Self {
159 Self {
160 response_timeout: RESPONSE_TIMEOUT.default(),
161 single_response_timeout: RESPONSE_TIMEOUT.default(),
162 streaming_response_timeout: RESPONSE_TIMEOUT.default(),
163 idle_timeout: IDLE_TIMEOUT.default(),
164 }
165 }
166}
167
/// A DNS client connection over a byte stream (such as TCP).
///
/// This is a cheap handle: it only holds the sending half of the request
/// channel served by the matching [`Transport`].
#[derive(Debug)]
pub struct Connection<Req, ReqMulti> {
    /// Sender for requests to the transport.
    sender: mpsc::Sender<ChanReq<Req, ReqMulti>>,
}
176
177impl<Req, ReqMulti> Connection<Req, ReqMulti> {
178 pub fn new<Stream>(
185 stream: Stream,
186 ) -> (Self, Transport<Stream, Req, ReqMulti>) {
187 Self::with_config(stream, Default::default())
188 }
189
190 pub fn with_config<Stream>(
197 stream: Stream,
198 config: Config,
199 ) -> (Self, Transport<Stream, Req, ReqMulti>) {
200 let (sender, transport) = Transport::new(stream, config);
201 (Self { sender }, transport)
202 }
203}
204
205impl<Req, ReqMulti> Connection<Req, ReqMulti>
206where
207 Req: ComposeRequest + 'static,
208 ReqMulti: ComposeRequestMulti + 'static,
209{
210 async fn handle_request_impl(
215 self,
216 msg: Req,
217 ) -> Result<Message<Bytes>, Error> {
218 let (sender, receiver) = oneshot::channel();
219 let sender = ReplySender::Single(Some(sender));
220 let msg = ReqSingleMulti::Single(msg);
221 let req = ChanReq { sender, msg };
222 self.sender.send(req).await.map_err(|_| {
223 Error::ConnectionClosed
226 })?;
227 receiver.await.map_err(|_| Error::StreamReceiveError)?
228 }
229
230 async fn handle_streaming_request_impl(
232 self,
233 msg: ReqMulti,
234 sender: mpsc::Sender<Result<Option<Message<Bytes>>, Error>>,
235 ) -> Result<(), Error> {
236 let reply_sender = ReplySender::Stream(sender);
237 let msg = ReqSingleMulti::Multi(msg);
238 let req = ChanReq {
239 sender: reply_sender,
240 msg,
241 };
242 self.sender.send(req).await.map_err(|_| {
243 Error::ConnectionClosed
246 })?;
247 Ok(())
248 }
249
250 pub fn get_request(&self, request_msg: Req) -> Request {
252 Request {
253 fut: Box::pin(self.clone().handle_request_impl(request_msg)),
254 }
255 }
256
257 fn get_streaming_request(&self, request_msg: ReqMulti) -> RequestMulti {
259 let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
260 RequestMulti {
261 stream: receiver,
262 fut: Some(Box::pin(
263 self.clone()
264 .handle_streaming_request_impl(request_msg, sender),
265 )),
266 }
267 }
268}
269
270impl<Req, ReqMulti> Clone for Connection<Req, ReqMulti> {
271 fn clone(&self) -> Self {
272 Self {
273 sender: self.sender.clone(),
274 }
275 }
276}
277
278impl<Req, ReqMulti> SendRequest<Req> for Connection<Req, ReqMulti>
279where
280 Req: ComposeRequest + 'static,
281 ReqMulti: ComposeRequestMulti + Debug + Send + Sync + 'static,
282{
283 fn send_request(
284 &self,
285 request_msg: Req,
286 ) -> Box<dyn GetResponse + Send + Sync> {
287 Box::new(self.get_request(request_msg))
288 }
289}
290
291impl<Req, ReqMulti> SendRequestMulti<ReqMulti> for Connection<Req, ReqMulti>
292where
293 Req: ComposeRequest + Debug + Send + Sync + 'static,
294 ReqMulti: ComposeRequestMulti + 'static,
295{
296 fn send_request(
297 &self,
298 request_msg: ReqMulti,
299 ) -> Box<dyn GetResponseMulti + Send + Sync> {
300 Box::new(self.get_streaming_request(request_msg))
301 }
302}
303
/// A pending request that expects a single response.
pub struct Request {
    /// Future that sends the request to the transport and resolves to the
    /// response.
    fut: Pin<
        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
    >,
}
313
314impl Request {
315 async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
317 (&mut self.fut).await
318 }
319}
320
321impl GetResponse for Request {
322 fn get_response(
323 &mut self,
324 ) -> Pin<
325 Box<
326 dyn Future<Output = Result<Message<Bytes>, Error>>
327 + Send
328 + Sync
329 + '_,
330 >,
331 > {
332 Box::pin(self.get_response_impl())
333 }
334}
335
336impl Debug for Request {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 f.debug_struct("Request")
339 .field("fut", &format_args!("_"))
340 .finish()
341 }
342}
343
/// A pending request that may produce a stream of responses.
pub struct RequestMulti {
    /// Receives the responses; `Ok(None)` marks the end of the stream.
    stream: mpsc::Receiver<Result<Option<Message<Bytes>>, Error>>,

    /// Future that hands the request to the transport. It is taken and
    /// awaited on the first call to `get_response`.
    // The boxed future type is spelled out in full, hence the allow.
    #[allow(clippy::type_complexity)]
    fut: Option<
        Pin<Box<dyn Future<Output = Result<(), Error>> + Send + Sync>>,
    >,
}
357
358impl RequestMulti {
359 async fn get_response_impl(
361 &mut self,
362 ) -> Result<Option<Message<Bytes>>, Error> {
363 if self.fut.is_some() {
364 let fut = self.fut.take().expect("Some expected");
365 fut.await?;
366 }
367
368 self.stream
370 .recv()
371 .await
372 .ok_or(Error::ConnectionClosed)
373 .map_err(|_| Error::ConnectionClosed)?
374 }
375}
376
377impl GetResponseMulti for RequestMulti {
378 fn get_response(
379 &mut self,
380 ) -> Pin<
381 Box<
382 dyn Future<Output = Result<Option<Message<Bytes>>, Error>>
383 + Send
384 + Sync
385 + '_,
386 >,
387 > {
388 let fut = self.get_response_impl();
389 Box::pin(fut)
390 }
391}
392
393impl Debug for RequestMulti {
394 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395 f.debug_struct("Request")
396 .field("fut", &format_args!("_"))
397 .finish()
398 }
399}
400
/// The I/O side of a stream connection.
///
/// Created together with a [`Connection`]. Its `run` method owns the
/// stream, writes requests received from the connection handles and routes
/// the replies back to them.
#[derive(Debug)]
pub struct Transport<Stream, Req, ReqMulti> {
    /// The underlying byte stream.
    stream: Stream,

    /// Transport configuration.
    config: Config,

    /// Receiving half of the request channel fed by [`Connection`]s.
    receiver: mpsc::Receiver<ChanReq<Req, ReqMulti>>,
}
415
/// Where the responses for a request are delivered.
#[derive(Debug)]
enum ReplySender {
    /// Sender for a single response. It is taken (left as `None`) when the
    /// response is sent.
    Single(Option<oneshot::Sender<ChanResp>>),

    /// Sender for a stream of responses. `Ok(None)` marks the end of the
    /// stream.
    Stream(mpsc::Sender<Result<Option<Message<Bytes>>, Error>>),
}
425
426impl ReplySender {
427 async fn send(&mut self, resp: ChanResp) -> Result<(), ()> {
429 match self {
430 ReplySender::Single(sender) => match sender.take() {
431 Some(sender) => sender.send(resp).map_err(|_| ()),
432 None => Err(()),
433 },
434 ReplySender::Stream(sender) => {
435 sender.send(resp.map(Some)).await.map_err(|_| ())
436 }
437 }
438 }
439
440 async fn send_eof(&mut self) -> Result<(), ()> {
442 match self {
443 ReplySender::Single(_) => {
444 panic!("cannot send EOF for Single");
445 }
446 ReplySender::Stream(sender) => {
447 sender.send(Ok(None)).await.map_err(|_| ())
448 }
449 }
450 }
451
452 fn is_stream(&self) -> bool {
454 matches!(self, Self::Stream(_))
455 }
456}
457
/// A request as passed to the transport.
#[derive(Debug)]
enum ReqSingleMulti<Req, ReqMulti> {
    /// Request that expects a single response.
    Single(Req),

    /// Request that expects a stream of responses. The transport only
    /// accepts AXFR and IXFR here.
    Multi(ReqMulti),
}
467
/// A request sent from a [`Connection`] to its [`Transport`].
#[derive(Debug)]
struct ChanReq<Req, ReqMulti> {
    /// The request message.
    msg: ReqSingleMulti<Req, ReqMulti>,

    /// Where the response(s) have to be delivered.
    sender: ReplySender,
}
477
/// A single response (or error) delivered to a waiting request.
type ChanResp = Result<Message<Bytes>, Error>;
480
/// Run-time state kept by the transport's run loop.
struct Status {
    /// Current connection state.
    state: ConnState,

    /// Whether the next outgoing request should still get a TCP keepalive
    /// option. Starts out true and is cleared once the option has been
    /// added to a request.
    send_keepalive: bool,

    /// Idle timeout currently in effect. Initialised from the config and
    /// replaced by the timeout carried in a TCP keepalive option of a reply.
    idle_timeout: Duration,
}
501
/// State of a stream connection as tracked by the transport's run loop.
enum ConnState {
    /// The connection is in use. `Some(instant)` records when the response
    /// timer was (re)started: on a new request when no timer was running,
    /// and on every received reply. `None` means no timer is running.
    Active(Option<Instant>),

    /// No requests are outstanding; the connection has been idle since the
    /// given instant.
    Idle(Instant),

    /// The idle timeout expired; the connection is being closed.
    IdleTimeout,

    /// Reading from the stream failed.
    ReadError(Error),

    /// No reply arrived within the response timeout.
    ReadTimeout,

    /// Writing to the stream failed.
    WriteError(Error),
}
530
531impl std::fmt::Display for ConnState {
533 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
534 match self {
535 ConnState::Active(instant) => f.write_fmt(format_args!(
536 "Active (since {}s ago)",
537 instant
538 .map(|v| Instant::now().duration_since(v).as_secs())
539 .unwrap_or_default()
540 )),
541 ConnState::Idle(instant) => f.write_fmt(format_args!(
542 "Idle (since {}s ago)",
543 Instant::now().duration_since(*instant).as_secs()
544 )),
545 ConnState::IdleTimeout => f.write_str("IdleTimeout"),
546 ConnState::ReadError(err) => {
547 f.write_fmt(format_args!("ReadError: {err}"))
548 }
549 ConnState::ReadTimeout => f.write_str("ReadTimeout"),
550 ConnState::WriteError(err) => {
551 f.write_fmt(format_args!("WriteError: {err}"))
552 }
553 }
554 }
555}
556
/// Progress of a zone transfer response stream, as tracked by
/// `check_stream`.
#[derive(Debug)]
enum XFRState {
    /// AXFR requested; waiting for the opening SOA record.
    AXFRInit,

    /// AXFR under way; holds the serial of the opening SOA record. A later
    /// SOA record with the same serial completes the transfer.
    AXFRFirstSoa(Serial),

    /// IXFR requested; waiting for the first SOA record.
    IXFRInit,

    /// The first SOA record of an IXFR response has been seen (its serial is
    /// kept). The next record decides between a completed transfer, an
    /// incremental transfer and an AXFR-style full transfer.
    IXFRFirstSoa(Serial),

    /// Inside an IXFR difference sequence after its first SOA record
    /// (presumably the deletion part of an RFC 1995 difference sequence).
    IXFRFirstDiffSoa(Serial),

    /// Inside an IXFR difference sequence after its second SOA record
    /// (presumably the addition part). An SOA record carrying the first
    /// serial now completes the transfer.
    IXFRSecondDiffSoa(Serial),

    /// The transfer is complete.
    Done,

    /// The stream is invalid or out of sync; further responses are
    /// rejected.
    Error,
}
578
579impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti> {
580 fn new(
582 stream: Stream,
583 config: Config,
584 ) -> (mpsc::Sender<ChanReq<Req, ReqMulti>>, Self) {
585 let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP);
586 (
587 sender,
588 Self {
589 config,
590 stream,
591 receiver,
592 },
593 )
594 }
595}
596
597impl<Stream, Req, ReqMulti> Transport<Stream, Req, ReqMulti>
598where
599 Stream: AsyncRead + AsyncWrite,
600 Req: ComposeRequest,
601 ReqMulti: ComposeRequestMulti,
602{
603 pub async fn run(mut self) {
605 let (reply_sender, mut reply_receiver) =
606 mpsc::channel::<Message<Bytes>>(READ_REPLY_CHAN_CAP);
607
608 let (read_stream, mut write_stream) = tokio::io::split(self.stream);
609
610 let reader_fut = Self::reader(read_stream, reply_sender);
611 tokio::pin!(reader_fut);
612
613 let mut status = Status {
614 state: ConnState::Active(None),
615 idle_timeout: self.config.idle_timeout,
616 send_keepalive: true,
617 };
618 let mut query_vec =
619 Queries::<(ChanReq<Req, ReqMulti>, Option<XFRState>)>::new();
620
621 let mut reqmsg: Option<Vec<u8>> = None;
622 let mut reqmsg_offset = 0;
623
624 loop {
625 let opt_timeout = match status.state {
626 ConnState::Active(opt_instant) => {
627 if let Some(instant) = opt_instant {
628 let elapsed = instant.elapsed();
629 if elapsed > self.config.response_timeout {
630 Self::error(
631 Error::StreamReadTimeout,
632 &mut query_vec,
633 )
634 .await;
635 status.state = ConnState::ReadTimeout;
636 break;
637 }
638 Some(self.config.response_timeout - elapsed)
639 } else {
640 None
641 }
642 }
643 ConnState::Idle(instant) => {
644 let elapsed = instant.elapsed();
645 if elapsed >= status.idle_timeout {
646 status.state = ConnState::IdleTimeout;
649 break;
650 }
651 Some(status.idle_timeout - elapsed)
652 }
653 ConnState::IdleTimeout
654 | ConnState::ReadError(_)
655 | ConnState::WriteError(_) => None, ConnState::ReadTimeout => {
657 panic!("should not be in loop with ReadTimeout");
658 }
659 };
660
661 let timeout = match opt_timeout {
663 Some(timeout) => timeout,
664 None =>
665 {
667 self.config.response_timeout
668 }
669 };
670
671 let sleep_fut = sleep(timeout);
672 let recv_fut = self.receiver.recv();
673
674 let (do_write, msg) = match &reqmsg {
675 None => {
676 let msg: &[u8] = &[];
677 (false, msg)
678 }
679 Some(msg) => {
680 let msg: &[u8] = msg;
681 (true, msg)
682 }
683 };
684
685 tokio::select! {
686 biased;
687 res = &mut reader_fut => {
688 while let Ok(answer) = reply_receiver.try_recv() {
690 Self::demux_reply(answer, &mut status, &mut query_vec).await;
691 }
692
693 match res {
694 Ok(_) =>
695 panic!("reader terminated"),
699 Err(error) => {
700 Self::error(error.clone(), &mut query_vec).await;
701 status.state = ConnState::ReadError(error);
702 break
706 }
707 }
708 }
709 opt_answer = reply_receiver.recv() => {
710 let answer = opt_answer.expect("reader died?");
711 Self::demux_reply(answer, &mut status, &mut query_vec).await;
712 }
713 res = write_stream.write(&msg[reqmsg_offset..]),
714 if do_write => {
715 match res {
716 Err(error) => {
717 let error =
718 Error::StreamWriteError(Arc::new(error));
719 Self::error(error.clone(), &mut query_vec).await;
720 status.state =
721 ConnState::WriteError(error);
722 break;
723 }
724 Ok(len) => {
725 reqmsg_offset += len;
726 if reqmsg_offset >= msg.len() {
727 reqmsg = None;
728 reqmsg_offset = 0;
729 }
730 }
731 }
732 }
733 res = recv_fut, if !do_write => {
734 match res {
735 Some(req) => {
736 if req.sender.is_stream() {
737 self.config.response_timeout =
738 self.config.streaming_response_timeout;
739 } else {
740 self.config.response_timeout =
741 self.config.single_response_timeout;
742 }
743 Self::insert_req(
744 req, &mut status, &mut reqmsg, &mut query_vec
745 );
746 }
747 None => {
748 break;
751 }
752 }
753 }
754 _ = sleep_fut => {
755 }
758
759 }
760
761 match status.state {
763 ConnState::Active(_) | ConnState::Idle(_) => {
764 }
766 ConnState::IdleTimeout => break,
767 ConnState::ReadError(_)
768 | ConnState::ReadTimeout
769 | ConnState::WriteError(_) => {
770 panic!("Should not be here");
771 }
772 }
773 }
774
775 trace!("Closing TCP connecting in state: {}", status.state);
776
777 _ = write_stream.shutdown().await;
779 }
780
781 async fn reader(
790 mut sock: tokio::io::ReadHalf<Stream>,
791 sender: mpsc::Sender<Message<Bytes>>,
792 ) -> Result<(), Error> {
793 loop {
794 let read_res = sock.read_u16().await;
795 let len = match read_res {
796 Ok(len) => len,
797 Err(error) => {
798 return Err(Error::StreamReadError(Arc::new(error)));
799 }
800 } as usize;
801
802 let mut buf = BytesMut::with_capacity(len);
803
804 loop {
805 let curlen = buf.len();
806 if curlen >= len {
807 if curlen > len {
808 panic!(
809 "reader: got too much data {curlen}, expetect {len}");
810 }
811
812 break;
814 }
815
816 let read_res = sock.read_buf(&mut buf).await;
817
818 match read_res {
819 Ok(readlen) => {
820 if readlen == 0 {
821 return Err(Error::StreamUnexpectedEndOfData);
822 }
823 }
824 Err(error) => {
825 return Err(Error::StreamReadError(Arc::new(error)));
826 }
827 };
828
829 }
831
832 let reply_message = Message::<Bytes>::from_octets(buf.into());
833 match reply_message {
834 Ok(answer) => {
835 sender
836 .send(answer)
837 .await
838 .expect("can't send reply to run");
839 }
840 Err(_) => {
841 return Err(Error::ShortMessage);
843 }
844 }
845 }
846 }
847
848 async fn error(
850 error: Error,
851 query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
852 ) {
853 for (mut req, _) in query_vec.drain() {
856 _ = req.sender.send(Err(error.clone())).await;
857 }
858 }
859
860 fn handle_opts<Octs: Octets + AsRef<[u8]>>(
864 opts: &OptRecord<Octs>,
865 status: &mut Status,
866 ) {
867 for option in opts.opt().iter().flatten() {
871 if let AllOptData::TcpKeepalive(tcpkeepalive) = option {
872 Self::handle_keepalive(tcpkeepalive, status);
873 }
874 }
875 }
876
877 async fn demux_reply(
882 answer: Message<Bytes>,
883 status: &mut Status,
884 query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
885 ) {
886 if let Some(opts) = answer.opt() {
888 Self::handle_opts(&opts, status);
889 };
890
891 status.state = ConnState::Active(Some(Instant::now()));
893
894 let id = answer.header().id();
895
896 let (mut req, mut opt_xfr_data) = match query_vec.try_remove(id) {
898 Some(req) => req,
899 None => {
900 return;
903 }
904 };
905 let mut send_eof = false;
906 let answer = if match &req.msg {
907 ReqSingleMulti::Single(msg) => msg.is_answer(answer.for_slice()),
908 ReqSingleMulti::Multi(msg) => {
909 let xfr_data =
910 opt_xfr_data.expect("xfr_data should be present");
911 let (eof, xfr_data, is_answer) =
912 check_stream(msg, xfr_data, &answer);
913 send_eof = eof;
914 opt_xfr_data = Some(xfr_data);
915 is_answer
916 }
917 } {
918 Ok(answer)
919 } else {
920 Err(Error::WrongReplyForQuery)
921 };
922 _ = req.sender.send(answer).await;
923
924 if req.sender.is_stream() {
925 if send_eof {
926 _ = req.sender.send_eof().await;
927 } else {
928 query_vec.insert_at(id, (req, opt_xfr_data));
929 }
930 }
931
932 if query_vec.is_empty() {
933 status.state = ConnState::Active(None);
938
939 status.state = if status.idle_timeout.is_zero() {
940 ConnState::IdleTimeout
943 } else {
944 ConnState::Idle(Instant::now())
945 }
946 }
947 }
948
949 fn insert_req(
956 mut req: ChanReq<Req, ReqMulti>,
957 status: &mut Status,
958 reqmsg: &mut Option<Vec<u8>>,
959 query_vec: &mut Queries<(ChanReq<Req, ReqMulti>, Option<XFRState>)>,
960 ) {
961 match &status.state {
962 ConnState::Active(timer) => {
963 if timer.is_none() {
965 status.state = ConnState::Active(Some(Instant::now()));
966 }
967 }
968 ConnState::Idle(_) => {
969 status.state = ConnState::Active(Some(Instant::now()));
971 }
972 ConnState::IdleTimeout => {
973 _ = req.sender.send(Err(Error::StreamIdleTimeout));
975 return;
976 }
977 ConnState::ReadError(error) => {
978 _ = req.sender.send(Err(error.clone()));
979 return;
980 }
981 ConnState::ReadTimeout => {
982 _ = req.sender.send(Err(Error::StreamReadTimeout));
983 return;
984 }
985 ConnState::WriteError(error) => {
986 _ = req.sender.send(Err(error.clone()));
987 return;
988 }
989 }
990
991 let xfr_data = match &req.msg {
992 ReqSingleMulti::Single(_) => None,
993 ReqSingleMulti::Multi(msg) => {
994 let qtype = match msg.to_message().and_then(|m| {
995 m.sole_question()
996 .map_err(|_| Error::MessageParseError)
997 .map(|q| q.qtype())
998 }) {
999 Ok(msg) => msg,
1000 Err(e) => {
1001 _ = req.sender.send(Err(e));
1002 return;
1003 }
1004 };
1005 if qtype == Rtype::AXFR {
1006 Some(XFRState::AXFRInit)
1007 } else if qtype == Rtype::IXFR {
1008 Some(XFRState::IXFRInit)
1009 } else {
1010 _ = req.sender.send(Err(Error::FormError));
1012 return;
1013 }
1014 }
1015 };
1016
1017 let (index, (req, _)) = match query_vec.insert((req, xfr_data)) {
1021 Ok(res) => res,
1022 Err((mut req, _)) => {
1023 _ = req
1025 .sender
1026 .send(Err(Error::StreamTooManyOutstandingQueries));
1027 return;
1028 }
1029 };
1030
1031 let hdr = match &mut req.msg {
1040 ReqSingleMulti::Single(msg) => msg.header_mut(),
1041 ReqSingleMulti::Multi(msg) => msg.header_mut(),
1042 };
1043 hdr.set_id(index);
1044
1045 if status.send_keepalive
1046 && match &mut req.msg {
1047 ReqSingleMulti::Single(msg) => {
1048 msg.add_opt(&TcpKeepalive::new(None)).is_ok()
1049 }
1050 ReqSingleMulti::Multi(msg) => {
1051 msg.add_opt(&TcpKeepalive::new(None)).is_ok()
1052 }
1053 }
1054 {
1055 status.send_keepalive = false;
1056 }
1057
1058 match Self::convert_query(&req.msg) {
1059 Ok(msg) => {
1060 *reqmsg = Some(msg);
1061 }
1062 Err(err) => {
1063 if let Some((mut req, _)) = query_vec.try_remove(index) {
1065 _ = req.sender.send(Err(err));
1066 }
1067 }
1068 }
1069 }
1070
1071 fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) {
1073 if let Some(value) = opt_value.timeout() {
1074 let value_dur = Duration::from(value);
1075 status.idle_timeout = value_dur;
1076 }
1077 }
1078
1079 fn convert_query(
1081 msg: &ReqSingleMulti<Req, ReqMulti>,
1082 ) -> Result<Vec<u8>, Error> {
1083 match msg {
1084 ReqSingleMulti::Single(msg) => {
1085 let mut target = StreamTarget::new_vec();
1086 msg.append_message(&mut target)
1087 .map_err(|_| Error::StreamLongMessage)?;
1088 Ok(target.into_target())
1089 }
1090 ReqSingleMulti::Multi(msg) => {
1091 let target = StreamTarget::new_vec();
1092 let target = msg
1093 .append_message(target)
1094 .map_err(|_| Error::StreamLongMessage)?;
1095 Ok(target.finish().into_target())
1096 }
1097 }
1098 }
1099}
1100
/// Checks one response of a zone transfer stream and advances the transfer
/// state machine.
///
/// `msg` is the streaming request, `xfr_state` the state after the previous
/// response, and `answer` the newly received response.
///
/// Returns `(eof, new_state, is_answer)`. `eof` is true when the stream
/// should be ended after this response. `is_answer` is false when the
/// response is not acceptable; the caller (`demux_reply`) then reports
/// `Error::WrongReplyForQuery` instead of the response.
fn check_stream<CRM>(
    msg: &CRM,
    mut xfr_state: XFRState,
    answer: &Message<Bytes>,
) -> (bool, XFRState, bool)
where
    CRM: ComposeRequestMulti,
{
    // The first response must answer the question; later responses are not
    // checked against the request here.
    match xfr_state {
        XFRState::AXFRInit | XFRState::IXFRInit => {
            if !msg.is_answer(answer.for_slice()) {
                // Keep the stream open but reject everything from now on.
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
        }
        XFRState::AXFRFirstSoa(_)
        | XFRState::IXFRFirstSoa(_)
        | XFRState::IXFRFirstDiffSoa(_)
        | XFRState::IXFRSecondDiffSoa(_) =>
        // Transfer already under way: nothing to check.
        {}
        XFRState::Done => {
            // A response after the transfer completed is an error.
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        XFRState::Error =>
        // Once in the error state, every further response is rejected.
        {
            return (false, xfr_state, false)
        }
    }

    // A response with an error rcode ends the stream, provided it answers
    // the question.
    if answer.header().rcode() != Rcode::NOERROR {
        if !msg.is_answer(answer.for_slice()) {
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        return (true, xfr_state, true);
    }

    let ans_sec = match answer.answer() {
        Ok(ans) => ans,
        Err(_) => {
            // Unparseable response: end the stream and report it as not an
            // answer.
            xfr_state = XFRState::Error;
            return (true, xfr_state, false);
        }
    };
    for rr in
        ans_sec.into_records::<AllRecordData<Bytes, ParsedName<Bytes>>>()
    {
        let rr = match rr {
            Ok(rr) => rr,
            Err(_) => {
                // Unparseable record: end the stream as above.
                xfr_state = XFRState::Error;
                return (true, xfr_state, false);
            }
        };
        match xfr_state {
            XFRState::AXFRInit => {
                // The first record of an AXFR has to be an SOA record.
                if let AllRecordData::Soa(soa) = rr.data() {
                    xfr_state = XFRState::AXFRFirstSoa(soa.serial());
                    continue;
                }
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            XFRState::AXFRFirstSoa(serial) => {
                // The next SOA record closes the transfer and must carry the
                // same serial as the opening one.
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        xfr_state = XFRState::Done;
                        continue;
                    }

                    xfr_state = XFRState::Error;
                    return (false, xfr_state, false);
                }

                // Any other record: part of the zone, keep going.
            }
            XFRState::IXFRInit => {
                // The first record of an IXFR response has to be an SOA
                // record.
                if let AllRecordData::Soa(soa) = rr.data() {
                    xfr_state = XFRState::IXFRFirstSoa(soa.serial());
                    continue;
                }
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            XFRState::IXFRFirstSoa(serial) => {
                // Three possibilities for the record after the first SOA:
                // an SOA with the same serial ends the transfer; an SOA with
                // a different serial starts the first difference sequence;
                // any other record means the server sent a full (AXFR-style)
                // transfer.
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        xfr_state = XFRState::Done;
                        continue;
                    }

                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
                    continue;
                }

                xfr_state = XFRState::AXFRFirstSoa(serial);
            }
            XFRState::IXFRFirstDiffSoa(serial) => {
                // The next SOA record moves to the second half of the
                // difference sequence.
                if let AllRecordData::Soa(_) = rr.data() {
                    xfr_state = XFRState::IXFRSecondDiffSoa(serial);
                    continue;
                }

                // Any other record: keep going.
            }
            XFRState::IXFRSecondDiffSoa(serial) => {
                // An SOA record with the first serial ends the transfer; any
                // other SOA record starts the next difference sequence.
                if let AllRecordData::Soa(soa) = rr.data() {
                    if serial == soa.serial() {
                        xfr_state = XFRState::Done;
                        continue;
                    }

                    xfr_state = XFRState::IXFRFirstDiffSoa(serial);
                    continue;
                }

                // Any other record: keep going.
            }
            XFRState::Done => {
                // Records after the closing SOA are an error.
                xfr_state = XFRState::Error;
                return (false, xfr_state, false);
            }
            // Every path that sets `Error` returns immediately.
            XFRState::Error => panic!("should not be here"),
        }
    }

    // Decide based on the state reached at the end of this response.
    match xfr_state {
        XFRState::AXFRInit | XFRState::IXFRInit => {
            // The answer section was empty.
            xfr_state = XFRState::Error;
            return (false, xfr_state, false);
        }
        XFRState::AXFRFirstSoa(_)
        | XFRState::IXFRFirstDiffSoa(_)
        | XFRState::IXFRSecondDiffSoa(_) =>
        // More responses are expected.
        {}
        XFRState::IXFRFirstSoa(_) => {
            // The response ended right after the first SOA record. This is
            // treated as a complete transfer (presumably the single-SOA
            // "client is up to date" reply of RFC 1995 — confirm).
            xfr_state = XFRState::Done;
            return (true, xfr_state, true);
        }
        XFRState::Done => return (true, xfr_state, true),
        XFRState::Error => unreachable!(),
    }

    (false, xfr_state, true)
}
1297
/// A table of outstanding queries indexed by DNS message ID.
///
/// The position of an item in `vec` is the message ID used for the
/// corresponding request, so IDs always fit in a `u16`. Free slots are only
/// searched for (and reused) while at least half of the vector is free;
/// otherwise new items are appended.
#[derive(Clone, Debug)]
struct Queries<T> {
    /// Number of occupied slots in `vec`.
    count: usize,

    /// Where the search for a free slot starts. The methods below keep every
    /// slot before this index occupied.
    curr: usize,

    /// The slots; `None` marks a free slot.
    vec: Vec<Option<T>>,
}

impl<T> Queries<T> {
    /// Creates an empty table.
    fn new() -> Self {
        Queries {
            count: 0,
            curr: 0,
            vec: Vec::new(),
        }
    }

    /// Returns whether the table holds no items.
    fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Stores `req` and returns its ID together with a reference to the
    /// stored item.
    ///
    /// Gives `req` back as `Err` once more than half of the 16-bit ID space
    /// is in use.
    fn insert(&mut self, req: T) -> Result<(u16, &mut T), T> {
        if self.count * 2 > usize::from(u16::MAX) {
            return Err(req);
        }

        // Looking for a hole only pays off when at least half of the slots
        // are free.
        let hole = if self.vec.len() < 2 * self.count {
            None
        } else {
            (self.curr..self.vec.len()).find(|&slot| self.vec[slot].is_none())
        };

        let slot = match hole {
            Some(slot) => {
                self.vec[slot] = Some(req);
                slot
            }
            None => {
                self.vec.push(Some(req));
                self.vec.len() - 1
            }
        };

        self.count += 1;
        if slot == self.curr {
            self.curr += 1;
        }
        let item = self.vec[slot].as_mut().expect("no inserted item?");
        let id = u16::try_from(slot).expect("query vec too large");
        Ok((id, item))
    }

    /// Puts `req` back into the (free) slot `id`, typically right after it
    /// was taken out with [`Queries::try_remove`].
    fn insert_at(&mut self, id: u16, req: T) {
        let slot = usize::from(id);
        self.vec[slot] = Some(req);

        self.count += 1;
        if self.curr == slot {
            self.curr += 1;
        }
    }

    /// Removes and returns the item with the given ID, if there is one.
    fn try_remove(&mut self, index: u16) -> Option<T> {
        let slot = usize::from(index);
        let removed = self.vec.get_mut(slot).and_then(Option::take)?;
        self.count = self.count.saturating_sub(1);
        self.curr = cmp::min(self.curr, slot);
        Some(removed)
    }

    /// Removes all items, yielding them in ID order.
    fn drain(&mut self) -> impl Iterator<Item = T> + '_ {
        self.count = 0;
        self.curr = 0;
        self.vec.drain(..).flatten()
    }
}
1415
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    // The loops index two parallel collections (`idxs` and the item value
    // `i`), so an index loop is the clearest form here.
    #[allow(clippy::needless_range_loop)]
    fn queries_insert_remove() {
        // IDs of the items currently stored in `queries`, indexed by item.
        let mut idxs = [None; 20];
        let mut queries = Queries::new();

        for i in 0..12 {
            let (idx, item) = queries.insert(i).expect("test failed");
            idxs[i] = Some(idx);
            assert_eq!(i, *item);
        }
        assert_eq!(queries.count, 12);
        assert_eq!(queries.vec.iter().flatten().count(), 12);

        for i in [1, 2, 3, 4, 7, 9] {
            let item = queries
                .try_remove(idxs[i].expect("test failed"))
                .expect("test failed");
            assert_eq!(i, item);
            idxs[i] = None;
        }
        assert_eq!(queries.count, 6);
        assert_eq!(queries.vec.iter().flatten().count(), 6);

        for i in 12..20 {
            let (idx, item) = queries.insert(i).expect("test failed");
            idxs[i] = Some(idx);
            assert_eq!(i, *item);
        }
        assert_eq!(queries.count, 14);
        assert_eq!(queries.vec.iter().flatten().count(), 14);

        for i in 0..20 {
            if let Some(idx) = idxs[i] {
                let item = queries.try_remove(idx).expect("test failed");
                assert_eq!(i, item);
            }
        }
        assert_eq!(queries.count, 0);
        assert_eq!(queries.vec.iter().flatten().count(), 0);
    }

    #[test]
    fn queries_overrun() {
        // Insert far more items than there are message IDs. Inserts must
        // start failing (instead of panicking) once the table is full.
        let mut queries = Queries::new();
        let mut accepted = 0;
        for i in 0..usize::from(u16::MAX) * 2 {
            match queries.insert(i) {
                Ok(_) => accepted += 1,
                Err(rejected) => assert_eq!(rejected, i),
            }
        }
        // `insert` accepts items while `2 * count <= u16::MAX`, i.e. up to
        // 2^15 items.
        assert_eq!(accepted, 32768);
        assert_eq!(queries.count, 32768);
    }
}
1475}