1use super::host::lookup_host;
4use crate::base::iana::{Class, Rtype};
5use crate::base::message::Message;
6use crate::base::name::{Dname, ToDname, ToRelativeDname};
7use crate::base::wire::ParseError;
8use crate::rdata::{Aaaa, Srv, A};
9use crate::resolv::resolver::Resolver;
10use core::fmt;
11use futures_util::stream::{self, Stream, StreamExt};
12use octseq::octets::Octets;
13use rand::distributions::{Distribution, Uniform};
14use std::net::{IpAddr, SocketAddr};
15use std::vec::Vec;
16use std::{io, mem, ops};
17
/// Octets type backing the owned domain names stored in SRV items.
///
/// Without the `smallvec` feature this is a plain heap-allocated `Vec<u8>`.
#[cfg(not(feature = "smallvec"))]
type OctetsVec = Vec<u8>;

/// With the `smallvec` feature enabled, names are stored in `octseq`'s
/// `SmallOctets` instead (NOTE(review): presumably to keep short names
/// inline — confirm).
#[cfg(feature = "smallvec")]
type OctetsVec = octseq::octets::SmallOctets;
40
41pub async fn lookup_srv(
65 resolver: &impl Resolver,
66 service: impl ToRelativeDname,
67 name: impl ToDname,
68 fallback_port: u16,
69) -> Result<Option<FoundSrvs>, SrvError> {
70 let full_name = match (&service).chain(&name) {
71 Ok(name) => name,
72 Err(_) => return Err(SrvError::LongName),
73 };
74 let answer = resolver.query((full_name, Rtype::Srv)).await?;
75 FoundSrvs::new(answer.as_ref().for_slice(), name, fallback_port)
76}
77
/// The outcome of a successful SRV lookup, as produced by [`lookup_srv`].
#[derive(Clone, Debug)]
pub struct FoundSrvs {
    /// The found items.
    ///
    /// `Ok` holds the items taken from the SRV records of the answer. `new`
    /// sorts them by priority and shuffles them by weight.
    ///
    /// `Err` holds a single item. `new` uses it for the fallback item it
    /// creates when the answer contained no SRV records. `merge` turns this
    /// variant into `Ok` before adding more items.
    items: Result<Vec<SrvItem>, SrvItem>,
}
90
91impl FoundSrvs {
92 pub fn into_stream<R: Resolver>(
104 self,
105 resolver: &R,
106 ) -> impl Stream<Item = Result<ResolvedSrvItem, io::Error>> + '_
107 where
108 R::Octets: Octets,
109 {
110 let iter = match self.items {
114 Ok(vec) => {
115 Some(vec.into_iter()).into_iter().flatten().chain(None)
116 }
117 Err(one) => None.into_iter().flatten().chain(Some(one)),
118 };
119 stream::iter(iter).then(move |item| item.resolve(resolver))
120 }
121
122 pub fn into_srvs(self) -> impl Iterator<Item = Srv<Dname<OctetsVec>>> {
130 let (left, right) = match self.items {
131 Ok(ok) => (Some(ok.into_iter()), None),
132 Err(err) => (None, Some(std::iter::once(err))),
133 };
134 left.into_iter()
135 .flatten()
136 .chain(right.into_iter().flatten())
137 .map(|item| item.srv)
138 }
139
140 pub fn merge(&mut self, other: &Self) {
144 if self.items.is_err() {
145 let one =
146 mem::replace(&mut self.items, Ok(Vec::new())).unwrap_err();
147 self.items.as_mut().unwrap().push(one);
148 }
149 match self.items {
150 Ok(ref mut items) => {
151 match other.items {
152 Ok(ref vec) => items.extend_from_slice(vec),
153 Err(ref one) => items.push(one.clone()),
154 }
155 Self::reorder_items(items);
156 }
157 Err(_) => unreachable!(),
158 }
159 }
160}
161
162impl FoundSrvs {
163 fn new(
164 answer: &Message<[u8]>,
165 fallback_name: impl ToDname,
166 fallback_port: u16,
167 ) -> Result<Option<Self>, SrvError> {
168 let name =
169 answer.canonical_name().ok_or(SrvError::MalformedAnswer)?;
170 let mut items = Self::process_records(answer, &name)?;
171
172 if items.is_empty() {
173 return Ok(Some(FoundSrvs {
174 items: Err(SrvItem::fallback(fallback_name, fallback_port)),
175 }));
176 }
177 if items.len() == 1 && items[0].target().is_root() {
178 return Ok(None);
180 }
181
182 Self::process_additional(&mut items, answer)?;
184 Self::reorder_items(&mut items);
185 Ok(Some(FoundSrvs { items: Ok(items) }))
186 }
187
188 fn process_records(
189 answer: &Message<[u8]>,
190 name: &impl ToDname,
191 ) -> Result<Vec<SrvItem>, SrvError> {
192 let mut res = Vec::new();
193 for record in answer.answer()?.limit_to_in::<Srv<_>>().flatten() {
195 if record.owner() == name {
196 res.push(SrvItem::from_rdata(record.data()))
197 }
198 }
199 Ok(res)
200 }
201
202 fn process_additional(
203 items: &mut [SrvItem],
204 answer: &Message<[u8]>,
205 ) -> Result<(), SrvError> {
206 let additional = answer.additional()?;
207 for item in items {
208 let mut addrs = Vec::new();
209 for record in additional {
210 let record = match record {
211 Ok(record) => record,
212 Err(_) => continue,
213 };
214 if record.class() != Class::In
215 || record.owner() != item.target()
216 {
217 continue;
218 }
219 if let Ok(Some(record)) = record.to_record::<A>() {
220 addrs.push(record.data().addr().into())
221 }
222 if let Ok(Some(record)) = record.to_record::<Aaaa>() {
223 addrs.push(record.data().addr().into())
224 }
225 }
226 if !addrs.is_empty() {
227 item.resolved = Some(addrs)
228 }
229 }
230 Ok(())
231 }
232
233 fn reorder_items(items: &mut [SrvItem]) {
234 items.sort_by_key(|k| (k.priority(), k.weight()));
238
239 let mut current_prio = 0;
241 let mut weight_sum = 0;
242 let mut first_index = 0;
243 for i in 0..items.len() {
244 if current_prio != items[i].priority() {
245 current_prio = items[i].priority();
246 Self::reorder_by_weight(
247 &mut items[first_index..i],
248 weight_sum,
249 );
250 weight_sum = 0;
251 first_index = i;
252 }
253 weight_sum += u32::from(items[i].weight());
254 }
255 Self::reorder_by_weight(&mut items[first_index..], weight_sum);
256 }
257
258 fn reorder_by_weight(items: &mut [SrvItem], weight_sum: u32) {
260 let mut rng = rand::thread_rng();
261 let mut weight_sum = weight_sum;
262 for i in 0..items.len() {
263 let range = Uniform::new(0, weight_sum + 1);
264 let mut sum: u32 = 0;
265 let pick = range.sample(&mut rng);
266 for j in 0..items.len() {
267 sum += u32::from(items[j].weight());
268 if sum >= pick {
269 weight_sum -= u32::from(items[j].weight());
270 items.swap(i, j);
271 break;
272 }
273 }
274 }
275 }
276}
277
/// A single target found during an SRV lookup.
///
/// Dereferences to the underlying SRV record data (priority, weight, port,
/// target).
#[derive(Clone, Debug)]
pub struct SrvItem {
    /// The SRV record data, with the target stored as an owned name.
    srv: Srv<Dname<OctetsVec>>,

