1use std::fmt;
13use std::future::Future;
14use std::io;
15use std::net::SocketAddr;
16use std::path::PathBuf;
17use std::pin::Pin;
18use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22use anyhow::bail;
23use async_trait::async_trait;
24use clap::builder::ArgPredicate;
25use futures::stream::{BoxStream, Stream, StreamExt};
26use mz_dyncfg::{Config, ConfigSet};
27use mz_ore::channel::trigger;
28use mz_ore::error::ErrorExt;
29use mz_ore::netio::AsyncReady;
30use mz_ore::option::OptionExt;
31use mz_ore::task::JoinSetExt;
32use openssl::ssl::{SslAcceptor, SslContext, SslFiletype, SslMethod};
33use proxy_header::{ParseConfig, ProxiedAddress, ProxyHeader};
34use scopeguard::ScopeGuard;
35use socket2::{SockRef, TcpKeepalive};
36use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, Interest, ReadBuf, Ready};
37use tokio::net::{TcpListener, TcpStream};
38use tokio::sync::oneshot;
39use tokio::task::JoinSet;
40use tokio_stream::wrappers::{IntervalStream, TcpListenerStream};
41use tracing::{debug, error, warn};
42use uuid::Uuid;
43
44const KEEPALIVE: TcpKeepalive = TcpKeepalive::new()
49 .with_time(Duration::from_secs(60))
50 .with_interval(Duration::from_secs(60))
51 .with_retries(9);
52
53pub type ConnectionHandler = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
55
56pub struct Connection {
59 conn_uuid: Arc<Mutex<Option<Uuid>>>,
60 tcp_stream: TcpStream,
61}
62
63impl Connection {
64 fn new(tcp_stream: TcpStream) -> Connection {
65 Connection {
66 conn_uuid: Arc::new(Mutex::new(None)),
67 tcp_stream,
68 }
69 }
70
71 pub fn uuid_handle(&self) -> ConnectionUuidHandle {
73 ConnectionUuidHandle(Arc::clone(&self.conn_uuid))
74 }
75
76 pub async fn take_proxy_header_address(&mut self) -> Option<ProxiedAddress> {
85 let mut buf = [0u8; 1024];
90 let len = match self.tcp_stream.peek(&mut buf).await {
91 Ok(n) if n > 0 => n,
92 _ => {
93 debug!("Failed to read from client socket or no data received");
94 return None;
95 }
96 };
97
98 let (header, hlen) = match ProxyHeader::parse(
100 &buf[..len],
101 ParseConfig {
102 include_tlvs: false,
103 allow_v1: false,
104 allow_v2: true,
105 },
106 ) {
107 Ok((header, hlen)) => (header, hlen),
108 Err(proxy_header::Error::Invalid) => {
109 debug!(
110 "Proxy header is invalid. This is likely due to no no header being provided",
111 );
112 return None;
113 }
114 Err(e) => {
115 debug!("Proxy header parse error '{:?}', ignoring header.", e);
116 return None;
117 }
118 };
119 debug!("Proxied connection with header {:?}", header);
120 let address = header.proxied_address().map(|a| a.to_owned());
121 let _ = self.read_exact(&mut buf[..hlen]).await;
123 address
124 }
125
126 pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
128 self.tcp_stream.peer_addr()
129 }
130}
131
132impl AsyncRead for Connection {
133 fn poll_read(
134 mut self: Pin<&mut Self>,
135 cx: &mut Context,
136 buf: &mut ReadBuf,
137 ) -> Poll<io::Result<()>> {
138 Pin::new(&mut self.tcp_stream).poll_read(cx, buf)
139 }
140}
141
142impl AsyncWrite for Connection {
143 fn poll_write(
144 mut self: Pin<&mut Self>,
145 cx: &mut Context,
146 buf: &[u8],
147 ) -> Poll<io::Result<usize>> {
148 Pin::new(&mut self.tcp_stream).poll_write(cx, buf)
149 }
150
151 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
152 Pin::new(&mut self.tcp_stream).poll_flush(cx)
153 }
154
155 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
156 Pin::new(&mut self.tcp_stream).poll_shutdown(cx)
157 }
158}
159
160#[async_trait]
161impl AsyncReady for Connection {
162 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
163 self.tcp_stream.ready(interest).await
164 }
165}
166
167pub struct ConnectionUuidHandle(Arc<Mutex<Option<Uuid>>>);
177
178impl ConnectionUuidHandle {
179 pub fn get(&self) -> Option<Uuid> {
181 *self.0.lock().expect("lock poisoned")
182 }
183
184 pub fn set(&self, conn_uuid: Uuid) {
186 *self.0.lock().expect("lock poisoned") = Some(conn_uuid);
187 }
188
189 pub fn display(&self) -> impl fmt::Display {
191 self.get().display_or("<unknown>")
192 }
193}
194
195pub trait Server {
197 const NAME: &'static str;
199
200 fn handle_connection(&self, conn: Connection) -> ConnectionHandler;
202}
203
204pub trait ConnectionStream: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
206
207impl<T> ConnectionStream for T where T: Stream<Item = io::Result<TcpStream>> + Unpin + Send {}
208
209#[derive(Debug)]
211pub struct ListenerHandle {
212 local_addr: SocketAddr,
213 _trigger: trigger::Trigger,
214}
215
216impl ListenerHandle {
217 pub fn local_addr(&self) -> SocketAddr {
219 self.local_addr
220 }
221}
222
223pub async fn listen(
229 addr: &SocketAddr,
230) -> Result<(ListenerHandle, Pin<Box<dyn ConnectionStream>>), io::Error> {
231 let listener = TcpListener::bind(addr).await?;
232 let local_addr = listener.local_addr()?;
233 let (trigger, trigger_rx) = trigger::channel();
234 let handle = ListenerHandle {
235 local_addr,
236 _trigger: trigger,
237 };
238 let stream = TcpListenerStream::new(listener).take_until(trigger_rx);
241 Ok((handle, Box::pin(stream)))
242}
243
244pub struct ServeConfig<S, C>
246where
247 S: Server,
248 C: ConnectionStream,
249{
250 pub server: S,
252 pub conns: C,
254 pub dyncfg: Option<ServeDyncfg>,
256}
257
258pub struct ServeDyncfg {
260 pub config_set: ConfigSet,
262 pub sigterm_wait_config: &'static Config<Duration>,
268}
269
270pub async fn serve<S, C>(
275 ServeConfig {
276 server,
277 mut conns,
278 dyncfg,
279 }: ServeConfig<S, C>,
280) -> JoinSet<()>
281where
282 S: Server,
283 C: ConnectionStream,
284{
285 let task_name = format!("handle_{}_connection", S::NAME);
286 let mut set = JoinSet::new();
287 loop {
288 tokio::select! {
289 conn = conns.next() => {
291 let conn = match conn {
292 None => break,
293 Some(Ok(conn)) => conn,
294 Some(Err(err)) => {
295 error!("error accepting connection: {}", err);
296 continue;
297 }
298 };
299 conn.set_nodelay(true).expect("set_nodelay failed");
311 if let Err(e) = SockRef::from(&conn).set_tcp_keepalive(&KEEPALIVE) {
316 error!("failed enabling keepalive: {e}");
317 continue;
318 }
319 let conn = Connection::new(conn);
320 let conn_uuid = conn.uuid_handle();
321 let fut = server.handle_connection(conn);
322 set.spawn_named(|| &task_name, async move {
323 let guard = scopeguard::guard((), |_| {
324 debug!(
325 server = S::NAME,
326 conn_uuid = %conn_uuid.display(),
327 "dropping connection without explicit termination",
328 );
329 });
330
331 match fut.await {
332 Ok(()) => {
333 debug!(
334 server = S::NAME,
335 conn_uuid = %conn_uuid.display(),
336 "successfully handled connection",
337 );
338 }
339 Err(e) => {
340 warn!(
341 server = S::NAME,
342 conn_uuid = %conn_uuid.display(),
343 "error handling connection: {}",
344 e.display_with_causes(),
345 );
346 }
347 }
348
349 let () = ScopeGuard::into_inner(guard);
350 });
351 }
352 res = set.join_next(), if set.len() > 0 => {
355 if let Some(Err(e)) = res {
356 warn!(
357 "error joining connection in {}: {}",
358 S::NAME,
359 e.display_with_causes()
360 );
361 }
362 }
363 }
364 }
365 if let Some(dyncfg) = dyncfg {
366 let wait = dyncfg.sigterm_wait_config.get(&dyncfg.config_set);
367 if set.len() > 0 {
368 warn!(
369 "{} exiting, {} outstanding connections, waiting for {:?}",
370 S::NAME,
371 set.len(),
372 wait
373 );
374 }
375 let timedout = tokio::time::timeout(wait, async {
376 while let Some(res) = set.join_next().await {
377 if let Err(e) = res {
378 warn!(
379 "error joining connection in {}: {}",
380 S::NAME,
381 e.display_with_causes()
382 );
383 }
384 }
385 })
386 .await;
387 if timedout.is_err() {
388 warn!(
389 "{}: wait timeout of {:?} exceeded, {} outstanding connections",
390 S::NAME,
391 wait,
392 set.len()
393 );
394 }
395 }
396 set
397}
398
399#[derive(Clone, Debug)]
401pub struct TlsConfig {
402 pub context: SslContext,
404 pub mode: TlsMode,
406}
407
/// Specifies how strictly TLS is demanded.
///
/// NOTE(review): the exact enforcement semantics live in the code consuming
/// this enum, which is not visible here.
#[derive(Debug, Clone, Copy)]
pub enum TlsMode {
    /// TLS is allowed (presumably optional).
    Allow,
    /// TLS is required.
    Require,
}
416
/// Locations of the files needed to build a TLS context.
#[derive(Debug, Clone)]
pub struct TlsCertConfig {
    /// Path to the certificate chain file.
    pub cert: PathBuf,
    /// Path to the PEM-encoded private key file.
    pub key: PathBuf,
}
425
426impl TlsCertConfig {
427 pub fn load_context(&self) -> Result<SslContext, anyhow::Error> {
429 let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
436 builder.set_certificate_chain_file(&self.cert)?;
437 builder.set_private_key_file(&self.key, SslFiletype::PEM)?;
438 Ok(builder.build().into_context())
439 }
440
441 pub fn reloading_context(
447 &self,
448 mut ticker: ReloadTrigger,
449 ) -> Result<ReloadingSslContext, anyhow::Error> {
450 let context = Arc::new(RwLock::new(self.load_context()?));
451 let updater_context = Arc::clone(&context);
452 let config = self.clone();
453 mz_ore::task::spawn(|| "TlsCertConfig reloading_context", async move {
454 while let Some(chan) = ticker.next().await {
455 let result = match config.load_context() {
456 Ok(ctx) => {
457 *updater_context.write().expect("poisoned") = ctx;
458 Ok(())
459 }
460 Err(err) => {
461 tracing::error!("failed to reload SSL certificate: {err}");
462 Err(err)
463 }
464 };
465 if let Some(chan) = chan {
466 let _ = chan.send(result);
467 }
468 }
469 tracing::warn!("TlsCertConfig reloading_context updater closed");
470 });
471 Ok(ReloadingSslContext { context })
472 }
473}
474
475#[derive(Clone, Debug)]
477pub struct ReloadingSslContext {
478 context: Arc<RwLock<SslContext>>,
480}
481
482impl ReloadingSslContext {
483 pub fn get(&self) -> RwLockReadGuard<SslContext> {
484 self.context.read().expect("poisoned")
485 }
486}
487
488#[derive(Clone, Debug)]
490pub struct ReloadingTlsConfig {
491 pub context: ReloadingSslContext,
493 pub mode: TlsMode,
495}
496
497pub type ReloadTrigger = BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>;
498
499pub fn default_cert_reload_ticker() -> ReloadTrigger {
501 let ticker = IntervalStream::new(tokio::time::interval(Duration::from_secs(60 * 60)));
502 let ticker = ticker.map(|_| None);
503 let ticker = Box::pin(ticker);
504 ticker
505}
506
507pub fn cert_reload_never_reload() -> ReloadTrigger {
509 let ticker = futures::stream::empty();
510 let ticker = Box::pin(ticker);
511 ticker
512}
513
514#[derive(Debug, Clone, clap::Parser)]
516pub struct TlsCliArgs {
517 #[clap(
526 long, env = "TLS_MODE",
527 value_parser = ["disable", "require"],
528 default_value = "disable",
529 default_value_ifs = [
530 ("frontegg_tenant", ArgPredicate::IsPresent, Some("require")),
531 ("frontegg_resolver_template", ArgPredicate::IsPresent, Some("require")),
532 ],
533 value_name = "MODE",
534 )]
535 tls_mode: String,
536 #[clap(
538 long,
539 env = "TLS_CERT",
540 requires = "tls_key",
541 required_if_eq_any([("tls_mode", "require")]),
542 value_name = "PATH"
543 )]
544 tls_cert: Option<PathBuf>,
545 #[clap(
547 long,
548 env = "TLS_KEY",
549 requires = "tls_cert",
550 required_if_eq_any([("tls_mode", "require")]),
551 value_name = "PATH"
552 )]
553 tls_key: Option<PathBuf>,
554}
555
556impl TlsCliArgs {
557 pub fn into_config(self) -> Result<Option<TlsCertConfig>, anyhow::Error> {
559 if self.tls_mode == "disable" {
560 if self.tls_cert.is_some() {
561 bail!("cannot specify --tls-mode=disable and --tls-cert simultaneously");
562 }
563 if self.tls_key.is_some() {
564 bail!("cannot specify --tls-mode=disable and --tls-key simultaneously");
565 }
566 Ok(None)
567 } else {
568 let cert = self.tls_cert.unwrap();
569 let key = self.tls_key.unwrap();
570 Ok(Some(TlsCertConfig { cert, key }))
571 }
572 }
573}