mz_server_core/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Methods common to servers listening for TCP connections.
11
12use std::fmt;
13use std::future::Future;
14use std::io;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::pin::Pin;
18use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22use anyhow::bail;
23use async_trait::async_trait;
24use clap::builder::ArgPredicate;
25use futures::stream::{BoxStream, Stream, StreamExt};
26use mz_dyncfg::{Config, ConfigSet};
27use mz_ore::channel::trigger;
28use mz_ore::error::ErrorExt;
29use mz_ore::netio::AsyncReady;
30use mz_ore::option::OptionExt;
31use mz_ore::task::JoinSetExt;
32use openssl::ssl::{SslAcceptor, SslContext, SslFiletype, SslMethod};
33use proxy_header::{ParseConfig, ProxiedAddress, ProxyHeader};
34use scopeguard::ScopeGuard;
35use socket2::{SockRef, TcpKeepalive};
36use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, Interest, ReadBuf, Ready};
37use tokio::net::{TcpListener, TcpStream};
38use tokio::sync::oneshot;
39use tokio::task::JoinSet;
40use tokio_stream::wrappers::{IntervalStream, TcpListenerStream};
41use tracing::{debug, error, warn};
42use uuid::Uuid;
43
44/// TCP keepalive settings. The idle time and interval match CockroachDB [0].
45/// The number of retries matches the Linux default.
46///
47/// [0]: https://github.com/cockroachdb/cockroach/pull/14063
48const KEEPALIVE: TcpKeepalive = TcpKeepalive::new()
49    .with_time(Duration::from_secs(60))
50    .with_interval(Duration::from_secs(60))
51    .with_retries(9);
52
53/// A future that handles a connection.
54pub type ConnectionHandler = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
55
56/// A wrapper around a [`TcpStream`] that can identify a connection across
57/// processes.
58pub struct Connection {
59    conn_uuid: Arc<Mutex<Option<Uuid>>>,
60    tcp_stream: TcpStream,
61}
62
63impl Connection {
64    fn new(tcp_stream: TcpStream) -> Connection {
65        Connection {
66            conn_uuid: Arc::new(Mutex::new(None)),
67            tcp_stream,
68        }
69    }
70
71    /// Returns a handle to the connection UUID.
72    pub fn uuid_handle(&self) -> ConnectionUuidHandle {
73        ConnectionUuidHandle(Arc::clone(&self.conn_uuid))
74    }
75
76    /// Attempts to parse a proxy header from the tcp_stream.
77    /// If none is found or it is unable to be parsed None will
78    /// be returned. If a header is found it will be returned and its
79    /// bytes will be removed from the stream.
80    ///
81    /// It is possible an invalid header was sent, if that is the case
82    /// any downstream service will be responsible for returning errors
83    /// to the client.
84    pub async fn take_proxy_header_address(&mut self) -> Option<ProxiedAddress> {
85        // 1024 bytes is a rather large header for tcp proxy header, unless
86        // if the header contains TLV fields or uses a unix socket address
87        // this could easily be hit. We'll use a 1024 byte max buf to allow
88        // limited support for this.
89        let mut buf = [0u8; 1024];
90        let len = match self.tcp_stream.peek(&mut buf).await {
91            Ok(n) if n > 0 => n,
92            _ => {
93                debug!("Failed to read from client socket or no data received");
94                return None;
95            }
96        };
97
98        // Attempt to parse the header, and log failures.
99        let (header, hlen) = match ProxyHeader::parse(
100            &buf[..len],
101            ParseConfig {
102                include_tlvs: false,
103                allow_v1: false,
104                allow_v2: true,
105            },
106        ) {
107            Ok((header, hlen)) => (header, hlen),
108            Err(proxy_header::Error::Invalid) => {
109                debug!(
110                    "Proxy header is invalid. This is likely due to no no header being provided",
111                );
112                return None;
113            }
114            Err(e) => {
115                debug!("Proxy header parse error '{:?}', ignoring header.", e);
116                return None;
117            }
118        };
119        debug!("Proxied connection with header {:?}", header);
120        let address = header.proxied_address().map(|a| a.to_owned());
121        // Proxy header found, clear the bytes.
122        let _ = self.read_exact(&mut buf[..hlen]).await;
123        address
124    }
125
126    /// Peer address of the inner tcp_stream.
127    pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
128        self.tcp_stream.peer_addr()
129    }
130}
131
132impl AsyncRead for Connection {
133    fn poll_read(
134        mut self: Pin<&mut Self>,
135        cx: &mut Context,
136        buf: &mut ReadBuf,
137    ) -> Poll<io::Result<()>> {
138        Pin::new(&mut self.tcp_stream).poll_read(cx, buf)
139    }
140}
141
142impl AsyncWrite for Connection {
143    fn poll_write(
144        mut self: Pin<&mut Self>,
145        cx: &mut Context,
146        buf: &[u8],
147    ) -> Poll<io::Result<usize>> {
148        Pin::new(&mut self.tcp_stream).poll_write(cx, buf)
149    }
150
151    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
152        Pin::new(&mut self.tcp_stream).poll_flush(cx)
153    }
154
155    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
156        Pin::new(&mut self.tcp_stream).poll_shutdown(cx)
157    }
158}
159
160#[async_trait]
161impl AsyncReady for Connection {
162    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
163        self.tcp_stream.ready(interest).await
164    }
165}
166
167/// A handle that permits getting and setting the UUID for a [`Connection`].
168///
169/// A connection's UUID is a globally unique value that can identify a given
170/// connection across environments and process boundaries. Connection UUIDs are
171/// never reused.
172///
173/// This is distinct from environmentd's concept of a "connection ID", which is
174/// a `u32` that only identifies a connection within a given environment and
175/// only during its lifetime. These connection IDs are frequently reused.
176pub struct ConnectionUuidHandle(Arc<Mutex<Option<Uuid>>>);
177
178impl ConnectionUuidHandle {
179    /// Gets the UUID for the connection, if it exists.
180    pub fn get(&self) -> Option<Uuid> {
181        *self.0.lock().expect("lock poisoned")
182    }
183
184    /// Sets the UUID for this connection.
185    pub fn set(&self, conn_uuid: Uuid) {
186        *self.0.lock().expect("lock poisoned") = Some(conn_uuid);
187    }
188
189    /// Returns a displayble that renders a possibly missing connection UUID.
190    pub fn display(&self) -> impl fmt::Display {
191        self.get().display_or("<unknown>")
192    }
193}
194
195/// A server handles incoming network connections.
196pub trait Server {
197    /// Returns the name of the connection handler for use in e.g. log messages.
198    const NAME: &'static str;
199
200    /// Handles a single connection.
201    fn handle_connection(&self, conn: Connection) -> ConnectionHandler;
202}
203
204/// A stream of incoming connections.
205pub trait ConnectionStream: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
206
207impl<T> ConnectionStream for T where T: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
208
209/// A handle to a listener created by [`listen`].
210#[derive(Debug)]
211pub struct ListenerHandle {
212    local_addr: SocketAddr,
213    _trigger: trigger::Trigger,
214}
215
216impl ListenerHandle {
217    /// Returns the local address to which the listener is bound.
218    pub fn local_addr(&self) -> SocketAddr {
219        self.local_addr
220    }
221}
222
223/// Listens for incoming TCP connections on the specified address.
224///
225/// Returns a handle to the listener and the stream of incoming connections
226/// produced by the listener. When the handle is dropped, the listener is
227/// closed, and the stream of incoming connections terminates.
228pub async fn listen(
229    addr: &SocketAddr,
230) -> Result<(ListenerHandle, Pin<Box<dyn ConnectionStream>>), io::Error> {
231    let listener = TcpListener::bind(addr).await?;
232    let local_addr = listener.local_addr()?;
233    let (trigger, trigger_rx) = trigger::channel();
234    let handle = ListenerHandle {
235        local_addr,
236        _trigger: trigger,
237    };
238    // TODO(benesch): replace `TCPListenerStream`s with `listener.incoming()` if
239    // that is restored when the `Stream` trait stabilizes.
240    let stream = TcpListenerStream::new(listener).take_until(trigger_rx);
241    Ok((handle, Box::pin(stream)))
242}
243
244/// Configuration for [`serve`].
245pub struct ServeConfig<S, C>
246where
247    S: Server,
248    C: ConnectionStream,
249{
250    /// The server for the connections.
251    pub server: S,
252    /// The stream of incoming TCP connections.
253    pub conns: C,
254    /// Optional dynamic configuration for the server.
255    pub dyncfg: Option<ServeDyncfg>,
256}
257
258/// Dynamic configuration for [`ServeConfig`].
259pub struct ServeDyncfg {
260    /// The current bundle of dynamic configuration values.
261    pub config_set: ConfigSet,
262    /// A configuration in `config_set` that specifies how long to wait for
263    /// connections to terminate after receiving a SIGTERM before forcibly
264    /// terminated.
265    ///
266    /// If `None`, then forcible shutdown occurs immediately.
267    pub sigterm_wait_config: &'static Config<Duration>,
268}
269
270/// Serves incoming TCP connections.
271///
272/// Returns handles to the outstanding connections after the configured timeout
273/// has expired or all connections have completed.
274pub async fn serve<S, C>(
275    ServeConfig {
276        server,
277        mut conns,
278        dyncfg,
279    }: ServeConfig<S, C>,
280) -> JoinSet<()>
281where
282    S: Server,
283    C: ConnectionStream,
284{
285    let task_name = format!("handle_{}_connection", S::NAME);
286    let mut set = JoinSet::new();
287    loop {
288        tokio::select! {
289            // next() is cancel safe.
290            conn = conns.next() => {
291                let conn = match conn {
292                    None => break,
293                    Some(Ok(conn)) => conn,
294                    Some(Err(err)) => {
295                        error!("error accepting connection: {}", err);
296                        continue;
297                    }
298                };
299                // Set TCP_NODELAY to disable tinygram prevention (Nagle's
300                // algorithm), which forces a 40ms delay between each query
301                // on linux. According to John Nagle [0], the true problem
302                // is delayed acks, but disabling those is a receive-side
303                // operation (TCP_QUICKACK), and we can't always control the
304                // client. PostgreSQL sets TCP_NODELAY on both sides of its
305                // sockets, so it seems sane to just do the same.
306                //
307                // If set_nodelay fails, it's a programming error, so panic.
308                //
309                // [0]: https://news.ycombinator.com/item?id=10608356
310                conn.set_nodelay(true).expect("set_nodelay failed");
311                // Enable TCP keepalives to avoid any idle connection timeouts that may
312                // be enforced by networking devices between us and the client. Idle SQL
313                // connections are expected--e.g., a `SUBSCRIBE` to a view containing
314                // critical alerts will ideally be producing no data most of the time.
315                if let Err(e) = SockRef::from(&conn).set_tcp_keepalive(&KEEPALIVE) {
316                    error!("failed enabling keepalive: {e}");
317                    continue;
318                }
319                let conn = Connection::new(conn);
320                let conn_uuid = conn.uuid_handle();
321                let fut = server.handle_connection(conn);
322                set.spawn_named(|| &task_name, async move {
323                    let guard = scopeguard::guard((), |_| {
324                        debug!(
325                            server = S::NAME,
326                            conn_uuid = %conn_uuid.display(),
327                            "dropping connection without explicit termination",
328                        );
329                    });
330
331                    match fut.await {
332                        Ok(()) => {
333                            debug!(
334                                server = S::NAME,
335                                conn_uuid = %conn_uuid.display(),
336                                "successfully handled connection",
337                            );
338                        }
339                        Err(e) => {
340                            warn!(
341                                server = S::NAME,
342                                conn_uuid = %conn_uuid.display(),
343                                "error handling connection: {}",
344                                e.display_with_causes(),
345                            );
346                        }
347                    }
348
349                    let () = ScopeGuard::into_inner(guard);
350                });
351            }
352            // Actively cull completed tasks from the JoinSet so it does not grow unbounded. This
353            // method is cancel safe.
354            res = set.join_next(), if set.len() > 0 => {
355                if let Some(Err(e)) = res {
356                    warn!(
357                        "error joining connection in {}: {}",
358                        S::NAME,
359                        e.display_with_causes()
360                    );
361                }
362            }
363        }
364    }
365    if let Some(dyncfg) = dyncfg {
366        let wait = dyncfg.sigterm_wait_config.get(&dyncfg.config_set);
367        if set.len() > 0 {
368            warn!(
369                "{} exiting, {} outstanding connections, waiting for {:?}",
370                S::NAME,
371                set.len(),
372                wait
373            );
374        }
375        let timedout = tokio::time::timeout(wait, async {
376            while let Some(res) = set.join_next().await {
377                if let Err(e) = res {
378                    warn!(
379                        "error joining connection in {}: {}",
380                        S::NAME,
381                        e.display_with_causes()
382                    );
383                }
384            }
385        })
386        .await;
387        if timedout.is_err() {
388            warn!(
389                "{}: wait timeout of {:?} exceeded, {} outstanding connections",
390                S::NAME,
391                wait,
392                set.len()
393            );
394        }
395    }
396    set
397}
398
399/// Configures a server's TLS encryption and authentication.
400#[derive(Clone, Debug)]
401pub struct TlsConfig {
402    /// The SSL context used to manage incoming TLS negotiations.
403    pub context: SslContext,
404    /// The TLS mode.
405    pub mode: TlsMode,
406}
407
/// How strictly TLS encryption is enforced.
#[derive(Clone, Copy, Debug)]
pub enum TlsMode {
    /// Allow TLS encryption.
    Allow,
    /// Require that clients negotiate TLS encryption.
    Require,
}
416
/// The on-disk certificate and key used to encrypt connections with TLS.
#[derive(Clone, Debug)]
pub struct TlsCertConfig {
    /// The path to the TLS certificate.
    pub cert: PathBuf,
    /// The path to the TLS key.
    pub key: PathBuf,
}
425
426impl TlsCertConfig {
427    /// Returns the SSL context to use in TlsConfigs.
428    pub fn load_context(&self) -> Result<SslContext, anyhow::Error> {
429        // Mozilla publishes three presets: old, intermediate, and modern. They
430        // recommend the intermediate preset for general purpose servers, which
431        // is what we use, as it is compatible with nearly every client released
432        // in the last five years but does not include any known-problematic
433        // ciphers. We once tried to use the modern preset, but it was
434        // incompatible with Fivetran, and presumably other JDBC-based tools.
435        let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
436        builder.set_certificate_chain_file(&self.cert)?;
437        builder.set_private_key_file(&self.key, SslFiletype::PEM)?;
438        Ok(builder.build().into_context())
439    }
440
441    /// Like [Self::load_context] but attempts to reload the files each time `ticker` yields an item.
442    /// Returns an error based on the files currently on disk. When `ticker` receives, the
443    /// certificates are reloaded from the context. The result of the reloading is returned on the
444    /// oneshot if present, and an Ok result means new connections will use the new certificates. An
445    /// Err result will not change the current certificates.
446    pub fn reloading_context(
447        &self,
448        mut ticker: ReloadTrigger,
449    ) -> Result<ReloadingSslContext, anyhow::Error> {
450        let context = Arc::new(RwLock::new(self.load_context()?));
451        let updater_context = Arc::clone(&context);
452        let config = self.clone();
453        mz_ore::task::spawn(|| "TlsCertConfig reloading_context", async move {
454            while let Some(chan) = ticker.next().await {
455                let result = match config.load_context() {
456                    Ok(ctx) => {
457                        *updater_context.write().expect("poisoned") = ctx;
458                        Ok(())
459                    }
460                    Err(err) => {
461                        tracing::error!("failed to reload SSL certificate: {err}");
462                        Err(err)
463                    }
464                };
465                if let Some(chan) = chan {
466                    let _ = chan.send(result);
467                }
468            }
469            tracing::warn!("TlsCertConfig reloading_context updater closed");
470        });
471        Ok(ReloadingSslContext { context })
472    }
473}
474
475/// An SslContext whose inner value can be updated.
476#[derive(Clone, Debug)]
477pub struct ReloadingSslContext {
478    /// The current SSL context.
479    context: Arc<RwLock<SslContext>>,
480}
481
482impl ReloadingSslContext {
483    pub fn get(&self) -> RwLockReadGuard<SslContext> {
484        self.context.read().expect("poisoned")
485    }
486}
487
488/// Configures a server's TLS encryption and authentication with reloading.
489#[derive(Clone, Debug)]
490pub struct ReloadingTlsConfig {
491    /// The SSL context used to manage incoming TLS negotiations.
492    pub context: ReloadingSslContext,
493    /// The TLS mode.
494    pub mode: TlsMode,
495}
496
497pub type ReloadTrigger = BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>;
498
499/// Returns a ReloadTrigger that triggers once per hour.
500pub fn default_cert_reload_ticker() -> ReloadTrigger {
501    let ticker = IntervalStream::new(tokio::time::interval(Duration::from_secs(60 * 60)));
502    let ticker = ticker.map(|_| None);
503    let ticker = Box::pin(ticker);
504    ticker
505}
506
507/// Returns a ReloadTrigger that never triggers.
508pub fn cert_reload_never_reload() -> ReloadTrigger {
509    let ticker = futures::stream::empty();
510    let ticker = Box::pin(ticker);
511    ticker
512}
513
514/// Command line arguments for TLS.
515#[derive(Debug, Clone, clap::Parser)]
516pub struct TlsCliArgs {
517    /// How stringently to demand TLS authentication and encryption.
518    ///
519    /// If set to "disable", then environmentd rejects HTTP and PostgreSQL
520    /// connections that negotiate TLS.
521    ///
522    /// If set to "require", then environmentd requires that all HTTP and
523    /// PostgreSQL connections negotiate TLS. Unencrypted connections will be
524    /// rejected.
525    #[clap(
526        long, env = "TLS_MODE",
527        value_parser = ["disable", "require"],
528        default_value = "disable",
529        default_value_ifs = [
530            ("frontegg_tenant", ArgPredicate::IsPresent, Some("require")),
531            ("frontegg_resolver_template", ArgPredicate::IsPresent, Some("require")),
532        ],
533        value_name = "MODE",
534    )]
535    tls_mode: String,
536    /// Certificate file for TLS connections.
537    #[clap(
538        long,
539        env = "TLS_CERT",
540        requires = "tls_key",
541        required_if_eq_any([("tls_mode", "require")]),
542        value_name = "PATH"
543    )]
544    tls_cert: Option<PathBuf>,
545    /// Private key file for TLS connections.
546    #[clap(
547        long,
548        env = "TLS_KEY",
549        requires = "tls_cert",
550        required_if_eq_any([("tls_mode", "require")]),
551        value_name = "PATH"
552    )]
553    tls_key: Option<PathBuf>,
554}
555
556impl TlsCliArgs {
557    /// Convert args into configuration.
558    pub fn into_config(self) -> Result<Option<TlsCertConfig>, anyhow::Error> {
559        if self.tls_mode == "disable" {
560            if self.tls_cert.is_some() {
561                bail!("cannot specify --tls-mode=disable and --tls-cert simultaneously");
562            }
563            if self.tls_key.is_some() {
564                bail!("cannot specify --tls-mode=disable and --tls-key simultaneously");
565            }
566            Ok(None)
567        } else {
568            let cert = self.tls_cert.unwrap();
569            let key = self.tls_key.unwrap();
570            Ok(Some(TlsCertConfig { cert, key }))
571        }
572    }
573}