// domain/net/client/dgram.rs

//! A client over datagram protocols.
//!
//! This module implements a DNS client for use with datagram protocols, i.e.,
//! message-oriented, connectionless, unreliable network protocols. In
//! practice, this is pretty much exclusively UDP.

7#![warn(missing_docs)]

// To do:
// - cookies

12use crate::base::Message;
13use crate::net::client::protocol::{
14    AsyncConnect, AsyncDgramRecv, AsyncDgramRecvEx, AsyncDgramSend,
15    AsyncDgramSendEx,
16};
17use crate::net::client::request::{
18    ComposeRequest, Error, GetResponse, SendRequest,
19};
20use crate::utils::config::DefMinMax;
21use bytes::Bytes;
22use core::fmt;
23use octseq::OctetsInto;
24use std::boxed::Box;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::vec::Vec;
29use std::{error, io};
30use tokio::sync::Semaphore;
31use tokio::time::{timeout_at, Duration, Instant};
32use tracing::trace;
33
34//------------ Configuration Constants ----------------------------------------
35
36/// Configuration limits for the maximum number of parallel requests.
37const MAX_PARALLEL: DefMinMax<usize> = DefMinMax::new(100, 1, 1000);
38
39/// Configuration limits for the read timeout.
40const READ_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
41    Duration::from_secs(5),
42    Duration::from_millis(1),
43    Duration::from_secs(60),
44);
45
46/// Configuration limits for the maximum number of retries.
47const MAX_RETRIES: DefMinMax<u8> = DefMinMax::new(5, 0, 100);
48
49/// Default UDP payload size.
50const DEF_UDP_PAYLOAD_SIZE: u16 = 1232;
51
52/// The default receive buffer size.
53const DEF_RECV_SIZE: usize = 2000;
54
55//------------ Config ---------------------------------------------------------
56
57/// Configuration of a datagram transport.
58#[derive(Clone, Debug)]
59pub struct Config {
60    /// Maximum number of parallel requests for a transport connection.
61    max_parallel: usize,
62
63    /// Read timeout.
64    read_timeout: Duration,
65
66    /// Maximum number of retries.
67    max_retries: u8,
68
69    /// EDNS UDP payload size.
70    ///
71    /// If this is `None`, no OPT record will be included at all.
72    udp_payload_size: Option<u16>,
73
74    /// Receive buffer size.
75    recv_size: usize,
76}
77
78impl Config {
79    /// Creates a new config with default values.
80    pub fn new() -> Self {
81        Default::default()
82    }
83
84    /// Sets the maximum number of parallel requests.
85    ///
86    /// Once this many number of requests are currently outstanding,
87    /// additional requests will wait.
88    ///
89    /// If this value is too small or too large, it will be capped.
90    pub fn set_max_parallel(&mut self, value: usize) {
91        self.max_parallel = MAX_PARALLEL.limit(value)
92    }
93
94    /// Returns the maximum number of parallel requests.
95    pub fn max_parallel(&self) -> usize {
96        self.max_parallel
97    }
98
99    /// Sets the read timeout.
100    ///
101    /// The read timeout is the maximum amount of time to wait for any
102    /// response after a request was sent.
103    ///
104    /// If this value is too small or too large, it will be capped.
105    pub fn set_read_timeout(&mut self, value: Duration) {
106        self.read_timeout = READ_TIMEOUT.limit(value)
107    }
108
109    /// Returns the read timeout.
110    pub fn read_timeout(&self) -> Duration {
111        self.read_timeout
112    }
113
114    /// Sets the maximum number of times a request is retried before giving
115    /// up.
116    ///
117    /// If this value is too small or too large, it will be capped.
118    pub fn set_max_retries(&mut self, value: u8) {
119        self.max_retries = MAX_RETRIES.limit(value)
120    }
121
122    /// Returns the maximum number of request retries.
123    pub fn max_retries(&self) -> u8 {
124        self.max_retries
125    }
126
127    /// Sets the requested UDP payload size.
128    ///
129    /// This value indicates to the server the maximum size of a UDP packet.
130    /// For UDP on public networks, this value should be left at the default
131    /// of 1232 to avoid issues rising from packet fragmentation. See
132    /// [draft-ietf-dnsop-avoid-fragmentation] for a discussion on these
133    /// issues and recommendations.
134    ///
135    /// On private networks or protocols other than UDP, other values can be
136    /// used.
137    ///
138    /// Setting the UDP payload size to `None` currently results in messages
139    /// that will not include an OPT record.
140    ///
141    /// [draft-ietf-dnsop-avoid-fragmentation]: https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/
142    pub fn set_udp_payload_size(&mut self, value: Option<u16>) {
143        self.udp_payload_size = value;
144    }
145
146    /// Returns the UDP payload size.
147    pub fn udp_payload_size(&self) -> Option<u16> {
148        self.udp_payload_size
149    }
150
151    /// Sets the receive buffer size.
152    ///
153    /// This is the amount of memory that is allocated for receiving a
154    /// response.
155    pub fn set_recv_size(&mut self, size: usize) {
156        self.recv_size = size
157    }
158
159    /// Returns the receive buffer size.
160    pub fn recv_size(&self) -> usize {
161        self.recv_size
162    }
163}
164
165impl Default for Config {
166    fn default() -> Self {
167        Self {
168            max_parallel: MAX_PARALLEL.default(),
169            read_timeout: READ_TIMEOUT.default(),
170            max_retries: MAX_RETRIES.default(),
171            udp_payload_size: Some(DEF_UDP_PAYLOAD_SIZE),
172            recv_size: DEF_RECV_SIZE,
173        }
174    }
175}
176
177//------------ Connection -----------------------------------------------------
178
179/// A datagram protocol connection.
180#[derive(Clone, Debug)]
181pub struct Connection<S> {
182    /// Actual state of the connection.
183    state: Arc<ConnectionState<S>>,
184}
185
186/// Because it owns the connection’s resources, this type is not [`Clone`].
187/// However, it is entirely safe to share it by sticking it into e.g. an arc.
188#[derive(Debug)]
189struct ConnectionState<S> {
190    /// User configuration variables.
191    config: Config,
192
193    /// Connections to datagram sockets.
194    connect: S,
195
196    /// Semaphore to limit access to UDP sockets.
197    semaphore: Semaphore,
198}
199
200impl<S> Connection<S> {
201    /// Create a new datagram transport with default configuration.
202    pub fn new(connect: S) -> Self {
203        Self::with_config(connect, Default::default())
204    }
205
206    /// Create a new datagram transport with a given configuration.
207    pub fn with_config(connect: S, config: Config) -> Self {
208        Self {
209            state: Arc::new(ConnectionState {
210                semaphore: Semaphore::new(config.max_parallel),
211                config,
212                connect,
213            }),
214        }
215    }
216}
217
218impl<S> Connection<S>
219where
220    S: AsyncConnect,
221    S::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin,
222{
223    /// Performs a request.
224    ///
225    /// Sends the provided and returns either a response or an error. If there
226    /// are currently too many active queries, the future will wait until the
227    /// number has dropped below the limit.
228    async fn handle_request_impl<Req: ComposeRequest>(
229        self,
230        mut request: Req,
231    ) -> Result<Message<Bytes>, Error> {
232        // Acquire the semaphore or wait for it.
233        let _permit = self
234            .state
235            .semaphore
236            .acquire()
237            .await
238            .expect("semaphore closed");
239
240        // The buffer we will reuse on subsequent requests
241        let mut buf = Vec::new();
242
243        // Transmit loop.
244        for _ in 0..1 + self.state.config.max_retries {
245            let mut sock = self
246                .state
247                .connect
248                .connect()
249                .await
250                .map_err(QueryError::connect)?;
251
252            // Set random ID in header
253            request.header_mut().set_random_id();
254
255            // Set UDP payload size if necessary.
256            if let Some(size) = self.state.config.udp_payload_size {
257                request.set_udp_payload_size(size)
258            }
259
260            // Create the message and send it out.
261            let request_msg = request.to_message()?;
262            let dgram = request_msg.as_slice();
263            let sent = sock.send(dgram).await.map_err(QueryError::send)?;
264            if sent != dgram.len() {
265                return Err(QueryError::short_send().into());
266            }
267
268            // Receive loop. It may at most take read_timeout time.
269            let deadline = Instant::now() + self.state.config.read_timeout;
270            while deadline > Instant::now() {
271                // The buffer might have been truncated in a previous
272                // iteration.
273                buf.resize(self.state.config.recv_size, 0);
274
275                let len =
276                    match timeout_at(deadline, sock.recv(&mut buf)).await {
277                        Ok(Ok(len)) => len,
278                        Ok(Err(err)) => {
279                            // Receiving failed.
280                            return Err(QueryError::receive(err).into());
281                        }
282                        Err(_) => {
283                            // Timeout.
284                            trace!("Receive timed out");
285                            break;
286                        }
287                    };
288
289                trace!("Received {len} bytes of message");
290                buf.truncate(len);
291
292                // We ignore garbage since there is a timer on this whole
293                // thing.
294                let answer = match Message::try_from_octets(buf) {
295                    Ok(answer) => answer,
296                    Err(old_buf) => {
297                        // Just go back to receiving.
298                        trace!("Received bytes were garbage, reading more");
299                        buf = old_buf;
300                        continue;
301                    }
302                };
303
304                if !request.is_answer(answer.for_slice()) {
305                    // Wrong answer, go back to receiving
306                    trace!("Received message is not the answer we were waiting for, reading more");
307                    buf = answer.into_octets();
308                    continue;
309                }
310
311                trace!("Received message is accepted");
312                return Ok(answer.octets_into());
313            }
314        }
315        Err(QueryError::timeout().into())
316    }
317}
318
319//--- SendRequest
320
321impl<S, Req> SendRequest<Req> for Connection<S>
322where
323    S: AsyncConnect + Clone + Send + Sync + 'static,
324    S::Connection:
325        AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static,
326    Req: ComposeRequest + Send + Sync + 'static,
327{
328    fn send_request(
329        &self,
330        request_msg: Req,
331    ) -> Box<dyn GetResponse + Send + Sync> {
332        Box::new(Request {
333            fut: Box::pin(self.clone().handle_request_impl(request_msg)),
334        })
335    }
336}
337
338//------------ Request ------------------------------------------------------
339
340/// The state of a DNS request.
341pub struct Request {
342    /// Future that does the actual work of GetResponse.
343    fut: Pin<
344        Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
345    >,
346}
347
348impl Request {
349    /// Async function that waits for the future stored in Request to complete.
350    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
351        (&mut self.fut).await
352    }
353}
354
355impl fmt::Debug for Request {
356    fn fmt(&self, _: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
357        todo!()
358    }
359}
360
361impl GetResponse for Request {
362    fn get_response(
363        &mut self,
364    ) -> Pin<
365        Box<
366            dyn Future<Output = Result<Message<Bytes>, Error>>
367                + Send
368                + Sync
369                + '_,
370        >,
371    > {
372        Box::pin(self.get_response_impl())
373    }
374}
375
//============ Errors ========================================================

