Skip to main content

mz_pgwire/
server.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::future::Future;
11use std::net::IpAddr;
12use std::pin::Pin;
13use std::str::FromStr;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use mz_authenticator::GenericOidcAuthenticator;
18use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
19use mz_ore::now::{SYSTEM_TIME, epoch_to_uuid_v7};
20use mz_pgwire_common::{
21    ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ConnectionCounter, FrontendStartupMessage,
22    MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, decode_startup,
23};
24use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind};
25use mz_server_core::{Connection, ConnectionHandler, ReloadingTlsConfig};
26use openssl::ssl::Ssl;
27use tokio::io::AsyncWriteExt;
28use tokio_metrics::TaskMetrics;
29use tokio_openssl::SslStream;
30use tracing::{debug, error, trace};
31
32use crate::codec::FramedConn;
33use crate::metrics::{Metrics, MetricsConfig};
34use crate::protocol;
35
36/// Configures a [`Server`].
37#[derive(Debug)]
38pub struct Config {
39    /// The label for the mz_connection_status metric.
40    pub label: &'static str,
41    /// A client for the adapter with which the server will communicate.
42    pub adapter_client: mz_adapter::Client,
43    /// The TLS configuration for the server.
44    ///
45    /// If not present, then TLS is not enabled, and clients requests to
46    /// negotiate TLS will be rejected.
47    pub tls: Option<ReloadingTlsConfig>,
48    /// Frontegg JWT authenticator.
49    pub frontegg: Option<FronteggAuthenticator>,
50    /// OIDC authenticator.
51    pub oidc: GenericOidcAuthenticator,
52    /// The authentication method defined by the server's listener
53    /// configuration.
54    pub authenticator_kind: AuthenticatorKind,
55    /// The registry entries that the pgwire server uses to report metrics.
56    pub metrics: MetricsConfig,
57    /// Global connection limit and count
58    pub active_connection_counter: ConnectionCounter,
59    /// Helm chart version
60    pub helm_chart_version: Option<String>,
61    /// Whether to allow reserved users (ie: mz_system).
62    pub allowed_roles: AllowedRoles,
63}
64
/// A server that communicates with clients via the pgwire protocol.
///
/// Built from a [`Config`] by [`Server::new`]. Every field is cloned (or
/// copied) into each connection future produced by
/// [`Server::handle_connection`], so connections never borrow the server.
pub struct Server {
    /// TLS configuration; when `None`, client `SslRequest`s are rejected.
    tls: Option<ReloadingTlsConfig>,
    /// Client for the adapter; each connection allocates its ID through it.
    adapter_client: mz_adapter::Client,
    /// The authentication method configured on the listener.
    authenticator_kind: AuthenticatorKind,
    /// Frontegg JWT authenticator, if any.
    frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator.
    oidc: GenericOidcAuthenticator,
    /// Server metrics, constructed with the `label` from [`Config`].
    metrics: Metrics,
    /// Global connection limit and active connection count.
    active_connection_counter: ConnectionCounter,
    /// The Helm chart version, if known.
    helm_chart_version: Option<String>,
    /// Which roles may connect (e.g. whether reserved users are allowed).
    allowed_roles: AllowedRoles,
}
77
78#[async_trait]
79impl mz_server_core::Server for Server {
80    const NAME: &'static str = "pgwire";
81
82    fn handle_connection(
83        &self,
84        conn: Connection,
85        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
86    ) -> ConnectionHandler {
87        // Using fully-qualified syntax means we won't accidentally call
88        // ourselves (i.e., silently infinitely recurse) if the name or type of
89        // `crate::Server::handle_connection` changes.
90        Box::pin(crate::Server::handle_connection(
91            self,
92            conn,
93            tokio_metrics_intervals,
94        ))
95    }
96}
97
98impl Server {
99    /// Constructs a new server.
100    pub fn new(config: Config) -> Server {
101        Server {
102            tls: config.tls,
103            adapter_client: config.adapter_client,
104            authenticator_kind: config.authenticator_kind,
105            frontegg: config.frontegg,
106            oidc: config.oidc,
107            metrics: Metrics::new(config.metrics, config.label),
108            active_connection_counter: config.active_connection_counter,
109            helm_chart_version: config.helm_chart_version,
110            allowed_roles: config.allowed_roles,
111        }
112    }
113
114    #[mz_ore::instrument(level = "debug")]
115    pub fn handle_connection(
116        &self,
117        conn: Connection,
118        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
119    ) -> impl Future<Output = Result<(), anyhow::Error>> + Send + 'static {
120        let adapter_client = self.adapter_client.clone();
121        let authenticator_kind = self.authenticator_kind;
122        let frontegg = self.frontegg.clone();
123        let oidc = self.oidc.clone();
124        let tls = self.tls.clone();
125        let metrics = self.metrics.clone();
126        let active_connection_counter = self.active_connection_counter.clone();
127        let helm_chart_version = self.helm_chart_version.clone();
128        let allowed_roles = self.allowed_roles;
129
130        // TODO(guswynn): remove this redundant_closure_call
131        #[allow(clippy::redundant_closure_call)]
132        async move {
133            let result = (|| {
134                async move {
135                    let conn_id = adapter_client.new_conn_id()?;
136                    let mut conn = Conn::Unencrypted(conn);
137                    loop {
138                        let message = decode_startup(&mut conn).await?;
139
140                        match &message {
141                            Some(message) => trace!("cid={} recv={:?}", conn_id, message),
142                            None => trace!("cid={} recv=<eof>", conn_id),
143                        }
144
145                        conn = match message {
146                            // Clients sometimes hang up during the startup sequence, e.g.
147                            // because they receive an unacceptable response to an
148                            // `SslRequest`. This is considered a graceful termination.
149                            None => return Ok(()),
150
151                            Some(FrontendStartupMessage::Startup {
152                                version,
153                                mut params,
154                            }) => {
155                                // If someone (usually the balancer) forwarded a connection UUID,
156                                // then use that, otherwise generate one.
157                                let conn_uuid_handle = conn.inner_mut().uuid_handle();
158                                let conn_uuid = params
159                                    .remove(CONN_UUID_KEY)
160                                    .and_then(|uuid| {
161                                        uuid.parse()
162                                            .inspect_err(|e| {
163                                                error!(
164                                                    "pgwire connection with invalid conn UUID: {e}",
165                                                )
166                                            })
167                                            .ok()
168                                    });
169                                let conn_uuid_forwarded = conn_uuid.is_some();
170                                // FIXME(ptravers): we should be able to inject
171                                // the clock when instantiating the `Server`
172                                // but as of writing there's no great way, I can
173                                // see, to harmonize the lifetimes of the return
174                                // type and &self which must house `NowFn`.
175                                let conn_uuid = conn_uuid.unwrap_or_else(
176                                    || epoch_to_uuid_v7(&(SYSTEM_TIME.clone())()),
177                                );
178                                conn_uuid_handle.set(conn_uuid);
179                                debug!(
180                                    conn_uuid = %conn_uuid_handle.display(),
181                                    conn_uuid_forwarded,
182                                    "starting new pgwire connection in adapter",
183                                );
184
185                                let direct_peer_addr = conn
186                                    .inner_mut()
187                                    .peer_addr()
188                                    .context("fetching peer addr")?
189                                    .ip();
190                                let peer_addr= match params.remove(MZ_FORWARDED_FOR_KEY) {
191                                    Some(ip_str) => {
192                                        match IpAddr::from_str(&ip_str) {
193                                            Ok(ip) => Some(ip),
194                                            Err(e) => {
195                                                error!("pgwire connection with invalid mz_forwarded_for address: {e}");
196                                                None
197                                            }
198                                        }
199                                    }
200                                    None => Some(direct_peer_addr)
201                                };
202                                let mut conn = FramedConn::new(
203                                    conn_id.clone(),
204                                    peer_addr,
205                                    conn,
206                                );
207
208                                protocol::run(protocol::RunParams {
209                                    tls_mode: tls.as_ref().map(|tls| tls.mode),
210                                    adapter_client,
211                                    conn: &mut conn,
212                                    conn_uuid,
213                                    version,
214                                    params,
215                                    frontegg,
216                                    oidc,
217                                    authenticator_kind,
218                                    active_connection_counter,
219                                    helm_chart_version,
220                                    allowed_roles,
221                                    tokio_metrics_intervals,
222                                })
223                                .await?;
224                                conn.flush().await?;
225                                return Ok(());
226                            }
227
228                            Some(FrontendStartupMessage::CancelRequest {
229                                conn_id,
230                                secret_key,
231                            }) => {
232                                adapter_client.cancel_request(conn_id, secret_key);
233                                // For security, the client is not told whether the cancel
234                                // request succeeds or fails.
235                                return Ok(());
236                            }
237
238                            Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
239                                (Conn::Unencrypted(mut conn), Some(tls)) => {
240                                    trace!("cid={} send=AcceptSsl", conn_id);
241                                    conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
242                                    let mut ssl_stream =
243                                        SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
244                                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
245                                        let _ = ssl_stream.get_mut().shutdown().await;
246                                        return Err(e.into());
247                                    }
248                                    Conn::Ssl(ssl_stream)
249                                }
250                                (mut conn, _) => {
251                                    trace!("cid={} send=RejectSsl", conn_id);
252                                    conn.write_all(&[REJECT_ENCRYPTION]).await?;
253                                    conn
254                                }
255                            },
256
257                            Some(FrontendStartupMessage::GssEncRequest) => {
258                                trace!("cid={} send=RejectGssEnc", conn_id);
259                                conn.write_all(&[REJECT_ENCRYPTION]).await?;
260                                conn
261                            }
262                        }
263                    }
264                }
265            })()
266            .await;
267            metrics.connection_status(result.is_ok()).inc();
268            result
269        }
270    }
271}