domain/resolv/stub/
mod.rs

//! A stub resolver.
//!
//! The simplest possible resolver just relays all messages to one of a
//! set of pre-configured resolvers that will do the actual work. This is
//! equivalent to what the resolver part of the C library does. This module
//! provides such a stub resolver that emulates this C resolver as closely
//! as possible, in particular in the way it is configured.
//!
//! The main type is [`StubResolver`] that implements the [`Resolver`] trait
//! and thus can be used with the various lookup functions.
11
12use self::conf::{
13    ResolvConf, ResolvOptions, SearchSuffix, ServerConf, Transport,
14};
15use crate::base::iana::Rcode;
16use crate::base::message::Message;
17use crate::base::message_builder::{
18    AdditionalBuilder, MessageBuilder, StreamTarget,
19};
20use crate::base::name::{ToDname, ToRelativeDname};
21use crate::base::question::Question;
22use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs};
23use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts};
24use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError};
25use crate::resolv::resolver::{Resolver, SearchNames};
26use bytes::Bytes;
27use octseq::array::Array;
28use std::boxed::Box;
29use std::future::Future;
30use std::net::{IpAddr, SocketAddr};
31use std::pin::Pin;
32use std::slice::SliceIndex;
33use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
34use std::sync::Arc;
35use std::vec::Vec;
36use std::{io, ops};
37use tokio::io::{AsyncReadExt, AsyncWriteExt};
38use tokio::net::{TcpStream, UdpSocket};
39#[cfg(feature = "resolv-sync")]
40use tokio::runtime;
41use tokio::time::timeout;
42
43//------------ Sub-modules ---------------------------------------------------
44
45pub mod conf;
46
47//------------ Module Configuration ------------------------------------------
48
/// Upper bound on how often binding a fresh random local UDP port is
/// retried after a failed bind before the error is handed to the caller.
const RETRY_RANDOM_PORT: usize = 10;
51
52//------------ StubResolver --------------------------------------------------
53
54/// A DNS stub resolver.
55///
56/// This type collects all information making it possible to start DNS
57/// queries. You can create a new resolver using the system’s configuration
58/// using the [`new()`] associate function or using your own configuration
59/// with [`from_conf()`].
60///
61/// Stub resolver values can be cloned relatively cheaply as they keep all
62/// information behind an arc.
63///
64/// If you want to run a single query or lookup on a resolver synchronously,
65/// you can do so simply by using the [`run()`] or [`run_with_conf()`]
66/// associated functions.
67///
68/// [`new()`]: #method.new
69/// [`from_conf()`]: #method.from_conf
70/// [`query()`]: #method.query
71/// [`run()`]: #method.run
72/// [`run_with_conf()`]: #method.run_with_conf
73#[derive(Clone, Debug)]
74pub struct StubResolver {
75    /// Preferred servers.
76    preferred: ServerList,
77
78    /// Streaming servers.
79    stream: ServerList,
80
81    /// Resolver options.
82    options: ResolvOptions,
83}
84
85impl StubResolver {
86    /// Creates a new resolver using the system’s default configuration.
87    pub fn new() -> Self {
88        Self::from_conf(ResolvConf::default())
89    }
90
91    /// Creates a new resolver using the given configuraiton.
92    pub fn from_conf(conf: ResolvConf) -> Self {
93        StubResolver {
94            preferred: ServerList::from_conf(&conf, |s| {
95                s.transport.is_preferred()
96            }),
97            stream: ServerList::from_conf(&conf, |s| s.transport.is_stream()),
98            options: conf.options,
99        }
100    }
101
102    pub fn options(&self) -> &ResolvOptions {
103        &self.options
104    }
105
106    pub async fn query<N: ToDname, Q: Into<Question<N>>>(
107        &self,
108        question: Q,
109    ) -> Result<Answer, io::Error> {
110        Query::new(self)?
111            .run(Query::create_message(question.into()))
112            .await
113    }
114
115    async fn query_message(
116        &self,
117        message: QueryMessage,
118    ) -> Result<Answer, io::Error> {
119        Query::new(self)?.run(message).await
120    }
121}
122
123impl StubResolver {
124    pub async fn lookup_addr(
125        &self,
126        addr: IpAddr,
127    ) -> Result<FoundAddrs<&Self>, io::Error> {
128        lookup_addr(&self, addr).await
129    }
130
131    pub async fn lookup_host(
132        &self,
133        qname: impl ToDname,
134    ) -> Result<FoundHosts<&Self>, io::Error> {
135        lookup_host(&self, qname).await
136    }
137
138    pub async fn search_host(
139        &self,
140        qname: impl ToRelativeDname,
141    ) -> Result<FoundHosts<&Self>, io::Error> {
142        search_host(&self, qname).await
143    }
144
145    /// Performs an SRV lookup using this resolver.
146    ///
147    /// See the documentation for the [`lookup_srv`] function for details.
148    pub async fn lookup_srv(
149        &self,
150        service: impl ToRelativeDname,
151        name: impl ToDname,
152        fallback_port: u16,
153    ) -> Result<Option<FoundSrvs>, SrvError> {
154        lookup_srv(&self, service, name, fallback_port).await
155    }
156}
157
#[cfg(feature = "resolv-sync")]
#[cfg_attr(docsrs, doc(cfg(feature = "resolv-sync")))]
impl StubResolver {
    /// Synchronously performs a DNS operation atop a standard resolver.
    ///
    /// This associated function removes almost all boilerplate for the
    /// case that you want to perform some DNS operation, either a query or
    /// lookup, on a resolver using the system’s configuration and wait for
    /// the result.
    ///
    /// The only argument is a closure taking a `StubResolver` by value and
    /// returning a future. Whatever that future resolves to will be
    /// returned.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime used to drive the future cannot be
    /// created.
    pub fn run<R, F>(op: F) -> R::Output
    where
        R: Future + Send + 'static,
        R::Output: Send + 'static,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        Self::run_with_conf(ResolvConf::default(), op)
    }

    /// Synchronously performs a DNS operation atop a configured resolver.
    ///
    /// This is like [`run()`] but also takes a resolver configuration for
    /// tailor-making your own resolver.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime used to drive the future cannot be
    /// created.
    ///
    /// [`run()`]: #method.run
    pub fn run_with_conf<R, F>(conf: ResolvConf, op: F) -> R::Output
    where
        R: Future + Send + 'static,
        R::Output: Send + 'static,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        let resolver = Self::from_conf(conf);
        // A single-threaded runtime is enough to drive one operation to
        // completion on the calling thread.
        let runtime = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build Tokio runtime for StubResolver::run");
        runtime.block_on(op(resolver))
    }
}
200
201impl Default for StubResolver {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
207impl<'a> Resolver for &'a StubResolver {
208    type Octets = Bytes;
209    type Answer = Answer;
210    type Query =
211        Pin<Box<dyn Future<Output = Result<Answer, io::Error>> + Send + 'a>>;
212
213    fn query<N, Q>(&self, question: Q) -> Self::Query
214    where
215        N: ToDname,
216        Q: Into<Question<N>>,
217    {
218        let message = Query::create_message(question.into());
219        Box::pin(self.query_message(message))
220    }
221}
222
223impl<'a> SearchNames for &'a StubResolver {
224    type Name = SearchSuffix;
225    type Iter = SearchIter<'a>;
226
227    fn search_iter(&self) -> Self::Iter {
228        SearchIter {
229            resolver: self,
230            pos: 0,
231        }
232    }
233}
234
235//------------ Query ---------------------------------------------------------
236
237pub struct Query<'a> {
238    /// The resolver whose configuration we are using.
239    resolver: &'a StubResolver,
240
241    /// Are we still in the preferred server list or have gone streaming?
242    preferred: bool,
243
244    /// The number of attempts, starting with zero.
245    attempt: usize,
246
247    /// The index in the server list we currently trying.
248    counter: ServerListCounter,
249
250    /// The preferred error to return.
251    ///
252    /// Every time we finish a single query, we see if we can update this with
253    /// a better one. If we finally have to fail, we return this result. This
254    /// is a result so we can return a servfail answer if that is the only
255    /// answer we get. (Remember, SERVFAIL is returned for a bogus answer, so
256    /// you might want to know.)
257    error: Result<Answer, io::Error>,
258}
259
260impl<'a> Query<'a> {
261    pub fn new(resolver: &'a StubResolver) -> Result<Self, io::Error> {
262        let (preferred, counter) =
263            if resolver.options().use_vc || resolver.preferred.is_empty() {
264                if resolver.stream.is_empty() {
265                    return Err(io::Error::new(
266                        io::ErrorKind::NotFound,
267                        "no servers available",
268                    ));
269                }
270                (false, resolver.stream.counter(resolver.options().rotate))
271            } else {
272                (true, resolver.preferred.counter(resolver.options().rotate))
273            };
274        Ok(Query {
275            resolver,
276            preferred,
277            attempt: 0,
278            counter,
279            error: Err(io::Error::new(
280                io::ErrorKind::TimedOut,
281                "all timed out",
282            )),
283        })
284    }
285
286    pub async fn run(
287        mut self,
288        mut message: QueryMessage,
289    ) -> Result<Answer, io::Error> {
290        loop {
291            match self.run_query(&mut message).await {
292                Ok(answer) => {
293                    if answer.header().rcode() == Rcode::FormErr
294                        && self.current_server().does_edns()
295                    {
296                        // FORMERR with EDNS: turn off EDNS and try again.
297                        self.current_server().disable_edns();
298                        continue;
299                    } else if answer.header().rcode() == Rcode::ServFail {
300                        // SERVFAIL: go to next server.
301                        self.update_error_servfail(answer);
302                    } else if answer.header().tc()
303                        && self.preferred
304                        && !self.resolver.options().ign_tc
305                    {
306                        // Truncated. If we can, switch to stream transports
307                        // and try again. Otherwise return the truncated
308                        // answer.
309                        if self.switch_to_stream() {
310                            continue;
311                        } else {
312                            return Ok(answer);
313                        }
314                    } else {
315                        // I guess we have an answer ...
316                        return Ok(answer);
317                    }
318                }
319                Err(err) => self.update_error(err),
320            }
321            if !self.next_server() {
322                return self.error;
323            }
324        }
325    }
326
327    fn create_message(question: Question<impl ToDname>) -> QueryMessage {
328        let mut message = MessageBuilder::from_target(
329            StreamTarget::new(Default::default()).unwrap(),
330        )
331        .unwrap();
332        message.header_mut().set_rd(true);
333        let mut message = message.question();
334        message.push(question).unwrap();
335        message.additional()
336    }
337
338    async fn run_query(
339        &mut self,
340        message: &mut QueryMessage,
341    ) -> Result<Answer, io::Error> {
342        let server = self.current_server();
343        server.prepare_message(message);
344        server.query(message).await
345    }
346
347    fn current_server(&self) -> &ServerInfo {
348        let list = if self.preferred {
349            &self.resolver.preferred
350        } else {
351            &self.resolver.stream
352        };
353        self.counter.info(list)
354    }
355
356    fn update_error(&mut self, err: io::Error) {
357        // We keep the last error except for timeouts or if we have a servfail
358        // answer already. Since we start with a timeout, we still get a that
359        // if everything times out.
360        if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
361            self.error = Err(err)
362        }
363    }
364
365    fn update_error_servfail(&mut self, answer: Answer) {
366        self.error = Ok(answer)
367    }
368
369    fn switch_to_stream(&mut self) -> bool {
370        if !self.preferred {
371            // We already did this.
372            return false;
373        }
374        self.preferred = false;
375        self.attempt = 0;
376        self.counter =
377            self.resolver.stream.counter(self.resolver.options().rotate);
378        true
379    }
380
381    fn next_server(&mut self) -> bool {
382        if self.counter.next() {
383            return true;
384        }
385        self.attempt += 1;
386        if self.attempt >= self.resolver.options().attempts {
387            return false;
388        }
389        self.counter = if self.preferred {
390            self.resolver
391                .preferred
392                .counter(self.resolver.options().rotate)
393        } else {
394            self.resolver.stream.counter(self.resolver.options().rotate)
395        };
396        true
397    }
398}
399
400//------------ QueryMessage --------------------------------------------------
401
402// XXX This needs to be re-evaluated if we start adding OPTtions to the query.
403pub(super) type QueryMessage = AdditionalBuilder<StreamTarget<Array<512>>>;
404
405//------------ Answer --------------------------------------------------------
406
407/// The answer to a question.
408///
409/// This type is a wrapper around the DNS [`Message`] containing the answer
410/// that provides some additional information.
411#[derive(Clone)]
412pub struct Answer {
413    message: Message<Bytes>,
414}
415
416impl Answer {
417    /// Returns whether the answer is a final answer to be returned.
418    pub fn is_final(&self) -> bool {
419        (self.message.header().rcode() == Rcode::NoError
420            || self.message.header().rcode() == Rcode::NXDomain)
421            && !self.message.header().tc()
422    }
423
424    /// Returns whether the answer is truncated.
425    pub fn is_truncated(&self) -> bool {
426        self.message.header().tc()
427    }
428
429    pub fn into_message(self) -> Message<Bytes> {
430        self.message
431    }
432}
433
434impl From<Message<Bytes>> for Answer {
435    fn from(message: Message<Bytes>) -> Self {
436        Answer { message }
437    }
438}
439
440impl ops::Deref for Answer {
441    type Target = Message<Bytes>;
442
443    fn deref(&self) -> &Self::Target {
444        &self.message
445    }
446}
447
448impl AsRef<Message<Bytes>> for Answer {
449    fn as_ref(&self) -> &Message<Bytes> {
450        &self.message
451    }
452}
453
454//------------ ServerInfo ----------------------------------------------------
455
456#[derive(Clone, Debug)]
457struct ServerInfo {
458    /// The basic server configuration.
459    conf: ServerConf,
460
461    /// Whether this server supports EDNS.
462    ///
463    /// We start out with assuming it does and unset it if we get a FORMERR.
464    edns: Arc<AtomicBool>,
465}
466
467impl ServerInfo {
468    pub fn does_edns(&self) -> bool {
469        self.edns.load(Ordering::Relaxed)
470    }
471
472    pub fn disable_edns(&self) {
473        self.edns.store(false, Ordering::Relaxed);
474    }
475
476    pub fn prepare_message(&self, query: &mut QueryMessage) {
477        query.rewind();
478        if self.does_edns() {
479            query
480                .opt(|opt| {
481                    opt.set_udp_payload_size(self.conf.udp_payload_size);
482                    Ok(())
483                })
484                .unwrap();
485        }
486    }
487
488    pub async fn query(
489        &self,
490        query: &QueryMessage,
491    ) -> Result<Answer, io::Error> {
492        let res = match self.conf.transport {
493            Transport::Udp => {
494                timeout(
495                    self.conf.request_timeout,
496                    Self::udp_query(
497                        query,
498                        self.conf.addr,
499                        self.conf.recv_size,
500                    ),
501                )
502                .await
503            }
504            Transport::Tcp => {
505                timeout(
506                    self.conf.request_timeout,
507                    Self::tcp_query(query, self.conf.addr),
508                )
509                .await
510            }
511        };
512        match res {
513            Ok(Ok(answer)) => Ok(answer),
514            Ok(Err(err)) => Err(err),
515            Err(_) => Err(io::Error::new(
516                io::ErrorKind::TimedOut,
517                "request timed out",
518            )),
519        }
520    }
521
522    pub async fn tcp_query(
523        query: &QueryMessage,
524        addr: SocketAddr,
525    ) -> Result<Answer, io::Error> {
526        let mut sock = TcpStream::connect(&addr).await?;
527        sock.write_all(query.as_target().as_stream_slice()).await?;
528
529        // This loop can be infinite because we have a timeout on this whole
530        // thing, anyway.
531        loop {
532            let mut buf = Vec::new();
533            let len = sock.read_u16().await? as u64;
534            AsyncReadExt::take(&mut sock, len)
535                .read_to_end(&mut buf)
536                .await?;
537            if let Ok(answer) = Message::from_octets(buf.into()) {
538                if answer.is_answer(&query.as_message()) {
539                    return Ok(answer.into());
540                }
541            // else try with the next message.
542            } else {
543                return Err(io::Error::new(
544                    io::ErrorKind::Other,
545                    "short buf",
546                ));
547            }
548        }
549    }
550
551    pub async fn udp_query(
552        query: &QueryMessage,
553        addr: SocketAddr,
554        recv_size: usize,
555    ) -> Result<Answer, io::Error> {
556        let sock = Self::udp_bind(addr.is_ipv4()).await?;
557        sock.connect(addr).await?;
558        let sent = sock.send(query.as_target().as_dgram_slice()).await?;
559        if sent != query.as_target().as_dgram_slice().len() {
560            return Err(io::Error::new(
561                io::ErrorKind::Other,
562                "short UDP send",
563            ));
564        }
565        loop {
566            let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here.
567            let len = sock.recv(&mut buf).await?;
568            buf.truncate(len);
569
570            // We ignore garbage since there is a timer on this whole thing.
571            let answer = match Message::from_octets(buf.into()) {
572                Ok(answer) => answer,
573                Err(_) => continue,
574            };
575            if !answer.is_answer(&query.as_message()) {
576                continue;
577            }
578            return Ok(answer.into());
579        }
580    }
581
582    async fn udp_bind(v4: bool) -> Result<UdpSocket, io::Error> {
583        let mut i = 0;
584        loop {
585            let local: SocketAddr = if v4 {
586                ([0u8; 4], 0).into()
587            } else {
588                ([0u16; 8], 0).into()
589            };
590            match UdpSocket::bind(&local).await {
591                Ok(sock) => return Ok(sock),
592                Err(err) => {
593                    if i == RETRY_RANDOM_PORT {
594                        return Err(err);
595                    } else {
596                        i += 1
597                    }
598                }
599            }
600        }
601    }
602}
603
604impl From<ServerConf> for ServerInfo {
605    fn from(conf: ServerConf) -> Self {
606        ServerInfo {
607            conf,
608            edns: Arc::new(AtomicBool::new(true)),
609        }
610    }
611}
612
613impl<'a> From<&'a ServerConf> for ServerInfo {
614    fn from(conf: &'a ServerConf) -> Self {
615        conf.clone().into()
616    }
617}
618
619//------------ ServerList ----------------------------------------------------
620
621#[derive(Clone, Debug)]
622struct ServerList {
623    /// The actual list of servers.
624    servers: Vec<ServerInfo>,
625
626    /// Where to start accessing the list.
627    ///
628    /// In rotate mode, this value will always keep growing and will have to
629    /// be used modulo `servers`’s length.
630    ///
631    /// When it eventually wraps around the end of usize’s range, there will
632    /// be a jump in rotation. Since that will happen only oh-so-often, we
633    /// accept that in favour of simpler code.
634    start: Arc<AtomicUsize>,
635}
636
637impl ServerList {
638    pub fn from_conf<F>(conf: &ResolvConf, filter: F) -> Self
639    where
640        F: Fn(&ServerConf) -> bool,
641    {
642        ServerList {
643            servers: {
644                conf.servers
645                    .iter()
646                    .filter(|f| filter(f))
647                    .map(Into::into)
648                    .collect()
649            },
650            start: Arc::new(AtomicUsize::new(0)),
651        }
652    }
653
654    pub fn is_empty(&self) -> bool {
655        self.servers.is_empty()
656    }
657
658    pub fn counter(&self, rotate: bool) -> ServerListCounter {
659        let res = ServerListCounter::new(self);
660        if rotate {
661            self.rotate()
662        }
663        res
664    }
665
666    pub fn iter(&self) -> ServerListIter {
667        ServerListIter::new(self)
668    }
669
670    pub fn rotate(&self) {
671        self.start.fetch_add(1, Ordering::SeqCst);
672    }
673}
674
675impl<'a> IntoIterator for &'a ServerList {
676    type Item = &'a ServerInfo;
677    type IntoIter = ServerListIter<'a>;
678
679    fn into_iter(self) -> Self::IntoIter {
680        self.iter()
681    }
682}
683
684impl<I: SliceIndex<[ServerInfo]>> ops::Index<I> for ServerList {
685    type Output = <I as SliceIndex<[ServerInfo]>>::Output;
686
687    fn index(&self, index: I) -> &<I as SliceIndex<[ServerInfo]>>::Output {
688        self.servers.index(index)
689    }
690}
691
692//------------ ServerListCounter ---------------------------------------------
693
694#[derive(Clone, Debug)]
695struct ServerListCounter {
696    cur: usize,
697    end: usize,
698}
699
700impl ServerListCounter {
701    fn new(list: &ServerList) -> Self {
702        if list.servers.is_empty() {
703            return ServerListCounter { cur: 0, end: 0 };
704        }
705
706        // We modulo the start value here to prevent hick-ups towards the
707        // end of usize’s range.
708        let start = list.start.load(Ordering::Relaxed) % list.servers.len();
709        ServerListCounter {
710            cur: start,
711            end: start + list.servers.len(),
712        }
713    }
714
715    #[allow(clippy::should_implement_trait)]
716    pub fn next(&mut self) -> bool {
717        let next = self.cur + 1;
718        if next < self.end {
719            self.cur = next;
720            true
721        } else {
722            false
723        }
724    }
725
726    pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo {
727        &list[self.cur % list.servers.len()]
728    }
729}
730
731//------------ ServerListIter ------------------------------------------------
732
733#[derive(Clone, Debug)]
734struct ServerListIter<'a> {
735    servers: &'a ServerList,
736    counter: ServerListCounter,
737}
738
739impl<'a> ServerListIter<'a> {
740    fn new(list: &'a ServerList) -> Self {
741        ServerListIter {
742            servers: list,
743            counter: ServerListCounter::new(list),
744        }
745    }
746}
747
748impl<'a> Iterator for ServerListIter<'a> {
749    type Item = &'a ServerInfo;
750
751    fn next(&mut self) -> Option<Self::Item> {
752        if self.counter.next() {
753            Some(self.counter.info(self.servers))
754        } else {
755            None
756        }
757    }
758}
759
760//------------ SearchIter ----------------------------------------------------
761
762#[derive(Clone, Debug)]
763pub struct SearchIter<'a> {
764    resolver: &'a StubResolver,
765    pos: usize,
766}
767
768impl<'a> Iterator for SearchIter<'a> {
769    type Item = SearchSuffix;
770
771    fn next(&mut self) -> Option<Self::Item> {
772        if let Some(res) = self.resolver.options().search.get(self.pos) {
773            self.pos += 1;
774            Some(res.clone())
775        } else {
776            None
777        }
778    }
779}