//------------ QueryError ----------------------------------------------------

/// A query failed.
///
/// Pairs the processing step that failed ([`QueryErrorKind`]) with the
/// underlying IO error.
#[derive(Debug)]
pub struct QueryError {
    /// Which step failed?
    kind: QueryErrorKind,

    /// The underlying IO error.
    io: io::Error,
}

impl QueryError {
    /// Create a new `QueryError`.
    fn new(kind: QueryErrorKind, io: io::Error) -> Self {
        Self { kind, io }
    }

    /// Create a new connect error.
    fn connect(io: io::Error) -> Self {
        Self::new(QueryErrorKind::Connect, io)
    }

    /// Create a new send error.
    fn send(io: io::Error) -> Self {
        Self::new(QueryErrorKind::Send, io)
    }

    /// Create a new short send error.
    ///
    /// Used when fewer bytes than the whole request were sent.
    fn short_send() -> Self {
        Self::new(
            QueryErrorKind::Send,
            io::Error::other("short request sent"),
        )
    }

    /// Create a new timeout error.
    fn timeout() -> Self {
        Self::new(
            QueryErrorKind::Timeout,
            io::Error::new(io::ErrorKind::TimedOut, "timeout expired"),
        )
    }

    /// Create a new receive error.
    fn receive(io: io::Error) -> Self {
        Self::new(QueryErrorKind::Receive, io)
    }
}

