mz_pgwire_common/
conn.rs
1use std::pin::Pin;
11use std::sync::{Arc, Mutex};
12use std::task::{Context, Poll};
13
14use async_trait::async_trait;
15use derivative::Derivative;
16use mz_ore::netio::AsyncReady;
17use mz_server_core::TlsMode;
18use tokio::io::{self, AsyncRead, AsyncWrite, Interest, ReadBuf, Ready};
19use tokio_openssl::SslStream;
20use tokio_postgres::error::SqlState;
21
22use crate::ErrorResponse;
23
/// The key string `mz_connection_uuid`.
// NOTE(review): presumably the name of a connection parameter that carries a
// connection UUID -- confirm against the call sites.
pub const CONN_UUID_KEY: &str = "mz_connection_uuid";

/// The key string `mz_forwarded_for`.
// NOTE(review): presumably the name of a connection parameter that carries the
// forwarded client address -- confirm against the call sites.
pub const MZ_FORWARDED_FOR_KEY: &str = "mz_forwarded_for";
26
27#[derive(Debug)]
28pub enum Conn<A> {
29 Unencrypted(A),
30 Ssl(SslStream<A>),
31}
32
33impl<A> Conn<A> {
34 pub fn inner_mut(&mut self) -> &mut A {
35 match self {
36 Conn::Unencrypted(inner) => inner,
37 Conn::Ssl(inner) => inner.get_mut(),
38 }
39 }
40
41 pub fn ensure_tls_compatibility(
43 &self,
44 tls_mode: &Option<TlsMode>,
45 ) -> Result<(), ErrorResponse> {
46 match (tls_mode, self) {
51 (None, Conn::Unencrypted(_)) => (),
52 (None, Conn::Ssl(_)) => unreachable!(),
53 (Some(TlsMode::Allow), Conn::Unencrypted(_)) => (),
54 (Some(TlsMode::Allow), Conn::Ssl(_)) => (),
55 (Some(TlsMode::Require), Conn::Ssl(_)) => (),
56 (Some(TlsMode::Require), Conn::Unencrypted(_)) => {
57 return Err(ErrorResponse::fatal(
58 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
59 "TLS encryption is required",
60 ));
61 }
62 }
63
64 Ok(())
65 }
66}
67
68impl<A> AsyncRead for Conn<A>
69where
70 A: AsyncRead + AsyncWrite + Unpin,
71{
72 fn poll_read(
73 self: Pin<&mut Self>,
74 cx: &mut Context,
75 buf: &mut ReadBuf,
76 ) -> Poll<io::Result<()>> {
77 match self.get_mut() {
78 Conn::Unencrypted(inner) => Pin::new(inner).poll_read(cx, buf),
79 Conn::Ssl(inner) => Pin::new(inner).poll_read(cx, buf),
80 }
81 }
82}
83
84impl<A> AsyncWrite for Conn<A>
85where
86 A: AsyncRead + AsyncWrite + Unpin,
87{
88 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
89 match self.get_mut() {
90 Conn::Unencrypted(inner) => Pin::new(inner).poll_write(cx, buf),
91 Conn::Ssl(inner) => Pin::new(inner).poll_write(cx, buf),
92 }
93 }
94
95 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
96 match self.get_mut() {
97 Conn::Unencrypted(inner) => Pin::new(inner).poll_flush(cx),
98 Conn::Ssl(inner) => Pin::new(inner).poll_flush(cx),
99 }
100 }
101
102 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
103 match self.get_mut() {
104 Conn::Unencrypted(inner) => Pin::new(inner).poll_shutdown(cx),
105 Conn::Ssl(inner) => Pin::new(inner).poll_shutdown(cx),
106 }
107 }
108}
109
110#[async_trait]
111impl<A> AsyncReady for Conn<A>
112where
113 A: AsyncRead + AsyncWrite + AsyncReady + Sync + Unpin,
114{
115 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
116 match self {
117 Conn::Unencrypted(inner) => inner.ready(interest).await,
118 Conn::Ssl(inner) => inner.ready(interest).await,
119 }
120 }
121}
122
/// Per-user facts that `ConnectionCounter::allocate_connection` consults when
/// deciding whether a connection may be admitted.
#[derive(Clone, Copy, Debug)]
pub struct UserMetadata {
    /// When true, the user may also take slots from the superuser reserve.
    pub is_admin: bool,
    /// When false, the user's connections bypass the counter entirely.
    pub should_limit_connections: bool,
}
129
130#[derive(Debug, Clone)]
131pub struct ConnectionCounter {
132 inner: Arc<Mutex<ConnectionCounterInner>>,
133}
134
135impl ConnectionCounter {
136 pub fn allocate_connection(
141 &self,
142 metadata: impl Into<UserMetadata>,
143 ) -> Result<Option<ConnectionHandle>, ConnectionError> {
144 let mut inner = self.inner.lock().expect("environmentd panicked");
145 let metadata = metadata.into();
146
147 if !metadata.should_limit_connections {
148 return Ok(None);
149 }
150
151 if (metadata.is_admin && inner.reserved_remaining() > 0)
152 || inner.non_reserved_remaining() > 0
153 {
154 inner.inc_connection_count();
155 Ok(Some(self.create_handle()))
156 } else {
157 Err(ConnectionError::TooManyConnections {
158 current: inner.current,
159 limit: inner.limit,
160 })
161 }
162 }
163
164 pub fn update_limit(&self, new_limit: u64) {
166 let mut inner = self.inner.lock().expect("environmentd panicked");
167 inner.limit = new_limit;
168 }
169
170 pub fn update_superuser_reserved(&self, new_reserve: u64) {
172 let mut inner = self.inner.lock().expect("environmentd panicked");
173 inner.superuser_reserved = new_reserve;
174 }
175
176 fn create_handle(&self) -> ConnectionHandle {
177 let inner = Arc::clone(&self.inner);
178 let decrement_fn = Box::new(move || {
179 let mut inner = inner.lock().expect("environmentd panicked");
180 inner.dec_connection_count();
181 });
182
183 ConnectionHandle {
184 decrement_fn: Some(decrement_fn),
185 }
186 }
187}
188
189impl Default for ConnectionCounter {
190 fn default() -> Self {
191 let inner = ConnectionCounterInner::new(10, 3);
192 ConnectionCounter {
193 inner: Arc::new(Mutex::new(inner)),
194 }
195 }
196}
197
/// The mutable state behind a `ConnectionCounter`.
#[derive(Debug)]
pub struct ConnectionCounterInner {
    /// Number of connections currently counted.
    current: u64,
    /// Total number of connections that may be counted at once.
    limit: u64,
    /// Portion of `limit` that only admins may use.
    superuser_reserved: u64,
}

