1use async_stream::stream;
13use async_trait::async_trait;
14use futures::future::{self, BoxFuture};
15use futures::stream::{Stream, StreamExt, TryStreamExt};
16use http::uri::PathAndQuery;
17use hyper_util::rt::TokioIo;
18use mz_ore::metric;
19use mz_ore::metrics::{DeleteOnDropGauge, MetricsRegistry, UIntGaugeVec};
20use mz_ore::netio::{Listener, SocketAddr, SocketAddrType};
21use mz_proto::{ProtoType, RustType};
22use prometheus::core::AtomicU64;
23use semver::Version;
24use std::error::Error;
25use std::fmt::{self, Debug};
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::time::UNIX_EPOCH;
30use tokio::net::UnixStream;
31use tokio::select;
32use tokio::sync::mpsc::{self, UnboundedSender};
33use tokio::sync::{Mutex, oneshot};
34use tokio_stream::wrappers::UnboundedReceiverStream;
35use tonic::body::BoxBody;
36use tonic::codegen::InterceptedService;
37use tonic::metadata::AsciiMetadataValue;
38use tonic::server::NamedService;
39use tonic::service::Interceptor;
40use tonic::transport::{Channel, Endpoint, Server};
41use tonic::{IntoStreamingRequest, Request, Response, Status, Streaming};
42use tower::Service;
43use tracing::{debug, error, info, warn};
44
45use crate::client::{GenericClient, Partitionable, Partitioned};
46use crate::codec::{StatCodec, StatsCollector};
47use crate::params::GrpcClientParameters;
48use crate::transport;
49
50include!(concat!(env!("OUT_DIR"), "/mz_service.params.rs"));
51
/// Maximum gRPC message size, in bytes, used for both encoding and decoding.
///
/// The value is `usize::MAX`, i.e. no practical limit is imposed.
/// `BidiProtoClient::new` passes it to both `max_decoding_message_size` and
/// `max_encoding_message_size` of the underlying tonic client.
pub const MAX_GRPC_MESSAGE_SIZE: usize = usize::MAX;
54
/// The transport used by gRPC clients in this module: a tonic [`Channel`]
/// wrapped in a [`VersionAttachInterceptor`], which inserts the client's
/// version into the metadata of every outgoing request.
pub type ClientTransport = InterceptedService<Channel, VersionAttachInterceptor>;
56
/// Bundles the types and the endpoint path that describe one gRPC service.
///
/// [`GrpcClient`] is generic over an implementation of this trait and uses it
/// to pick the message types of its command and response streams.
pub trait ProtoServiceTypes: Debug + Clone + Send {
    /// Protobuf message type the client sends to the server.
    type PC: prost::Message + Clone + 'static;
    /// Protobuf message type the client receives from the server.
    type PR: prost::Message + Clone + Default + 'static;
    /// Statistics collector handed to the codec that encodes `PC` and decodes
    /// `PR` messages.
    type STATS: StatsCollector<Self::PC, Self::PR> + 'static;
    /// Path of the bidirectional streaming RPC. `GrpcClient::connect` passes
    /// it to `BidiProtoClient`, which turns it into a `PathAndQuery`.
    const URL: &'static str;
}
64
/// A client for a remote server that speaks protobuf over a bidirectional
/// gRPC stream.
///
/// After [`GrpcClient::connect`] succeeds, the client holds the two halves of
/// the stream: an unbounded channel whose receiving end feeds the request
/// stream, and the tonic [`Streaming`] of responses.
#[derive(Debug)]
pub struct GrpcClient<G>
where
    G: ProtoServiceTypes,
{
    /// Sender for outgoing commands; its receiver backs the request stream.
    tx: UnboundedSender<G::PC>,
    /// Stream of responses coming back from the server.
    rx: Streaming<G::PR>,
}
89
90impl<G> GrpcClient<G>
91where
92 G: ProtoServiceTypes,
93{
94 pub async fn connect(
97 addr: String,
98 version: Version,
99 metrics: G::STATS,
100 params: &GrpcClientParameters,
101 ) -> Result<Self, anyhow::Error> {
102 debug!("GrpcClient {}: Attempt to connect", addr);
103
104 let channel = match SocketAddrType::guess(&addr) {
105 SocketAddrType::Inet => {
106 let mut endpoint = Endpoint::new(format!("http://{}", addr))?;
107 if let Some(connect_timeout) = params.connect_timeout {
108 endpoint = endpoint.connect_timeout(connect_timeout);
109 }
110 if let Some(keep_alive_timeout) = params.http2_keep_alive_timeout {
111 endpoint = endpoint.keep_alive_timeout(keep_alive_timeout);
112 }
113 if let Some(keep_alive_interval) = params.http2_keep_alive_interval {
114 endpoint = endpoint.http2_keep_alive_interval(keep_alive_interval);
115 }
116 endpoint.connect().await?
117 }
118 SocketAddrType::Unix => {
119 let addr = addr.clone();
120 Endpoint::from_static("http://localhost") .connect_with_connector(tower::service_fn(move |_| {
122 let addr = addr.clone();
123 async { UnixStream::connect(addr).await.map(TokioIo::new) }
124 }))
125 .await?
126 }
127 SocketAddrType::Turmoil => unimplemented!(),
128 };
129 let service = InterceptedService::new(channel, VersionAttachInterceptor::new(version));
130 let mut client = BidiProtoClient::new(service, G::URL, metrics);
131 let (tx, rx) = mpsc::unbounded_channel();
132 let rx = client
133 .establish_bidi_stream(UnboundedReceiverStream::new(rx))
134 .await?
135 .into_inner();
136 info!("GrpcClient {}: connected", &addr);
137 Ok(GrpcClient { tx, rx })
138 }
139
140 pub async fn connect_partitioned<C, R>(
142 dests: Vec<(String, G::STATS)>,
143 version: Version,
144 params: &GrpcClientParameters,
145 ) -> Result<Partitioned<Self, C, R>, anyhow::Error>
146 where
147 (C, R): Partitionable<C, R>,
148 {
149 let clients = future::try_join_all(
150 dests
151 .into_iter()
152 .map(|(addr, metrics)| Self::connect(addr, version.clone(), metrics, params)),
153 )
154 .await?;
155 Ok(Partitioned::new(clients))
156 }
157}
158
159#[async_trait]
160impl<G, C, R> GenericClient<C, R> for GrpcClient<G>
161where
162 C: RustType<G::PC> + Send + Sync + 'static,
163 R: RustType<G::PR> + Send + Sync + 'static,
164 G: ProtoServiceTypes,
165{
166 async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error> {
167 self.tx.send(cmd.into_proto())?;
168 Ok(())
169 }
170
171 async fn recv(&mut self) -> Result<Option<R>, anyhow::Error> {
177 match self.rx.try_next().await? {
180 None => Ok(None),
181 Some(response) => Ok(Some(response.into_rust()?)),
182 }
183 }
184}
185
/// A thin wrapper around a tonic gRPC client that issues a single
/// bidirectional streaming RPC.
///
/// `PC` is the protobuf message type sent on the request stream and `PR` the
/// one received on the response stream. `S` collects statistics about both
/// via the codec.
pub struct BidiProtoClient<PC, PR, S>
where
    PC: prost::Message + 'static,
    PR: Default + prost::Message + 'static,
    S: StatsCollector<PC, PR>,
{
    /// The underlying tonic client, running over [`ClientTransport`].
    inner: tonic::client::Grpc<ClientTransport>,
    /// Path of the streaming RPC on the server.
    path: &'static str,
    /// Codec used for the stream; built from the stats collector.
    codec: StatCodec<PC, PR, S>,
}
203
204impl<PC, PR, S> BidiProtoClient<PC, PR, S>
205where
206 PC: Clone + prost::Message + 'static,
207 PR: Clone + Default + prost::Message + 'static,
208 S: StatsCollector<PC, PR> + 'static,
209{
210 fn new(inner: ClientTransport, path: &'static str, stats_collector: S) -> Self
211 where
212 Self: Sized,
213 {
214 let inner = tonic::client::Grpc::new(inner)
215 .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE)
216 .max_encoding_message_size(MAX_GRPC_MESSAGE_SIZE);
217 let codec = StatCodec::new(stats_collector);
218 BidiProtoClient { inner, path, codec }
219 }
220
221 async fn establish_bidi_stream(
222 &mut self,
223 rx: UnboundedReceiverStream<PC>,
224 ) -> Result<Response<Streaming<PR>>, Status> {
225 self.inner.ready().await.map_err(|e| {
226 tonic::Status::new(
227 tonic::Code::Unknown,
228 format!("Service was not ready: {}", e),
229 )
230 })?;
231 let path = PathAndQuery::from_static(self.path);
232 self.inner
233 .streaming(rx.into_streaming_request(), path, self.codec.clone())
234 .await
235 }
236}
237
/// A gRPC server that forwards each bidirectional stream to a client built by
/// `F`.
///
/// The server is a handle around shared, reference-counted state. It is
/// created inside [`GrpcServer::serve`] and handed to the service builder.
pub struct GrpcServer<F> {
    /// State shared between the server handle and the response streams it
    /// creates.
    state: Arc<GrpcServerState<F>>,
}
250
/// State shared by a [`GrpcServer`] and the streams it hands out.
struct GrpcServerState<F> {
    /// Sender half of the cancellation channel belonging to the most recently
    /// accepted stream. `forward_bidi_stream` overwrites it on every call,
    /// which drops the previous sender.
    cancel_tx: Mutex<oneshot::Sender<()>>,
    /// Factory invoked once per accepted stream to build the client that
    /// commands are forwarded to.
    client_builder: F,
    /// Metrics for this server, labeled with the service name.
    metrics: PerGrpcServerMetrics,
}
256
impl<F, G> GrpcServer<F>
where
    F: Fn() -> G + Send + Sync + 'static,
{
    /// Sets up a gRPC server and returns a future that runs it.
    ///
    /// The server state, the service (via `service_builder`), and the request
    /// validation layer are all constructed eagerly, and the log lines are
    /// emitted, when this function is called. Binding `listen_addr` and
    /// serving only happen once the returned future is polled.
    ///
    /// * `metrics` supplies the per-server metrics, labeled with `S::NAME`.
    /// * `version` is the version every request must announce.
    /// * `host`, if provided, is the host a request's URI must name when the
    ///   URI carries a host at all; if `None`, host checking is disabled and
    ///   a warning is logged.
    /// * `client_builder` is invoked once for each stream handled by
    ///   `forward_bidi_stream`.
    /// * `service_builder` turns the server handle into the tonic service.
    ///
    /// # Errors
    ///
    /// The returned future resolves to an error if binding the listener or
    /// running the tonic server fails.
    pub fn serve<S, Fs>(
        metrics: &GrpcServerMetrics,
        listen_addr: SocketAddr,
        version: Version,
        host: Option<String>,
        client_builder: F,
        service_builder: Fs,
    ) -> impl Future<Output = Result<(), anyhow::Error>> + use<S, Fs, F, G>
    where
        S: Service<
                http::Request<BoxBody>,
                Response = http::Response<BoxBody>,
                Error = std::convert::Infallible,
            > + NamedService
            + Clone
            + Send
            + 'static,
        S::Future: Send + 'static,
        Fs: FnOnce(Self) -> S + Send + 'static,
    {
        // Seed the state with a placeholder sender; the first call to
        // `forward_bidi_stream` replaces it. The matching receiver is never
        // awaited.
        let (cancel_tx, _cancel_rx) = oneshot::channel();
        let state = GrpcServerState {
            cancel_tx: Mutex::new(cancel_tx),
            client_builder,
            metrics: metrics.for_server(S::NAME),
        };
        let server = Self {
            state: Arc::new(state),
        };
        let service = service_builder(server);

        if host.is_none() {
            warn!("no host provided; request destination host checking is disabled");
        }
        let validation = RequestValidationLayer { version, host };

        info!("Starting to listen on {}", listen_addr);

        async {
            let listener = Listener::bind(listen_addr).await?;

            Server::builder()
                .layer(validation)
                .add_service(service)
                .serve_with_incoming(listener)
                .await?;
            Ok(())
        }
    }

    /// Handles a bidirectional stream request by forwarding commands to, and
    /// responses from, a freshly built client.
    ///
    /// The returned response stream runs until one of the following happens:
    ///
    /// * the request stream ends or yields an error,
    /// * a command cannot be converted from protobuf,
    /// * the client's `recv` returns `Ok(None)`,
    /// * the stream's cancellation receiver resolves, which happens when a
    ///   later call to this method replaces (and thereby drops) the sender
    ///   installed here.
    ///
    /// Errors returned by the client's `send` or `recv` are yielded to the
    /// caller as `Status::unknown` and do not by themselves end the stream.
    pub async fn forward_bidi_stream<C, R, PC, PR>(
        &self,
        request: Request<Streaming<PC>>,
    ) -> Result<Response<ResponseStream<PR>>, Status>
    where
        G: GenericClient<C, R> + 'static,
        C: RustType<PC> + Send + Sync + 'static + fmt::Debug,
        R: RustType<PR> + Send + Sync + 'static + fmt::Debug,
        PC: fmt::Debug + Send + Sync + 'static,
        PR: fmt::Debug + Send + Sync + 'static,
    {
        info!("GrpcServer: remote client connected");

        // Install this stream's cancellation sender. Overwriting the stored
        // sender drops the previous one, which wakes the previous stream's
        // `cancel_rx` and makes that stream stop.
        let (cancel_tx, mut cancel_rx) = oneshot::channel();
        *self.state.cancel_tx.lock().await = cancel_tx;

        // Build a new client and forward commands and responses until the
        // stream ends or is canceled.
        let mut request = request.into_inner();
        let state = Arc::clone(&self.state);
        let stream = stream! {
            let mut client = (state.client_builder)();
            loop {
                select! {
                    command = request.next() => {
                        let command = match command {
                            None => break,
                            Some(Ok(command)) => command,
                            Some(Err(e)) => {
                                warn!("error handling client: {e}");
                                break;
                            }
                        };

                        // Record the arrival time, in whole seconds since the
                        // Unix epoch, before attempting conversion.
                        match UNIX_EPOCH.elapsed() {
                            Ok(ts) => state.metrics.last_command_received.set(ts.as_secs()),
                            Err(e) => error!("failed to get system time: {e}"),
                        }

                        let command = match command.into_rust() {
                            Ok(command) => command,
                            Err(e) => {
                                error!("error converting command from protobuf: {}", e);
                                break;
                            }
                        };

                        // NOTE(review): a failed send is reported to the
                        // caller but the loop keeps running — confirm this is
                        // intended.
                        if let Err(e) = client.send(command).await {
                            yield Err(Status::unknown(e.to_string()));
                        }
                    }
                    response = client.recv() => {
                        match response {
                            Ok(Some(response)) => yield Ok(response.into_proto()),
                            Ok(None) => break,
                            Err(e) => yield Err(Status::unknown(e.to_string())),
                        }
                    }
                    // Resolves either with a value or because the sender was
                    // dropped; both cases stop the stream.
                    _ = &mut cancel_rx => break,
                }
            }
        };
        Ok(Response::new(ResponseStream::new(stream)))
    }
}
397
/// The stream of responses returned by `GrpcServer::forward_bidi_stream`.
///
/// Wraps a boxed, pinned stream of `Result<PR, Status>` items so the concrete
/// stream type does not leak into the signature, and logs a message when it
/// is dropped.
pub struct ResponseStream<PR>(Pin<Box<dyn Stream<Item = Result<PR, Status>> + Send>>);
403
404impl<PR> ResponseStream<PR> {
405 fn new<S>(stream: S) -> Self
406 where
407 S: Stream<Item = Result<PR, Status>> + Send + 'static,
408 {
409 Self(Box::pin(stream))
410 }
411}
412
413impl<PR> Stream for ResponseStream<PR> {
414 type Item = Result<PR, Status>;
415
416 fn poll_next(
417 mut self: Pin<&mut Self>,
418 cx: &mut std::task::Context<'_>,
419 ) -> std::task::Poll<Option<Self::Item>> {
420 self.0.poll_next_unpin(cx)
421 }
422}
423
impl<PR> Drop for ResponseStream<PR> {
    /// Logs that the response stream has gone away; performs no other
    /// cleanup.
    fn drop(&mut self) {
        info!("GrpcServer: response stream disconnected");
    }
}
429
/// Metrics shared by all gRPC servers registered with one metrics registry.
#[derive(Debug)]
pub struct GrpcServerMetrics {
    /// Time at which each server received its last command, keyed by the
    /// `server_name` label.
    last_command_received: UIntGaugeVec,
}
435
436impl GrpcServerMetrics {
437 pub fn register_with(registry: &MetricsRegistry) -> Self {
439 Self {
440 last_command_received: registry.register(metric!(
441 name: "mz_grpc_server_last_command_received",
442 help: "The time at which the server received its last command.",
443 var_labels: ["server_name"],
444 )),
445 }
446 }
447
448 pub fn for_server(&self, name: &'static str) -> PerGrpcServerMetrics {
449 PerGrpcServerMetrics {
450 last_command_received: self
451 .last_command_received
452 .get_delete_on_drop_metric(vec![name]),
453 }
454 }
455}
456
/// Metrics for a single gRPC server, obtained from
/// `GrpcServerMetrics::for_server`.
#[derive(Clone, Debug)]
pub struct PerGrpcServerMetrics {
    /// Time of the last received command, stored as whole seconds since the
    /// Unix epoch.
    last_command_received: DeleteOnDropGauge<AtomicU64, Vec<&'static str>>,
}
461
462impl<C, R> transport::Metrics<C, R> for PerGrpcServerMetrics {
463 fn bytes_sent(&mut self, _len: usize) {}
464 fn bytes_received(&mut self, _len: usize) {}
465 fn message_sent(&mut self, _msg: &C) {}
466
467 fn message_received(&mut self, _msg: &R) {
468 match UNIX_EPOCH.elapsed() {
469 Ok(ts) => self.last_command_received.set(ts.as_secs()),
470 Err(e) => error!("failed to get system time: {e}"),
471 }
472 }
473}
474
/// Name of the header/metadata key that carries the client's version.
///
/// `VersionAttachInterceptor` writes it on the client side and
/// `RequestValidation` reads it on the server side.
const VERSION_HEADER_KEY: &str = "x-mz-version";
476
/// A gRPC interceptor that attaches a version to every outgoing request,
/// under the `x-mz-version` metadata key.
#[derive(Debug, Clone)]
pub struct VersionAttachInterceptor {
    /// The version, pre-rendered as an ASCII metadata value.
    version: AsciiMetadataValue,
}
482
483impl VersionAttachInterceptor {
484 fn new(version: Version) -> VersionAttachInterceptor {
485 VersionAttachInterceptor {
486 version: version
487 .to_string()
488 .try_into()
489 .expect("semver versions are valid metadata values"),
490 }
491 }
492}
493
494impl Interceptor for VersionAttachInterceptor {
495 fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
496 request
497 .metadata_mut()
498 .insert(VERSION_HEADER_KEY, self.version.clone());
499 Ok(request)
500 }
501}
502
/// A `tower` layer that wraps services in [`RequestValidation`].
#[derive(Clone)]
struct RequestValidationLayer {
    /// Version that incoming requests must announce.
    version: Version,
    /// Host that incoming requests must target, or `None` to skip the host
    /// check.
    host: Option<String>,
}
509
510impl<S> tower::Layer<S> for RequestValidationLayer {
511 type Service = RequestValidation<S>;
512
513 fn layer(&self, inner: S) -> Self::Service {
514 let version = self
515 .version
516 .to_string()
517 .try_into()
518 .expect("version is a valid header value");
519 RequestValidation {
520 inner,
521 version,
522 host: self.host.clone(),
523 }
524 }
525}
526
/// A service wrapper that rejects requests whose version header or target
/// host does not match what the server expects.
#[derive(Clone)]
struct RequestValidation<S> {
    /// The wrapped service that receives requests that pass validation.
    inner: S,
    /// Expected value of the `x-mz-version` header.
    version: http::HeaderValue,
    /// Expected request host, or `None` to skip the host check.
    host: Option<String>,
}
534
535impl<S, B> Service<http::Request<B>> for RequestValidation<S>
536where
537 S: Service<http::Request<B>, Error = Box<dyn Error + Send + Sync + 'static>>,
538 S::Response: Send + 'static,
539 S::Future: Send + 'static,
540{
541 type Response = S::Response;
542 type Error = S::Error;
543 type Future = BoxFuture<'static, Result<S::Response, S::Error>>;
544
545 fn poll_ready(
546 &mut self,
547 cx: &mut std::task::Context<'_>,
548 ) -> std::task::Poll<Result<(), Self::Error>> {
549 self.inner.poll_ready(cx)
550 }
551
552 fn call(&mut self, req: http::Request<B>) -> Self::Future {
553 let error = |msg| {
554 let error: S::Error = Box::new(Status::permission_denied(msg));
555 Box::pin(future::ready(Err(error)))
556 };
557
558 let Some(req_version) = req.headers().get(VERSION_HEADER_KEY) else {
559 return error("request missing version header".into());
560 };
561 if req_version != self.version {
562 return error(format!(
563 "request has version {req_version:?} but {:?} required",
564 self.version
565 ));
566 }
567
568 let req_host = req.uri().host();
569 if let (Some(req_host), Some(host)) = (req_host, &self.host) {
570 if req_host != host {
571 return error(format!(
572 "request has host {req_host:?} but {host:?} required"
573 ));
574 }
575 }
576
577 Box::pin(self.inner.call(req))
578 }
579}