impl QueryError {
    /// Returns information about when the query has failed.
    pub fn kind(&self) -> QueryErrorKind {
        self.kind
    }

    /// Converts the query error into the underlying IO error.
    pub fn io_error(self) -> io::Error {
        self.io
    }
}

impl From<QueryError> for io::Error {
    fn from(err: QueryError) -> io::Error {
        err.io
    }
}

impl fmt::Display for QueryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The IO error is already part of the message, so it is deliberately
        // not also exposed through `Error::source`, to avoid printing it twice
        // in error chains.
        write!(f, "{}: {}", self.kind.error_str(), self.io)
    }
}

impl error::Error for QueryError {}

//------------ QueryErrorKind ------------------------------------------------

/// Which part of processing the query failed?
///
/// Implements `PartialEq`/`Eq` so callers can test e.g.
/// `err.kind() == QueryErrorKind::Timeout`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum QueryErrorKind {
    /// Failed to connect to the remote.
    Connect,

    /// Failed to send the request.
    Send,

    /// The request has timed out.
    Timeout,

    /// Failed to read the response.
    Receive,
}

impl QueryErrorKind {
    /// Returns the string to be used when displaying a query error.
    ///
    /// A timeout is reported as a read failure here; the accompanying IO
    /// error message states that the timeout expired.
    fn error_str(self) -> &'static str {
        match self {
            Self::Connect => "connecting failed",
            Self::Send => "sending request failed",
            Self::Timeout | Self::Receive => "reading response failed",
        }
    }
}

impl fmt::Display for QueryErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Connect => "connecting failed",
            Self::Send => "sending request failed",
            Self::Timeout => "request timeout",
            Self::Receive => "reading response failed",
        })
    }
}