// mz_server_core/src/lib.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Methods common to servers listening for TCP connections.
11
12use std::fmt;
13use std::future::Future;
14use std::io;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::pin::Pin;
18use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22use anyhow::bail;
23use async_trait::async_trait;
24use clap::builder::ArgPredicate;
25use futures::stream::{BoxStream, Stream, StreamExt};
26use mz_dyncfg::{Config, ConfigSet};
27use mz_ore::channel::trigger;
28use mz_ore::error::ErrorExt;
29use mz_ore::netio::AsyncReady;
30use mz_ore::option::OptionExt;
31use mz_ore::task::JoinSetExt;
32use openssl::ssl::{SslAcceptor, SslContext, SslFiletype, SslMethod};
33use proxy_header::{ParseConfig, ProxiedAddress, ProxyHeader};
34use schemars::JsonSchema;
35use scopeguard::ScopeGuard;
36use serde::{Deserialize, Serialize};
37use socket2::{SockRef, TcpKeepalive};
38use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, Interest, ReadBuf, Ready};
39use tokio::net::{TcpListener, TcpStream};
40use tokio::sync::oneshot;
41use tokio::task::JoinSet;
42use tokio_metrics::TaskMetrics;
43use tokio_stream::wrappers::{IntervalStream, TcpListenerStream};
44use tracing::{debug, error, warn};
45use uuid::Uuid;
46
pub mod listeners;

/// TCP keepalive settings. The idle time and interval match CockroachDB [0].
/// The number of retries matches the Linux default.
///
/// [0]: https://github.com/cockroachdb/cockroach/pull/14063
const KEEPALIVE: TcpKeepalive = TcpKeepalive::new()
    .with_time(Duration::from_secs(60))
    .with_interval(Duration::from_secs(60))
    .with_retries(9);

/// A future that handles a connection.
pub type ConnectionHandler = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
60
/// A wrapper around a [`TcpStream`] that can identify a connection across
/// processes.
pub struct Connection {
    // Shared slot for the globally unique connection UUID. Initially `None`;
    // set via a `ConnectionUuidHandle` obtained from `uuid_handle`.
    conn_uuid: Arc<Mutex<Option<Uuid>>>,
    // The underlying TCP stream to which all I/O is delegated.
    tcp_stream: TcpStream,
}
67
impl Connection {
    /// Wraps a freshly accepted TCP stream with an initially-unset connection
    /// UUID.
    fn new(tcp_stream: TcpStream) -> Connection {
        Connection {
            conn_uuid: Arc::new(Mutex::new(None)),
            tcp_stream,
        }
    }

    /// Returns a handle to the connection UUID.
    pub fn uuid_handle(&self) -> ConnectionUuidHandle {
        ConnectionUuidHandle(Arc::clone(&self.conn_uuid))
    }

    /// Attempts to parse a proxy header from the tcp_stream.
    /// If none is found or it is unable to be parsed None will
    /// be returned. If a header is found it will be returned and its
    /// bytes will be removed from the stream.
    ///
    /// It is possible an invalid header was sent, if that is the case
    /// any downstream service will be responsible for returning errors
    /// to the client.
    pub async fn take_proxy_header_address(&mut self) -> Option<ProxiedAddress> {
        // 1024 bytes is larger than most PROXY headers, but a header carrying
        // TLV fields or a Unix socket address could exceed it. A 1024-byte
        // buffer gives limited support for those larger headers.
        let mut buf = [0u8; 1024];
        // Peek rather than read: if no PROXY header is present, the bytes must
        // stay in the stream for the downstream protocol handler.
        let len = match self.tcp_stream.peek(&mut buf).await {
            Ok(n) if n > 0 => n,
            _ => {
                debug!("Failed to read from client socket or no data received");
                return None;
            }
        };

        // Attempt to parse the header, and log failures.
        let (header, hlen) = match ProxyHeader::parse(
            &buf[..len],
            ParseConfig {
                include_tlvs: false,
                allow_v1: false,
                allow_v2: true,
            },
        ) {
            Ok((header, hlen)) => (header, hlen),
            Err(proxy_header::Error::Invalid) => {
                debug!("Proxy header is invalid. This is likely due to no header being provided",);
                return None;
            }
            // Data matches the PROXY v2 signature prefix but the header
            // is incomplete — likely split across TCP segments. Read the
            // 16-byte fixed v2 header to learn the total size, then read
            // the remaining address bytes.
            Err(proxy_header::Error::BufferTooShort) => {
                return self.read_proxy_v2_header(&mut buf).await;
            }
            Err(e) => {
                debug!("Proxy header parse error '{:?}', ignoring header.", e);
                return None;
            }
        };
        debug!("Proxied connection with header {:?}", header);
        let address = header.proxied_address().map(|a| a.to_owned());
        // Proxy header found, clear the bytes. A read error here is ignored;
        // downstream reads will surface it.
        let _ = self.read_exact(&mut buf[..hlen]).await;
        address
    }

