1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Types related to connection monitoring and management.

use std::fmt::{Debug, Formatter};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};

/// Metadata that tracks the state of an active connection.
///
/// Cloning is cheap with respect to the poison callback: it is stored behind an
/// [`Arc`], so every clone invokes the same underlying closure.
#[derive(Clone)]
pub struct ConnectionMetadata {
    /// Flag supplied by the caller of [`ConnectionMetadata::new`]; within this
    /// file it is only surfaced through the `Debug` output.
    // NOTE(review): presumably indicates the connection goes through a proxy — confirm with the caller.
    is_proxied: bool,
    /// Remote address for this connection, if one is set.
    remote_addr: Option<SocketAddr>,
    /// Callback invoked by [`ConnectionMetadata::poison`].
    poison_fn: Arc<dyn Fn() + Send + Sync>,
}

impl ConnectionMetadata {
    /// Build a new [`ConnectionMetadata`].
    ///
    /// * `is_proxied` - stored as-is and reported in the `Debug` output.
    /// * `remote_addr` - the remote address of the connection, if known.
    /// * `poison` - callback run every time [`ConnectionMetadata::poison`] is called.
    pub fn new(
        is_proxied: bool,
        remote_addr: Option<SocketAddr>,
        poison: impl Fn() + Send + Sync + 'static,
    ) -> Self {
        // Erase the concrete closure type so it can be stored in the struct and
        // shared between clones.
        let poison_fn: Arc<dyn Fn() + Send + Sync> = Arc::new(poison);
        Self {
            is_proxied,
            remote_addr,
            poison_fn,
        }
    }

    /// Mark this connection as poisoned so that it is not reused.
    ///
    /// Emits an `info`-level tracing event and then runs the callback that was
    /// passed to [`ConnectionMetadata::new`]; preventing reuse is up to that callback.
    pub fn poison(&self) {
        tracing::info!("smithy connection was poisoned");
        let callback = &*self.poison_fn;
        callback()
    }

    /// Return the remote address of this connection, or `None` if none was set.
    pub fn remote_addr(&self) -> Option<SocketAddr> {
        self.remote_addr
    }
}

impl Debug for ConnectionMetadata {
    /// Formats the metadata as `SmithyConnection { is_proxied, remote_addr }`.
    ///
    /// The poison callback is an opaque closure and is left out of the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            is_proxied,
            remote_addr,
            poison_fn: _,
        } = self;

        // The printed name intentionally stays "SmithyConnection" (not the type
        // name) so the output is unchanged.
        let mut out = f.debug_struct("SmithyConnection");
        out.field("is_proxied", is_proxied);
        out.field("remote_addr", remote_addr);
        out.finish()
    }
}

// Signature of the retriever stored in `CaptureSmithyConnection`: a thread-safe
// callable that yields the current connection's metadata, or `None`.
type LoaderFn = dyn Fn() -> Option<ConnectionMetadata> + Send + Sync;

/// State for a middleware that will monitor and manage connections.
///
/// The retriever is stored behind an `Arc<Mutex<..>>`, so all clones of a
/// `CaptureSmithyConnection` observe the same retriever.
// `Debug` cannot be derived because the boxed `LoaderFn` trait object does not implement it.
#[allow(missing_debug_implementations)]
#[derive(Clone, Default)]
pub struct CaptureSmithyConnection {
    /// Retriever installed via `set_connection_retriever`; `None` until one is set.
    loader: Arc<Mutex<Option<Box<LoaderFn>>>>,
}

impl CaptureSmithyConnection {
    /// Create a new connection monitor.
    pub fn new() -> Self {
        Self {
            loader: Default::default(),
        }
    }

    /// Set the retriever that will capture the `hyper` connection.
    pub fn set_connection_retriever<F>(&self, f: F)
    where
        F: Fn() -> Option<ConnectionMetadata> + Send + Sync + 'static,
    {
        *self.loader.lock().unwrap() = Some(Box::new(f));
    }

    /// Get the associated connection metadata.
    pub fn get(&self) -> Option<ConnectionMetadata> {
        match self.loader.lock().unwrap().as_ref() {
            Some(loader) => loader(),
            None => {
                println!("no loader was set :-/");
                None
            }
        }
    }
}

#[cfg(test)]
mod test {
    use crate::connection::{CaptureSmithyConnection, ConnectionMetadata};

    /// A retriever installed on one handle is visible through a clone made
    /// before the retriever was set.
    #[test]
    fn retrieve_connection_metadata() {
        let capture = CaptureSmithyConnection::new();
        let shared = capture.clone();

        // Nothing has been registered yet, so there is no metadata to return.
        assert!(capture.get().is_none());

        capture.set_connection_retriever(|| {
            let metadata = ConnectionMetadata::new(true, None, || {});
            Some(metadata)
        });

        // Both the original handle and its clone see the retriever.
        assert!(capture.get().is_some());
        assert!(shared.get().is_some());
    }
}