1#![warn(missing_docs)]
8
9use crate::base::Message;
13use crate::net::client::protocol::{
14 AsyncConnect, AsyncDgramRecv, AsyncDgramRecvEx, AsyncDgramSend,
15 AsyncDgramSendEx,
16};
17use crate::net::client::request::{
18 ComposeRequest, Error, GetResponse, SendRequest,
19};
20use crate::utils::config::DefMinMax;
21use bytes::Bytes;
22use core::fmt;
23use octseq::OctetsInto;
24use std::boxed::Box;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::vec::Vec;
29use std::{error, io};
30use tokio::sync::Semaphore;
31use tokio::time::{timeout_at, Duration, Instant};
32use tracing::trace;
33
/// Configuration range for the number of requests that may be in flight at
/// once on a single connection (this value sizes the connection's
/// semaphore).
///
/// The arguments to `DefMinMax::new` are presumably the default, minimum
/// and maximum, in that order -- TODO confirm against `DefMinMax`.
const MAX_PARALLEL: DefMinMax<usize> = DefMinMax::new(100, 1, 1000);

/// Configuration range for the time spent waiting for a response after a
/// request has been sent.
///
/// Presumably a 5 second default, limited to between 1 millisecond and
/// 60 seconds -- TODO confirm the argument order of `DefMinMax::new`.
const READ_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(5),
    Duration::from_millis(1),
    Duration::from_secs(60),
);

/// Configuration range for the number of times a request is sent again
/// after the initial attempt.
const MAX_RETRIES: DefMinMax<u8> = DefMinMax::new(5, 0, 100);

/// UDP payload size that is set on outgoing requests by default.
const DEF_UDP_PAYLOAD_SIZE: u16 = 1232;