    /// Fallback path for [`Self::take_proxy_header_address`] when the initial
    /// peek returned an incomplete PROXY v2 header. Reads the fixed 16-byte
    /// v2 prefix to learn the total header size, then reads the rest.
    async fn read_proxy_v2_header(&mut self, buf: &mut [u8; 1024]) -> Option<ProxiedAddress> {
        // PROXY v2 fixed prefix: 12-byte signature + ver/cmd + fam/proto + 2-byte length.
        const V2_PREFIX_LEN: usize = 16;
        if self.read_exact(&mut buf[..V2_PREFIX_LEN]).await.is_err() {
            debug!("Failed to read PROXY v2 fixed header");
            return None;
        }
        // Bytes 14-15 of the prefix carry the big-endian length of the
        // address (and any TLV) section that follows.
        let addr_len = usize::from(u16::from_be_bytes([buf[14], buf[15]]));
        let total = V2_PREFIX_LEN + addr_len;
        if total > buf.len() {
            debug!("PROXY v2 header too large: {total} bytes");
            return None;
        }
        if self
            .read_exact(&mut buf[V2_PREFIX_LEN..total])
            .await
            .is_err()
        {
            debug!("Failed to read PROXY v2 address data");
            return None;
        }
        // Re-parse now that the complete header is buffered.
        match ProxyHeader::parse(
            &buf[..total],
            ParseConfig {
                include_tlvs: false,
                allow_v1: false,
                allow_v2: true,
            },
        ) {
            Ok((header, _)) => {
                debug!("Proxied connection with header {:?}", header);
                header.proxied_address().map(|a| a.to_owned())
            }
            Err(e) => {
                debug!("Proxy header parse error '{:?}', ignoring header.", e);
                None
            }
        }
    }

