1use std::fmt;
13use std::future::Future;
14use std::io;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::pin::Pin;
18use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22use anyhow::bail;
23use async_trait::async_trait;
24use clap::builder::ArgPredicate;
25use futures::stream::{BoxStream, Stream, StreamExt};
26use mz_dyncfg::{Config, ConfigSet};
27use mz_ore::channel::trigger;
28use mz_ore::error::ErrorExt;
29use mz_ore::netio::AsyncReady;
30use mz_ore::option::OptionExt;
31use mz_ore::task::JoinSetExt;
32use openssl::ssl::{SslAcceptor, SslContext, SslFiletype, SslMethod};
33use proxy_header::{ParseConfig, ProxiedAddress, ProxyHeader};
34use schemars::JsonSchema;
35use scopeguard::ScopeGuard;
36use serde::{Deserialize, Serialize};
37use socket2::{SockRef, TcpKeepalive};
38use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, Interest, ReadBuf, Ready};
39use tokio::net::{TcpListener, TcpStream};
40use tokio::sync::oneshot;
41use tokio::task::JoinSet;
42use tokio_metrics::TaskMetrics;
43use tokio_stream::wrappers::{IntervalStream, TcpListenerStream};
44use tracing::{debug, error, warn};
45use uuid::Uuid;
46
/// Submodule `listeners`; its items are defined in a separate source file.
pub mod listeners;
48
49const KEEPALIVE: TcpKeepalive = TcpKeepalive::new()
54 .with_time(Duration::from_secs(60))
55 .with_interval(Duration::from_secs(60))
56 .with_retries(9);
57
58pub type ConnectionHandler = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
60
/// An accepted TCP connection, paired with a shared slot for its connection UUID.
pub struct Connection {
    /// The connection's UUID, once one has been assigned. The slot is shared with every
    /// `ConnectionUuidHandle` returned by `uuid_handle`.
    conn_uuid: Arc<Mutex<Option<Uuid>>>,
    /// The underlying TCP stream; the `AsyncRead`/`AsyncWrite`/`AsyncReady` impls delegate to it.
    tcp_stream: TcpStream,
}
67
68impl Connection {
69 fn new(tcp_stream: TcpStream) -> Connection {
70 Connection {
71 conn_uuid: Arc::new(Mutex::new(None)),
72 tcp_stream,
73 }
74 }
75
76 pub fn uuid_handle(&self) -> ConnectionUuidHandle {
78 ConnectionUuidHandle(Arc::clone(&self.conn_uuid))
79 }
80
81 pub async fn take_proxy_header_address(&mut self) -> Option<ProxiedAddress> {
90 let mut buf = [0u8; 1024];
95 let len = match self.tcp_stream.peek(&mut buf).await {
96 Ok(n) if n > 0 => n,
97 _ => {
98 debug!("Failed to read from client socket or no data received");
99 return None;
100 }
101 };
102
103 let (header, hlen) = match ProxyHeader::parse(
105 &buf[..len],
106 ParseConfig {
107 include_tlvs: false,
108 allow_v1: false,
109 allow_v2: true,
110 },
111 ) {
112 Ok((header, hlen)) => (header, hlen),
113 Err(proxy_header::Error::Invalid) => {
114 debug!("Proxy header is invalid. This is likely due to no header being provided",);
115 return None;
116 }
117 Err(proxy_header::Error::BufferTooShort) => {
122 return self.read_proxy_v2_header(&mut buf).await;
123 }
124 Err(e) => {
125 debug!("Proxy header parse error '{:?}', ignoring header.", e);
126 return None;
127 }
128 };
129 debug!("Proxied connection with header {:?}", header);
130 let address = header.proxied_address().map(|a| a.to_owned());
131 let _ = self.read_exact(&mut buf[..hlen]).await;
133 address
134 }
135
136 async fn read_proxy_v2_header(&mut self, buf: &mut [u8; 1024]) -> Option<ProxiedAddress> {
140 const V2_PREFIX_LEN: usize = 16;
142 if self.read_exact(&mut buf[..V2_PREFIX_LEN]).await.is_err() {
143 debug!("Failed to read PROXY v2 fixed header");
144 return None;
145 }
146 let addr_len = usize::from(u16::from_be_bytes([buf[14], buf[15]]));
147 let total = V2_PREFIX_LEN + addr_len;
148 if total > buf.len() {
149 debug!("PROXY v2 header too large: {total} bytes");
150 return None;
151 }
152 if self
153 .read_exact(&mut buf[V2_PREFIX_LEN..total])
154 .await
155 .is_err()
156 {
157 debug!("Failed to read PROXY v2 address data");
158 return None;
159 }
160 match ProxyHeader::parse(
161 &buf[..total],
162 ParseConfig {
163 include_tlvs: false,
164 allow_v1: false,
165 allow_v2: true,
166 },
167 ) {
168 Ok((header, _)) => {
169 debug!("Proxied connection with header {:?}", header);
170 header.proxied_address().map(|a| a.to_owned())
171 }
172 Err(e) => {
173 debug!("Proxy header parse error '{:?}', ignoring header.", e);
174 None
175 }
176 }
177 }
178
179 pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
181 self.tcp_stream.peer_addr()
182 }
183}
184
185impl AsyncRead for Connection {
186 fn poll_read(
187 mut self: Pin<&mut Self>,
188 cx: &mut Context,
189 buf: &mut ReadBuf,
190 ) -> Poll<io::Result<()>> {
191 Pin::new(&mut self.tcp_stream).poll_read(cx, buf)
192 }
193}
194
195impl AsyncWrite for Connection {
196 fn poll_write(
197 mut self: Pin<&mut Self>,
198 cx: &mut Context,
199 buf: &[u8],
200 ) -> Poll<io::Result<usize>> {
201 Pin::new(&mut self.tcp_stream).poll_write(cx, buf)
202 }
203
204 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
205 Pin::new(&mut self.tcp_stream).poll_flush(cx)
206 }
207
208 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
209 Pin::new(&mut self.tcp_stream).poll_shutdown(cx)
210 }
211}
212
213#[async_trait]
214impl AsyncReady for Connection {
215 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
216 self.tcp_stream.ready(interest).await
217 }
218}
219
/// A shareable handle to a connection's UUID slot, obtained from `Connection::uuid_handle`.
pub struct ConnectionUuidHandle(Arc<Mutex<Option<Uuid>>>);
230
231impl ConnectionUuidHandle {
232 pub fn get(&self) -> Option<Uuid> {
234 *self.0.lock().expect("lock poisoned")
235 }
236
237 pub fn set(&self, conn_uuid: Uuid) {
239 *self.0.lock().expect("lock poisoned") = Some(conn_uuid);
240 }
241
242 pub fn display(&self) -> impl fmt::Display {
244 self.get().display_or("<unknown>")
245 }
246}
247
/// A server whose connections are driven by [`serve`].
pub trait Server {
    /// Name of the server.
    ///
    /// `serve` uses it in log messages and in the name of each connection-handling task
    /// (`handle_<NAME>_connection`).
    const NAME: &'static str;

