domain/resolv/lookup/
srv.rs

1//! Looking up SRV records.
2
3use super::host::lookup_host;
4use crate::base::iana::{Class, Rtype};
5use crate::base::message::Message;
6use crate::base::name::{Dname, ToDname, ToRelativeDname};
7use crate::base::wire::ParseError;
8use crate::rdata::{Aaaa, Srv, A};
9use crate::resolv::resolver::Resolver;
10use core::fmt;
11use futures_util::stream::{self, Stream, StreamExt};
12use octseq::octets::Octets;
13use rand::distributions::{Distribution, Uniform};
14use std::net::{IpAddr, SocketAddr};
15use std::vec::Vec;
16use std::{io, mem, ops};
17
18// Look up SRV record. Three outcomes:
19//
20// *  at least one SRV record with a regular target,
21// *  one single SRV record with the root target -- no such service,
22// *  no SRV records at all.
23//
// In the first case we have a set of (target, port) pairs which we need to
// resolve further if there were no address records for the target in the
// additional section.
27//
28// In the second case we have nothing.
29//
30// In the third case we have a single (target, port) pair with the original
31// host and the fallback port which we need to resolve further.
32
33//------------ OctetsVec -----------------------------------------------------
34
// The octets type backing the owned domain names kept in the SRV items of
// this module. With the `smallvec` feature enabled, octseq's `SmallOctets`
// is used.
#[cfg(feature = "smallvec")]
type OctetsVec = octseq::octets::SmallOctets;

