1use std::fmt;
13use std::future::Future;
14use std::io;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::pin::Pin;
18use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22use anyhow::bail;
23use async_trait::async_trait;
24use clap::builder::ArgPredicate;
25use futures::stream::{BoxStream, Stream, StreamExt};
26use mz_dyncfg::{Config, ConfigSet};
27use mz_ore::channel::trigger;
28use mz_ore::error::ErrorExt;
29use mz_ore::netio::AsyncReady;
30use mz_ore::option::OptionExt;
31use mz_ore::task::JoinSetExt;
32use openssl::ssl::{SslAcceptor, SslContext, SslFiletype, SslMethod};
33use proxy_header::{ParseConfig, ProxiedAddress, ProxyHeader};
34use schemars::JsonSchema;
35use scopeguard::ScopeGuard;
36use serde::{Deserialize, Serialize};
37use socket2::{SockRef, TcpKeepalive};
38use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, Interest, ReadBuf, Ready};
39use tokio::net::{TcpListener, TcpStream};
40use tokio::sync::oneshot;
41use tokio::task::JoinSet;
42use tokio_metrics::TaskMetrics;
43use tokio_stream::wrappers::{IntervalStream, TcpListenerStream};
44use tracing::{debug, error, warn};
45use uuid::Uuid;
46
47pub mod listeners;
48
/// TCP keepalive settings applied to every connection accepted by `serve`.
///
/// Keepalive probing starts after 60 seconds of idleness, probes are sent
/// every 60 seconds, and the connection is considered dead after 9
/// unanswered probes (TCP_KEEPIDLE / TCP_KEEPINTVL / TCP_KEEPCNT).
const KEEPALIVE: TcpKeepalive = TcpKeepalive::new()
    .with_time(Duration::from_secs(60))
    .with_interval(Duration::from_secs(60))
    .with_retries(9);
57
/// The boxed future returned by `Server::handle_connection`; driving it to
/// completion handles one client connection.
pub type ConnectionHandler = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
60
/// A client connection accepted by `serve`.
///
/// Wraps the raw `TcpStream` together with a shared connection-UUID slot.
/// The slot starts out empty and can be filled in (and later read) through
/// a `ConnectionUuidHandle`.
pub struct Connection {
    /// Shared UUID slot; `None` until a server sets it.
    conn_uuid: Arc<Mutex<Option<Uuid>>>,
    /// The socket all reads, writes, and readiness queries are delegated to.
    tcp_stream: TcpStream,
}
67
68impl Connection {
69    fn new(tcp_stream: TcpStream) -> Connection {
70        Connection {
71            conn_uuid: Arc::new(Mutex::new(None)),
72            tcp_stream,
73        }
74    }
75
76    pub fn uuid_handle(&self) -> ConnectionUuidHandle {
78        ConnectionUuidHandle(Arc::clone(&self.conn_uuid))
79    }
80
81    pub async fn take_proxy_header_address(&mut self) -> Option<ProxiedAddress> {
90        let mut buf = [0u8; 1024];
95        let len = match self.tcp_stream.peek(&mut buf).await {
96            Ok(n) if n > 0 => n,
97            _ => {
98                debug!("Failed to read from client socket or no data received");
99                return None;
100            }
101        };
102
103        let (header, hlen) = match ProxyHeader::parse(
105            &buf[..len],
106            ParseConfig {
107                include_tlvs: false,
108                allow_v1: false,
109                allow_v2: true,
110            },
111        ) {
112            Ok((header, hlen)) => (header, hlen),
113            Err(proxy_header::Error::Invalid) => {
114                debug!(
115                    "Proxy header is invalid. This is likely due to no no header being provided",
116                );
117                return None;
118            }
119            Err(e) => {
120                debug!("Proxy header parse error '{:?}', ignoring header.", e);
121                return None;
122            }
123        };
124        debug!("Proxied connection with header {:?}", header);
125        let address = header.proxied_address().map(|a| a.to_owned());
126        let _ = self.read_exact(&mut buf[..hlen]).await;
128        address
129    }
130
131    pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
133        self.tcp_stream.peer_addr()
134    }
135}
136
137impl AsyncRead for Connection {
138    fn poll_read(
139        mut self: Pin<&mut Self>,
140        cx: &mut Context,
141        buf: &mut ReadBuf,
142    ) -> Poll<io::Result<()>> {
143        Pin::new(&mut self.tcp_stream).poll_read(cx, buf)
144    }
145}
146
147impl AsyncWrite for Connection {
148    fn poll_write(
149        mut self: Pin<&mut Self>,
150        cx: &mut Context,
151        buf: &[u8],
152    ) -> Poll<io::Result<usize>> {
153        Pin::new(&mut self.tcp_stream).poll_write(cx, buf)
154    }
155
156    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
157        Pin::new(&mut self.tcp_stream).poll_flush(cx)
158    }
159
160    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
161        Pin::new(&mut self.tcp_stream).poll_shutdown(cx)
162    }
163}
164
165#[async_trait]
166impl AsyncReady for Connection {
167    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
168        self.tcp_stream.ready(interest).await
169    }
170}
171
172pub struct ConnectionUuidHandle(Arc<Mutex<Option<Uuid>>>);
182
183impl ConnectionUuidHandle {
184    pub fn get(&self) -> Option<Uuid> {
186        *self.0.lock().expect("lock poisoned")
187    }
188
189    pub fn set(&self, conn_uuid: Uuid) {
191        *self.0.lock().expect("lock poisoned") = Some(conn_uuid);
192    }
193
194    pub fn display(&self) -> impl fmt::Display {
196        self.get().display_or("<unknown>")
197    }
198}
199
/// A protocol server that `serve` drives, one spawned task per accepted
/// connection.
pub trait Server {
    /// Short server name, used in spawned task names
    /// (`handle_{NAME}_connection`) and in log messages.
    const NAME: &'static str;