/// Default size, in bytes, of the buffer used to receive a response.
const DEF_RECV_SIZE: usize = 2000;
54
55#[derive(Clone, Debug)]
59pub struct Config {
60 max_parallel: usize,
62
63 read_timeout: Duration,
65
66 max_retries: u8,
68
69 udp_payload_size: Option<u16>,
73
74 recv_size: usize,
76}
77
78impl Config {
79 pub fn new() -> Self {
81 Default::default()
82 }
83
84 pub fn set_max_parallel(&mut self, value: usize) {
91 self.max_parallel = MAX_PARALLEL.limit(value)
92 }
93
94 pub fn max_parallel(&self) -> usize {
96 self.max_parallel
97 }
98
99 pub fn set_read_timeout(&mut self, value: Duration) {
106 self.read_timeout = READ_TIMEOUT.limit(value)
107 }
108
109 pub fn read_timeout(&self) -> Duration {
111 self.read_timeout
112 }
113
114 pub fn set_max_retries(&mut self, value: u8) {
119 self.max_retries = MAX_RETRIES.limit(value)
120 }
121
122 pub fn max_retries(&self) -> u8 {
124 self.max_retries
125 }
126
127 pub fn set_udp_payload_size(&mut self, value: Option<u16>) {
143 self.udp_payload_size = value;
144 }
145
146 pub fn udp_payload_size(&self) -> Option<u16> {
148 self.udp_payload_size
149 }
150
151 pub fn set_recv_size(&mut self, size: usize) {
156 self.recv_size = size
157 }
158
159 pub fn recv_size(&self) -> usize {
161 self.recv_size
162 }
163}
164
165impl Default for Config {
166 fn default() -> Self {
167 Self {
168 max_parallel: MAX_PARALLEL.default(),
169 read_timeout: READ_TIMEOUT.default(),
170 max_retries: MAX_RETRIES.default(),
171 udp_payload_size: Some(DEF_UDP_PAYLOAD_SIZE),
172 recv_size: DEF_RECV_SIZE,
173 }
174 }
175}
176
177#[derive(Clone, Debug)]
181pub struct Connection<S> {
182 state: Arc<ConnectionState<S>>,
184}
185
186#[derive(Debug)]
189struct ConnectionState<S> {
190 config: Config,
192
193 connect: S,
195
196 semaphore: Semaphore,
198}
199
200impl<S> Connection<S> {
201 pub fn new(connect: S) -> Self {
203 Self::with_config(connect, Default::default())
204 }
205
206 pub fn with_config(connect: S, config: Config) -> Self {
208 Self {
209 state: Arc::new(ConnectionState {
210 semaphore: Semaphore::new(config.max_parallel),
211 config,
212 connect,
213 }),
214 }
215 }
216}
217
218impl<S> Connection<S>
219where
220 S: AsyncConnect,
221 S::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin,
222{
223 async fn handle_request_impl<Req: ComposeRequest>(
229 self,
230 mut request: Req,
231 ) -> Result<Message<Bytes>, Error> {
232 let _permit = self
234 .state
235 .semaphore
236 .acquire()
237 .await
238 .expect("semaphore closed");
239
240 let mut buf = Vec::new();
242
243 for _ in 0..1 + self.state.config.max_retries {
245 let mut sock = self
246 .state
247 .connect
248 .connect()
249 .await
250 .map_err(QueryError::connect)?;
251
252 request.header_mut().set_random_id();
254
255 if let Some(size) = self.state.config.udp_payload_size {
257 request.set_udp_payload_size(size)
258 }
259
260 let request_msg = request.to_message()?;
262 let dgram = request_msg.as_slice();
263 let sent = sock.send(dgram).await.map_err(QueryError::send)?;
264 if sent != dgram.len() {
265 return Err(QueryError::short_send().into());
266 }
267
268 let deadline = Instant::now() + self.state.config.read_timeout;
270 while deadline > Instant::now() {
271 buf.resize(self.state.config.recv_size, 0);
274
275 let len =
276 match timeout_at(deadline, sock.recv(&mut buf)).await {
277 Ok(Ok(len)) => len,
278 Ok(Err(err)) => {
279 return Err(QueryError::receive(err).into());
281 }
282 Err(_) => {
283 trace!("Receive timed out");
285 break;
286 }
287 };
288
289 trace!("Received {len} bytes of message");
290 buf.truncate(len);
291
292 let answer = match Message::try_from_octets(buf) {
295 Ok(answer) => answer,
296 Err(old_buf) => {
297 trace!("Received bytes were garbage, reading more");
299 buf = old_buf;
300 continue;
301 }
302 };
303
304 if !request.is_answer(answer.for_slice()) {
305 trace!("Received message is not the answer we were waiting for, reading more");
307 buf = answer.into_octets();
308 continue;
309 }
310
311 trace!("Received message is accepted");
312 return Ok(answer.octets_into());
313 }
314 }
315 Err(QueryError::timeout().into())
316 }
317}
318
319impl<S, Req> SendRequest<Req> for Connection<S>
322where
323 S: AsyncConnect + Clone + Send + Sync + 'static,
324 S::Connection:
325 AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static,
326 Req: ComposeRequest + Send + Sync + 'static,
327{
328 fn send_request(
329 &self,
330 request_msg: Req,
331 ) -> Box<dyn GetResponse + Send + Sync> {
332 Box::new(Request {
333 fut: Box::pin(self.clone().handle_request_impl(request_msg)),
334 })
335 }
336}
337
338pub struct Request {
342 fut: Pin<
344 Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
345 >,
346}
347
348impl Request {
349 async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
351 (&mut self.fut).await
352 }
353}
354
355impl fmt::Debug for Request {
356 fn fmt(&self, _: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
357 todo!()
358 }
359}
360
361impl GetResponse for Request {
362 fn get_response(
363 &mut self,
364 ) -> Pin<
365 Box<
366 dyn Future<Output = Result<Message<Bytes>, Error>>
367 + Send
368 + Sync
369 + '_,
370 >,
371 > {
372 Box::pin(self.get_response_impl())
373 }
374}
375
376#[derive(Debug)]
382pub struct QueryError {
383 kind: QueryErrorKind,
385
386 io: std::io::Error,
388}
389
390impl QueryError {
391 fn new(kind: QueryErrorKind, io: io::Error) -> Self {
393 Self { kind, io }
394 }
395
396 fn connect(io: io::Error) -> Self {
398 Self::new(QueryErrorKind::Connect, io)
399 }
400
401 fn send(io: io::Error) -> Self {
403 Self::new(QueryErrorKind::Send, io)
404 }
405
406 fn short_send() -> Self {
408 Self::new(
409 QueryErrorKind::Send,
410 io::Error::other("short request sent"),
411 )
412 }
413
414 fn timeout() -> Self {
416 Self::new(
417 QueryErrorKind::Timeout,
418 io::Error::new(io::ErrorKind::TimedOut, "timeout expired"),
419 )
420 }
421
422 fn receive(io: io::Error) -> Self {
424 Self::new(QueryErrorKind::Receive, io)
425 }
426}
427
428impl QueryError {
429 pub fn kind(&self) -> QueryErrorKind {
431 self.kind
432 }
433
434 pub fn io_error(self) -> std::io::Error {
436 self.io
437 }
438}
439
440impl From<QueryError> for std::io::Error {
441 fn from(err: QueryError) -> std::io::Error {
442 err.io
443 }
444}
445
446impl fmt::Display for QueryError {
447 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
448 write!(f, "{}: {}", self.kind.error_str(), self.io)
449 }
450}
451
452impl error::Error for QueryError {}
453
/// The stage at which a query over a datagram connection failed.
#[derive(Copy, Clone, Debug)]
pub enum QueryErrorKind {
    /// Connecting failed.
    Connect,

    /// Sending the request failed.
    Send,

    /// No response arrived in time.
    Timeout,

    /// Receiving the response failed.
    Receive,
}

impl QueryErrorKind {
    /// Returns the static text used as the prefix when a `QueryError` is
    /// displayed.
    ///
    /// Timeouts share the text of receive failures here; the accompanying
    /// I/O error supplies the detail.
    fn error_str(self) -> &'static str {
        match self {
            Self::Receive | Self::Timeout => "reading response failed",
            Self::Send => "sending request failed",
            Self::Connect => "connecting failed",
        }
    }
}

impl fmt::Display for QueryErrorKind {
    /// Writes a short description of the kind.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match *self {
            Self::Connect => "connecting failed",
            Self::Send => "sending request failed",
            Self::Timeout => "request timeout",
            Self::Receive => "reading response failed",
        };
        f.write_str(text)
    }
}