    /// Whether this item was synthesized as the fallback because no SRV
    /// records were found (set by `SrvItem::fallback`).
    // The field is written but not read anywhere in this module, hence the
    // `allow`.
    #[allow(dead_code)]
    fallback: bool,

    /// Addresses of the target taken from the additional section of the
    /// answer, if any were present. `resolve` uses these instead of
    /// performing a new lookup.
    resolved: Option<Vec<IpAddr>>,
}
292
293impl SrvItem {
294 fn from_rdata(srv: &Srv<impl ToDname>) -> Self {
295 SrvItem {
296 srv: Srv::new(
297 srv.priority(),
298 srv.weight(),
299 srv.port(),
300 srv.target().to_dname().unwrap(),
301 ),
302 fallback: false,
303 resolved: None,
304 }
305 }
306
307 fn fallback(name: impl ToDname, fallback_port: u16) -> Self {
308 SrvItem {
309 srv: Srv::new(0, 0, fallback_port, name.to_dname().unwrap()),
310 fallback: true,
311 resolved: None,
312 }
313 }
314
315 pub async fn resolve<R: Resolver>(
317 self,
318 resolver: &R,
319 ) -> Result<ResolvedSrvItem, io::Error>
320 where
321 R::Octets: Octets,
322 {
323 let port = self.port();
324 if let Some(resolved) = self.resolved {
325 return Ok(ResolvedSrvItem {
326 srv: self.srv,
327 resolved: {
328 resolved
329 .into_iter()
330 .map(|addr| SocketAddr::new(addr, port))
331 .collect()
332 },
333 });
334 }
335 let resolved = lookup_host(resolver, self.target()).await?;
336 Ok(ResolvedSrvItem {
337 srv: self.srv,
338 resolved: {
339 resolved
340 .iter()
341 .map(|addr| SocketAddr::new(addr, port))
342 .collect()
343 },
344 })
345 }
346}
347
348impl AsRef<Srv<Dname<OctetsVec>>> for SrvItem {
349 fn as_ref(&self) -> &Srv<Dname<OctetsVec>> {
350 &self.srv
351 }
352}
353
354impl ops::Deref for SrvItem {
355 type Target = Srv<Dname<OctetsVec>>;
356
357 fn deref(&self) -> &Self::Target {
358 self.as_ref()
359 }
360}
361
/// An SRV item together with the socket addresses of its target.
///
/// Produced by `SrvItem::resolve`. Dereferences to the underlying SRV record
/// data.
#[derive(Clone, Debug)]
pub struct ResolvedSrvItem {
    /// The SRV record data.
    srv: Srv<Dname<OctetsVec>>,
    /// The target's addresses, each combined with the SRV record's port.
    resolved: Vec<SocketAddr>,
}
370
371impl ResolvedSrvItem {
372 pub fn resolved(&self) -> &[SocketAddr] {
374 &self.resolved
375 }
376}
377
378impl AsRef<Srv<Dname<OctetsVec>>> for ResolvedSrvItem {
379 fn as_ref(&self) -> &Srv<Dname<OctetsVec>> {
380 &self.srv
381 }
382}
383
384impl ops::Deref for ResolvedSrvItem {
385 type Target = Srv<Dname<OctetsVec>>;
386
387 fn deref(&self) -> &Self::Target {
388 self.as_ref()
389 }
390}
391
/// Errors that can occur during an SRV lookup.
#[derive(Debug)]
pub enum SrvError {
    /// The service name chained with the domain name is too long.
    /// `lookup_srv` returns this when chaining them fails.
    LongName,

    /// The answer could not be processed: it had no canonical name, or one
    /// of its sections failed to parse.
    MalformedAnswer,

    /// An I/O error, e.g. from executing the SRV query.
    Query(io::Error),
}
400
401impl fmt::Display for SrvError {
402 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403 match self {
404 SrvError::LongName => write!(f, "name too long"),
405 SrvError::MalformedAnswer => write!(f, "malformed answer"),
406 SrvError::Query(e) => write!(f, "error executing query {}", e),
407 }
408 }
409}
410
411impl std::error::Error for SrvError {}
412
413impl From<io::Error> for SrvError {
414 fn from(err: io::Error) -> SrvError {
415 SrvError::Query(err)
416 }
417}
418
419impl From<ParseError> for SrvError {
420 fn from(_: ParseError) -> SrvError {
421 SrvError::MalformedAnswer
422 }
423}