// Without the `smallvec` feature, names are kept in a plain byte vector.
#[cfg(not(feature = "smallvec"))]
type OctetsVec = Vec<u8>;
40
41//------------ lookup_srv ----------------------------------------------------
42
43/// Creates a future that looks up SRV records.
44///
45/// The future will use the resolver given in `resolver` to query the
46/// DNS for SRV records associated with domain name `name` and service
47/// `service`.
48///
49/// The value returned upon success can be turned into a stream of
50/// [`ResolvedSrvItem`]s corresponding to the found SRV records, ordered as per
51/// the usage rules defined in [RFC 2782]. If no matching SRV record is found,
52/// A/AAAA queries on the bare domain name `name` will be attempted, yielding
53/// a single element upon success using the port given by `fallback_port`,
54/// typcially the standard port for the service in question.
55///
56/// Each item in the stream can be turned into an iterator over socket
57/// addresses as accepted by, for instance, [`TcpStream::connect`].
58///
59/// The future resolves to `None` whenever the request service is
60/// “decidedly not available” at the requested domain, that is there is a
61/// single SRV record with the root label as its target.
62///
63///[`TcpStream::connect`]: tokio::net::TcpStream::connect
64pub async fn lookup_srv(
65    resolver: &impl Resolver,
66    service: impl ToRelativeDname,
67    name: impl ToDname,
68    fallback_port: u16,
69) -> Result<Option<FoundSrvs>, SrvError> {
70    let full_name = match (&service).chain(&name) {
71        Ok(name) => name,
72        Err(_) => return Err(SrvError::LongName),
73    };
74    let answer = resolver.query((full_name, Rtype::Srv)).await?;
75    FoundSrvs::new(answer.as_ref().for_slice(), name, fallback_port)
76}
77
78//------------ FoundSrvs -----------------------------------------------------
79
/// This is the return type for [`lookup_srv`].
#[derive(Clone, Debug)]
pub struct FoundSrvs {
    /// The SRV items we found.
    ///
    /// If this is `Ok(some)`, there were SRV records. If this is `Err(some)`,
    /// there weren’t any SRV records and the sole item is the bare host and
    /// fallback port.
    items: Result<Vec<SrvItem>, SrvItem>,
}
90
91impl FoundSrvs {
92    /// Converts the found SRV records into socket addresses.
93    ///
94    /// The method takes a reference to a resolver and returns a stream of
95    /// socket addresses in the order prescribed by the SRV records. Each
96    /// returned item provides the set of addresses for one host.
97    ///
98    /// Note that if you are using the
99    /// [`StubResolver`][crate::resolv::stub::StubResolver], you will have to
100    /// pass in a double reference since [`Resolver`] is implemented for a
101    /// reference to it and this method requires a reference to that impl
102    /// being passed. This quirk will be fixed in future versions.
103    pub fn into_stream<R: Resolver>(
104        self,
105        resolver: &R,
106    ) -> impl Stream<Item = Result<ResolvedSrvItem, io::Error>> + '_
107    where
108        R::Octets: Octets,
109    {
110        // Let’s make a somewhat elaborate single iterator from self.items
111        // that we can use as the base for the stream: We turn the result into
112        // two options of the two cases and chain those up.
113        let iter = match self.items {
114            Ok(vec) => {
115                Some(vec.into_iter()).into_iter().flatten().chain(None)
116            }
117            Err(one) => None.into_iter().flatten().chain(Some(one)),
118        };
119        stream::iter(iter).then(move |item| item.resolve(resolver))
120    }
121
122    /// Converts the value into an iterator over the found SRV records.
123    ///
124    /// If results were found, this returns them in the order prescribed by
125    /// the SRV records.
126    ///
127    /// If not results were found, the iterator will yield a single entry
128    /// with the bare host and the default fallback port.
129    pub fn into_srvs(self) -> impl Iterator<Item = Srv<Dname<OctetsVec>>> {
130        let (left, right) = match self.items {
131            Ok(ok) => (Some(ok.into_iter()), None),
132            Err(err) => (None, Some(std::iter::once(err))),
133        };
134        left.into_iter()
135            .flatten()
136            .chain(right.into_iter().flatten())
137            .map(|item| item.srv)
138    }
139
140    /// Merges all results from `other` into `self`.
141    ///
142    /// Reorders merged results as if they were from a single query.
143    pub fn merge(&mut self, other: &Self) {
144        if self.items.is_err() {
145            let one =
146                mem::replace(&mut self.items, Ok(Vec::new())).unwrap_err();
147            self.items.as_mut().unwrap().push(one);
148        }
149        match self.items {
150            Ok(ref mut items) => {
151                match other.items {
152                    Ok(ref vec) => items.extend_from_slice(vec),
153                    Err(ref one) => items.push(one.clone()),
154                }
155                Self::reorder_items(items);
156            }
157            Err(_) => unreachable!(),
158        }
159    }
160}
161
162impl FoundSrvs {
163    fn new(
164        answer: &Message<[u8]>,
165        fallback_name: impl ToDname,
166        fallback_port: u16,
167    ) -> Result<Option<Self>, SrvError> {
168        let name =
169            answer.canonical_name().ok_or(SrvError::MalformedAnswer)?;
170        let mut items = Self::process_records(answer, &name)?;
171
172        if items.is_empty() {
173            return Ok(Some(FoundSrvs {
174                items: Err(SrvItem::fallback(fallback_name, fallback_port)),
175            }));
176        }
177        if items.len() == 1 && items[0].target().is_root() {
178            // Exactly one record with target "." indicates no service.
179            return Ok(None);
180        }
181
182        // Build results including potentially resolved IP addresses
183        Self::process_additional(&mut items, answer)?;
184        Self::reorder_items(&mut items);
185        Ok(Some(FoundSrvs { items: Ok(items) }))
186    }
187
188    fn process_records(
189        answer: &Message<[u8]>,
190        name: &impl ToDname,
191    ) -> Result<Vec<SrvItem>, SrvError> {
192        let mut res = Vec::new();
193        // XXX We could also error out if any SRV error is broken?
194        for record in answer.answer()?.limit_to_in::<Srv<_>>().flatten() {
195            if record.owner() == name {
196                res.push(SrvItem::from_rdata(record.data()))
197            }
198        }
199        Ok(res)
200    }
201
202    fn process_additional(
203        items: &mut [SrvItem],
204        answer: &Message<[u8]>,
205    ) -> Result<(), SrvError> {
206        let additional = answer.additional()?;
207        for item in items {
208            let mut addrs = Vec::new();
209            for record in additional {
210                let record = match record {
211                    Ok(record) => record,
212                    Err(_) => continue,
213                };
214                if record.class() != Class::In
215                    || record.owner() != item.target()
216                {
217                    continue;
218                }
219                if let Ok(Some(record)) = record.to_record::<A>() {
220                    addrs.push(record.data().addr().into())
221                }
222                if let Ok(Some(record)) = record.to_record::<Aaaa>() {
223                    addrs.push(record.data().addr().into())
224                }
225            }
226            if !addrs.is_empty() {
227                item.resolved = Some(addrs)
228            }
229        }
230        Ok(())
231    }
232
233    fn reorder_items(items: &mut [SrvItem]) {
234        // First, reorder by priority and weight, effectively
235        // grouping by priority, with weight 0 records at the beginning of
236        // each group.
237        items.sort_by_key(|k| (k.priority(), k.weight()));
238
239        // Find each group and reorder them using reorder_by_weight
240        let mut current_prio = 0;
241        let mut weight_sum = 0;
242        let mut first_index = 0;
243        for i in 0..items.len() {
244            if current_prio != items[i].priority() {
245                current_prio = items[i].priority();
246                Self::reorder_by_weight(
247                    &mut items[first_index..i],
248                    weight_sum,
249                );
250                weight_sum = 0;
251                first_index = i;
252            }
253            weight_sum += u32::from(items[i].weight());
254        }
255        Self::reorder_by_weight(&mut items[first_index..], weight_sum);
256    }
257
258    /// Reorders items in a priority level based on their weight
259    fn reorder_by_weight(items: &mut [SrvItem], weight_sum: u32) {
260        let mut rng = rand::thread_rng();
261        let mut weight_sum = weight_sum;
262        for i in 0..items.len() {
263            let range = Uniform::new(0, weight_sum + 1);
264            let mut sum: u32 = 0;
265            let pick = range.sample(&mut rng);
266            for j in 0..items.len() {
267                sum += u32::from(items[j].weight());
268                if sum >= pick {
269                    weight_sum -= u32::from(items[j].weight());
270                    items.swap(i, j);
271                    break;
272                }
273            }
274        }
275    }
276}
277
278//------------ SrvItem -------------------------------------------------------
279
/// A single SRV record, or the fallback standing in for one, together with
/// any addresses already known for its target.
#[derive(Clone, Debug)]
pub struct SrvItem {
    /// The SRV record.
    srv: Srv<Dname<OctetsVec>>,