    /// Peer address of the inner tcp_stream.
    pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
        self.tcp_stream.peer_addr()
    }
}
184
185impl AsyncRead for Connection {
186    fn poll_read(
187        mut self: Pin<&mut Self>,
188        cx: &mut Context,
189        buf: &mut ReadBuf,
190    ) -> Poll<io::Result<()>> {
191        Pin::new(&mut self.tcp_stream).poll_read(cx, buf)
192    }
193}
194
195impl AsyncWrite for Connection {
196    fn poll_write(
197        mut self: Pin<&mut Self>,
198        cx: &mut Context,
199        buf: &[u8],
200    ) -> Poll<io::Result<usize>> {
201        Pin::new(&mut self.tcp_stream).poll_write(cx, buf)
202    }
203
204    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
205        Pin::new(&mut self.tcp_stream).poll_flush(cx)
206    }
207
208    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
209        Pin::new(&mut self.tcp_stream).poll_shutdown(cx)
210    }
211}
212
#[async_trait]
impl AsyncReady for Connection {
    /// Waits until the underlying TCP stream is ready for the requested
    /// interest (readable and/or writable), delegating to [`TcpStream::ready`].
    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
        self.tcp_stream.ready(interest).await
    }
}
219
/// A handle that permits getting and setting the UUID for a [`Connection`].
///
/// A connection's UUID is a globally unique value that can identify a given
/// connection across environments and process boundaries. Connection UUIDs are
/// never reused.
///
/// This is distinct from environmentd's concept of a "connection ID", which is
/// a `u32` that only identifies a connection within a given environment and
/// only during its lifetime. These connection IDs are frequently reused.
// The inner state is shared with the owning `Connection` via `Arc`, so a set
// performed through this handle is visible to every clone of the slot.
pub struct ConnectionUuidHandle(Arc<Mutex<Option<Uuid>>>);
230
231impl ConnectionUuidHandle {
232    /// Gets the UUID for the connection, if it exists.
233    pub fn get(&self) -> Option<Uuid> {
234        *self.0.lock().expect("lock poisoned")
235    }
236
237    /// Sets the UUID for this connection.
238    pub fn set(&self, conn_uuid: Uuid) {
239        *self.0.lock().expect("lock poisoned") = Some(conn_uuid);
240    }
241
242    /// Returns a displayable that renders a possibly missing connection UUID.
243    pub fn display(&self) -> impl fmt::Display {
244        self.get().display_or("<unknown>")
245    }
246}
247
/// A server handles incoming network connections.
pub trait Server {
    /// Returns the name of the connection handler for use in e.g. log messages.
    const NAME: &'static str;

    /// Handles a single connection.
    ///
    /// `tokio_metrics_intervals` yields metrics snapshots for the task that
    /// runs this connection (produced by a `tokio_metrics::TaskMonitor`).
    fn handle_connection(
        &self,
        conn: Connection,
        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
    ) -> ConnectionHandler;
}
260
/// A stream of incoming connections.
pub trait ConnectionStream: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}

// Blanket implementation: any unpinned, sendable stream of TCP accept results
// is a `ConnectionStream`.
impl<T> ConnectionStream for T where T: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
265
/// A handle to a listener created by [`listen`].
#[derive(Debug)]
pub struct ListenerHandle {
    /// The local address to which the listener is bound.
    pub local_addr: SocketAddr,
    // Dropping this trigger closes the listener: the connection stream is a
    // `take_until` over the trigger's receiver (see `listen`).
    _trigger: trigger::Trigger,
}

