1use self::conf::{
13 ResolvConf, ResolvOptions, SearchSuffix, ServerConf, Transport,
14};
15use crate::base::iana::Rcode;
16use crate::base::message::Message;
17use crate::base::message_builder::{AdditionalBuilder, MessageBuilder};
18use crate::base::name::{ToName, ToRelativeName};
19use crate::base::question::Question;
20use crate::net::client::dgram_stream;
21use crate::net::client::multi_stream;
22use crate::net::client::protocol::{TcpConnect, UdpConnect};
23use crate::net::client::redundant;
24use crate::net::client::request::{
25 ComposeRequest, Error, RequestMessage, SendRequest,
26};
27use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs};
28use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts};
29use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError};
30use crate::resolv::resolver::{Resolver, SearchNames};
31use bytes::Bytes;
32use futures_util::stream::{FuturesUnordered, StreamExt};
33use octseq::array::Array;
34use std::boxed::Box;
35use std::fmt::Debug;
36use std::future::Future;
37use std::net::IpAddr;
38use std::pin::Pin;
39use std::string::ToString;
40use std::sync::atomic::{AtomicBool, Ordering};
41use std::sync::Arc;
42use std::vec::Vec;
43use std::{io, ops};
44#[cfg(feature = "resolv-sync")]
45use tokio::runtime;
46use tokio::sync::Mutex;
47use tokio::time::timeout;
48
49pub mod conf;
52
53#[derive(Debug)]
77pub struct StubResolver {
78 transport: Mutex<Option<redundant::Connection<RequestMessage<Vec<u8>>>>>,
79
80 options: ResolvOptions,
82
83 servers: Vec<ServerConf>,
84}
85
86impl StubResolver {
87 pub fn new() -> Self {
89 Self::from_conf(ResolvConf::default())
90 }
91
92 pub fn from_conf(conf: ResolvConf) -> Self {
94 StubResolver {
95 transport: None.into(),
96 options: conf.options,
97
98 servers: conf.servers,
99 }
100 }
101
102 pub fn options(&self) -> &ResolvOptions {
103 &self.options
104 }
105
106 pub async fn add_connection(
108 &self,
109 connection: Box<
110 dyn SendRequest<RequestMessage<Vec<u8>>> + Send + Sync,
111 >,
112 ) {
113 self.get_transport()
114 .await
115 .expect("The 'redundant::Connection' task should not fail")
116 .add(connection)
117 .await
118 .expect("The 'redundant::Connection' task should not fail");
119 }
120
121 pub async fn query<N: ToName, Q: Into<Question<N>>>(
122 &self,
123 question: Q,
124 ) -> Result<Answer, io::Error> {
125 Query::new(self)?
126 .run(Query::create_message(question.into()))
127 .await
128 }
129
130 async fn query_message(
131 &self,
132 message: QueryMessage,
133 ) -> Result<Answer, io::Error> {
134 Query::new(self)?.run(message).await
135 }
136
137 async fn setup_transport<
138 CR: Clone + Debug + ComposeRequest + Send + Sync + 'static,
139 >(
140 &self,
141 ) -> Result<redundant::Connection<CR>, Error> {
142 let (redun, transp) = redundant::Connection::new();
144
145 let redun_run_fut = transp.run();
147
148 tokio::spawn(async move {
153 redun_run_fut.await;
154 });
155
156 let fut_list_tcp = FuturesUnordered::new();
157 let fut_list_udp_tcp = FuturesUnordered::new();
158
159 for s in &self.servers {
166 if self.options.use_vc || matches!(s.transport, Transport::Tcp) {
169 let (conn, tran) =
170 multi_stream::Connection::new(TcpConnect::new(s.addr));
171 fut_list_tcp.push(tran.run());
173 redun.add(Box::new(conn)).await?;
174 } else {
175 let udp_connect = UdpConnect::new(s.addr);
176 let tcp_connect = TcpConnect::new(s.addr);
177 let (conn, tran) =
178 dgram_stream::Connection::new(udp_connect, tcp_connect);
179 fut_list_udp_tcp.push(tran.run());
181 redun.add(Box::new(conn)).await?;
182 }
183 }
184
185 tokio::spawn(async move {
186 run(fut_list_tcp, fut_list_udp_tcp).await;
187 });
188
189 Ok(redun)
190 }
191
192 async fn get_transport(
193 &self,
194 ) -> Result<redundant::Connection<RequestMessage<Vec<u8>>>, Error> {
195 let mut opt_transport = self.transport.lock().await;
196
197 match &*opt_transport {
198 Some(transport) => Ok(transport.clone()),
199 None => {
200 let transport = self.setup_transport().await?;
201 *opt_transport = Some(transport.clone());
202 Ok(transport)
203 }
204 }
205 }
206}
207
208async fn run<TcpFut: Future, UdpTcpFut: Future>(
209 mut fut_list_tcp: FuturesUnordered<TcpFut>,
210 mut fut_list_udp_tcp: FuturesUnordered<UdpTcpFut>,
211) {
212 loop {
213 let tcp_empty = fut_list_tcp.is_empty();
214 let udp_tcp_empty = fut_list_udp_tcp.is_empty();
215 if tcp_empty && udp_tcp_empty {
216 break;
217 }
218 tokio::select! {
219 _ = fut_list_tcp.next(), if !tcp_empty => {
220 }
222 _ = fut_list_udp_tcp.next(), if !udp_tcp_empty => {
223 }
225 }
226 }
227}
228
229impl StubResolver {
230 pub async fn lookup_addr(
231 &self,
232 addr: IpAddr,
233 ) -> Result<FoundAddrs<&Self>, io::Error> {
234 lookup_addr(&self, addr).await
235 }
236
237 pub async fn lookup_host(
238 &self,
239 qname: impl ToName,
240 ) -> Result<FoundHosts<&Self>, io::Error> {
241 lookup_host(&self, qname).await
242 }
243
244 pub async fn search_host(
245 &self,
246 qname: impl ToRelativeName,
247 ) -> Result<FoundHosts<&Self>, io::Error> {
248 search_host(&self, qname).await
249 }
250
251 pub async fn lookup_srv(
255 &self,
256 service: impl ToRelativeName,
257 name: impl ToName,
258 fallback_port: u16,
259 ) -> Result<Option<FoundSrvs>, SrvError> {
260 lookup_srv(&self, service, name, fallback_port).await
261 }
262}
263
#[cfg(feature = "resolv-sync")]
#[cfg_attr(docsrs, doc(cfg(feature = "resolv-sync")))]
impl StubResolver {
    /// Synchronously runs `op` with a resolver using the default
    /// configuration.
    ///
    /// This is `run_with_conf` called with `ResolvConf::default()`.
    pub fn run<R, T, E, F>(op: F) -> R::Output
    where
        R: Future<Output = Result<T, E>> + Send + 'static,
        E: From<io::Error>,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        let conf = ResolvConf::default();
        Self::run_with_conf(conf, op)
    }

    /// Synchronously runs `op` with a resolver created from `conf`.
    ///
    /// A current-thread Tokio runtime with all drivers enabled is built
    /// and blocks on the future returned by `op`.
    ///
    /// # Errors
    ///
    /// If the runtime cannot be built, the I/O error is converted into
    /// `E` and returned. Otherwise the output of the future is returned.
    pub fn run_with_conf<R, T, E, F>(conf: ResolvConf, op: F) -> R::Output
    where
        R: Future<Output = Result<T, E>> + Send + 'static,
        E: From<io::Error>,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        let resolver = Self::from_conf(conf);

        let mut builder = runtime::Builder::new_current_thread();
        let runtime = builder.enable_all().build()?;

        let future = op(resolver);
        runtime.block_on(future)
    }
}
305
306impl Default for StubResolver {
307 fn default() -> Self {
308 Self::new()
309 }
310}
311
312impl<'a> Resolver for &'a StubResolver {
313 type Octets = Bytes;
314 type Answer = Answer;
315 type Query =
316 Pin<Box<dyn Future<Output = Result<Answer, io::Error>> + Send + 'a>>;
317
318 fn query<N, Q>(&self, question: Q) -> Self::Query
319 where
320 N: ToName,
321 Q: Into<Question<N>>,
322 {
323 let message = Query::create_message(question.into());
324 Box::pin(self.query_message(message))
325 }
326}
327
328impl<'a> SearchNames for &'a StubResolver {
329 type Name = SearchSuffix;
330 type Iter = SearchIter<'a>;
331
332 fn search_iter(&self) -> Self::Iter {
333 SearchIter {
334 resolver: self,
335 pos: 0,
336 }
337 }
338}
339
340pub struct Query<'a> {
343 resolver: &'a StubResolver,
345
346 edns: Arc<AtomicBool>,
347
348 error: Result<Answer, io::Error>,
356}
357
358impl<'a> Query<'a> {
359 pub fn new(resolver: &'a StubResolver) -> Result<Self, io::Error> {
360 Ok(Query {
361 resolver,
362 edns: Arc::new(AtomicBool::new(true)),
363 error: Err(io::Error::new(
364 io::ErrorKind::TimedOut,
365 "all timed out",
366 )),
367 })
368 }
369
370 pub async fn run(
371 mut self,
372 mut message: QueryMessage,
373 ) -> Result<Answer, io::Error> {
374 loop {
375 match self.run_query(&mut message).await {
376 Ok(answer) => {
377 if answer.header().rcode() == Rcode::FORMERR
378 && self.does_edns()
379 {
380 self.disable_edns();
382 continue;
383 } else if answer.header().rcode() == Rcode::SERVFAIL {
384 self.update_error_servfail(answer);
386 } else {
387 return Ok(answer);
389 }
390 }
391 Err(err) => self.update_error(err),
392 }
393 return self.error;
394 }
395 }
396
397 fn create_message(question: Question<impl ToName>) -> QueryMessage {
398 let mut message = MessageBuilder::from_target(Default::default())
399 .expect("MessageBuilder should not fail");
400 message.header_mut().set_rd(true);
401 let mut message = message.question();
402 message.push(question).expect("push should not fail");
403 message.additional()
404 }
405
406 async fn run_query(
407 &mut self,
408 message: &mut QueryMessage,
409 ) -> Result<Answer, io::Error> {
410 let msg = Message::from_octets(message.as_target().to_vec())
411 .expect("Message::from_octets should not fail");
412
413 let request_msg = RequestMessage::new(msg)
414 .map_err(|e| io::Error::other(e.to_string()))?;
415
416 let transport = self
417 .resolver
418 .get_transport()
419 .await
420 .map_err(|e| io::Error::other(e.to_string()))?;
421 let mut gr_fut = transport.send_request(request_msg);
422 let reply =
423 timeout(self.resolver.options.timeout, gr_fut.get_response())
424 .await?
425 .map_err(|e| io::Error::other(e.to_string()))?;
426 Ok(Answer { message: reply })
427 }
428
429 fn update_error(&mut self, err: io::Error) {
430 if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
434 self.error = Err(err)
435 }
436 }
437
438 fn update_error_servfail(&mut self, answer: Answer) {
439 self.error = Ok(answer)
440 }
441
442 pub fn does_edns(&self) -> bool {
443 self.edns.load(Ordering::Relaxed)
444 }
445
446 pub fn disable_edns(&self) {
447 self.edns.store(false, Ordering::Relaxed);
448 }
449}
450
451pub(super) type QueryMessage = AdditionalBuilder<Array<512>>;
455
456#[derive(Clone)]
463pub struct Answer {
464 message: Message<Bytes>,
465}
466
467impl Answer {
468 pub fn is_final(&self) -> bool {
470 (self.message.header().rcode() == Rcode::NOERROR
471 || self.message.header().rcode() == Rcode::NXDOMAIN)
472 && !self.message.header().tc()
473 }
474
475 pub fn is_truncated(&self) -> bool {
477 self.message.header().tc()
478 }
479
480 pub fn into_message(self) -> Message<Bytes> {
481 self.message
482 }
483}
484
485impl From<Message<Bytes>> for Answer {
486 fn from(message: Message<Bytes>) -> Self {
487 Answer { message }
488 }
489}
490
491impl ops::Deref for Answer {
492 type Target = Message<Bytes>;
493
494 fn deref(&self) -> &Self::Target {
495 &self.message
496 }
497}
498
499impl AsRef<Message<Bytes>> for Answer {
500 fn as_ref(&self) -> &Message<Bytes> {
501 &self.message
502 }
503}
504
505#[derive(Clone, Debug)]
508pub struct SearchIter<'a> {
509 resolver: &'a StubResolver,
510 pos: usize,
511}
512
513impl Iterator for SearchIter<'_> {
514 type Item = SearchSuffix;
515
516 fn next(&mut self) -> Option<Self::Item> {
517 if let Some(res) = self.resolver.options().search.get(self.pos) {
518 self.pos += 1;
519 Some(res.clone())
520 } else {
521 None
522 }
523 }
524}