mz_pgwire_common/
conn.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::pin::Pin;
11use std::sync::{Arc, Mutex};
12use std::task::{Context, Poll};
13
14use async_trait::async_trait;
15use derivative::Derivative;
16use mz_ore::netio::AsyncReady;
17use mz_server_core::TlsMode;
18use tokio::io::{self, AsyncRead, AsyncWrite, Interest, ReadBuf, Ready};
19use tokio_openssl::SslStream;
20use tokio_postgres::error::SqlState;
21
22use crate::ErrorResponse;
23
/// Key string `mz_connection_uuid`.
///
/// NOTE(review): presumably the name of the connection parameter that carries
/// a connection's UUID; confirm against the callers that read it.
pub const CONN_UUID_KEY: &'static str = "mz_connection_uuid";
/// Key string `mz_forwarded_for`.
///
/// NOTE(review): presumably the name of the connection parameter that carries
/// the forwarded client address; confirm against the callers that read it.
pub const MZ_FORWARDED_FOR_KEY: &'static str = "mz_forwarded_for";
26
27#[derive(Debug)]
28pub enum Conn<A> {
29    Unencrypted(A),
30    Ssl(SslStream<A>),
31}
32
33impl<A> Conn<A> {
34    pub fn inner_mut(&mut self) -> &mut A {
35        match self {
36            Conn::Unencrypted(inner) => inner,
37            Conn::Ssl(inner) => inner.get_mut(),
38        }
39    }
40
41    /// Returns an error if tls_mode is incompatible with this connection's stream type.
42    pub fn ensure_tls_compatibility(
43        &self,
44        tls_mode: &Option<TlsMode>,
45    ) -> Result<(), ErrorResponse> {
46        // Validate that the connection is compatible with the TLS mode.
47        //
48        // The match here explicitly spells out all cases to be resilient to
49        // future changes to TlsMode.
50        match (tls_mode, self) {
51            (None, Conn::Unencrypted(_)) => (),
52            (None, Conn::Ssl(_)) => unreachable!(),
53            (Some(TlsMode::Allow), Conn::Unencrypted(_)) => (),
54            (Some(TlsMode::Allow), Conn::Ssl(_)) => (),
55            (Some(TlsMode::Require), Conn::Ssl(_)) => (),
56            (Some(TlsMode::Require), Conn::Unencrypted(_)) => {
57                return Err(ErrorResponse::fatal(
58                    SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
59                    "TLS encryption is required",
60                ));
61            }
62        }
63
64        Ok(())
65    }
66}
67
68impl<A> AsyncRead for Conn<A>
69where
70    A: AsyncRead + AsyncWrite + Unpin,
71{
72    fn poll_read(
73        self: Pin<&mut Self>,
74        cx: &mut Context,
75        buf: &mut ReadBuf,
76    ) -> Poll<io::Result<()>> {
77        match self.get_mut() {
78            Conn::Unencrypted(inner) => Pin::new(inner).poll_read(cx, buf),
79            Conn::Ssl(inner) => Pin::new(inner).poll_read(cx, buf),
80        }
81    }
82}
83
84impl<A> AsyncWrite for Conn<A>
85where
86    A: AsyncRead + AsyncWrite + Unpin,
87{
88    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
89        match self.get_mut() {
90            Conn::Unencrypted(inner) => Pin::new(inner).poll_write(cx, buf),
91            Conn::Ssl(inner) => Pin::new(inner).poll_write(cx, buf),
92        }
93    }
94
95    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
96        match self.get_mut() {
97            Conn::Unencrypted(inner) => Pin::new(inner).poll_flush(cx),
98            Conn::Ssl(inner) => Pin::new(inner).poll_flush(cx),
99        }
100    }
101
102    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
103        match self.get_mut() {
104            Conn::Unencrypted(inner) => Pin::new(inner).poll_shutdown(cx),
105            Conn::Ssl(inner) => Pin::new(inner).poll_shutdown(cx),
106        }
107    }
108}
109
110#[async_trait]
111impl<A> AsyncReady for Conn<A>
112where
113    A: AsyncRead + AsyncWrite + AsyncReady + Sync + Unpin,
114{
115    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
116        match self {
117            Conn::Unencrypted(inner) => inner.ready(interest).await,
118            Conn::Ssl(inner) => inner.ready(interest).await,
119        }
120    }
121}
122
/// Metadata about a user that is required to allocate a [`ConnectionHandle`].
#[derive(Clone, Copy, Debug)]
pub struct UserMetadata {
    /// Whether the user may also draw on the connections reserved for
    /// superusers.
    pub is_admin: bool,
    /// Whether connections for this user count against the connection limit
    /// at all; when `false`, no handle is allocated for them.
    pub should_limit_connections: bool,
}
129
130#[derive(Debug, Clone)]
131pub struct ConnectionCounter {
132    inner: Arc<Mutex<ConnectionCounterInner>>,
133}
134
135impl ConnectionCounter {
136    /// Returns a [`ConnectionHandle`] which must be kept alive for the entire duration of the
137    /// external connection.
138    ///
139    /// Dropping the [`ConnectionHandle`] decrements the connection count.
140    pub fn allocate_connection(
141        &self,
142        metadata: impl Into<UserMetadata>,
143    ) -> Result<Option<ConnectionHandle>, ConnectionError> {
144        let mut inner = self.inner.lock().expect("environmentd panicked");
145        let metadata = metadata.into();
146
147        if !metadata.should_limit_connections {
148            return Ok(None);
149        }
150
151        if (metadata.is_admin && inner.reserved_remaining() > 0)
152            || inner.non_reserved_remaining() > 0
153        {
154            inner.inc_connection_count();
155            Ok(Some(self.create_handle()))
156        } else {
157            Err(ConnectionError::TooManyConnections {
158                current: inner.current,
159                limit: inner.limit,
160            })
161        }
162    }
163
164    /// Updates the maximum number of connections we allow.
165    pub fn update_limit(&self, new_limit: u64) {
166        let mut inner = self.inner.lock().expect("environmentd panicked");
167        inner.limit = new_limit;
168    }
169
170    /// Updates the number of connections we reserve for superusers.
171    pub fn update_superuser_reserved(&self, new_reserve: u64) {
172        let mut inner = self.inner.lock().expect("environmentd panicked");
173        inner.superuser_reserved = new_reserve;
174    }
175
176    fn create_handle(&self) -> ConnectionHandle {
177        let inner = Arc::clone(&self.inner);
178        let decrement_fn = Box::new(move || {
179            let mut inner = inner.lock().expect("environmentd panicked");
180            inner.dec_connection_count();
181        });
182
183        ConnectionHandle {
184            decrement_fn: Some(decrement_fn),
185        }
186    }
187}
188
189impl Default for ConnectionCounter {
190    fn default() -> Self {
191        let inner = ConnectionCounterInner::new(10, 3);
192        ConnectionCounter {
193            inner: Arc::new(Mutex::new(inner)),
194        }
195    }
196}
197
/// Bookkeeping for connection limiting: the live count plus the two limits it
/// is checked against.
#[derive(Debug)]
pub struct ConnectionCounterInner {
    /// Current number of connections.
    current: u64,
    /// Total number of connections allowed.
    limit: u64,
    /// Number of connections in `limit` we'll reserve for superusers.
    superuser_reserved: u64,
}