    /// Returns a future that handles `conn`.
    ///
    /// `tokio_metrics_intervals` yields metrics for the task that `serve` spawns to run the
    /// returned future; that task is instrumented with the same `TaskMonitor`.
    fn handle_connection(
        &self,
        conn: Connection,
        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
    ) -> ConnectionHandler;
}
260
/// A stream of accepted TCP connections (or accept errors).
///
/// This acts as a trait alias: the blanket impl below implements it for every `Unpin + Send`
/// stream of `io::Result<TcpStream>`.
pub trait ConnectionStream: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}

impl<T> ConnectionStream for T where T: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
265
/// Handle to a listener created by [`listen`].
///
/// `_trigger` is paired with the receiver that bounds the connection stream returned by
/// `listen` (via `take_until`). Presumably dropping this handle ends that stream; confirm
/// against `mz_ore::channel::trigger`.
#[derive(Debug)]
pub struct ListenerHandle {
    /// Local address of the listener, as reported by the socket after binding.
    pub local_addr: SocketAddr,
    _trigger: trigger::Trigger,
}
272
273impl ListenerHandle {
274 pub fn local_addr(&self) -> SocketAddr {
276 self.local_addr
277 }
278}
279
280pub async fn listen(
286 addr: &SocketAddr,
287) -> Result<(ListenerHandle, Pin<Box<dyn ConnectionStream>>), io::Error> {
288 let listener = TcpListener::bind(addr).await?;
289 let local_addr = listener.local_addr()?;
290 let (trigger, trigger_rx) = trigger::channel();
291 let handle = ListenerHandle {
292 local_addr,
293 _trigger: trigger,
294 };
295 let stream = TcpListenerStream::new(listener).take_until(trigger_rx);
298 Ok((handle, Box::pin(stream)))
299}
300
/// Configuration for [`serve`].
pub struct ServeConfig<S, C>
where
    S: Server,
    C: ConnectionStream,
{
    /// Server whose `handle_connection` is called for each accepted connection.
    pub server: S,
    /// Stream of incoming connections. `serve` stops accepting once it ends.
    pub conns: C,
    /// Optional dynamic configuration.
    ///
    /// When set, `serve` waits (up to a configured duration) for outstanding connections after
    /// `conns` ends. When unset, it returns right away.
    pub dyncfg: Option<ServeDyncfg>,
}
314
/// Dynamic configuration consulted by [`serve`] after its connection stream ends.
pub struct ServeDyncfg {
    /// Config set that `sigterm_wait_config` is read from.
    pub config_set: ConfigSet,
    /// How long `serve` waits for outstanding connections to finish once `conns` ends.
    ///
    /// The name suggests this is the grace period during SIGTERM-driven shutdown; confirm with
    /// callers.
    pub sigterm_wait_config: &'static Config<Duration>,
}
326
327pub async fn serve<S, C>(
332 ServeConfig {
333 server,
334 mut conns,
335 dyncfg,
336 }: ServeConfig<S, C>,
337) -> JoinSet<()>
338where
339 S: Server,
340 C: ConnectionStream,
341{
342 let task_name = format!("handle_{}_connection", S::NAME);
343 let mut set = JoinSet::new();
344 loop {
345 tokio::select! {
346 conn = conns.next() => {
348 let conn = match conn {
349 None => break,
350 Some(Ok(conn)) => conn,
351 Some(Err(err)) => {
352 error!("error accepting connection: {}", err);
353 continue;
354 }
355 };
356 conn.set_nodelay(true).expect("set_nodelay failed");
368 if let Err(e) = SockRef::from(&conn).set_tcp_keepalive(&KEEPALIVE) {
373 error!("failed enabling keepalive: {e}");
374 continue;
375 }
376 let conn = Connection::new(conn);
377 let conn_uuid = conn.uuid_handle();
378 let metrics_monitor = tokio_metrics::TaskMonitor::new();
379 let tokio_metrics_intervals = metrics_monitor.intervals();
380 let fut = server.handle_connection(conn, tokio_metrics_intervals);
381 set.spawn_named(|| &task_name, metrics_monitor.instrument(async move {
382 let guard = scopeguard::guard((), |_| {
383 debug!(
384 server = S::NAME,
385 conn_uuid = %conn_uuid.display(),
386 "dropping connection without explicit termination",
387 );
388 });
389
390 match fut.await {
391 Ok(()) => {
392 debug!(
393 server = S::NAME,
394 conn_uuid = %conn_uuid.display(),
395 "successfully handled connection",
396 );
397 }
398 Err(e) => {
399 warn!(
400 server = S::NAME,
401 conn_uuid = %conn_uuid.display(),
402 "error handling connection: {}",
403 e.display_with_causes(),
404 );
405 }
406 }
407
408 let () = ScopeGuard::into_inner(guard);
409 }));
410 }
411 res = set.join_next(), if set.len() > 0 => {
414 if let Some(Err(e)) = res {
415 warn!(
416 "error joining connection in {}: {}",
417 S::NAME,
418 e.display_with_causes()
419 );
420 }
421 }
422 }
423 }
424 if let Some(dyncfg) = dyncfg {
425 let wait = dyncfg.sigterm_wait_config.get(&dyncfg.config_set);
426 if set.len() > 0 {
427 warn!(
428 "{} exiting, {} outstanding connections, waiting for {:?}",
429 S::NAME,
430 set.len(),
431 wait
432 );
433 }
434 let timedout = tokio::time::timeout(wait, async {
435 while let Some(res) = set.join_next().await {
436 if let Err(e) = res {
437 warn!(
438 "error joining connection in {}: {}",
439 S::NAME,
440 e.display_with_causes()
441 );
442 }
443 }
444 })
445 .await;
446 if timedout.is_err() {
447 warn!(
448 "{}: wait timeout of {:?} exceeded, {} outstanding connections",
449 S::NAME,
450 wait,
451 set.len()
452 );
453 }
454 }
455 set
456}
457
/// TLS configuration: an SSL context together with the TLS mode.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// The SSL context.
    pub context: SslContext,
    /// The TLS mode.
    pub mode: TlsMode,
}
466
// The TLS mode carried by `TlsConfig` and `ReloadingTlsConfig`.
// Plain `//` comments are deliberate here: the `JsonSchema` derive copies `///` doc comments
// into the generated schema, which would change its output.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, JsonSchema)]
pub enum TlsMode {
    // Presumably: clients may, but need not, use TLS. Confirm against the code that consumes it.
    Allow,
    // Presumably: clients must use TLS. Confirm against the code that consumes it.
    Require,
}
475
/// Paths to the TLS certificate chain and private key on disk.
#[derive(Debug, Clone)]
pub struct TlsCertConfig {
    /// Certificate chain file, loaded by `load_context` via `set_certificate_chain_file`.
    pub cert: PathBuf,
    /// Private key file, loaded by `load_context` as PEM.
    pub key: PathBuf,
}
484
485impl TlsCertConfig {
486 pub fn load_context(&self) -> Result<SslContext, anyhow::Error> {
488 let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
495 builder.set_certificate_chain_file(&self.cert)?;
496 builder.set_private_key_file(&self.key, SslFiletype::PEM)?;
497 Ok(builder.build().into_context())
498 }
499
500 pub fn reloading_context(
506 &self,
507 mut ticker: ReloadTrigger,
508 ) -> Result<ReloadingSslContext, anyhow::Error> {
509 let context = Arc::new(RwLock::new(self.load_context()?));
510 let updater_context = Arc::clone(&context);
511 let config = self.clone();
512 mz_ore::task::spawn(|| "TlsCertConfig reloading_context", async move {
513 while let Some(chan) = ticker.next().await {
514 let result = match config.load_context() {
515 Ok(ctx) => {
516 *updater_context.write().expect("poisoned") = ctx;
517 Ok(())
518 }
519 Err(err) => {
520 tracing::error!("failed to reload SSL certificate: {err}");
521 Err(err)
522 }
523 };
524 if let Some(chan) = chan {
525 let _ = chan.send(result);
526 }
527 }
528 tracing::warn!("TlsCertConfig reloading_context updater closed");
529 });
530 Ok(ReloadingSslContext { context })
531 }
532}
533
/// An `SslContext` behind a shared lock.
///
/// The task spawned by `TlsCertConfig::reloading_context` replaces the context on each
/// successful reload. Clones share the same underlying context.
#[derive(Clone, Debug)]
pub struct ReloadingSslContext {
    /// The current context.
    context: Arc<RwLock<SslContext>>,
}
540
541impl ReloadingSslContext {
542 pub fn get(&self) -> RwLockReadGuard<'_, SslContext> {
543 self.context.read().expect("poisoned")
544 }
545}
546
/// Like `TlsConfig`, but with an SSL context that can be reloaded at runtime.
#[derive(Clone, Debug)]
pub struct ReloadingTlsConfig {
    /// The reloadable SSL context.
    pub context: ReloadingSslContext,
    /// The TLS mode.
    pub mode: TlsMode,
}
555
/// Stream of certificate reload requests consumed by `TlsCertConfig::reloading_context`.
///
/// Each item triggers one reload. If the item carries a sender, the reload result is sent on it.
pub type ReloadTrigger = BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>;
557
558pub fn default_cert_reload_ticker() -> ReloadTrigger {
560 let ticker = IntervalStream::new(tokio::time::interval(Duration::from_secs(60 * 60)));
561 let ticker = ticker.map(|_| None);
562 let ticker = Box::pin(ticker);
563 ticker
564}
565
566pub fn cert_reload_never_reload() -> ReloadTrigger {
568 let ticker = futures::stream::empty();
569 let ticker = Box::pin(ticker);
570 ticker
571}
572
// Command-line arguments for configuring TLS, converted into a `TlsCertConfig` by
// `into_config`. Plain `//` comments are deliberate: clap's derive turns `///` doc comments
// into `--help` text, which would change the CLI output.
#[derive(Debug, Clone, clap::Parser)]
pub struct TlsCliArgs {
    // TLS mode, restricted to `disable` or `require` by `value_parser`. Defaults to `disable`,
    // or to `require` when `frontegg_tenant` or `frontegg_resolver_template` is present.
    #[clap(
        long, env = "TLS_MODE",
        value_parser = ["disable", "require"],
        default_value = "disable",
        default_value_ifs = [
            ("frontegg_tenant", ArgPredicate::IsPresent, Some("require")),
            ("frontegg_resolver_template", ArgPredicate::IsPresent, Some("require")),
        ],
        value_name = "MODE",
    )]
    tls_mode: String,
    // Path to the TLS certificate. Requires `tls_key`.
    #[clap(
        long,
        env = "TLS_CERT",
        requires = "tls_key",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_cert: Option<PathBuf>,
    // Path to the TLS private key. Requires `tls_cert`.
    #[clap(
        long,
        env = "TLS_KEY",
        requires = "tls_cert",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_key: Option<PathBuf>,
}
614
615impl TlsCliArgs {
616 pub fn into_config(self) -> Result<Option<TlsCertConfig>, anyhow::Error> {
618 if self.tls_mode == "disable" {
619 if self.tls_cert.is_some() {
620 bail!("cannot specify --tls-mode=disable and --tls-cert simultaneously");
621 }
622 if self.tls_key.is_some() {
623 bail!("cannot specify --tls-mode=disable and --tls-key simultaneously");
624 }
625 Ok(None)
626 } else {
627 let cert = self.tls_cert.unwrap();
628 let key = self.tls_key.unwrap();
629 Ok(Some(TlsCertConfig { cert, key }))
630 }
631 }
632}