1use self::conf::{
13 ResolvConf, ResolvOptions, SearchSuffix, ServerConf, Transport,
14};
15use crate::base::iana::Rcode;
16use crate::base::message::Message;
17use crate::base::message_builder::{
18 AdditionalBuilder, MessageBuilder, StreamTarget,
19};
20use crate::base::name::{ToDname, ToRelativeDname};
21use crate::base::question::Question;
22use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs};
23use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts};
24use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError};
25use crate::resolv::resolver::{Resolver, SearchNames};
26use bytes::Bytes;
27use octseq::array::Array;
28use std::boxed::Box;
29use std::future::Future;
30use std::net::{IpAddr, SocketAddr};
31use std::pin::Pin;
32use std::slice::SliceIndex;
33use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
34use std::sync::Arc;
35use std::vec::Vec;
36use std::{io, ops};
37use tokio::io::{AsyncReadExt, AsyncWriteExt};
38use tokio::net::{TcpStream, UdpSocket};
39#[cfg(feature = "resolv-sync")]
40use tokio::runtime;
41use tokio::time::timeout;
42
/// Configuration types (`ResolvConf`, `ResolvOptions`, `ServerConf`, …)
/// used by the stub resolver.
pub mod conf;

/// How many additional times `ServerInfo::udp_bind` retries binding a UDP
/// socket to an OS-assigned local port (port 0) before giving up and
/// returning the last bind error.
const RETRY_RANDOM_PORT: usize = 10;
51
52#[derive(Clone, Debug)]
74pub struct StubResolver {
75 preferred: ServerList,
77
78 stream: ServerList,
80
81 options: ResolvOptions,
83}
84
85impl StubResolver {
86 pub fn new() -> Self {
88 Self::from_conf(ResolvConf::default())
89 }
90
91 pub fn from_conf(conf: ResolvConf) -> Self {
93 StubResolver {
94 preferred: ServerList::from_conf(&conf, |s| {
95 s.transport.is_preferred()
96 }),
97 stream: ServerList::from_conf(&conf, |s| s.transport.is_stream()),
98 options: conf.options,
99 }
100 }
101
102 pub fn options(&self) -> &ResolvOptions {
103 &self.options
104 }
105
106 pub async fn query<N: ToDname, Q: Into<Question<N>>>(
107 &self,
108 question: Q,
109 ) -> Result<Answer, io::Error> {
110 Query::new(self)?
111 .run(Query::create_message(question.into()))
112 .await
113 }
114
115 async fn query_message(
116 &self,
117 message: QueryMessage,
118 ) -> Result<Answer, io::Error> {
119 Query::new(self)?.run(message).await
120 }
121}
122
123impl StubResolver {
124 pub async fn lookup_addr(
125 &self,
126 addr: IpAddr,
127 ) -> Result<FoundAddrs<&Self>, io::Error> {
128 lookup_addr(&self, addr).await
129 }
130
131 pub async fn lookup_host(
132 &self,
133 qname: impl ToDname,
134 ) -> Result<FoundHosts<&Self>, io::Error> {
135 lookup_host(&self, qname).await
136 }
137
138 pub async fn search_host(
139 &self,
140 qname: impl ToRelativeDname,
141 ) -> Result<FoundHosts<&Self>, io::Error> {
142 search_host(&self, qname).await
143 }
144
145 pub async fn lookup_srv(
149 &self,
150 service: impl ToRelativeDname,
151 name: impl ToDname,
152 fallback_port: u16,
153 ) -> Result<Option<FoundSrvs>, SrvError> {
154 lookup_srv(&self, service, name, fallback_port).await
155 }
156}
157
#[cfg(feature = "resolv-sync")]
#[cfg_attr(docsrs, doc(cfg(feature = "resolv-sync")))]
impl StubResolver {
    /// Runs the future returned by `op` on a new current-thread Tokio
    /// runtime. `op` receives a resolver built from the default
    /// configuration.
    ///
    /// # Panics
    ///
    /// Panics if the runtime cannot be built.
    pub fn run<R, F>(op: F) -> R::Output
    where
        R: Future + Send + 'static,
        R::Output: Send + 'static,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        let conf: ResolvConf = Default::default();
        Self::run_with_conf(conf, op)
    }

    /// Same as [`run`](Self::run), but builds the resolver from `conf`.
    ///
    /// # Panics
    ///
    /// Panics if the runtime cannot be built.
    pub fn run_with_conf<R, F>(conf: ResolvConf, op: F) -> R::Output
    where
        R: Future + Send + 'static,
        R::Output: Send + 'static,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        let resolver = StubResolver::from_conf(conf);
        let mut builder = runtime::Builder::new_current_thread();
        builder.enable_all();
        let rt = builder.build().unwrap();
        let task = op(resolver);
        rt.block_on(task)
    }
}
200
201impl Default for StubResolver {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
207impl<'a> Resolver for &'a StubResolver {
208 type Octets = Bytes;
209 type Answer = Answer;
210 type Query =
211 Pin<Box<dyn Future<Output = Result<Answer, io::Error>> + Send + 'a>>;
212
213 fn query<N, Q>(&self, question: Q) -> Self::Query
214 where
215 N: ToDname,
216 Q: Into<Question<N>>,
217 {
218 let message = Query::create_message(question.into());
219 Box::pin(self.query_message(message))
220 }
221}
222
223impl<'a> SearchNames for &'a StubResolver {
224 type Name = SearchSuffix;
225 type Iter = SearchIter<'a>;
226
227 fn search_iter(&self) -> Self::Iter {
228 SearchIter {
229 resolver: self,
230 pos: 0,
231 }
232 }
233}
234
235pub struct Query<'a> {
238 resolver: &'a StubResolver,
240
241 preferred: bool,
243
244 attempt: usize,
246
247 counter: ServerListCounter,
249
250 error: Result<Answer, io::Error>,
258}
259
260impl<'a> Query<'a> {
261 pub fn new(resolver: &'a StubResolver) -> Result<Self, io::Error> {
262 let (preferred, counter) =
263 if resolver.options().use_vc || resolver.preferred.is_empty() {
264 if resolver.stream.is_empty() {
265 return Err(io::Error::new(
266 io::ErrorKind::NotFound,
267 "no servers available",
268 ));
269 }
270 (false, resolver.stream.counter(resolver.options().rotate))
271 } else {
272 (true, resolver.preferred.counter(resolver.options().rotate))
273 };
274 Ok(Query {
275 resolver,
276 preferred,
277 attempt: 0,
278 counter,
279 error: Err(io::Error::new(
280 io::ErrorKind::TimedOut,
281 "all timed out",
282 )),
283 })
284 }
285
286 pub async fn run(
287 mut self,
288 mut message: QueryMessage,
289 ) -> Result<Answer, io::Error> {
290 loop {
291 match self.run_query(&mut message).await {
292 Ok(answer) => {
293 if answer.header().rcode() == Rcode::FormErr
294 && self.current_server().does_edns()
295 {
296 self.current_server().disable_edns();
298 continue;
299 } else if answer.header().rcode() == Rcode::ServFail {
300 self.update_error_servfail(answer);
302 } else if answer.header().tc()
303 && self.preferred
304 && !self.resolver.options().ign_tc
305 {
306 if self.switch_to_stream() {
310 continue;
311 } else {
312 return Ok(answer);
313 }
314 } else {
315 return Ok(answer);
317 }
318 }
319 Err(err) => self.update_error(err),
320 }
321 if !self.next_server() {
322 return self.error;
323 }
324 }
325 }
326
327 fn create_message(question: Question<impl ToDname>) -> QueryMessage {
328 let mut message = MessageBuilder::from_target(
329 StreamTarget::new(Default::default()).unwrap(),
330 )
331 .unwrap();
332 message.header_mut().set_rd(true);
333 let mut message = message.question();
334 message.push(question).unwrap();
335 message.additional()
336 }
337
338 async fn run_query(
339 &mut self,
340 message: &mut QueryMessage,
341 ) -> Result<Answer, io::Error> {
342 let server = self.current_server();
343 server.prepare_message(message);
344 server.query(message).await
345 }
346
347 fn current_server(&self) -> &ServerInfo {
348 let list = if self.preferred {
349 &self.resolver.preferred
350 } else {
351 &self.resolver.stream
352 };
353 self.counter.info(list)
354 }
355
356 fn update_error(&mut self, err: io::Error) {
357 if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
361 self.error = Err(err)
362 }
363 }
364
365 fn update_error_servfail(&mut self, answer: Answer) {
366 self.error = Ok(answer)
367 }
368
369 fn switch_to_stream(&mut self) -> bool {
370 if !self.preferred {
371 return false;
373 }
374 self.preferred = false;
375 self.attempt = 0;
376 self.counter =
377 self.resolver.stream.counter(self.resolver.options().rotate);
378 true
379 }
380
381 fn next_server(&mut self) -> bool {
382 if self.counter.next() {
383 return true;
384 }
385 self.attempt += 1;
386 if self.attempt >= self.resolver.options().attempts {
387 return false;
388 }
389 self.counter = if self.preferred {
390 self.resolver
391 .preferred
392 .counter(self.resolver.options().rotate)
393 } else {
394 self.resolver.stream.counter(self.resolver.options().rotate)
395 };
396 true
397 }
398}
399
/// The type of an outgoing query message.
///
/// It is a message builder positioned at the additional section, writing
/// into a 512-byte array wrapped in a `StreamTarget`. `tcp_query` sends it
/// via `as_stream_slice()` and `udp_query` sends it via `as_dgram_slice()`.
pub(super) type QueryMessage = AdditionalBuilder<StreamTarget<Array<512>>>;
404
405#[derive(Clone)]
412pub struct Answer {
413 message: Message<Bytes>,
414}
415
416impl Answer {
417 pub fn is_final(&self) -> bool {
419 (self.message.header().rcode() == Rcode::NoError
420 || self.message.header().rcode() == Rcode::NXDomain)
421 && !self.message.header().tc()
422 }
423
424 pub fn is_truncated(&self) -> bool {
426 self.message.header().tc()
427 }
428
429 pub fn into_message(self) -> Message<Bytes> {
430 self.message
431 }
432}
433
434impl From<Message<Bytes>> for Answer {
435 fn from(message: Message<Bytes>) -> Self {
436 Answer { message }
437 }
438}
439
440impl ops::Deref for Answer {
441 type Target = Message<Bytes>;
442
443 fn deref(&self) -> &Self::Target {
444 &self.message
445 }
446}
447
448impl AsRef<Message<Bytes>> for Answer {
449 fn as_ref(&self) -> &Message<Bytes> {
450 &self.message
451 }
452}
453
454#[derive(Clone, Debug)]
457struct ServerInfo {
458 conf: ServerConf,
460
461 edns: Arc<AtomicBool>,
465}
466
467impl ServerInfo {
468 pub fn does_edns(&self) -> bool {
469 self.edns.load(Ordering::Relaxed)
470 }
471
472 pub fn disable_edns(&self) {
473 self.edns.store(false, Ordering::Relaxed);
474 }
475
476 pub fn prepare_message(&self, query: &mut QueryMessage) {
477 query.rewind();
478 if self.does_edns() {
479 query
480 .opt(|opt| {
481 opt.set_udp_payload_size(self.conf.udp_payload_size);
482 Ok(())
483 })
484 .unwrap();
485 }
486 }
487
488 pub async fn query(
489 &self,
490 query: &QueryMessage,
491 ) -> Result<Answer, io::Error> {
492 let res = match self.conf.transport {
493 Transport::Udp => {
494 timeout(
495 self.conf.request_timeout,
496 Self::udp_query(
497 query,
498 self.conf.addr,
499 self.conf.recv_size,
500 ),
501 )
502 .await
503 }
504 Transport::Tcp => {
505 timeout(
506 self.conf.request_timeout,
507 Self::tcp_query(query, self.conf.addr),
508 )
509 .await
510 }
511 };
512 match res {
513 Ok(Ok(answer)) => Ok(answer),
514 Ok(Err(err)) => Err(err),
515 Err(_) => Err(io::Error::new(
516 io::ErrorKind::TimedOut,
517 "request timed out",
518 )),
519 }
520 }
521
522 pub async fn tcp_query(
523 query: &QueryMessage,
524 addr: SocketAddr,
525 ) -> Result<Answer, io::Error> {
526 let mut sock = TcpStream::connect(&addr).await?;
527 sock.write_all(query.as_target().as_stream_slice()).await?;
528
529 loop {
532 let mut buf = Vec::new();
533 let len = sock.read_u16().await? as u64;
534 AsyncReadExt::take(&mut sock, len)
535 .read_to_end(&mut buf)
536 .await?;
537 if let Ok(answer) = Message::from_octets(buf.into()) {
538 if answer.is_answer(&query.as_message()) {
539 return Ok(answer.into());
540 }
541 } else {
543 return Err(io::Error::new(
544 io::ErrorKind::Other,
545 "short buf",
546 ));
547 }
548 }
549 }
550
551 pub async fn udp_query(
552 query: &QueryMessage,
553 addr: SocketAddr,
554 recv_size: usize,
555 ) -> Result<Answer, io::Error> {
556 let sock = Self::udp_bind(addr.is_ipv4()).await?;
557 sock.connect(addr).await?;
558 let sent = sock.send(query.as_target().as_dgram_slice()).await?;
559 if sent != query.as_target().as_dgram_slice().len() {
560 return Err(io::Error::new(
561 io::ErrorKind::Other,
562 "short UDP send",
563 ));
564 }
565 loop {
566 let mut buf = vec![0; recv_size]; let len = sock.recv(&mut buf).await?;
568 buf.truncate(len);
569
570 let answer = match Message::from_octets(buf.into()) {
572 Ok(answer) => answer,
573 Err(_) => continue,
574 };
575 if !answer.is_answer(&query.as_message()) {
576 continue;
577 }
578 return Ok(answer.into());
579 }
580 }
581
582 async fn udp_bind(v4: bool) -> Result<UdpSocket, io::Error> {
583 let mut i = 0;
584 loop {
585 let local: SocketAddr = if v4 {
586 ([0u8; 4], 0).into()
587 } else {
588 ([0u16; 8], 0).into()
589 };
590 match UdpSocket::bind(&local).await {
591 Ok(sock) => return Ok(sock),
592 Err(err) => {
593 if i == RETRY_RANDOM_PORT {
594 return Err(err);
595 } else {
596 i += 1
597 }
598 }
599 }
600 }
601 }
602}
603
604impl From<ServerConf> for ServerInfo {
605 fn from(conf: ServerConf) -> Self {
606 ServerInfo {
607 conf,
608 edns: Arc::new(AtomicBool::new(true)),
609 }
610 }
611}
612
613impl<'a> From<&'a ServerConf> for ServerInfo {
614 fn from(conf: &'a ServerConf) -> Self {
615 conf.clone().into()
616 }
617}
618
619#[derive(Clone, Debug)]
622struct ServerList {
623 servers: Vec<ServerInfo>,
625
626 start: Arc<AtomicUsize>,
635}
636
637impl ServerList {
638 pub fn from_conf<F>(conf: &ResolvConf, filter: F) -> Self
639 where
640 F: Fn(&ServerConf) -> bool,
641 {
642 ServerList {
643 servers: {
644 conf.servers
645 .iter()
646 .filter(|f| filter(f))
647 .map(Into::into)
648 .collect()
649 },
650 start: Arc::new(AtomicUsize::new(0)),
651 }
652 }
653
654 pub fn is_empty(&self) -> bool {
655 self.servers.is_empty()
656 }
657
658 pub fn counter(&self, rotate: bool) -> ServerListCounter {
659 let res = ServerListCounter::new(self);
660 if rotate {
661 self.rotate()
662 }
663 res
664 }
665
666 pub fn iter(&self) -> ServerListIter {
667 ServerListIter::new(self)
668 }
669
670 pub fn rotate(&self) {
671 self.start.fetch_add(1, Ordering::SeqCst);
672 }
673}
674
675impl<'a> IntoIterator for &'a ServerList {
676 type Item = &'a ServerInfo;
677 type IntoIter = ServerListIter<'a>;
678
679 fn into_iter(self) -> Self::IntoIter {
680 self.iter()
681 }
682}
683
684impl<I: SliceIndex<[ServerInfo]>> ops::Index<I> for ServerList {
685 type Output = <I as SliceIndex<[ServerInfo]>>::Output;
686
687 fn index(&self, index: I) -> &<I as SliceIndex<[ServerInfo]>>::Output {
688 self.servers.index(index)
689 }
690}
691
692#[derive(Clone, Debug)]
695struct ServerListCounter {
696 cur: usize,
697 end: usize,
698}
699
700impl ServerListCounter {
701 fn new(list: &ServerList) -> Self {
702 if list.servers.is_empty() {
703 return ServerListCounter { cur: 0, end: 0 };
704 }
705
706 let start = list.start.load(Ordering::Relaxed) % list.servers.len();
709 ServerListCounter {
710 cur: start,
711 end: start + list.servers.len(),
712 }
713 }
714
715 #[allow(clippy::should_implement_trait)]
716 pub fn next(&mut self) -> bool {
717 let next = self.cur + 1;
718 if next < self.end {
719 self.cur = next;
720 true
721 } else {
722 false
723 }
724 }
725
726 pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo {
727 &list[self.cur % list.servers.len()]
728 }
729}
730
731#[derive(Clone, Debug)]
734struct ServerListIter<'a> {
735 servers: &'a ServerList,
736 counter: ServerListCounter,
737}
738
739impl<'a> ServerListIter<'a> {
740 fn new(list: &'a ServerList) -> Self {
741 ServerListIter {
742 servers: list,
743 counter: ServerListCounter::new(list),
744 }
745 }
746}
747
748impl<'a> Iterator for ServerListIter<'a> {
749 type Item = &'a ServerInfo;
750
751 fn next(&mut self) -> Option<Self::Item> {
752 if self.counter.next() {
753 Some(self.counter.info(self.servers))
754 } else {
755 None
756 }
757 }
758}
759
760#[derive(Clone, Debug)]
763pub struct SearchIter<'a> {
764 resolver: &'a StubResolver,
765 pos: usize,
766}
767
768impl<'a> Iterator for SearchIter<'a> {
769 type Item = SearchSuffix;
770
771 fn next(&mut self) -> Option<Self::Item> {
772 if let Some(res) = self.resolver.options().search.get(self.pos) {
773 self.pos += 1;
774 Some(res.clone())
775 } else {
776 None
777 }
778 }
779}