1use std::future::Future;
11use std::net::IpAddr;
12use std::pin::Pin;
13use std::str::FromStr;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use mz_authenticator::Authenticator;
18use mz_ore::now::{SYSTEM_TIME, epoch_to_uuid_v7};
19use mz_pgwire_common::{
20 ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ConnectionCounter, FrontendStartupMessage,
21 MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, decode_startup,
22};
23use mz_server_core::listeners::AllowedRoles;
24use mz_server_core::{Connection, ConnectionHandler, ReloadingTlsConfig};
25use openssl::ssl::Ssl;
26use tokio::io::AsyncWriteExt;
27use tokio_metrics::TaskMetrics;
28use tokio_openssl::SslStream;
29use tracing::{debug, error, trace};
30
31use crate::codec::FramedConn;
32use crate::metrics::{Metrics, MetricsConfig};
33use crate::protocol;
34
/// Configuration for a pgwire [`Server`]; consumed by [`Server::new`].
#[derive(Debug)]
pub struct Config {
    /// Label passed to [`Metrics::new`] together with `metrics`
    /// (presumably distinguishes this listener's metrics — confirm).
    pub label: &'static str,
    /// Adapter client used to allocate connection IDs, forward cancel requests,
    /// and run the pgwire protocol for each connection.
    pub adapter_client: mz_adapter::Client,
    /// TLS configuration. When `Some`, a client's `SslRequest` on an
    /// unencrypted connection is accepted and a TLS handshake is performed;
    /// when `None`, SSL requests are rejected.
    pub tls: Option<ReloadingTlsConfig>,
    /// Authenticator handed to the protocol layer for each connection.
    pub authenticator: Authenticator,
    /// Metrics configuration used (with `label`) to build the server's [`Metrics`].
    pub metrics: MetricsConfig,
    /// Connection counter handed to the protocol layer for each connection
    /// (presumably tracks currently active connections — confirm).
    pub active_connection_counter: ConnectionCounter,
    /// Optional Helm chart version, handed to the protocol layer.
    pub helm_chart_version: Option<String>,
    /// Role restrictions handed to the protocol layer
    /// (presumably limits which roles may connect via this listener — confirm).
    pub allowed_roles: AllowedRoles,
}
58
/// A pgwire server, built from a [`Config`] by [`Server::new`].
///
/// Each field is cloned (or copied) into every connection future produced by
/// [`Server::handle_connection`], so the futures do not borrow the server.
pub struct Server {
    /// TLS configuration; `None` means SSL requests are always rejected.
    tls: Option<ReloadingTlsConfig>,
    /// Client for the adapter layer (connection IDs, cancellation, protocol).
    adapter_client: mz_adapter::Client,
    /// Authenticator passed to the protocol layer.
    authenticator: Authenticator,
    /// Metrics built from the config's `metrics` and `label`; records whether
    /// each connection finished successfully.
    metrics: Metrics,
    /// Connection counter passed to the protocol layer.
    active_connection_counter: ConnectionCounter,
    /// Optional Helm chart version passed to the protocol layer.
    helm_chart_version: Option<String>,
    /// Role restrictions passed to the protocol layer.
    allowed_roles: AllowedRoles,
}
69
70#[async_trait]
71impl mz_server_core::Server for Server {
72 const NAME: &'static str = "pgwire";
73
74 fn handle_connection(
75 &self,
76 conn: Connection,
77 tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
78 ) -> ConnectionHandler {
79 Box::pin(crate::Server::handle_connection(
83 self,
84 conn,
85 tokio_metrics_intervals,
86 ))
87 }
88}
89
impl Server {
    /// Builds a pgwire [`Server`] from its [`Config`].
    ///
    /// Every field is moved over unchanged, except `metrics` and `label`, which
    /// are combined via [`Metrics::new`] into the server's metrics handle.
    pub fn new(config: Config) -> Server {
        Server {
            tls: config.tls,
            adapter_client: config.adapter_client,
            authenticator: config.authenticator,
            metrics: Metrics::new(config.metrics, config.label),
            active_connection_counter: config.active_connection_counter,
            helm_chart_version: config.helm_chart_version,
            allowed_roles: config.allowed_roles,
        }
    }

    /// Returns a future that drives one client connection from its startup
    /// phase through the pgwire protocol.
    ///
    /// The future allocates a connection ID, then repeatedly decodes startup
    /// messages:
    /// - EOF: returns `Ok(())`.
    /// - `SslRequest`: if TLS is configured and the connection is still
    ///   unencrypted, replies with an accept byte and performs the TLS
    ///   handshake. Otherwise it replies with a reject byte and keeps the
    ///   current connection. In both cases it then reads the next startup
    ///   message.
    /// - `GssEncRequest`: always replies with a reject byte, then reads the
    ///   next startup message.
    /// - `CancelRequest`: forwards the cancellation to the adapter and returns
    ///   without replying.
    /// - `Startup`: resolves the connection UUID and peer address, runs
    ///   [`protocol::run`], flushes, and returns.
    ///
    /// When the future completes, `metrics.connection_status(ok)` is
    /// incremented, where `ok` records whether the connection ended in `Ok`.
    ///
    /// # Errors
    ///
    /// The future fails if any of these fail: allocating a connection ID,
    /// decoding a startup message, fetching the peer address, writing a
    /// response byte, creating the TLS session, the TLS handshake,
    /// `protocol::run`, or the final flush.
    #[mz_ore::instrument(level = "debug")]
    pub fn handle_connection(
        &self,
        conn: Connection,
        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
    ) -> impl Future<Output = Result<(), anyhow::Error>> + Send + 'static {
        // Clone the server state up front so the returned future owns
        // everything it uses and is `'static` (it does not borrow `self`).
        let adapter_client = self.adapter_client.clone();
        let authenticator = self.authenticator.clone();
        let tls = self.tls.clone();
        let metrics = self.metrics.clone();
        let active_connection_counter = self.active_connection_counter.clone();
        let helm_chart_version = self.helm_chart_version.clone();
        let allowed_roles = self.allowed_roles;