impl ListenerHandle {
    /// Returns the local address to which the listener is bound.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }
}
279
280/// Listens for incoming TCP connections on the specified address.
281///
282/// Returns a handle to the listener and the stream of incoming connections
283/// produced by the listener. When the handle is dropped, the listener is
284/// closed, and the stream of incoming connections terminates.
285pub async fn listen(
286    addr: &SocketAddr,
287) -> Result<(ListenerHandle, Pin<Box<dyn ConnectionStream>>), io::Error> {
288    let listener = TcpListener::bind(addr).await?;
289    let local_addr = listener.local_addr()?;
290    let (trigger, trigger_rx) = trigger::channel();
291    let handle = ListenerHandle {
292        local_addr,
293        _trigger: trigger,
294    };
295    // TODO(benesch): replace `TCPListenerStream`s with `listener.incoming()` if
296    // that is restored when the `Stream` trait stabilizes.
297    let stream = TcpListenerStream::new(listener).take_until(trigger_rx);
298    Ok((handle, Box::pin(stream)))
299}
300
/// Configuration for [`serve`].
pub struct ServeConfig<S, C>
where
    S: Server,
    C: ConnectionStream,
{
    /// The server for the connections.
    pub server: S,
    /// The stream of incoming TCP connections.
    pub conns: C,
    /// Optional dynamic configuration for the server (see [`ServeDyncfg`]).
    pub dyncfg: Option<ServeDyncfg>,
}
314
/// Dynamic configuration for [`ServeConfig`].
pub struct ServeDyncfg {
    /// The current bundle of dynamic configuration values.
    pub config_set: ConfigSet,
    /// A configuration in `config_set` that specifies how long to wait for
    /// connections to terminate after receiving a SIGTERM before they are
    /// forcibly terminated.
    ///
    /// If [`ServeConfig::dyncfg`] is `None`, then forcible shutdown occurs
    /// immediately.
    pub sigterm_wait_config: &'static Config<Duration>,
}
326
/// Serves incoming TCP connections.
///
/// Returns handles to the outstanding connections after the configured timeout
/// has expired or all connections have completed.
pub async fn serve<S, C>(
    ServeConfig {
        server,
        mut conns,
        dyncfg,
    }: ServeConfig<S, C>,
) -> JoinSet<()>
where
    S: Server,
    C: ConnectionStream,
{
    let task_name = format!("handle_{}_connection", S::NAME);
    let mut set = JoinSet::new();
    loop {
        tokio::select! {
            // next() is cancel safe.
            conn = conns.next() => {
                let conn = match conn {
                    // The connection stream terminated; stop accepting and
                    // fall through to the drain phase below.
                    None => break,
                    Some(Ok(conn)) => conn,
                    Some(Err(err)) => {
                        error!("error accepting connection: {}", err);
                        continue;
                    }
                };
                // Set TCP_NODELAY to disable tinygram prevention (Nagle's
                // algorithm), which forces a 40ms delay between each query
                // on linux. According to John Nagle [0], the true problem
                // is delayed acks, but disabling those is a receive-side
                // operation (TCP_QUICKACK), and we can't always control the
                // client. PostgreSQL sets TCP_NODELAY on both sides of its
                // sockets, so it seems sane to just do the same.
                //
                // If set_nodelay fails, it's a programming error, so panic.
                //
                // [0]: https://news.ycombinator.com/item?id=10608356
                conn.set_nodelay(true).expect("set_nodelay failed");
                // Enable TCP keepalives to avoid any idle connection timeouts that may
                // be enforced by networking devices between us and the client. Idle SQL
                // connections are expected--e.g., a `SUBSCRIBE` to a view containing
                // critical alerts will ideally be producing no data most of the time.
                if let Err(e) = SockRef::from(&conn).set_tcp_keepalive(&KEEPALIVE) {
                    error!("failed enabling keepalive: {e}");
                    continue;
                }
                let conn = Connection::new(conn);
                let conn_uuid = conn.uuid_handle();
                let metrics_monitor = tokio_metrics::TaskMonitor::new();
                let tokio_metrics_intervals = metrics_monitor.intervals();
                let fut = server.handle_connection(conn, tokio_metrics_intervals);
                set.spawn_named(|| &task_name, metrics_monitor.instrument(async move {
                    // This guard only fires if the task is dropped before
                    // completing (e.g. cancelled); both arms of the match
                    // below defuse it via `into_inner` on the normal path.
                    let guard = scopeguard::guard((), |_| {
                        debug!(
                            server = S::NAME,
                            conn_uuid = %conn_uuid.display(),
                            "dropping connection without explicit termination",
                        );
                    });

                    match fut.await {
                        Ok(()) => {
                            debug!(
                                server = S::NAME,
                                conn_uuid = %conn_uuid.display(),
                                "successfully handled connection",
                            );
                        }
                        Err(e) => {
                            warn!(
                                server = S::NAME,
                                conn_uuid = %conn_uuid.display(),
                                "error handling connection: {}",
                                e.display_with_causes(),
                            );
                        }
                    }

                    // Defuse the guard so the "dropped without termination"
                    // message is not emitted after a clean finish.
                    let () = ScopeGuard::into_inner(guard);
                }));
            }
            // Actively cull completed tasks from the JoinSet so it does not grow unbounded. This
            // method is cancel safe.
            res = set.join_next(), if set.len() > 0 => {
                if let Some(Err(e)) = res {
                    warn!(
                        "error joining connection in {}: {}",
                        S::NAME,
                        e.display_with_causes()
                    );
                }
            }
        }
    }
    // Drain phase: if dynamic configuration is available, wait up to the
    // configured duration for outstanding connections to finish; otherwise
    // return immediately with whatever tasks are still running.
    if let Some(dyncfg) = dyncfg {
        let wait = dyncfg.sigterm_wait_config.get(&dyncfg.config_set);
        if set.len() > 0 {
            warn!(
                "{} exiting, {} outstanding connections, waiting for {:?}",
                S::NAME,
                set.len(),
                wait
            );
        }
        let timedout = tokio::time::timeout(wait, async {
            while let Some(res) = set.join_next().await {
                if let Err(e) = res {
                    warn!(
                        "error joining connection in {}: {}",
                        S::NAME,
                        e.display_with_causes()
                    );
                }
            }
        })
        .await;
        if timedout.is_err() {
            warn!(
                "{}: wait timeout of {:?} exceeded, {} outstanding connections",
                S::NAME,
                wait,
                set.len()
            );
        }
    }
    set
}
457
/// Configures a server's TLS encryption and authentication.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// The SSL context used to manage incoming TLS negotiations.
    pub context: SslContext,
    /// The TLS mode.
    pub mode: TlsMode,
}