impl ConnectionCounterInner {
    /// Creates the bookkeeping state with no open connections.
    ///
    /// # Panics
    ///
    /// Panics unless `superuser_reserved` is strictly less than `limit`.
    fn new(limit: u64, superuser_reserved: u64) -> Self {
        assert!(superuser_reserved < limit);
        Self {
            limit,
            superuser_reserved,
            current: 0,
        }
    }

    /// Records one more open connection.
    fn inc_connection_count(&mut self) {
        self.current += 1;
    }

    /// Records that one open connection has gone away.
    fn dec_connection_count(&mut self) {
        self.current -= 1;
    }

    /// The number of connections still available to superusers.
    fn reserved_remaining(&self) -> u64 {
        // Saturate at zero: the limit may have been lowered below the number
        // of connections that are already open.
        let Self { limit, current, .. } = self;
        limit.saturating_sub(*current)
    }

    /// The number of connections available to non-superusers.
    fn non_reserved_remaining(&self) -> u64 {
        // First set aside the superuser reserve, then subtract what is already
        // in use. Both steps saturate at zero, which covers a reserve or a
        // current count that exceeds the limit.
        self.limit
            .saturating_sub(self.superuser_reserved)
            .saturating_sub(self.current)
    }
}
242
243/// Handle to an open connection, allows us to maintain a count of all connections.
244///
245/// When Drop-ed decrements the count of open connections.
246#[derive(Derivative)]
247#[derivative(Debug)]
248pub struct ConnectionHandle {
249    #[derivative(Debug = "ignore")]
250    decrement_fn: Option<Box<dyn FnOnce() -> () + Send + Sync>>,
251}
252
253impl Drop for ConnectionHandle {
254    fn drop(&mut self) {
255        match self.decrement_fn.take() {
256            Some(decrement_fn) => (decrement_fn)(),
257            None => tracing::error!("ConnectionHandle dropped twice!?"),
258        }
259    }
260}
261
/// Reasons a connection could not be allocated.
#[derive(Debug)]
pub enum ConnectionError {
    /// There were too many connections
    TooManyConnections {
        /// Number of connections open when the request was refused.
        current: u64,
        /// Total number of connections allowed at that time.
        limit: u64,
    },
}