    /// Fall back?
    ///
    /// This is `true` for the item created from the bare host and fallback
    /// port and `false` for items created from actual SRV records.
    #[allow(dead_code)] // XXX Check if we can actually remove it.
    fallback: bool,

    /// A resolved answer if we have one.
    ///
    /// If this is `None`, the target is looked up when the item is
    /// resolved.
    resolved: Option<Vec<IpAddr>>,
}
292
293impl SrvItem {
294    fn from_rdata(srv: &Srv<impl ToDname>) -> Self {
295        SrvItem {
296            srv: Srv::new(
297                srv.priority(),
298                srv.weight(),
299                srv.port(),
300                srv.target().to_dname().unwrap(),
301            ),
302            fallback: false,
303            resolved: None,
304        }
305    }
306
307    fn fallback(name: impl ToDname, fallback_port: u16) -> Self {
308        SrvItem {
309            srv: Srv::new(0, 0, fallback_port, name.to_dname().unwrap()),
310            fallback: true,
311            resolved: None,
312        }
313    }
314
315    // Resolves the target.
316    pub async fn resolve<R: Resolver>(
317        self,
318        resolver: &R,
319    ) -> Result<ResolvedSrvItem, io::Error>
320    where
321        R::Octets: Octets,
322    {
323        let port = self.port();
324        if let Some(resolved) = self.resolved {
325            return Ok(ResolvedSrvItem {
326                srv: self.srv,
327                resolved: {
328                    resolved
329                        .into_iter()
330                        .map(|addr| SocketAddr::new(addr, port))
331                        .collect()
332                },
333            });
334        }
335        let resolved = lookup_host(resolver, self.target()).await?;
336        Ok(ResolvedSrvItem {
337            srv: self.srv,
338            resolved: {
339                resolved
340                    .iter()
341                    .map(|addr| SocketAddr::new(addr, port))
342                    .collect()
343            },
344        })
345    }
346}
347
// Gives access to the SRV record data stored in the item.
impl AsRef<Srv<Dname<OctetsVec>>> for SrvItem {
    fn as_ref(&self) -> &Srv<Dname<OctetsVec>> {
        &self.srv
    }
}
353
354impl ops::Deref for SrvItem {
355    type Target = Srv<Dname<OctetsVec>>;
356
357    fn deref(&self) -> &Self::Target {
358        self.as_ref()
359    }
360}
361
//------------ ResolvedSrvItem -----------------------------------------------
363
/// An SRV record which has itself been resolved into a [`SocketAddr`].
#[derive(Clone, Debug)]
pub struct ResolvedSrvItem {
    /// The SRV record the addresses belong to.
    srv: Srv<Dname<OctetsVec>>,

    /// The addresses of the record’s target combined with its port.
    resolved: Vec<SocketAddr>,
}
370
371impl ResolvedSrvItem {
372    /// Returns the resolved address for this record.
373    pub fn resolved(&self) -> &[SocketAddr] {
374        &self.resolved
375    }
376}
377
// Gives access to the SRV record data the addresses were resolved for.
impl AsRef<Srv<Dname<OctetsVec>>> for ResolvedSrvItem {
    fn as_ref(&self) -> &Srv<Dname<OctetsVec>> {
        &self.srv
    }
}
383
384impl ops::Deref for ResolvedSrvItem {
385    type Target = Srv<Dname<OctetsVec>>;
386
387    fn deref(&self) -> &Self::Target {
388        self.as_ref()
389    }
390}
391
392//------------ SrvError ------------------------------------------------------
393
/// An error happened while looking up SRV records.
#[derive(Debug)]
pub enum SrvError {
    /// The service and domain name could not be combined into one name.
    LongName,

    /// The answer could not be parsed or had no canonical name.
    MalformedAnswer,

    /// An I/O error was converted into this error, e.g., from the query.
    Query(io::Error),
}
400
401impl fmt::Display for SrvError {
402    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403        match self {
404            SrvError::LongName => write!(f, "name too long"),
405            SrvError::MalformedAnswer => write!(f, "malformed answer"),
406            SrvError::Query(e) => write!(f, "error executing query {}", e),
407        }
408    }
409}
410
// Uses the default method implementations of the trait, so no underlying
// source error is reported.
impl std::error::Error for SrvError {}
412
413impl From<io::Error> for SrvError {
414    fn from(err: io::Error) -> SrvError {
415        SrvError::Query(err)
416    }
417}
418
419impl From<ParseError> for SrvError {
420    fn from(_: ParseError) -> SrvError {
421        SrvError::MalformedAnswer
422    }
423}