        // The body is wrapped in an immediately-invoked closure returning an
        // inner future. Every `return` and `?` below exits only that inner
        // future, so the outer block always gets to record the connection
        // status metric before handing back the result.
        #[allow(clippy::redundant_closure_call)]
        async move {
            let result = (|| {
                async move {
                    let conn_id = adapter_client.new_conn_id()?;
                    let mut conn = Conn::Unencrypted(conn);
                    loop {
                        let message = decode_startup(&mut conn).await?;

                        match &message {
                            Some(message) => trace!("cid={} recv={:?}", conn_id, message),
                            None => trace!("cid={} recv=<eof>", conn_id),
                        }

                        // Arms that loop back evaluate to the (possibly
                        // upgraded) connection, which replaces `conn` for the
                        // next iteration. Terminal arms `return` instead.
                        conn = match message {
                            // EOF during startup: nothing left to do.
                            None => return Ok(()),

                            Some(FrontendStartupMessage::Startup {
                                version,
                                mut params,
                            }) => {
                                let conn_uuid_handle = conn.inner_mut().uuid_handle();
                                // Reuse a UUID supplied under `CONN_UUID_KEY`
                                // if it parses; an unparseable value is logged
                                // and ignored. Otherwise generate a v7 UUID from
                                // `SYSTEM_TIME`. The key is removed from
                                // `params` before they reach the protocol layer.
                                let conn_uuid = params
                                    .remove(CONN_UUID_KEY)
                                    .and_then(|uuid| uuid.parse().inspect_err(|e| error!("pgwire connection with invalid conn UUID: {e}")).ok());
                                let conn_uuid_forwarded = conn_uuid.is_some();
                                let conn_uuid = conn_uuid.unwrap_or_else(|| epoch_to_uuid_v7(&(SYSTEM_TIME.clone())()));
                                conn_uuid_handle.set(conn_uuid);
                                debug!(conn_uuid = %conn_uuid_handle.display(), conn_uuid_forwarded, "starting new pgwire connection in adapter");

                                let direct_peer_addr = conn
                                    .inner_mut()
                                    .peer_addr()
                                    .context("fetching peer addr")?
                                    .ip();
                                // Prefer an address supplied under
                                // `MZ_FORWARDED_FOR_KEY` (also removed from
                                // `params`). If that value does not parse, the
                                // peer address becomes `None` rather than
                                // falling back to the direct peer.
                                // NOTE(review): nothing here verifies that the
                                // direct peer is a trusted proxy, so any client
                                // can supply this key. Confirm this listener is
                                // only reachable through such a proxy.
                                let peer_addr = match params.remove(MZ_FORWARDED_FOR_KEY) {
                                    Some(ip_str) => {
                                        match IpAddr::from_str(&ip_str) {
                                            Ok(ip) => Some(ip),
                                            Err(e) => {
                                                error!("pgwire connection with invalid mz_forwarded_for address: {e}");
                                                None
                                            }
                                        }
                                    }
                                    None => Some(direct_peer_addr)
                                };
                                // Shadows the raw `conn` with a framed wrapper
                                // for the rest of the session.
                                let mut conn = FramedConn::new(
                                    conn_id.clone(),
                                    peer_addr,
                                    conn,
                                );

                                protocol::run(protocol::RunParams {
                                    tls_mode: tls.as_ref().map(|tls| tls.mode),
                                    adapter_client,
                                    conn: &mut conn,
                                    conn_uuid,
                                    version,
                                    params,
                                    authenticator,
                                    active_connection_counter,
                                    helm_chart_version,
                                    allowed_roles,
                                    tokio_metrics_intervals,
                                })
                                .await?;
                                conn.flush().await?;
                                return Ok(());
                            }

                            // `conn_id` here is the target connection named in
                            // the request; it shadows this connection's own ID.
                            Some(FrontendStartupMessage::CancelRequest {
                                conn_id,
                                secret_key,
                            }) => {
                                adapter_client.cancel_request(conn_id, secret_key);
                                // No reply is written to the client, whatever the
                                // outcome of the cancellation.
                                return Ok(());
                            }

                            Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
                                // TLS is configured and the connection is not
                                // yet encrypted: accept and upgrade.
                                (Conn::Unencrypted(mut conn), Some(tls)) => {
                                    trace!("cid={} send=AcceptSsl", conn_id);
                                    conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
                                    let mut ssl_stream =
                                        SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
                                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
                                        // Best-effort shutdown of the underlying
                                        // stream; its own error is ignored in
                                        // favor of the handshake error.
                                        let _ = ssl_stream.get_mut().shutdown().await;
                                        return Err(e.into());
                                    }
                                    Conn::Ssl(ssl_stream)
                                }
                                // No TLS configured, or the connection is already
                                // encrypted: reject and keep the connection as is.
                                (mut conn, _) => {
                                    trace!("cid={} send=RejectSsl", conn_id);
                                    conn.write_all(&[REJECT_ENCRYPTION]).await?;
                                    conn
                                }
                            },

                            // GSSAPI encryption is never offered.
                            Some(FrontendStartupMessage::GssEncRequest) => {
                                trace!("cid={} send=RejectGssEnc", conn_id);
                                conn.write_all(&[REJECT_ENCRYPTION]).await?;
                                conn
                            }
                        }
                    }
                }
            })()
            .await;
            metrics.connection_status(result.is_ok()).inc();
            result
        }
    }
}