1use std::fmt;
13use std::future::Future;
14use std::io;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::pin::Pin;
18use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22use anyhow::bail;
23use async_trait::async_trait;
24use clap::builder::ArgPredicate;
25use futures::stream::{BoxStream, Stream, StreamExt};
26use mz_dyncfg::{Config, ConfigSet};
27use mz_ore::channel::trigger;
28use mz_ore::error::ErrorExt;
29use mz_ore::netio::AsyncReady;
30use mz_ore::option::OptionExt;
31use mz_ore::task::JoinSetExt;
32use openssl::ssl::{SslAcceptor, SslContext, SslFiletype, SslMethod};
33use proxy_header::{ParseConfig, ProxiedAddress, ProxyHeader};
34use schemars::JsonSchema;
35use scopeguard::ScopeGuard;
36use serde::{Deserialize, Serialize};
37use socket2::{SockRef, TcpKeepalive};
38use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, Interest, ReadBuf, Ready};
39use tokio::net::{TcpListener, TcpStream};
40use tokio::sync::oneshot;
41use tokio::task::JoinSet;
42use tokio_metrics::TaskMetrics;
43use tokio_stream::wrappers::{IntervalStream, TcpListenerStream};
44use tracing::{debug, error, warn};
45use uuid::Uuid;
46
47pub mod listeners;
48
/// TCP keepalive parameters that [`serve`] applies to every accepted socket:
/// 60 seconds of idle time before the first probe, 60 seconds between probes,
/// and 9 probe retries.
const KEEPALIVE: TcpKeepalive = TcpKeepalive::new()
    .with_time(Duration::from_secs(60))
    .with_interval(Duration::from_secs(60))
    .with_retries(9);
57
/// A boxed, `Send` future that handles one connection and resolves to
/// `Ok(())` or an [`anyhow::Error`]. This is what
/// [`Server::handle_connection`] returns.
pub type ConnectionHandler = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
60
/// A TCP connection together with a shared, initially empty slot for a
/// connection UUID.
pub struct Connection {
    /// Shared UUID slot. Starts as `None`; other holders can read or fill it
    /// through the handle returned by [`Connection::uuid_handle`].
    conn_uuid: Arc<Mutex<Option<Uuid>>>,
    /// The underlying socket that all I/O is forwarded to.
    tcp_stream: TcpStream,
}
67
68impl Connection {
69 fn new(tcp_stream: TcpStream) -> Connection {
70 Connection {
71 conn_uuid: Arc::new(Mutex::new(None)),
72 tcp_stream,
73 }
74 }
75
76 pub fn uuid_handle(&self) -> ConnectionUuidHandle {
78 ConnectionUuidHandle(Arc::clone(&self.conn_uuid))
79 }
80
81 pub async fn take_proxy_header_address(&mut self) -> Option<ProxiedAddress> {
90 let mut buf = [0u8; 1024];
95 let len = match self.tcp_stream.peek(&mut buf).await {
96 Ok(n) if n > 0 => n,
97 _ => {
98 debug!("Failed to read from client socket or no data received");
99 return None;
100 }
101 };
102
103 let (header, hlen) = match ProxyHeader::parse(
105 &buf[..len],
106 ParseConfig {
107 include_tlvs: false,
108 allow_v1: false,
109 allow_v2: true,
110 },
111 ) {
112 Ok((header, hlen)) => (header, hlen),
113 Err(proxy_header::Error::Invalid) => {
114 debug!(
115 "Proxy header is invalid. This is likely due to no no header being provided",
116 );
117 return None;
118 }
119 Err(e) => {
120 debug!("Proxy header parse error '{:?}', ignoring header.", e);
121 return None;
122 }
123 };
124 debug!("Proxied connection with header {:?}", header);
125 let address = header.proxied_address().map(|a| a.to_owned());
126 let _ = self.read_exact(&mut buf[..hlen]).await;
128 address
129 }
130
131 pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
133 self.tcp_stream.peer_addr()
134 }
135}
136
137impl AsyncRead for Connection {
138 fn poll_read(
139 mut self: Pin<&mut Self>,
140 cx: &mut Context,
141 buf: &mut ReadBuf,
142 ) -> Poll<io::Result<()>> {
143 Pin::new(&mut self.tcp_stream).poll_read(cx, buf)
144 }
145}
146
147impl AsyncWrite for Connection {
148 fn poll_write(
149 mut self: Pin<&mut Self>,
150 cx: &mut Context,
151 buf: &[u8],
152 ) -> Poll<io::Result<usize>> {
153 Pin::new(&mut self.tcp_stream).poll_write(cx, buf)
154 }
155
156 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
157 Pin::new(&mut self.tcp_stream).poll_flush(cx)
158 }
159
160 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
161 Pin::new(&mut self.tcp_stream).poll_shutdown(cx)
162 }
163}
164
165#[async_trait]
166impl AsyncReady for Connection {
167 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
168 self.tcp_stream.ready(interest).await
169 }
170}
171
/// A handle onto a connection's shared UUID slot.
///
/// The slot is an `Arc<Mutex<Option<Uuid>>>`, so every handle obtained from
/// the same [`Connection`] observes the same value.
pub struct ConnectionUuidHandle(Arc<Mutex<Option<Uuid>>>);
182
183impl ConnectionUuidHandle {
184 pub fn get(&self) -> Option<Uuid> {
186 *self.0.lock().expect("lock poisoned")
187 }
188
189 pub fn set(&self, conn_uuid: Uuid) {
191 *self.0.lock().expect("lock poisoned") = Some(conn_uuid);
192 }
193
194 pub fn display(&self) -> impl fmt::Display {
196 self.get().display_or("<unknown>")
197 }
198}
199
/// A server that can handle individual connections handed to it by
/// [`serve`].
pub trait Server {
    /// Name of the server; used by [`serve`] in task names and log messages.
    const NAME: &'static str;