/// Specifies how strictly to enforce TLS encryption.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, JsonSchema)]
pub enum TlsMode {
    /// Allow, but do not require, TLS encryption.
    Allow,
    /// Require that clients negotiate TLS encryption.
    Require,
}
475
/// Configures TLS encryption for connections.
#[derive(Debug, Clone)]
pub struct TlsCertConfig {
    /// The path to the TLS certificate (chain file, PEM format).
    pub cert: PathBuf,
    /// The path to the TLS private key (PEM format).
    pub key: PathBuf,
}
484
impl TlsCertConfig {
    /// Returns the SSL context to use in TlsConfigs.
    pub fn load_context(&self) -> Result<SslContext, anyhow::Error> {
        // Mozilla publishes three presets: old, intermediate, and modern. They
        // recommend the intermediate preset for general purpose servers, which
        // is what we use, as it is compatible with nearly every client released
        // in the last five years but does not include any known-problematic
        // ciphers. We once tried to use the modern preset, but it was
        // incompatible with Fivetran, and presumably other JDBC-based tools.
        let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
        builder.set_certificate_chain_file(&self.cert)?;
        builder.set_private_key_file(&self.key, SslFiletype::PEM)?;
        Ok(builder.build().into_context())
    }

    /// Like [Self::load_context] but attempts to reload the files each time `ticker` yields an item.
    /// Returns an error based on the files currently on disk. When `ticker` yields, the
    /// certificates are reloaded from the files on disk. The result of the reloading is returned on
    /// the oneshot if present, and an Ok result means new connections will use the new certificates.
    /// An Err result will not change the current certificates.
    pub fn reloading_context(
        &self,
        mut ticker: ReloadTrigger,
    ) -> Result<ReloadingSslContext, anyhow::Error> {
        // The initial load is synchronous so callers get an immediate error if
        // the files on disk are unusable.
        let context = Arc::new(RwLock::new(self.load_context()?));
        let updater_context = Arc::clone(&context);
        let config = self.clone();
        mz_ore::task::spawn(|| "TlsCertConfig reloading_context", async move {
            while let Some(chan) = ticker.next().await {
                let result = match config.load_context() {
                    Ok(ctx) => {
                        // Swap the shared context in place; subsequent `get`
                        // calls observe the new certificates.
                        *updater_context.write().expect("poisoned") = ctx;
                        Ok(())
                    }
                    Err(err) => {
                        tracing::error!("failed to reload SSL certificate: {err}");
                        Err(err)
                    }
                };
                // Report the outcome to the requester, if this tick carried a
                // response channel. Send failures (dropped receiver) are fine.
                if let Some(chan) = chan {
                    let _ = chan.send(result);
                }
            }
            tracing::warn!("TlsCertConfig reloading_context updater closed");
        });
        Ok(ReloadingSslContext { context })
    }
}
533
/// An SslContext whose inner value can be updated.
#[derive(Clone, Debug)]
pub struct ReloadingSslContext {
    /// The current SSL context.
    context: Arc<RwLock<SslContext>>,
}