    /// Returns a future that handles `conn` until the connection is done.
    ///
    /// `tokio_metrics_intervals` comes from the `TaskMonitor` that
    /// instruments the task running the returned future, so it yields metrics
    /// about that very task.
    fn handle_connection(
        &self,
        conn: Connection,
        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
    ) -> ConnectionHandler;
}
212
/// A stream of accepted TCP connections that `serve` can consume.
pub trait ConnectionStream: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}

// Blanket impl: any stream with this item type and these bounds qualifies.
impl<T> ConnectionStream for T where T: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
217
218#[derive(Debug)]
220pub struct ListenerHandle {
221    pub local_addr: SocketAddr,
222    _trigger: trigger::Trigger,
223}
224
225impl ListenerHandle {
226    pub fn local_addr(&self) -> SocketAddr {
228        self.local_addr
229    }
230}
231
232pub async fn listen(
238    addr: &SocketAddr,
239) -> Result<(ListenerHandle, Pin<Box<dyn ConnectionStream>>), io::Error> {
240    let listener = TcpListener::bind(addr).await?;
241    let local_addr = listener.local_addr()?;
242    let (trigger, trigger_rx) = trigger::channel();
243    let handle = ListenerHandle {
244        local_addr,
245        _trigger: trigger,
246    };
247    let stream = TcpListenerStream::new(listener).take_until(trigger_rx);
250    Ok((handle, Box::pin(stream)))
251}
252
/// Arguments for `serve`.
pub struct ServeConfig<S, C>
where
    S: Server,
    C: ConnectionStream,
{
    /// The server that handles each accepted connection.
    pub server: S,
    /// Incoming connections. `serve` stops accepting once this stream ends.
    pub conns: C,
    /// Controls how long `serve` waits for outstanding connections after
    /// `conns` ends. With `None` it returns without waiting.
    pub dyncfg: Option<ServeDyncfg>,
}
266
/// Dynamic configuration for `serve`'s shutdown wait.
pub struct ServeDyncfg {
    /// The config set `sigterm_wait_config` is read from.
    pub config_set: ConfigSet,
    /// Maximum time `serve` waits for outstanding connections to finish.
    /// It is read from `config_set` once the connection stream has ended.
    pub sigterm_wait_config: &'static Config<Duration>,
}
278
279pub async fn serve<S, C>(
284    ServeConfig {
285        server,
286        mut conns,
287        dyncfg,
288    }: ServeConfig<S, C>,
289) -> JoinSet<()>
290where
291    S: Server,
292    C: ConnectionStream,
293{
294    let task_name = format!("handle_{}_connection", S::NAME);
295    let mut set = JoinSet::new();
296    loop {
297        tokio::select! {
298            conn = conns.next() => {
300                let conn = match conn {
301                    None => break,
302                    Some(Ok(conn)) => conn,
303                    Some(Err(err)) => {
304                        error!("error accepting connection: {}", err);
305                        continue;
306                    }
307                };
308                conn.set_nodelay(true).expect("set_nodelay failed");
320                if let Err(e) = SockRef::from(&conn).set_tcp_keepalive(&KEEPALIVE) {
325                    error!("failed enabling keepalive: {e}");
326                    continue;
327                }
328                let conn = Connection::new(conn);
329                let conn_uuid = conn.uuid_handle();
330                let metrics_monitor = tokio_metrics::TaskMonitor::new();
331                let tokio_metrics_intervals = metrics_monitor.intervals();
332                let fut = server.handle_connection(conn, tokio_metrics_intervals);
333                set.spawn_named(|| &task_name, metrics_monitor.instrument(async move {
334                    let guard = scopeguard::guard((), |_| {
335                        debug!(
336                            server = S::NAME,
337                            conn_uuid = %conn_uuid.display(),
338                            "dropping connection without explicit termination",
339                        );
340                    });
341
342                    match fut.await {
343                        Ok(()) => {
344                            debug!(
345                                server = S::NAME,
346                                conn_uuid = %conn_uuid.display(),
347                                "successfully handled connection",
348                            );
349                        }
350                        Err(e) => {
351                            warn!(
352                                server = S::NAME,
353                                conn_uuid = %conn_uuid.display(),
354                                "error handling connection: {}",
355                                e.display_with_causes(),
356                            );
357                        }
358                    }
359
360                    let () = ScopeGuard::into_inner(guard);
361                }));
362            }
363            res = set.join_next(), if set.len() > 0 => {
366                if let Some(Err(e)) = res {
367                    warn!(
368                        "error joining connection in {}: {}",
369                        S::NAME,
370                        e.display_with_causes()
371                    );
372                }
373            }
374        }
375    }
376    if let Some(dyncfg) = dyncfg {
377        let wait = dyncfg.sigterm_wait_config.get(&dyncfg.config_set);
378        if set.len() > 0 {
379            warn!(
380                "{} exiting, {} outstanding connections, waiting for {:?}",
381                S::NAME,
382                set.len(),
383                wait
384            );
385        }
386        let timedout = tokio::time::timeout(wait, async {
387            while let Some(res) = set.join_next().await {
388                if let Err(e) = res {
389                    warn!(
390                        "error joining connection in {}: {}",
391                        S::NAME,
392                        e.display_with_causes()
393                    );
394                }
395            }
396        })
397        .await;
398        if timedout.is_err() {
399            warn!(
400                "{}: wait timeout of {:?} exceeded, {} outstanding connections",
401                S::NAME,
402                wait,
403                set.len()
404            );
405        }
406    }
407    set
408}
409
/// TLS configuration with a fixed SSL context. `ReloadingTlsConfig` is the
/// reloadable counterpart.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// The SSL context to use for TLS connections.
    pub context: SslContext,
    /// Whether TLS is allowed or required.
    pub mode: TlsMode,
}
418
/// Whether clients may, or must, use TLS.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, JsonSchema)]
pub enum TlsMode {
    /// TLS is permitted. Presumably plaintext clients are still accepted —
    /// confirm in the protocol servers that consume this.
    Allow,
    /// TLS is required.
    Require,
}
427
/// Paths to the PEM certificate chain and private key used to build an SSL
/// context. See `TlsCertConfig::load_context`.
#[derive(Debug, Clone)]
pub struct TlsCertConfig {
    /// Path to the certificate chain file.
    pub cert: PathBuf,
    /// Path to the PEM-encoded private key file.
    pub key: PathBuf,
}
436
437impl TlsCertConfig {
438    pub fn load_context(&self) -> Result<SslContext, anyhow::Error> {
440        let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
447        builder.set_certificate_chain_file(&self.cert)?;
448        builder.set_private_key_file(&self.key, SslFiletype::PEM)?;
449        Ok(builder.build().into_context())
450    }
451
452    pub fn reloading_context(
458        &self,
459        mut ticker: ReloadTrigger,
460    ) -> Result<ReloadingSslContext, anyhow::Error> {
461        let context = Arc::new(RwLock::new(self.load_context()?));
462        let updater_context = Arc::clone(&context);
463        let config = self.clone();
464        mz_ore::task::spawn(|| "TlsCertConfig reloading_context", async move {
465            while let Some(chan) = ticker.next().await {
466                let result = match config.load_context() {
467                    Ok(ctx) => {
468                        *updater_context.write().expect("poisoned") = ctx;
469                        Ok(())
470                    }
471                    Err(err) => {
472                        tracing::error!("failed to reload SSL certificate: {err}");
473                        Err(err)
474                    }
475                };
476                if let Some(chan) = chan {
477                    let _ = chan.send(result);
478                }
479            }
480            tracing::warn!("TlsCertConfig reloading_context updater closed");
481        });
482        Ok(ReloadingSslContext { context })
483    }
484}
485
486#[derive(Clone, Debug)]
488pub struct ReloadingSslContext {
489    context: Arc<RwLock<SslContext>>,
491}
492
493impl ReloadingSslContext {
494    pub fn get(&self) -> RwLockReadGuard<'_, SslContext> {
495        self.context.read().expect("poisoned")
496    }
497}
498
/// Like `TlsConfig`, but the SSL context can be reloaded while in use.
#[derive(Clone, Debug)]
pub struct ReloadingTlsConfig {
    /// The shared, reloadable SSL context.
    pub context: ReloadingSslContext,
    /// Whether TLS is allowed or required.
    pub mode: TlsMode,
}
507
/// A stream of reload requests consumed by `TlsCertConfig::reloading_context`.
///
/// Each item triggers one reload attempt. If the item carries a sender, the
/// outcome of that attempt is sent back on it. `None` requests a reload
/// without reporting the result.
pub type ReloadTrigger = BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>;
509
510pub fn default_cert_reload_ticker() -> ReloadTrigger {
512    let ticker = IntervalStream::new(tokio::time::interval(Duration::from_secs(60 * 60)));
513    let ticker = ticker.map(|_| None);
514    let ticker = Box::pin(ticker);
515    ticker
516}
517
518pub fn cert_reload_never_reload() -> ReloadTrigger {
520    let ticker = futures::stream::empty();
521    let ticker = Box::pin(ticker);
522    ticker
523}
524
// Command-line / environment arguments for configuring TLS.
//
// NOTE: these are deliberately `//` comments. `///` doc comments on a
// clap-derived type become user-visible help/about text.
#[derive(Debug, Clone, clap::Parser)]
pub struct TlsCliArgs {
    // TLS mode: "disable" or "require". It defaults to "disable", or to
    // "require" when `frontegg_tenant` or `frontegg_resolver_template` is
    // present. Those are presumably arguments of an enclosing command that
    // flattens this struct — confirm.
    #[clap(
        long, env = "TLS_MODE",
        value_parser = ["disable", "require"],
        default_value = "disable",
        default_value_ifs = [
            ("frontegg_tenant", ArgPredicate::IsPresent, Some("require")),
            ("frontegg_resolver_template", ArgPredicate::IsPresent, Some("require")),
        ],
        value_name = "MODE",
    )]
    tls_mode: String,
    // Path to the TLS certificate. It must be given together with the key and
    // is required when the mode is "require".
    #[clap(
        long,
        env = "TLS_CERT",
        requires = "tls_key",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_cert: Option<PathBuf>,
    // Path to the TLS private key. It must be given together with the
    // certificate and is required when the mode is "require".
    #[clap(
        long,
        env = "TLS_KEY",
        requires = "tls_cert",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_key: Option<PathBuf>,
}
566
567impl TlsCliArgs {
568    pub fn into_config(self) -> Result<Option<TlsCertConfig>, anyhow::Error> {
570        if self.tls_mode == "disable" {
571            if self.tls_cert.is_some() {
572                bail!("cannot specify --tls-mode=disable and --tls-cert simultaneously");
573            }
574            if self.tls_key.is_some() {
575                bail!("cannot specify --tls-mode=disable and --tls-key simultaneously");
576            }
577            Ok(None)
578        } else {
579            let cert = self.tls_cert.unwrap();
580            let key = self.tls_key.unwrap();
581            Ok(Some(TlsCertConfig { cert, key }))
582        }
583    }
584}