    /// Builds the future that handles `conn`.
    ///
    /// `tokio_metrics_intervals` yields [`TaskMetrics`] samples; [`serve`]
    /// passes the intervals of the monitor that instruments the task running
    /// the returned future.
    fn handle_connection(
        &self,
        conn: Connection,
        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
    ) -> ConnectionHandler;
}
212
/// Shorthand for a `Send + Unpin` stream of incoming TCP connections (or the
/// I/O errors hit while accepting them).
pub trait ConnectionStream: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}

/// Blanket implementation: every stream satisfying the bounds is a
/// `ConnectionStream`.
impl<T> ConnectionStream for T where T: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
217
/// A handle to a listener created by [`listen`].
#[derive(Debug)]
pub struct ListenerHandle {
    /// The local address the listener is bound to.
    pub local_addr: SocketAddr,
    /// Sending half of the trigger whose receiver terminates the connection
    /// stream returned by [`listen`].
    // NOTE(review): presumably the trigger fires when this handle is dropped;
    // confirm against `mz_ore::channel::trigger`.
    _trigger: trigger::Trigger,
}
224
225impl ListenerHandle {
226 pub fn local_addr(&self) -> SocketAddr {
228 self.local_addr
229 }
230}
231
232pub async fn listen(
238 addr: &SocketAddr,
239) -> Result<(ListenerHandle, Pin<Box<dyn ConnectionStream>>), io::Error> {
240 let listener = TcpListener::bind(addr).await?;
241 let local_addr = listener.local_addr()?;
242 let (trigger, trigger_rx) = trigger::channel();
243 let handle = ListenerHandle {
244 local_addr,
245 _trigger: trigger,
246 };
247 let stream = TcpListenerStream::new(listener).take_until(trigger_rx);
250 Ok((handle, Box::pin(stream)))
251}
252
/// Arguments to [`serve`].
pub struct ServeConfig<S, C>
where
    S: Server,
    C: ConnectionStream,
{
    /// The server whose `handle_connection` is invoked for each connection.
    pub server: S,
    /// The stream of incoming connections to serve.
    pub conns: C,
    /// Optional dynamic configuration. When `Some`, [`serve`] waits (with a
    /// timeout) for outstanding connections after `conns` ends; when `None`,
    /// it returns without waiting.
    pub dyncfg: Option<ServeDyncfg>,
}
266
/// Dynamic configuration consulted by [`serve`] once the connection stream
/// ends.
pub struct ServeDyncfg {
    /// The config set that `sigterm_wait_config` is looked up in.
    pub config_set: ConfigSet,
    /// The config whose value is the maximum time [`serve`] waits for
    /// outstanding connections to finish before returning.
    pub sigterm_wait_config: &'static Config<Duration>,
}
278
279pub async fn serve<S, C>(
284 ServeConfig {
285 server,
286 mut conns,
287 dyncfg,
288 }: ServeConfig<S, C>,
289) -> JoinSet<()>
290where
291 S: Server,
292 C: ConnectionStream,
293{
294 let task_name = format!("handle_{}_connection", S::NAME);
295 let mut set = JoinSet::new();
296 loop {
297 tokio::select! {
298 conn = conns.next() => {
300 let conn = match conn {
301 None => break,
302 Some(Ok(conn)) => conn,
303 Some(Err(err)) => {
304 error!("error accepting connection: {}", err);
305 continue;
306 }
307 };
308 conn.set_nodelay(true).expect("set_nodelay failed");
320 if let Err(e) = SockRef::from(&conn).set_tcp_keepalive(&KEEPALIVE) {
325 error!("failed enabling keepalive: {e}");
326 continue;
327 }
328 let conn = Connection::new(conn);
329 let conn_uuid = conn.uuid_handle();
330 let metrics_monitor = tokio_metrics::TaskMonitor::new();
331 let tokio_metrics_intervals = metrics_monitor.intervals();
332 let fut = server.handle_connection(conn, tokio_metrics_intervals);
333 set.spawn_named(|| &task_name, metrics_monitor.instrument(async move {
334 let guard = scopeguard::guard((), |_| {
335 debug!(
336 server = S::NAME,
337 conn_uuid = %conn_uuid.display(),
338 "dropping connection without explicit termination",
339 );
340 });
341
342 match fut.await {
343 Ok(()) => {
344 debug!(
345 server = S::NAME,
346 conn_uuid = %conn_uuid.display(),
347 "successfully handled connection",
348 );
349 }
350 Err(e) => {
351 warn!(
352 server = S::NAME,
353 conn_uuid = %conn_uuid.display(),
354 "error handling connection: {}",
355 e.display_with_causes(),
356 );
357 }
358 }
359
360 let () = ScopeGuard::into_inner(guard);
361 }));
362 }
363 res = set.join_next(), if set.len() > 0 => {
366 if let Some(Err(e)) = res {
367 warn!(
368 "error joining connection in {}: {}",
369 S::NAME,
370 e.display_with_causes()
371 );
372 }
373 }
374 }
375 }
376 if let Some(dyncfg) = dyncfg {
377 let wait = dyncfg.sigterm_wait_config.get(&dyncfg.config_set);
378 if set.len() > 0 {
379 warn!(
380 "{} exiting, {} outstanding connections, waiting for {:?}",
381 S::NAME,
382 set.len(),
383 wait
384 );
385 }
386 let timedout = tokio::time::timeout(wait, async {
387 while let Some(res) = set.join_next().await {
388 if let Err(e) = res {
389 warn!(
390 "error joining connection in {}: {}",
391 S::NAME,
392 e.display_with_causes()
393 );
394 }
395 }
396 })
397 .await;
398 if timedout.is_err() {
399 warn!(
400 "{}: wait timeout of {:?} exceeded, {} outstanding connections",
401 S::NAME,
402 wait,
403 set.len()
404 );
405 }
406 }
407 set
408}
409
/// TLS configuration: an SSL context plus the mode in which it is enforced.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// The SSL context to use for TLS.
    pub context: SslContext,
    /// The TLS mode.
    pub mode: TlsMode,
}
418
// The TLS mode of a listener.
//
// Plain `//` comments are used on purpose: the `JsonSchema` derive may pick
// up `///` doc comments as schema descriptions, which would change the
// generated output.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, JsonSchema)]
pub enum TlsMode {
    // NOTE(review): presumably TLS is permitted but not mandatory — confirm
    // against the code that consumes this mode.
    Allow,
    // NOTE(review): presumably TLS is mandatory — confirm against the code
    // that consumes this mode.
    Require,
}
427
/// Locations of the TLS certificate and private key on disk.
#[derive(Debug, Clone)]
pub struct TlsCertConfig {
    /// Path to the certificate chain file.
    pub cert: PathBuf,
    /// Path to the PEM-encoded private key file.
    pub key: PathBuf,
}
436
437impl TlsCertConfig {
438 pub fn load_context(&self) -> Result<SslContext, anyhow::Error> {
440 let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
447 builder.set_certificate_chain_file(&self.cert)?;
448 builder.set_private_key_file(&self.key, SslFiletype::PEM)?;
449 Ok(builder.build().into_context())
450 }
451
452 pub fn reloading_context(
458 &self,
459 mut ticker: ReloadTrigger,
460 ) -> Result<ReloadingSslContext, anyhow::Error> {
461 let context = Arc::new(RwLock::new(self.load_context()?));
462 let updater_context = Arc::clone(&context);
463 let config = self.clone();
464 mz_ore::task::spawn(|| "TlsCertConfig reloading_context", async move {
465 while let Some(chan) = ticker.next().await {
466 let result = match config.load_context() {
467 Ok(ctx) => {
468 *updater_context.write().expect("poisoned") = ctx;
469 Ok(())
470 }
471 Err(err) => {
472 tracing::error!("failed to reload SSL certificate: {err}");
473 Err(err)
474 }
475 };
476 if let Some(chan) = chan {
477 let _ = chan.send(result);
478 }
479 }
480 tracing::warn!("TlsCertConfig reloading_context updater closed");
481 });
482 Ok(ReloadingSslContext { context })
483 }
484}
485
/// An [`SslContext`] behind a shared lock, so it can be replaced while
/// clones of this value keep observing the latest one. Created by
/// [`TlsCertConfig::reloading_context`].
#[derive(Clone, Debug)]
pub struct ReloadingSslContext {
    /// The current context; written by the reload task, read via `get`.
    context: Arc<RwLock<SslContext>>,
}
492
493impl ReloadingSslContext {
494 pub fn get(&self) -> RwLockReadGuard<'_, SslContext> {
495 self.context.read().expect("poisoned")
496 }
497}
498
/// TLS configuration whose SSL context can be reloaded at runtime.
#[derive(Clone, Debug)]
pub struct ReloadingTlsConfig {
    /// The reloadable SSL context.
    pub context: ReloadingSslContext,
    /// The TLS mode.
    pub mode: TlsMode,
}
507
/// A stream of certificate reload requests. Each item optionally carries a
/// oneshot sender on which the result of that reload is reported.
pub type ReloadTrigger = BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>;
509
510pub fn default_cert_reload_ticker() -> ReloadTrigger {
512 let ticker = IntervalStream::new(tokio::time::interval(Duration::from_secs(60 * 60)));
513 let ticker = ticker.map(|_| None);
514 let ticker = Box::pin(ticker);
515 ticker
516}
517
518pub fn cert_reload_never_reload() -> ReloadTrigger {
520 let ticker = futures::stream::empty();
521 let ticker = Box::pin(ticker);
522 ticker
523}
524
// Command-line arguments that configure TLS.
//
// Plain `//` comments are used on purpose: the clap derive turns `///` doc
// comments into user-visible help text, which would change program output.
#[derive(Debug, Clone, clap::Parser)]
pub struct TlsCliArgs {
    // The TLS mode: "disable" (the default) or "require". Defaults to
    // "require" instead when `frontegg_tenant` or
    // `frontegg_resolver_template` is present.
    #[clap(
        long, env = "TLS_MODE",
        value_parser = ["disable", "require"],
        default_value = "disable",
        default_value_ifs = [
            ("frontegg_tenant", ArgPredicate::IsPresent, Some("require")),
            ("frontegg_resolver_template", ArgPredicate::IsPresent, Some("require")),
        ],
        value_name = "MODE",
    )]
    tls_mode: String,
    // Path to the TLS certificate. Must be given together with `tls_key`,
    // and is required when `tls_mode` is "require".
    #[clap(
        long,
        env = "TLS_CERT",
        requires = "tls_key",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_cert: Option<PathBuf>,
    // Path to the TLS private key. Must be given together with `tls_cert`,
    // and is required when `tls_mode` is "require".
    #[clap(
        long,
        env = "TLS_KEY",
        requires = "tls_cert",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_key: Option<PathBuf>,
}
566
567impl TlsCliArgs {
568 pub fn into_config(self) -> Result<Option<TlsCertConfig>, anyhow::Error> {
570 if self.tls_mode == "disable" {
571 if self.tls_cert.is_some() {
572 bail!("cannot specify --tls-mode=disable and --tls-cert simultaneously");
573 }
574 if self.tls_key.is_some() {
575 bail!("cannot specify --tls-mode=disable and --tls-key simultaneously");
576 }
577 Ok(None)
578 } else {
579 let cert = self.tls_cert.unwrap();
580 let key = self.tls_key.unwrap();
581 Ok(Some(TlsCertConfig { cert, key }))
582 }
583 }
584}