impl ReloadingSslContext {
    /// Returns a read guard over the current SSL context.
    pub fn get(&self) -> RwLockReadGuard<'_, SslContext> {
        self.context.read().expect("poisoned")
    }
}
546
/// Configures a server's TLS encryption and authentication with reloading.
#[derive(Clone, Debug)]
pub struct ReloadingTlsConfig {
    /// The SSL context used to manage incoming TLS negotiations.
    pub context: ReloadingSslContext,
    /// The TLS mode.
    pub mode: TlsMode,
}

/// A stream of certificate-reload requests consumed by
/// [`TlsCertConfig::reloading_context`]. Each item may carry a oneshot sender
/// that receives the result of that reload attempt.
pub type ReloadTrigger = BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>;
557
558/// Returns a ReloadTrigger that triggers once per hour.
559pub fn default_cert_reload_ticker() -> ReloadTrigger {
560    let ticker = IntervalStream::new(tokio::time::interval(Duration::from_secs(60 * 60)));
561    let ticker = ticker.map(|_| None);
562    let ticker = Box::pin(ticker);
563    ticker
564}
565
566/// Returns a ReloadTrigger that never triggers.
567pub fn cert_reload_never_reload() -> ReloadTrigger {
568    let ticker = futures::stream::empty();
569    let ticker = Box::pin(ticker);
570    ticker
571}
572
/// Command line arguments for TLS.
#[derive(Debug, Clone, clap::Parser)]
pub struct TlsCliArgs {
    /// How stringently to demand TLS authentication and encryption.
    ///
    /// If set to "disable", then environmentd rejects HTTP and PostgreSQL
    /// connections that negotiate TLS.
    ///
    /// If set to "require", then environmentd requires that all HTTP and
    /// PostgreSQL connections negotiate TLS. Unencrypted connections will be
    /// rejected.
    // The default flips from "disable" to "require" when either Frontegg
    // argument is present (see `default_value_ifs` below).
    #[clap(
        long, env = "TLS_MODE",
        value_parser = ["disable", "require"],
        default_value = "disable",
        default_value_ifs = [
            ("frontegg_tenant", ArgPredicate::IsPresent, Some("require")),
            ("frontegg_resolver_template", ArgPredicate::IsPresent, Some("require")),
        ],
        value_name = "MODE",
    )]
    tls_mode: String,
    /// Certificate file for TLS connections.
    // clap enforces presence of both cert and key when tls_mode == "require".
    #[clap(
        long,
        env = "TLS_CERT",
        requires = "tls_key",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_cert: Option<PathBuf>,
    /// Private key file for TLS connections.
    #[clap(
        long,
        env = "TLS_KEY",
        requires = "tls_cert",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_key: Option<PathBuf>,
}
614
615impl TlsCliArgs {
616    /// Convert args into configuration.
617    pub fn into_config(self) -> Result<Option<TlsCertConfig>, anyhow::Error> {
618        if self.tls_mode == "disable" {
619            if self.tls_cert.is_some() {
620                bail!("cannot specify --tls-mode=disable and --tls-cert simultaneously");
621            }
622            if self.tls_key.is_some() {
623                bail!("cannot specify --tls-mode=disable and --tls-key simultaneously");
624            }
625            Ok(None)
626        } else {
627            let cert = self.tls_cert.unwrap();
628            let key = self.tls_key.unwrap();
629            Ok(Some(TlsCertConfig { cert, key }))
630        }
631    }
632}