mz_pgwire/
server.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::future::Future;
11use std::net::IpAddr;
12use std::pin::Pin;
13use std::str::FromStr;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use mz_authenticator::Authenticator;
18use mz_ore::now::{SYSTEM_TIME, epoch_to_uuid_v7};
19use mz_pgwire_common::{
20    ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ConnectionCounter, FrontendStartupMessage,
21    MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, decode_startup,
22};
23use mz_server_core::listeners::AllowedRoles;
24use mz_server_core::{Connection, ConnectionHandler, ReloadingTlsConfig};
25use openssl::ssl::Ssl;
26use tokio::io::AsyncWriteExt;
27use tokio_metrics::TaskMetrics;
28use tokio_openssl::SslStream;
29use tracing::{debug, error, trace};
30
31use crate::codec::FramedConn;
32use crate::metrics::{Metrics, MetricsConfig};
33use crate::protocol;
34
35/// Configures a [`Server`].
36#[derive(Debug)]
37pub struct Config {
38    /// The label for the mz_connection_status metric.
39    pub label: &'static str,
40    /// A client for the adapter with which the server will communicate.
41    pub adapter_client: mz_adapter::Client,
42    /// The TLS configuration for the server.
43    ///
44    /// If not present, then TLS is not enabled, and clients requests to
45    /// negotiate TLS will be rejected.
46    pub tls: Option<ReloadingTlsConfig>,
47    /// Authentication method to use. Frontegg, Password, or None.
48    pub authenticator: Authenticator,
49    /// The registry entries that the pgwire server uses to report metrics.
50    pub metrics: MetricsConfig,
51    /// Global connection limit and count
52    pub active_connection_counter: ConnectionCounter,
53    /// Helm chart version
54    pub helm_chart_version: Option<String>,
55    /// Whether to allow reserved users (ie: mz_system).
56    pub allowed_roles: AllowedRoles,
57}
58
59/// A server that communicates with clients via the pgwire protocol.
60pub struct Server {
61    tls: Option<ReloadingTlsConfig>,
62    adapter_client: mz_adapter::Client,
63    authenticator: Authenticator,
64    metrics: Metrics,
65    active_connection_counter: ConnectionCounter,
66    helm_chart_version: Option<String>,
67    allowed_roles: AllowedRoles,
68}
69
70#[async_trait]
71impl mz_server_core::Server for Server {
72    const NAME: &'static str = "pgwire";
73
74    fn handle_connection(
75        &self,
76        conn: Connection,
77        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
78    ) -> ConnectionHandler {
79        // Using fully-qualified syntax means we won't accidentally call
80        // ourselves (i.e., silently infinitely recurse) if the name or type of
81        // `crate::Server::handle_connection` changes.
82        Box::pin(crate::Server::handle_connection(
83            self,
84            conn,
85            tokio_metrics_intervals,
86        ))
87    }
88}
89
90impl Server {
91    /// Constructs a new server.
92    pub fn new(config: Config) -> Server {
93        Server {
94            tls: config.tls,
95            adapter_client: config.adapter_client,
96            authenticator: config.authenticator,
97            metrics: Metrics::new(config.metrics, config.label),
98            active_connection_counter: config.active_connection_counter,
99            helm_chart_version: config.helm_chart_version,
100            allowed_roles: config.allowed_roles,
101        }
102    }
103
104    #[mz_ore::instrument(level = "debug")]
105    pub fn handle_connection(
106        &self,
107        conn: Connection,
108        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
109    ) -> impl Future<Output = Result<(), anyhow::Error>> + Send + 'static {
110        let adapter_client = self.adapter_client.clone();
111        let authenticator = self.authenticator.clone();
112        let tls = self.tls.clone();
113        let metrics = self.metrics.clone();
114        let active_connection_counter = self.active_connection_counter.clone();
115        let helm_chart_version = self.helm_chart_version.clone();
116        let allowed_roles = self.allowed_roles;
117
118        // TODO(guswynn): remove this redundant_closure_call
119        #[allow(clippy::redundant_closure_call)]
120        async move {
121            let result = (|| {
122                async move {
123                    let conn_id = adapter_client.new_conn_id()?;
124                    let mut conn = Conn::Unencrypted(conn);
125                    loop {
126                        let message = decode_startup(&mut conn).await?;
127
128                        match &message {
129                            Some(message) => trace!("cid={} recv={:?}", conn_id, message),
130                            None => trace!("cid={} recv=<eof>", conn_id),
131                        }
132
133                        conn = match message {
134                            // Clients sometimes hang up during the startup sequence, e.g.
135                            // because they receive an unacceptable response to an
136                            // `SslRequest`. This is considered a graceful termination.
137                            None => return Ok(()),
138
139                            Some(FrontendStartupMessage::Startup {
140                                version,
141                                mut params,
142                            }) => {
143                                // If someone (usually the balancer) forwarded a connection UUID,
144                                // then use that, otherwise generate one.
145                                let conn_uuid_handle = conn.inner_mut().uuid_handle();
146                                let conn_uuid = params
147                                    .remove(CONN_UUID_KEY)
148                                    .and_then(|uuid| uuid.parse().inspect_err(|e| error!("pgwire connection with invalid conn UUID: {e}")).ok());
149                                let conn_uuid_forwarded = conn_uuid.is_some();
150                                // FIXME(ptravers): we should be able to inject the clock when instantiating the `Server`
151                                // but as of writing there's no great way, I can see, to harmonize the lifetimes of the return type
152                                // and &self which must house `NowFn`.
153                                let conn_uuid = conn_uuid.unwrap_or_else(|| epoch_to_uuid_v7(&(SYSTEM_TIME.clone())()));
154                                conn_uuid_handle.set(conn_uuid);
155                                debug!(conn_uuid = %conn_uuid_handle.display(), conn_uuid_forwarded, "starting new pgwire connection in adapter");
156
157                                let direct_peer_addr = conn
158                                    .inner_mut()
159                                    .peer_addr()
160                                    .context("fetching peer addr")?
161                                    .ip();
162                                let peer_addr= match params.remove(MZ_FORWARDED_FOR_KEY) {
163                                    Some(ip_str) => {
164                                        match IpAddr::from_str(&ip_str) {
165                                            Ok(ip) => Some(ip),
166                                            Err(e) => {
167                                                error!("pgwire connection with invalid mz_forwarded_for address: {e}");
168                                                None
169                                            }
170                                        }
171                                    }
172                                    None => Some(direct_peer_addr)
173                                };
174                                let mut conn = FramedConn::new(
175                                    conn_id.clone(),
176                                    peer_addr,
177                                    conn,
178                                );
179
180                                protocol::run(protocol::RunParams {
181                                    tls_mode: tls.as_ref().map(|tls| tls.mode),
182                                    adapter_client,
183                                    conn: &mut conn,
184                                    conn_uuid,
185                                    version,
186                                    params,
187                                    authenticator,
188                                    active_connection_counter,
189                                    helm_chart_version,
190                                    allowed_roles,
191                                    tokio_metrics_intervals,
192                                })
193                                .await?;
194                                conn.flush().await?;
195                                return Ok(());
196                            }
197
198                            Some(FrontendStartupMessage::CancelRequest {
199                                conn_id,
200                                secret_key,
201                            }) => {
202                                adapter_client.cancel_request(conn_id, secret_key);
203                                // For security, the client is not told whether the cancel
204                                // request succeeds or fails.
205                                return Ok(());
206                            }
207
208                            Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
209                                (Conn::Unencrypted(mut conn), Some(tls)) => {
210                                    trace!("cid={} send=AcceptSsl", conn_id);
211                                    conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
212                                    let mut ssl_stream =
213                                        SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
214                                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
215                                        let _ = ssl_stream.get_mut().shutdown().await;
216                                        return Err(e.into());
217                                    }
218                                    Conn::Ssl(ssl_stream)
219                                }
220                                (mut conn, _) => {
221                                    trace!("cid={} send=RejectSsl", conn_id);
222                                    conn.write_all(&[REJECT_ENCRYPTION]).await?;
223                                    conn
224                                }
225                            },
226
227                            Some(FrontendStartupMessage::GssEncRequest) => {
228                                trace!("cid={} send=RejectGssEnc", conn_id);
229                                conn.write_all(&[REJECT_ENCRYPTION]).await?;
230                                conn
231                            }
232                        }
233                    }
234                }
235            })()
236            .await;
237            metrics.connection_status(result.is_ok()).inc();
238            result
239        }
240    }
241}