impl ConnectionCounterInner {
    /// Creates an empty counter.
    ///
    /// # Panics
    ///
    /// Panics unless `superuser_reserved` is strictly less than `limit`.
    fn new(limit: u64, superuser_reserved: u64) -> Self {
        assert!(superuser_reserved < limit);
        Self {
            current: 0,
            limit,
            superuser_reserved,
        }
    }

    /// Records one more counted connection.
    fn inc_connection_count(&mut self) {
        self.current += 1;
    }

    /// Records one fewer counted connection.
    fn dec_connection_count(&mut self) {
        self.current -= 1;
    }

    /// Slots still free when the superuser reserve may be used: `limit` minus
    /// `current`, clamped at zero.
    fn reserved_remaining(&self) -> u64 {
        self.limit.saturating_sub(self.current)
    }

    /// Slots still free outside the superuser reserve: `limit` minus
    /// `superuser_reserved` minus `current`, with each step clamped at zero.
    fn non_reserved_remaining(&self) -> u64 {
        self.limit
            .saturating_sub(self.superuser_reserved)
            .saturating_sub(self.current)
    }
}
242
243#[derive(Derivative)]
247#[derivative(Debug)]
248pub struct ConnectionHandle {
249 #[derivative(Debug = "ignore")]
250 decrement_fn: Option<Box<dyn FnOnce() -> () + Send + Sync>>,
251}
252
253impl Drop for ConnectionHandle {
254 fn drop(&mut self) {
255 match self.decrement_fn.take() {
256 Some(decrement_fn) => (decrement_fn)(),
257 None => tracing::error!("ConnectionHandle dropped twice!?"),
258 }
259 }
260}
261
/// Reasons a connection slot could not be allocated.
#[derive(Debug)]
pub enum ConnectionError {
    /// No slot was available to the requesting user.
    TooManyConnections {
        /// Number of connections counted at the time of the request.
        current: u64,
        /// Connection limit in force at the time of the request.
        limit: u64,
    },
}