mz_pgwire/server.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::future::Future;
11use std::net::IpAddr;
12use std::pin::Pin;
13use std::str::FromStr;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use mz_authenticator::GenericOidcAuthenticator;
18use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
19use mz_ore::now::{SYSTEM_TIME, epoch_to_uuid_v7};
20use mz_pgwire_common::{
21 ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ConnectionCounter, FrontendStartupMessage,
22 MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, decode_startup,
23};
24use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind};
25use mz_server_core::{Connection, ConnectionHandler, ReloadingTlsConfig};
26use openssl::ssl::Ssl;
27use tokio::io::AsyncWriteExt;
28use tokio_metrics::TaskMetrics;
29use tokio_openssl::SslStream;
30use tracing::{debug, error, trace};
31
32use crate::codec::FramedConn;
33use crate::metrics::{Metrics, MetricsConfig};
34use crate::protocol;
35
36/// Configures a [`Server`].
37#[derive(Debug)]
38pub struct Config {
39 /// The label for the mz_connection_status metric.
40 pub label: &'static str,
41 /// A client for the adapter with which the server will communicate.
42 pub adapter_client: mz_adapter::Client,
43 /// The TLS configuration for the server.
44 ///
45 /// If not present, then TLS is not enabled, and clients requests to
46 /// negotiate TLS will be rejected.
47 pub tls: Option<ReloadingTlsConfig>,
48 /// Frontegg JWT authenticator.
49 pub frontegg: Option<FronteggAuthenticator>,
50 /// OIDC authenticator.
51 pub oidc: GenericOidcAuthenticator,
52 /// The authentication method defined by the server's listener
53 /// configuration.
54 pub authenticator_kind: AuthenticatorKind,
55 /// The registry entries that the pgwire server uses to report metrics.
56 pub metrics: MetricsConfig,
57 /// Global connection limit and count
58 pub active_connection_counter: ConnectionCounter,
59 /// Helm chart version
60 pub helm_chart_version: Option<String>,
61 /// Whether to allow reserved users (ie: mz_system).
62 pub allowed_roles: AllowedRoles,
63}
64
65/// A server that communicates with clients via the pgwire protocol.
66pub struct Server {
67 tls: Option<ReloadingTlsConfig>,
68 adapter_client: mz_adapter::Client,
69 authenticator_kind: AuthenticatorKind,
70 frontegg: Option<FronteggAuthenticator>,
71 oidc: GenericOidcAuthenticator,
72 metrics: Metrics,
73 active_connection_counter: ConnectionCounter,
74 helm_chart_version: Option<String>,
75 allowed_roles: AllowedRoles,
76}
77
78#[async_trait]
79impl mz_server_core::Server for Server {
80 const NAME: &'static str = "pgwire";
81
82 fn handle_connection(
83 &self,
84 conn: Connection,
85 tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
86 ) -> ConnectionHandler {
87 // Using fully-qualified syntax means we won't accidentally call
88 // ourselves (i.e., silently infinitely recurse) if the name or type of
89 // `crate::Server::handle_connection` changes.
90 Box::pin(crate::Server::handle_connection(
91 self,
92 conn,
93 tokio_metrics_intervals,
94 ))
95 }
96}
97
98impl Server {
99 /// Constructs a new server.
100 pub fn new(config: Config) -> Server {
101 Server {
102 tls: config.tls,
103 adapter_client: config.adapter_client,
104 authenticator_kind: config.authenticator_kind,
105 frontegg: config.frontegg,
106 oidc: config.oidc,
107 metrics: Metrics::new(config.metrics, config.label),
108 active_connection_counter: config.active_connection_counter,
109 helm_chart_version: config.helm_chart_version,
110 allowed_roles: config.allowed_roles,
111 }
112 }
113
114 #[mz_ore::instrument(level = "debug")]
115 pub fn handle_connection(
116 &self,
117 conn: Connection,
118 tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
119 ) -> impl Future<Output = Result<(), anyhow::Error>> + Send + 'static {
120 let adapter_client = self.adapter_client.clone();
121 let authenticator_kind = self.authenticator_kind;
122 let frontegg = self.frontegg.clone();
123 let oidc = self.oidc.clone();
124 let tls = self.tls.clone();
125 let metrics = self.metrics.clone();
126 let active_connection_counter = self.active_connection_counter.clone();
127 let helm_chart_version = self.helm_chart_version.clone();
128 let allowed_roles = self.allowed_roles;
129
130 // TODO(guswynn): remove this redundant_closure_call
131 #[allow(clippy::redundant_closure_call)]
132 async move {
133 let result = (|| {
134 async move {
135 let conn_id = adapter_client.new_conn_id()?;
136 let mut conn = Conn::Unencrypted(conn);
137 loop {
138 let message = decode_startup(&mut conn).await?;
139
140 match &message {
141 Some(message) => trace!("cid={} recv={:?}", conn_id, message),
142 None => trace!("cid={} recv=<eof>", conn_id),
143 }
144
145 conn = match message {
146 // Clients sometimes hang up during the startup sequence, e.g.
147 // because they receive an unacceptable response to an
148 // `SslRequest`. This is considered a graceful termination.
149 None => return Ok(()),
150
151 Some(FrontendStartupMessage::Startup {
152 version,
153 mut params,
154 }) => {
155 // If someone (usually the balancer) forwarded a connection UUID,
156 // then use that, otherwise generate one.
157 let conn_uuid_handle = conn.inner_mut().uuid_handle();
158 let conn_uuid = params
159 .remove(CONN_UUID_KEY)
160 .and_then(|uuid| {
161 uuid.parse()
162 .inspect_err(|e| {
163 error!(
164 "pgwire connection with invalid conn UUID: {e}",
165 )
166 })
167 .ok()
168 });
169 let conn_uuid_forwarded = conn_uuid.is_some();
170 // FIXME(ptravers): we should be able to inject
171 // the clock when instantiating the `Server`
172 // but as of writing there's no great way, I can
173 // see, to harmonize the lifetimes of the return
174 // type and &self which must house `NowFn`.
175 let conn_uuid = conn_uuid.unwrap_or_else(
176 || epoch_to_uuid_v7(&(SYSTEM_TIME.clone())()),
177 );
178 conn_uuid_handle.set(conn_uuid);
179 debug!(
180 conn_uuid = %conn_uuid_handle.display(),
181 conn_uuid_forwarded,
182 "starting new pgwire connection in adapter",
183 );
184
185 let direct_peer_addr = conn
186 .inner_mut()
187 .peer_addr()
188 .context("fetching peer addr")?
189 .ip();
190 let peer_addr= match params.remove(MZ_FORWARDED_FOR_KEY) {
191 Some(ip_str) => {
192 match IpAddr::from_str(&ip_str) {
193 Ok(ip) => Some(ip),
194 Err(e) => {
195 error!("pgwire connection with invalid mz_forwarded_for address: {e}");
196 None
197 }
198 }
199 }
200 None => Some(direct_peer_addr)
201 };
202 let mut conn = FramedConn::new(
203 conn_id.clone(),
204 peer_addr,
205 conn,
206 );
207
208 protocol::run(protocol::RunParams {
209 tls_mode: tls.as_ref().map(|tls| tls.mode),
210 adapter_client,
211 conn: &mut conn,
212 conn_uuid,
213 version,
214 params,
215 frontegg,
216 oidc,
217 authenticator_kind,
218 active_connection_counter,
219 helm_chart_version,
220 allowed_roles,
221 tokio_metrics_intervals,
222 })
223 .await?;
224 conn.flush().await?;
225 return Ok(());
226 }
227
228 Some(FrontendStartupMessage::CancelRequest {
229 conn_id,
230 secret_key,
231 }) => {
232 adapter_client.cancel_request(conn_id, secret_key);
233 // For security, the client is not told whether the cancel
234 // request succeeds or fails.
235 return Ok(());
236 }
237
238 Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
239 (Conn::Unencrypted(mut conn), Some(tls)) => {
240 trace!("cid={} send=AcceptSsl", conn_id);
241 conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
242 let mut ssl_stream =
243 SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
244 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
245 let _ = ssl_stream.get_mut().shutdown().await;
246 return Err(e.into());
247 }
248 Conn::Ssl(ssl_stream)
249 }
250 (mut conn, _) => {
251 trace!("cid={} send=RejectSsl", conn_id);
252 conn.write_all(&[REJECT_ENCRYPTION]).await?;
253 conn
254 }
255 },
256
257 Some(FrontendStartupMessage::GssEncRequest) => {
258 trace!("cid={} send=RejectGssEnc", conn_id);
259 conn.write_all(&[REJECT_ENCRYPTION]).await?;
260 conn
261 }
262 }
263 }
264 }
265 })()
266 .await;
267 metrics.connection_status(result.is_ok()).inc();
268 result
269 }
270 }
271}