mz_server_core/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Methods common to servers listening for TCP connections.
11
12use std::fmt;
13use std::future::Future;
14use std::io;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::pin::Pin;
18use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22use anyhow::bail;
23use async_trait::async_trait;
24use clap::builder::ArgPredicate;
25use futures::stream::{BoxStream, Stream, StreamExt};
26use mz_dyncfg::{Config, ConfigSet};
27use mz_ore::channel::trigger;
28use mz_ore::error::ErrorExt;
29use mz_ore::netio::AsyncReady;
30use mz_ore::option::OptionExt;
31use mz_ore::task::JoinSetExt;
32use openssl::ssl::{SslAcceptor, SslContext, SslFiletype, SslMethod};
33use proxy_header::{ParseConfig, ProxiedAddress, ProxyHeader};
34use schemars::JsonSchema;
35use scopeguard::ScopeGuard;
36use serde::{Deserialize, Serialize};
37use socket2::{SockRef, TcpKeepalive};
38use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, Interest, ReadBuf, Ready};
39use tokio::net::{TcpListener, TcpStream};
40use tokio::sync::oneshot;
41use tokio::task::JoinSet;
42use tokio_metrics::TaskMetrics;
43use tokio_stream::wrappers::{IntervalStream, TcpListenerStream};
44use tracing::{debug, error, warn};
45use uuid::Uuid;
46
47pub mod listeners;
48
49/// TCP keepalive settings. The idle time and interval match CockroachDB [0].
50/// The number of retries matches the Linux default.
51///
52/// [0]: https://github.com/cockroachdb/cockroach/pull/14063
53const KEEPALIVE: TcpKeepalive = TcpKeepalive::new()
54    .with_time(Duration::from_secs(60))
55    .with_interval(Duration::from_secs(60))
56    .with_retries(9);
57
58/// A future that handles a connection.
59pub type ConnectionHandler = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
60
61/// A wrapper around a [`TcpStream`] that can identify a connection across
62/// processes.
63pub struct Connection {
64    conn_uuid: Arc<Mutex<Option<Uuid>>>,
65    tcp_stream: TcpStream,
66}
67
68impl Connection {
69    fn new(tcp_stream: TcpStream) -> Connection {
70        Connection {
71            conn_uuid: Arc::new(Mutex::new(None)),
72            tcp_stream,
73        }
74    }
75
76    /// Returns a handle to the connection UUID.
77    pub fn uuid_handle(&self) -> ConnectionUuidHandle {
78        ConnectionUuidHandle(Arc::clone(&self.conn_uuid))
79    }
80
81    /// Attempts to parse a proxy header from the tcp_stream.
82    /// If none is found or it is unable to be parsed None will
83    /// be returned. If a header is found it will be returned and its
84    /// bytes will be removed from the stream.
85    ///
86    /// It is possible an invalid header was sent, if that is the case
87    /// any downstream service will be responsible for returning errors
88    /// to the client.
89    pub async fn take_proxy_header_address(&mut self) -> Option<ProxiedAddress> {
90        // 1024 bytes is a rather large header for tcp proxy header, unless
91        // if the header contains TLV fields or uses a unix socket address
92        // this could easily be hit. We'll use a 1024 byte max buf to allow
93        // limited support for this.
94        let mut buf = [0u8; 1024];
95        let len = match self.tcp_stream.peek(&mut buf).await {
96            Ok(n) if n > 0 => n,
97            _ => {
98                debug!("Failed to read from client socket or no data received");
99                return None;
100            }
101        };
102
103        // Attempt to parse the header, and log failures.
104        let (header, hlen) = match ProxyHeader::parse(
105            &buf[..len],
106            ParseConfig {
107                include_tlvs: false,
108                allow_v1: false,
109                allow_v2: true,
110            },
111        ) {
112            Ok((header, hlen)) => (header, hlen),
113            Err(proxy_header::Error::Invalid) => {
114                debug!(
115                    "Proxy header is invalid. This is likely due to no no header being provided",
116                );
117                return None;
118            }
119            Err(e) => {
120                debug!("Proxy header parse error '{:?}', ignoring header.", e);
121                return None;
122            }
123        };
124        debug!("Proxied connection with header {:?}", header);
125        let address = header.proxied_address().map(|a| a.to_owned());
126        // Proxy header found, clear the bytes.
127        let _ = self.read_exact(&mut buf[..hlen]).await;
128        address
129    }
130
131    /// Peer address of the inner tcp_stream.
132    pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
133        self.tcp_stream.peer_addr()
134    }
135}
136
137impl AsyncRead for Connection {
138    fn poll_read(
139        mut self: Pin<&mut Self>,
140        cx: &mut Context,
141        buf: &mut ReadBuf,
142    ) -> Poll<io::Result<()>> {
143        Pin::new(&mut self.tcp_stream).poll_read(cx, buf)
144    }
145}
146
147impl AsyncWrite for Connection {
148    fn poll_write(
149        mut self: Pin<&mut Self>,
150        cx: &mut Context,
151        buf: &[u8],
152    ) -> Poll<io::Result<usize>> {
153        Pin::new(&mut self.tcp_stream).poll_write(cx, buf)
154    }
155
156    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
157        Pin::new(&mut self.tcp_stream).poll_flush(cx)
158    }
159
160    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
161        Pin::new(&mut self.tcp_stream).poll_shutdown(cx)
162    }
163}
164
165#[async_trait]
166impl AsyncReady for Connection {
167    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
168        self.tcp_stream.ready(interest).await
169    }
170}
171
172/// A handle that permits getting and setting the UUID for a [`Connection`].
173///
174/// A connection's UUID is a globally unique value that can identify a given
175/// connection across environments and process boundaries. Connection UUIDs are
176/// never reused.
177///
178/// This is distinct from environmentd's concept of a "connection ID", which is
179/// a `u32` that only identifies a connection within a given environment and
180/// only during its lifetime. These connection IDs are frequently reused.
181pub struct ConnectionUuidHandle(Arc<Mutex<Option<Uuid>>>);
182
183impl ConnectionUuidHandle {
184    /// Gets the UUID for the connection, if it exists.
185    pub fn get(&self) -> Option<Uuid> {
186        *self.0.lock().expect("lock poisoned")
187    }
188
189    /// Sets the UUID for this connection.
190    pub fn set(&self, conn_uuid: Uuid) {
191        *self.0.lock().expect("lock poisoned") = Some(conn_uuid);
192    }
193
194    /// Returns a displayable that renders a possibly missing connection UUID.
195    pub fn display(&self) -> impl fmt::Display {
196        self.get().display_or("<unknown>")
197    }
198}
199
200/// A server handles incoming network connections.
201pub trait Server {
202    /// Returns the name of the connection handler for use in e.g. log messages.
203    const NAME: &'static str;
204
205    /// Handles a single connection.
206    fn handle_connection(
207        &self,
208        conn: Connection,
209        tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
210    ) -> ConnectionHandler;
211}
212
213/// A stream of incoming connections.
214pub trait ConnectionStream: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
215
216impl<T> ConnectionStream for T where T: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
217
218/// A handle to a listener created by [`listen`].
219#[derive(Debug)]
220pub struct ListenerHandle {
221    pub local_addr: SocketAddr,
222    _trigger: trigger::Trigger,
223}
224
225impl ListenerHandle {
226    /// Returns the local address to which the listener is bound.
227    pub fn local_addr(&self) -> SocketAddr {
228        self.local_addr
229    }
230}
231
232/// Listens for incoming TCP connections on the specified address.
233///
234/// Returns a handle to the listener and the stream of incoming connections
235/// produced by the listener. When the handle is dropped, the listener is
236/// closed, and the stream of incoming connections terminates.
237pub async fn listen(
238    addr: &SocketAddr,
239) -> Result<(ListenerHandle, Pin<Box<dyn ConnectionStream>>), io::Error> {
240    let listener = TcpListener::bind(addr).await?;
241    let local_addr = listener.local_addr()?;
242    let (trigger, trigger_rx) = trigger::channel();
243    let handle = ListenerHandle {
244        local_addr,
245        _trigger: trigger,
246    };
247    // TODO(benesch): replace `TCPListenerStream`s with `listener.incoming()` if
248    // that is restored when the `Stream` trait stabilizes.
249    let stream = TcpListenerStream::new(listener).take_until(trigger_rx);
250    Ok((handle, Box::pin(stream)))
251}
252
253/// Configuration for [`serve`].
254pub struct ServeConfig<S, C>
255where
256    S: Server,
257    C: ConnectionStream,
258{
259    /// The server for the connections.
260    pub server: S,
261    /// The stream of incoming TCP connections.
262    pub conns: C,
263    /// Optional dynamic configuration for the server.
264    pub dyncfg: Option<ServeDyncfg>,
265}
266
267/// Dynamic configuration for [`ServeConfig`].
268pub struct ServeDyncfg {
269    /// The current bundle of dynamic configuration values.
270    pub config_set: ConfigSet,
271    /// A configuration in `config_set` that specifies how long to wait for
272    /// connections to terminate after receiving a SIGTERM before forcibly
273    /// terminated.
274    ///
275    /// If `None`, then forcible shutdown occurs immediately.
276    pub sigterm_wait_config: &'static Config<Duration>,
277}
278
279/// Serves incoming TCP connections.
280///
281/// Returns handles to the outstanding connections after the configured timeout
282/// has expired or all connections have completed.
283pub async fn serve<S, C>(
284    ServeConfig {
285        server,
286        mut conns,
287        dyncfg,
288    }: ServeConfig<S, C>,
289) -> JoinSet<()>
290where
291    S: Server,
292    C: ConnectionStream,
293{
294    let task_name = format!("handle_{}_connection", S::NAME);
295    let mut set = JoinSet::new();
296    loop {
297        tokio::select! {
298            // next() is cancel safe.
299            conn = conns.next() => {
300                let conn = match conn {
301                    None => break,
302                    Some(Ok(conn)) => conn,
303                    Some(Err(err)) => {
304                        error!("error accepting connection: {}", err);
305                        continue;
306                    }
307                };
308                // Set TCP_NODELAY to disable tinygram prevention (Nagle's
309                // algorithm), which forces a 40ms delay between each query
310                // on linux. According to John Nagle [0], the true problem
311                // is delayed acks, but disabling those is a receive-side
312                // operation (TCP_QUICKACK), and we can't always control the
313                // client. PostgreSQL sets TCP_NODELAY on both sides of its
314                // sockets, so it seems sane to just do the same.
315                //
316                // If set_nodelay fails, it's a programming error, so panic.
317                //
318                // [0]: https://news.ycombinator.com/item?id=10608356
319                conn.set_nodelay(true).expect("set_nodelay failed");
320                // Enable TCP keepalives to avoid any idle connection timeouts that may
321                // be enforced by networking devices between us and the client. Idle SQL
322                // connections are expected--e.g., a `SUBSCRIBE` to a view containing
323                // critical alerts will ideally be producing no data most of the time.
324                if let Err(e) = SockRef::from(&conn).set_tcp_keepalive(&KEEPALIVE) {
325                    error!("failed enabling keepalive: {e}");
326                    continue;
327                }
328                let conn = Connection::new(conn);
329                let conn_uuid = conn.uuid_handle();
330                let metrics_monitor = tokio_metrics::TaskMonitor::new();
331                let tokio_metrics_intervals = metrics_monitor.intervals();
332                let fut = server.handle_connection(conn, tokio_metrics_intervals);
333                set.spawn_named(|| &task_name, metrics_monitor.instrument(async move {
334                    let guard = scopeguard::guard((), |_| {
335                        debug!(
336                            server = S::NAME,
337                            conn_uuid = %conn_uuid.display(),
338                            "dropping connection without explicit termination",
339                        );
340                    });
341
342                    match fut.await {
343                        Ok(()) => {
344                            debug!(
345                                server = S::NAME,
346                                conn_uuid = %conn_uuid.display(),
347                                "successfully handled connection",
348                            );
349                        }
350                        Err(e) => {
351                            warn!(
352                                server = S::NAME,
353                                conn_uuid = %conn_uuid.display(),
354                                "error handling connection: {}",
355                                e.display_with_causes(),
356                            );
357                        }
358                    }
359
360                    let () = ScopeGuard::into_inner(guard);
361                }));
362            }
363            // Actively cull completed tasks from the JoinSet so it does not grow unbounded. This
364            // method is cancel safe.
365            res = set.join_next(), if set.len() > 0 => {
366                if let Some(Err(e)) = res {
367                    warn!(
368                        "error joining connection in {}: {}",
369                        S::NAME,
370                        e.display_with_causes()
371                    );
372                }
373            }
374        }
375    }
376    if let Some(dyncfg) = dyncfg {
377        let wait = dyncfg.sigterm_wait_config.get(&dyncfg.config_set);
378        if set.len() > 0 {
379            warn!(
380                "{} exiting, {} outstanding connections, waiting for {:?}",
381                S::NAME,
382                set.len(),
383                wait
384            );
385        }
386        let timedout = tokio::time::timeout(wait, async {
387            while let Some(res) = set.join_next().await {
388                if let Err(e) = res {
389                    warn!(
390                        "error joining connection in {}: {}",
391                        S::NAME,
392                        e.display_with_causes()
393                    );
394                }
395            }
396        })
397        .await;
398        if timedout.is_err() {
399            warn!(
400                "{}: wait timeout of {:?} exceeded, {} outstanding connections",
401                S::NAME,
402                wait,
403                set.len()
404            );
405        }
406    }
407    set
408}
409
410/// Configures a server's TLS encryption and authentication.
411#[derive(Clone, Debug)]
412pub struct TlsConfig {
413    /// The SSL context used to manage incoming TLS negotiations.
414    pub context: SslContext,
415    /// The TLS mode.
416    pub mode: TlsMode,
417}
418
/// Specifies how strictly to enforce TLS encryption.
// NOTE(review): this type derives `JsonSchema`, `Serialize`, and
// `Deserialize`, so the variant names (and presumably the doc comments, via
// the generated schema) are externally visible — confirm before rewording or
// renaming anything here.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, JsonSchema)]
pub enum TlsMode {
    /// Allow TLS encryption.
    Allow,
    /// Require that clients negotiate TLS encryption.
    Require,
}
427
/// The on-disk location of the certificate and key used to encrypt
/// connections with TLS.
#[derive(Clone, Debug)]
pub struct TlsCertConfig {
    /// Path to the TLS certificate (chain) file.
    pub cert: PathBuf,
    /// Path to the TLS private key file.
    pub key: PathBuf,
}
436
437impl TlsCertConfig {
438    /// Returns the SSL context to use in TlsConfigs.
439    pub fn load_context(&self) -> Result<SslContext, anyhow::Error> {
440        // Mozilla publishes three presets: old, intermediate, and modern. They
441        // recommend the intermediate preset for general purpose servers, which
442        // is what we use, as it is compatible with nearly every client released
443        // in the last five years but does not include any known-problematic
444        // ciphers. We once tried to use the modern preset, but it was
445        // incompatible with Fivetran, and presumably other JDBC-based tools.
446        let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
447        builder.set_certificate_chain_file(&self.cert)?;
448        builder.set_private_key_file(&self.key, SslFiletype::PEM)?;
449        Ok(builder.build().into_context())
450    }
451
452    /// Like [Self::load_context] but attempts to reload the files each time `ticker` yields an item.
453    /// Returns an error based on the files currently on disk. When `ticker` receives, the
454    /// certificates are reloaded from the context. The result of the reloading is returned on the
455    /// oneshot if present, and an Ok result means new connections will use the new certificates. An
456    /// Err result will not change the current certificates.
457    pub fn reloading_context(
458        &self,
459        mut ticker: ReloadTrigger,
460    ) -> Result<ReloadingSslContext, anyhow::Error> {
461        let context = Arc::new(RwLock::new(self.load_context()?));
462        let updater_context = Arc::clone(&context);
463        let config = self.clone();
464        mz_ore::task::spawn(|| "TlsCertConfig reloading_context", async move {
465            while let Some(chan) = ticker.next().await {
466                let result = match config.load_context() {
467                    Ok(ctx) => {
468                        *updater_context.write().expect("poisoned") = ctx;
469                        Ok(())
470                    }
471                    Err(err) => {
472                        tracing::error!("failed to reload SSL certificate: {err}");
473                        Err(err)
474                    }
475                };
476                if let Some(chan) = chan {
477                    let _ = chan.send(result);
478                }
479            }
480            tracing::warn!("TlsCertConfig reloading_context updater closed");
481        });
482        Ok(ReloadingSslContext { context })
483    }
484}
485
486/// An SslContext whose inner value can be updated.
487#[derive(Clone, Debug)]
488pub struct ReloadingSslContext {
489    /// The current SSL context.
490    context: Arc<RwLock<SslContext>>,
491}
492
493impl ReloadingSslContext {
494    pub fn get(&self) -> RwLockReadGuard<'_, SslContext> {
495        self.context.read().expect("poisoned")
496    }
497}
498
499/// Configures a server's TLS encryption and authentication with reloading.
500#[derive(Clone, Debug)]
501pub struct ReloadingTlsConfig {
502    /// The SSL context used to manage incoming TLS negotiations.
503    pub context: ReloadingSslContext,
504    /// The TLS mode.
505    pub mode: TlsMode,
506}
507
508pub type ReloadTrigger = BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>;
509
510/// Returns a ReloadTrigger that triggers once per hour.
511pub fn default_cert_reload_ticker() -> ReloadTrigger {
512    let ticker = IntervalStream::new(tokio::time::interval(Duration::from_secs(60 * 60)));
513    let ticker = ticker.map(|_| None);
514    let ticker = Box::pin(ticker);
515    ticker
516}
517
518/// Returns a ReloadTrigger that never triggers.
519pub fn cert_reload_never_reload() -> ReloadTrigger {
520    let ticker = futures::stream::empty();
521    let ticker = Box::pin(ticker);
522    ticker
523}
524
/// Command line arguments for TLS.
// NOTE(review): the `///` comments in this struct are presumably rendered by
// clap as `--help` text, so treat them as user-facing strings — confirm before
// rewording them.
#[derive(Debug, Clone, clap::Parser)]
pub struct TlsCliArgs {
    /// How stringently to demand TLS authentication and encryption.
    ///
    /// If set to "disable", then environmentd rejects HTTP and PostgreSQL
    /// connections that negotiate TLS.
    ///
    /// If set to "require", then environmentd requires that all HTTP and
    /// PostgreSQL connections negotiate TLS. Unencrypted connections will be
    /// rejected.
    // Only "disable" and "require" are accepted. The default is "disable",
    // except that it becomes "require" when `frontegg_tenant` or
    // `frontegg_resolver_template` is present.
    #[clap(
        long, env = "TLS_MODE",
        value_parser = ["disable", "require"],
        default_value = "disable",
        default_value_ifs = [
            ("frontegg_tenant", ArgPredicate::IsPresent, Some("require")),
            ("frontegg_resolver_template", ArgPredicate::IsPresent, Some("require")),
        ],
        value_name = "MODE",
    )]
    tls_mode: String,
    /// Certificate file for TLS connections.
    // Must be given together with `tls_key`, and is declared required when
    // `tls_mode` equals "require".
    #[clap(
        long,
        env = "TLS_CERT",
        requires = "tls_key",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_cert: Option<PathBuf>,
    /// Private key file for TLS connections.
    // Must be given together with `tls_cert`, and is declared required when
    // `tls_mode` equals "require".
    #[clap(
        long,
        env = "TLS_KEY",
        requires = "tls_cert",
        required_if_eq_any([("tls_mode", "require")]),
        value_name = "PATH"
    )]
    tls_key: Option<PathBuf>,
}
566
567impl TlsCliArgs {
568    /// Convert args into configuration.
569    pub fn into_config(self) -> Result<Option<TlsCertConfig>, anyhow::Error> {
570        if self.tls_mode == "disable" {
571            if self.tls_cert.is_some() {
572                bail!("cannot specify --tls-mode=disable and --tls-cert simultaneously");
573            }
574            if self.tls_key.is_some() {
575                bail!("cannot specify --tls-mode=disable and --tls-key simultaneously");
576            }
577            Ok(None)
578        } else {
579            let cert = self.tls_cert.unwrap();
580            let key = self.tls_key.unwrap();
581            Ok(Some(TlsCertConfig { cert, key }))
582        }
583    }
584}