mz_pgwire/server.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::future::Future;
11use std::net::IpAddr;
12use std::pin::Pin;
13use std::str::FromStr;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use mz_frontegg_auth::Authenticator as FronteggAuthentication;
18use mz_ore::now::{SYSTEM_TIME, epoch_to_uuid_v7};
19use mz_pgwire_common::{
20 ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ConnectionCounter, FrontendStartupMessage,
21 MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, decode_startup,
22};
23use mz_server_core::{Connection, ConnectionHandler, ReloadingTlsConfig};
24use openssl::ssl::Ssl;
25use tokio::io::AsyncWriteExt;
26use tokio_openssl::SslStream;
27use tracing::{debug, error, trace};
28
29use crate::codec::FramedConn;
30use crate::metrics::{Metrics, MetricsConfig};
31use crate::protocol;
32
33/// Configures a [`Server`].
34#[derive(Debug)]
35pub struct Config {
36 /// The label for the mz_connection_status metric.
37 pub label: &'static str,
38 /// A client for the adapter with which the server will communicate.
39 pub adapter_client: mz_adapter::Client,
40 /// The TLS configuration for the server.
41 ///
42 /// If not present, then TLS is not enabled, and clients requests to
43 /// negotiate TLS will be rejected.
44 pub tls: Option<ReloadingTlsConfig>,
45 /// The Frontegg authentication configuration.
46 ///
47 /// If present, Frontegg authentication is enabled, and users may present
48 /// a valid Frontegg API token as a password to authenticate. Otherwise,
49 /// password authentication is disabled.
50 pub frontegg: Option<FronteggAuthentication>,
51 /// Whether to use self-hosted authentication.
52 pub use_self_hosted_auth: bool,
53 /// The registry entries that the pgwire server uses to report metrics.
54 pub metrics: MetricsConfig,
55 /// Whether this is an internal server that permits access to restricted
56 /// system resources.
57 pub internal: bool,
58 /// Global connection limit and count
59 pub active_connection_counter: ConnectionCounter,
60 /// Helm chart version
61 pub helm_chart_version: Option<String>,
62}
63
/// A server that communicates with clients via the pgwire protocol.
///
/// Holds the configuration shared by every connection; each call to
/// `handle_connection` clones what it needs out of this struct so the
/// per-connection future does not borrow the server.
pub struct Server {
    /// TLS configuration. When `None`, client requests to negotiate TLS are
    /// rejected.
    tls: Option<ReloadingTlsConfig>,
    /// Client for the adapter. Used to allocate connection IDs, forward cancel
    /// requests, and run sessions.
    adapter_client: mz_adapter::Client,
    /// Frontegg authentication configuration, if any; passed to `protocol::run`.
    frontegg: Option<FronteggAuthentication>,
    /// Metrics; the success/failure outcome of every connection is recorded here.
    metrics: Metrics,
    /// Whether this is an internal server; passed to `protocol::run`.
    internal: bool,
    /// Global connection limit and count; passed to `protocol::run`.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version, if provided; passed to `protocol::run`.
    helm_chart_version: Option<String>,
    /// Whether to use self-hosted authentication; passed to `protocol::run`.
    use_self_hosted_auth: bool,
}
75
76#[async_trait]
77impl mz_server_core::Server for Server {
78 const NAME: &'static str = "pgwire";
79
80 fn handle_connection(&self, conn: Connection) -> ConnectionHandler {
81 // Using fully-qualified syntax means we won't accidentally call
82 // ourselves (i.e., silently infinitely recurse) if the name or type of
83 // `crate::Server::handle_connection` changes.
84 Box::pin(crate::Server::handle_connection(self, conn))
85 }
86}
87
88impl Server {
89 /// Constructs a new server.
90 pub fn new(config: Config) -> Server {
91 Server {
92 tls: config.tls,
93 adapter_client: config.adapter_client,
94 frontegg: config.frontegg,
95 use_self_hosted_auth: config.use_self_hosted_auth,
96 metrics: Metrics::new(config.metrics, config.label),
97 internal: config.internal,
98 active_connection_counter: config.active_connection_counter,
99 helm_chart_version: config.helm_chart_version,
100 }
101 }
102
103 #[mz_ore::instrument(level = "debug")]
104 pub fn handle_connection(
105 &self,
106 conn: Connection,
107 ) -> impl Future<Output = Result<(), anyhow::Error>> + 'static + Send {
108 let adapter_client = self.adapter_client.clone();
109 let frontegg = self.frontegg.clone();
110 let use_self_hosted_auth = self.use_self_hosted_auth;
111 let tls = self.tls.clone();
112 let internal = self.internal;
113 let metrics = self.metrics.clone();
114 let active_connection_counter = self.active_connection_counter.clone();
115 let helm_chart_version = self.helm_chart_version.clone();
116
117 // TODO(guswynn): remove this redundant_closure_call
118 #[allow(clippy::redundant_closure_call)]
119 async move {
120 let result = (|| {
121 async move {
122 let conn_id = adapter_client.new_conn_id()?;
123 let mut conn = Conn::Unencrypted(conn);
124 loop {
125 let message = decode_startup(&mut conn).await?;
126
127 match &message {
128 Some(message) => trace!("cid={} recv={:?}", conn_id, message),
129 None => trace!("cid={} recv=<eof>", conn_id),
130 }
131
132 conn = match message {
133 // Clients sometimes hang up during the startup sequence, e.g.
134 // because they receive an unacceptable response to an
135 // `SslRequest`. This is considered a graceful termination.
136 None => return Ok(()),
137
138 Some(FrontendStartupMessage::Startup {
139 version,
140 mut params,
141 }) => {
142 // If someone (usually the balancer) forwarded a connection UUID,
143 // then use that, otherwise generate one.
144 let conn_uuid_handle = conn.inner_mut().uuid_handle();
145 let conn_uuid = params
146 .remove(CONN_UUID_KEY)
147 .and_then(|uuid| uuid.parse().inspect_err(|e| error!("pgwire connection with invalid conn UUID: {e}")).ok());
148 let conn_uuid_forwarded = conn_uuid.is_some();
149 // FIXME(ptravers): we should be able to inject the clock when instantiating the `Server`
150 // but as of writing there's no great way, I can see, to harmonize the lifetimes of the return type
151 // and &self which must house `NowFn`.
152 let conn_uuid = conn_uuid.unwrap_or_else(|| epoch_to_uuid_v7(&(SYSTEM_TIME.clone())()));
153 conn_uuid_handle.set(conn_uuid);
154 debug!(conn_uuid = %conn_uuid_handle.display(), conn_uuid_forwarded, "starting new pgwire connection in adapter");
155
156 let direct_peer_addr = conn
157 .inner_mut()
158 .peer_addr()
159 .context("fetching peer addr")?
160 .ip();
161 let peer_addr= match params.remove(MZ_FORWARDED_FOR_KEY) {
162 Some(ip_str) => {
163 match IpAddr::from_str(&ip_str) {
164 Ok(ip) => Some(ip),
165 Err(e) => {
166 error!("pgwire connection with invalid mz_forwarded_for address: {e}");
167 None
168 }
169 }
170 }
171 None => Some(direct_peer_addr)
172 };
173 let mut conn = FramedConn::new(
174 conn_id.clone(),
175 peer_addr,
176 conn,
177 );
178
179 protocol::run(protocol::RunParams {
180 tls_mode: tls.as_ref().map(|tls| tls.mode),
181 adapter_client,
182 conn: &mut conn,
183 conn_uuid,
184 version,
185 params,
186 frontegg: frontegg.as_ref(),
187 use_self_hosted_auth,
188 internal,
189 active_connection_counter,
190 helm_chart_version,
191 })
192 .await?;
193 conn.flush().await?;
194 return Ok(());
195 }
196
197 Some(FrontendStartupMessage::CancelRequest {
198 conn_id,
199 secret_key,
200 }) => {
201 adapter_client.cancel_request(conn_id, secret_key);
202 // For security, the client is not told whether the cancel
203 // request succeeds or fails.
204 return Ok(());
205 }
206
207 Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
208 (Conn::Unencrypted(mut conn), Some(tls)) => {
209 trace!("cid={} send=AcceptSsl", conn_id);
210 conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
211 let mut ssl_stream =
212 SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
213 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
214 let _ = ssl_stream.get_mut().shutdown().await;
215 return Err(e.into());
216 }
217 Conn::Ssl(ssl_stream)
218 }
219 (mut conn, _) => {
220 trace!("cid={} send=RejectSsl", conn_id);
221 conn.write_all(&[REJECT_ENCRYPTION]).await?;
222 conn
223 }
224 },
225
226 Some(FrontendStartupMessage::GssEncRequest) => {
227 trace!("cid={} send=RejectGssEnc", conn_id);
228 conn.write_all(&[REJECT_ENCRYPTION]).await?;
229 conn
230 }
231 }
232 }
233 }
234 })()
235 .await;
236 metrics.connection_status(result.is_ok()).inc();
237 result
238 }
239 }
240}