mz_pgwire/
server.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::future::Future;
11use std::net::IpAddr;
12use std::pin::Pin;
13use std::str::FromStr;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use mz_frontegg_auth::Authenticator as FronteggAuthentication;
18use mz_ore::now::{SYSTEM_TIME, epoch_to_uuid_v7};
19use mz_pgwire_common::{
20    ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ConnectionCounter, FrontendStartupMessage,
21    MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, decode_startup,
22};
23use mz_server_core::{Connection, ConnectionHandler, ReloadingTlsConfig};
24use openssl::ssl::Ssl;
25use tokio::io::AsyncWriteExt;
26use tokio_openssl::SslStream;
27use tracing::{debug, error, trace};
28
29use crate::codec::FramedConn;
30use crate::metrics::{Metrics, MetricsConfig};
31use crate::protocol;
32
33/// Configures a [`Server`].
34#[derive(Debug)]
35pub struct Config {
36    /// The label for the mz_connection_status metric.
37    pub label: &'static str,
38    /// A client for the adapter with which the server will communicate.
39    pub adapter_client: mz_adapter::Client,
40    /// The TLS configuration for the server.
41    ///
42    /// If not present, then TLS is not enabled, and clients requests to
43    /// negotiate TLS will be rejected.
44    pub tls: Option<ReloadingTlsConfig>,
45    /// The Frontegg authentication configuration.
46    ///
47    /// If present, Frontegg authentication is enabled, and users may present
48    /// a valid Frontegg API token as a password to authenticate. Otherwise,
49    /// password authentication is disabled.
50    pub frontegg: Option<FronteggAuthentication>,
51    /// Whether to use self-hosted authentication.
52    pub use_self_hosted_auth: bool,
53    /// The registry entries that the pgwire server uses to report metrics.
54    pub metrics: MetricsConfig,
55    /// Whether this is an internal server that permits access to restricted
56    /// system resources.
57    pub internal: bool,
58    /// Global connection limit and count
59    pub active_connection_counter: ConnectionCounter,
60    /// Helm chart version
61    pub helm_chart_version: Option<String>,
62}
63
64/// A server that communicates with clients via the pgwire protocol.
65pub struct Server {
66    tls: Option<ReloadingTlsConfig>,
67    adapter_client: mz_adapter::Client,
68    frontegg: Option<FronteggAuthentication>,
69    metrics: Metrics,
70    internal: bool,
71    active_connection_counter: ConnectionCounter,
72    helm_chart_version: Option<String>,
73    use_self_hosted_auth: bool,
74}
75
76#[async_trait]
77impl mz_server_core::Server for Server {
78    const NAME: &'static str = "pgwire";
79
80    fn handle_connection(&self, conn: Connection) -> ConnectionHandler {
81        // Using fully-qualified syntax means we won't accidentally call
82        // ourselves (i.e., silently infinitely recurse) if the name or type of
83        // `crate::Server::handle_connection` changes.
84        Box::pin(crate::Server::handle_connection(self, conn))
85    }
86}
87
88impl Server {
89    /// Constructs a new server.
90    pub fn new(config: Config) -> Server {
91        Server {
92            tls: config.tls,
93            adapter_client: config.adapter_client,
94            frontegg: config.frontegg,
95            use_self_hosted_auth: config.use_self_hosted_auth,
96            metrics: Metrics::new(config.metrics, config.label),
97            internal: config.internal,
98            active_connection_counter: config.active_connection_counter,
99            helm_chart_version: config.helm_chart_version,
100        }
101    }
102
103    #[mz_ore::instrument(level = "debug")]
104    pub fn handle_connection(
105        &self,
106        conn: Connection,
107    ) -> impl Future<Output = Result<(), anyhow::Error>> + 'static + Send {
108        let adapter_client = self.adapter_client.clone();
109        let frontegg = self.frontegg.clone();
110        let use_self_hosted_auth = self.use_self_hosted_auth;
111        let tls = self.tls.clone();
112        let internal = self.internal;
113        let metrics = self.metrics.clone();
114        let active_connection_counter = self.active_connection_counter.clone();
115        let helm_chart_version = self.helm_chart_version.clone();
116
117        // TODO(guswynn): remove this redundant_closure_call
118        #[allow(clippy::redundant_closure_call)]
119        async move {
120            let result = (|| {
121                async move {
122                    let conn_id = adapter_client.new_conn_id()?;
123                    let mut conn = Conn::Unencrypted(conn);
124                    loop {
125                        let message = decode_startup(&mut conn).await?;
126
127                        match &message {
128                            Some(message) => trace!("cid={} recv={:?}", conn_id, message),
129                            None => trace!("cid={} recv=<eof>", conn_id),
130                        }
131
132                        conn = match message {
133                            // Clients sometimes hang up during the startup sequence, e.g.
134                            // because they receive an unacceptable response to an
135                            // `SslRequest`. This is considered a graceful termination.
136                            None => return Ok(()),
137
138                            Some(FrontendStartupMessage::Startup {
139                                version,
140                                mut params,
141                            }) => {
142                                // If someone (usually the balancer) forwarded a connection UUID,
143                                // then use that, otherwise generate one.
144                                let conn_uuid_handle = conn.inner_mut().uuid_handle();
145                                let conn_uuid = params
146                                    .remove(CONN_UUID_KEY)
147                                    .and_then(|uuid| uuid.parse().inspect_err(|e| error!("pgwire connection with invalid conn UUID: {e}")).ok());
148                                let conn_uuid_forwarded = conn_uuid.is_some();
149                                // FIXME(ptravers): we should be able to inject the clock when instantiating the `Server`
150                                // but as of writing there's no great way, I can see, to harmonize the lifetimes of the return type
151                                // and &self which must house `NowFn`.
152                                let conn_uuid = conn_uuid.unwrap_or_else(|| epoch_to_uuid_v7(&(SYSTEM_TIME.clone())()));
153                                conn_uuid_handle.set(conn_uuid);
154                                debug!(conn_uuid = %conn_uuid_handle.display(), conn_uuid_forwarded, "starting new pgwire connection in adapter");
155
156                                let direct_peer_addr = conn
157                                    .inner_mut()
158                                    .peer_addr()
159                                    .context("fetching peer addr")?
160                                    .ip();
161                                let peer_addr= match params.remove(MZ_FORWARDED_FOR_KEY) {
162                                    Some(ip_str) => {
163                                        match IpAddr::from_str(&ip_str) {
164                                            Ok(ip) => Some(ip),
165                                            Err(e) => {
166                                                error!("pgwire connection with invalid mz_forwarded_for address: {e}");
167                                                None
168                                            }
169                                        }
170                                    }
171                                    None => Some(direct_peer_addr)
172                                };
173                                let mut conn = FramedConn::new(
174                                    conn_id.clone(),
175                                    peer_addr,
176                                    conn,
177                                );
178
179                                protocol::run(protocol::RunParams {
180                                    tls_mode: tls.as_ref().map(|tls| tls.mode),
181                                    adapter_client,
182                                    conn: &mut conn,
183                                    conn_uuid,
184                                    version,
185                                    params,
186                                    frontegg: frontegg.as_ref(),
187                                    use_self_hosted_auth,
188                                    internal,
189                                    active_connection_counter,
190                                    helm_chart_version,
191                                })
192                                .await?;
193                                conn.flush().await?;
194                                return Ok(());
195                            }
196
197                            Some(FrontendStartupMessage::CancelRequest {
198                                conn_id,
199                                secret_key,
200                            }) => {
201                                adapter_client.cancel_request(conn_id, secret_key);
202                                // For security, the client is not told whether the cancel
203                                // request succeeds or fails.
204                                return Ok(());
205                            }
206
207                            Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
208                                (Conn::Unencrypted(mut conn), Some(tls)) => {
209                                    trace!("cid={} send=AcceptSsl", conn_id);
210                                    conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
211                                    let mut ssl_stream =
212                                        SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
213                                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
214                                        let _ = ssl_stream.get_mut().shutdown().await;
215                                        return Err(e.into());
216                                    }
217                                    Conn::Ssl(ssl_stream)
218                                }
219                                (mut conn, _) => {
220                                    trace!("cid={} send=RejectSsl", conn_id);
221                                    conn.write_all(&[REJECT_ENCRYPTION]).await?;
222                                    conn
223                                }
224                            },
225
226                            Some(FrontendStartupMessage::GssEncRequest) => {
227                                trace!("cid={} send=RejectGssEnc", conn_id);
228                                conn.write_all(&[REJECT_ENCRYPTION]).await?;
229                                conn
230                            }
231                        }
232                    }
233                }
234            })()
235            .await;
236            metrics.connection_status(result.is_ok()).inc();
237            result
238        }
239    }
240}