mz_service/
grpc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! gRPC transport for the [client](crate::client) module.
11
12use async_stream::stream;
13use async_trait::async_trait;
14use futures::future::{self, BoxFuture};
15use futures::stream::{Stream, StreamExt, TryStreamExt};
16use http::uri::PathAndQuery;
17use hyper_util::rt::TokioIo;
18use mz_ore::metric;
19use mz_ore::metrics::{DeleteOnDropGauge, MetricsRegistry, UIntGaugeVec};
20use mz_ore::netio::{Listener, SocketAddr, SocketAddrType};
21use mz_proto::{ProtoType, RustType};
22use prometheus::core::AtomicU64;
23use semver::Version;
24use std::error::Error;
25use std::fmt::{self, Debug};
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::time::UNIX_EPOCH;
30use tokio::net::UnixStream;
31use tokio::select;
32use tokio::sync::mpsc::{self, UnboundedSender};
33use tokio::sync::{Mutex, oneshot};
34use tokio_stream::wrappers::UnboundedReceiverStream;
35use tonic::body::BoxBody;
36use tonic::codegen::InterceptedService;
37use tonic::metadata::AsciiMetadataValue;
38use tonic::server::NamedService;
39use tonic::service::Interceptor;
40use tonic::transport::{Channel, Endpoint, Server};
41use tonic::{IntoStreamingRequest, Request, Response, Status, Streaming};
42use tower::Service;
43use tracing::{debug, error, info, warn};
44
45use crate::client::{GenericClient, Partitionable, Partitioned};
46use crate::codec::{StatCodec, StatsCollector};
47use crate::params::GrpcClientParameters;
48use crate::transport;
49
// Pull in Rust code generated into Cargo's `OUT_DIR`. NOTE(review): presumably
// produced by this crate's build script from the `mz_service.params` protobuf
// package — confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/mz_service.params.rs"));

/// Maximum gRPC message size used for both encoding and decoding.
///
/// Set to `usize::MAX`, i.e. effectively unlimited. [`BidiProtoClient`] applies
/// it on the client side. On the server side, pass it to the tonic-generated
/// servers, e.g. `Server::new(Svc).max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE)`.
pub const MAX_GRPC_MESSAGE_SIZE: usize = usize::MAX;

/// The transport used by [`BidiProtoClient`]: a tonic [`Channel`] wrapped so
/// that every request is passed through a [`VersionAttachInterceptor`], which
/// attaches the client's version as request metadata.
pub type ClientTransport = InterceptedService<Channel, VersionAttachInterceptor>;
56
57/// Types that we send and receive over a service endpoint.
58pub trait ProtoServiceTypes: Debug + Clone + Send {
59    type PC: prost::Message + Clone + 'static;
60    type PR: prost::Message + Clone + Default + 'static;
61    type STATS: StatsCollector<Self::PC, Self::PR> + 'static;
62    const URL: &'static str;
63}
64
65/// A client to a remote dataflow server using gRPC and protobuf based
66/// communication.
67///
68/// The client opens a connection using the proto client stubs that are
69/// generated by tonic from a service definition. When the client is connected,
70/// it will call automatically the only RPC defined in the service description,
71/// encapsulated by the `BidiProtoClient` trait. This trait bound is not on the
72/// `Client` type parameter here, but it IS on the impl blocks. Bidirectional
73/// protobuf RPC sets up two streams that persist after the RPC has returned: A
74/// Request (Command) stream (for us, backed by a unbounded mpsc queue) going
75/// from this instance to the server and a response stream coming back
76/// (represented directly as a `Streaming<Response>` instance). The recv and send
77/// functions interact with the two mpsc channels or the streaming instance
78/// respectively.
79#[derive(Debug)]
80pub struct GrpcClient<G>
81where
82    G: ProtoServiceTypes,
83{
84    /// The sender for commands.
85    tx: UnboundedSender<G::PC>,
86    /// The receiver for responses.
87    rx: Streaming<G::PR>,
88}
89
90impl<G> GrpcClient<G>
91where
92    G: ProtoServiceTypes,
93{
94    /// Connects to the server at the given address, announcing the specified
95    /// client version.
96    pub async fn connect(
97        addr: String,
98        version: Version,
99        metrics: G::STATS,
100        params: &GrpcClientParameters,
101    ) -> Result<Self, anyhow::Error> {
102        debug!("GrpcClient {}: Attempt to connect", addr);
103
104        let channel = match SocketAddrType::guess(&addr) {
105            SocketAddrType::Inet => {
106                let mut endpoint = Endpoint::new(format!("http://{}", addr))?;
107                if let Some(connect_timeout) = params.connect_timeout {
108                    endpoint = endpoint.connect_timeout(connect_timeout);
109                }
110                if let Some(keep_alive_timeout) = params.http2_keep_alive_timeout {
111                    endpoint = endpoint.keep_alive_timeout(keep_alive_timeout);
112                }
113                if let Some(keep_alive_interval) = params.http2_keep_alive_interval {
114                    endpoint = endpoint.http2_keep_alive_interval(keep_alive_interval);
115                }
116                endpoint.connect().await?
117            }
118            SocketAddrType::Unix => {
119                let addr = addr.clone();
120                Endpoint::from_static("http://localhost") // URI is ignored
121                    .connect_with_connector(tower::service_fn(move |_| {
122                        let addr = addr.clone();
123                        async { UnixStream::connect(addr).await.map(TokioIo::new) }
124                    }))
125                    .await?
126            }
127            SocketAddrType::Turmoil => unimplemented!(),
128        };
129        let service = InterceptedService::new(channel, VersionAttachInterceptor::new(version));
130        let mut client = BidiProtoClient::new(service, G::URL, metrics);
131        let (tx, rx) = mpsc::unbounded_channel();
132        let rx = client
133            .establish_bidi_stream(UnboundedReceiverStream::new(rx))
134            .await?
135            .into_inner();
136        info!("GrpcClient {}: connected", &addr);
137        Ok(GrpcClient { tx, rx })
138    }
139
140    /// Like [`GrpcClient::connect`], but for multiple partitioned servers.
141    pub async fn connect_partitioned<C, R>(
142        dests: Vec<(String, G::STATS)>,
143        version: Version,
144        params: &GrpcClientParameters,
145    ) -> Result<Partitioned<Self, C, R>, anyhow::Error>
146    where
147        (C, R): Partitionable<C, R>,
148    {
149        let clients = future::try_join_all(
150            dests
151                .into_iter()
152                .map(|(addr, metrics)| Self::connect(addr, version.clone(), metrics, params)),
153        )
154        .await?;
155        Ok(Partitioned::new(clients))
156    }
157}
158
159#[async_trait]
160impl<G, C, R> GenericClient<C, R> for GrpcClient<G>
161where
162    C: RustType<G::PC> + Send + Sync + 'static,
163    R: RustType<G::PR> + Send + Sync + 'static,
164    G: ProtoServiceTypes,
165{
166    async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error> {
167        self.tx.send(cmd.into_proto())?;
168        Ok(())
169    }
170
171    /// # Cancel safety
172    ///
173    /// This method is cancel safe. If `recv` is used as the event in a [`tokio::select!`]
174    /// statement and some other branch completes first, it is guaranteed that no messages were
175    /// received by this client.
176    async fn recv(&mut self) -> Result<Option<R>, anyhow::Error> {
177        // `TryStreamExt::try_next` is cancel safe. The returned future only holds onto a
178        // reference to the underlying stream, so dropping it will never lose a value.
179        match self.rx.try_next().await? {
180            None => Ok(None),
181            Some(response) => Ok(Some(response.into_rust()?)),
182        }
183    }
184}
185
186/// Encapsulates the core functionality of a tonic gRPC client for a service
187/// that exposes a single bidirectional RPC stream.
188///
189/// The client calls back into the StatsCollector on each command send and
190/// response receive.
191///
192/// See the documentation on [`GrpcClient`] for details.
193pub struct BidiProtoClient<PC, PR, S>
194where
195    PC: prost::Message + 'static,
196    PR: Default + prost::Message + 'static,
197    S: StatsCollector<PC, PR>,
198{
199    inner: tonic::client::Grpc<ClientTransport>,
200    path: &'static str,
201    codec: StatCodec<PC, PR, S>,
202}
203
204impl<PC, PR, S> BidiProtoClient<PC, PR, S>
205where
206    PC: Clone + prost::Message + 'static,
207    PR: Clone + Default + prost::Message + 'static,
208    S: StatsCollector<PC, PR> + 'static,
209{
210    fn new(inner: ClientTransport, path: &'static str, stats_collector: S) -> Self
211    where
212        Self: Sized,
213    {
214        let inner = tonic::client::Grpc::new(inner)
215            .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE)
216            .max_encoding_message_size(MAX_GRPC_MESSAGE_SIZE);
217        let codec = StatCodec::new(stats_collector);
218        BidiProtoClient { inner, path, codec }
219    }
220
221    async fn establish_bidi_stream(
222        &mut self,
223        rx: UnboundedReceiverStream<PC>,
224    ) -> Result<Response<Streaming<PR>>, Status> {
225        self.inner.ready().await.map_err(|e| {
226            tonic::Status::new(
227                tonic::Code::Unknown,
228                format!("Service was not ready: {}", e),
229            )
230        })?;
231        let path = PathAndQuery::from_static(self.path);
232        self.inner
233            .streaming(rx.into_streaming_request(), path, self.codec.clone())
234            .await
235    }
236}
237
238/// A gRPC server that stitches a gRPC service with a single bidirectional
239/// stream to a [`GenericClient`].
240///
241/// It is the counterpart of [`GrpcClient`].
242///
243/// To use, implement the tonic-generated `ProtoService` trait for this type.
244/// The implementation of the bidirectional stream method should call
245/// [`GrpcServer::forward_bidi_stream`] to stitch the bidirectional stream to
246/// the client underlying this server.
247pub struct GrpcServer<F> {
248    state: Arc<GrpcServerState<F>>,
249}
250
251struct GrpcServerState<F> {
252    cancel_tx: Mutex<oneshot::Sender<()>>,
253    client_builder: F,
254    metrics: PerGrpcServerMetrics,
255}
256
impl<F, G> GrpcServer<F>
where
    F: Fn() -> G + Send + Sync + 'static,
{
    /// Starts the server, listening for gRPC connections on `listen_addr`.
    ///
    /// * `metrics`: the server metrics; a per-server child labeled with
    ///   `S::NAME` is created for this server.
    /// * `version`: the version every request must announce in its version
    ///   header. Requests with a missing or different version are rejected.
    /// * `host`: if set, requests that carry a different URI host are
    ///   rejected. If `None`, host checking is disabled and a warning is logged.
    /// * `client_builder`: used by [`GrpcServer::forward_bidi_stream`] to build
    ///   a fresh [`GenericClient`] for each bidirectional stream.
    /// * `service_builder`: turns this `GrpcServer<F>` into the [`Service`] to
    ///   serve.
    ///
    /// The trait bounds on `S` are intimidating, but `S` is just the return
    /// type of `service_builder`, i.e. a [`Service`] that represents a gRPC
    /// server. This is always encapsulated by the tonic-generated
    /// `ProtoServer::new` method for a specific Protobuf service.
    ///
    /// `service_builder` runs before this function returns. The listener,
    /// however, is only bound once the returned future is polled. The future
    /// resolves with an error if binding the listener or serving fails.
    pub fn serve<S, Fs>(
        metrics: &GrpcServerMetrics,
        listen_addr: SocketAddr,
        version: Version,
        host: Option<String>,
        client_builder: F,
        service_builder: Fs,
    ) -> impl Future<Output = Result<(), anyhow::Error>> + use<S, Fs, F, G>
    where
        S: Service<
                http::Request<BoxBody>,
                Response = http::Response<BoxBody>,
                Error = std::convert::Infallible,
            > + NamedService
            + Clone
            + Send
            + 'static,
        S::Future: Send + 'static,
        Fs: FnOnce(Self) -> S + Send + 'static,
    {
        // Placeholder cancellation sender. Its receiver is dropped when
        // `serve` returns, so nothing listens to it; the first call to
        // `forward_bidi_stream` replaces it with a live sender.
        let (cancel_tx, _cancel_rx) = oneshot::channel();
        let state = GrpcServerState {
            cancel_tx: Mutex::new(cancel_tx),
            client_builder,
            metrics: metrics.for_server(S::NAME),
        };
        let server = Self {
            state: Arc::new(state),
        };
        let service = service_builder(server);

        if host.is_none() {
            warn!("no host provided; request destination host checking is disabled");
        }
        // Rejects requests whose version header (or, if configured, host)
        // does not match this server.
        let validation = RequestValidationLayer { version, host };

        info!("Starting to listen on {}", listen_addr);

        // Runs lazily: nothing is bound until the caller polls this future.
        async {
            let listener = Listener::bind(listen_addr).await?;

            Server::builder()
                .layer(validation)
                .add_service(service)
                .serve_with_incoming(listener)
                .await?;
            Ok(())
        }
    }

    /// Handles a bidirectional stream request by forwarding commands to and
    /// responses from the server's underlying client.
    ///
    /// Call this method from the implementation of the tonic-generated
    /// `ProtoService`.
    ///
    /// Only one stream is active at a time: each call installs a new
    /// cancellation sender, dropping the previous one and thereby ending the
    /// previous stream. A fresh client is built with `client_builder` once the
    /// returned response stream is first polled.
    ///
    /// Forwarding stops when any of the following happens:
    ///
    /// * the remote's command stream ends or yields an error;
    /// * a command cannot be converted from protobuf;
    /// * the client's `recv` returns `Ok(None)`;
    /// * this stream is cancelled by a newer one.
    ///
    /// Errors from the client's `send`/`recv` are forwarded to the remote as
    /// `Unknown` statuses; the loop does not break on them.
    ///
    /// This method always returns `Ok`; failures surface through the
    /// response stream.
    pub async fn forward_bidi_stream<C, R, PC, PR>(
        &self,
        request: Request<Streaming<PC>>,
    ) -> Result<Response<ResponseStream<PR>>, Status>
    where
        G: GenericClient<C, R> + 'static,
        C: RustType<PC> + Send + Sync + 'static + fmt::Debug,
        R: RustType<PR> + Send + Sync + 'static + fmt::Debug,
        PC: fmt::Debug + Send + Sync + 'static,
        PR: fmt::Debug + Send + Sync + 'static,
    {
        info!("GrpcServer: remote client connected");

        // Install our cancellation token. This may drop an existing
        // cancellation token. We're allowed to run until someone else drops our
        // cancellation token.
        //
        // TODO(benesch): rather than blindly dropping the existing cancellation
        // token, we should check epochs, and only drop the existing connection
        // if it is at a lower epoch.
        // See: https://github.com/MaterializeInc/database-issues/issues/3840
        let (cancel_tx, mut cancel_rx) = oneshot::channel();
        *self.state.cancel_tx.lock().await = cancel_tx;

        // Construct a new client and forward commands and responses until
        // canceled.
        let mut request = request.into_inner();
        let state = Arc::clone(&self.state);
        let stream = stream! {
            // The stream body runs lazily, so the client is only built once
            // the response stream is first polled.
            let mut client = (state.client_builder)();
            loop {
                // `select!` drops the losing branch futures on every
                // iteration, so this relies on `request.next()` and
                // `client.recv()` being cancel safe.
                select! {
                    command = request.next() => {
                        // Stop when the remote's command stream ends or errors.
                        let command = match command {
                            None => break,
                            Some(Ok(command)) => command,
                            Some(Err(e)) => {
                                warn!("error handling client: {e}");
                                break;
                            }
                        };

                        // Record when the last command arrived, in seconds
                        // since the Unix epoch.
                        match UNIX_EPOCH.elapsed() {
                            Ok(ts) => state.metrics.last_command_received.set(ts.as_secs()),
                            Err(e) => error!("failed to get system time: {e}"),
                        }

                        let command = match command.into_rust() {
                            Ok(command) => command,
                            Err(e) => {
                                error!("error converting command from protobuf: {}", e);
                                break;
                            }
                        };

                        // A failed send is reported to the remote but does
                        // not end the loop.
                        if let Err(e) = client.send(command).await {
                            yield Err(Status::unknown(e.to_string()));
                        }
                    }
                    response = client.recv() => {
                        // `Ok(None)` means the client has no more responses.
                        match response {
                            Ok(Some(response)) => yield Ok(response.into_proto()),
                            Ok(None) => break,
                            Err(e) => yield Err(Status::unknown(e.to_string())),
                        }
                    }
                    // Completes once our sender is dropped, i.e. replaced by a
                    // newer stream's sender.
                    _ = &mut cancel_rx => break,
                }
            }
        };
        Ok(Response::new(ResponseStream::new(stream)))
    }
}
397
398/// A stream returning responses to GRPC clients.
399///
400/// This is defined as a struct, rather than a type alias, so that we can define a `Drop` impl that
401/// logs stream termination.
402pub struct ResponseStream<PR>(Pin<Box<dyn Stream<Item = Result<PR, Status>> + Send>>);
403
404impl<PR> ResponseStream<PR> {
405    fn new<S>(stream: S) -> Self
406    where
407        S: Stream<Item = Result<PR, Status>> + Send + 'static,
408    {
409        Self(Box::pin(stream))
410    }
411}
412
413impl<PR> Stream for ResponseStream<PR> {
414    type Item = Result<PR, Status>;
415
416    fn poll_next(
417        mut self: Pin<&mut Self>,
418        cx: &mut std::task::Context<'_>,
419    ) -> std::task::Poll<Option<Self::Item>> {
420        self.0.poll_next_unpin(cx)
421    }
422}
423
424impl<PR> Drop for ResponseStream<PR> {
425    fn drop(&mut self) {
426        info!("GrpcServer: response stream disconnected");
427    }
428}
429
430/// Metrics for a [`GrpcServer`].
431#[derive(Debug)]
432pub struct GrpcServerMetrics {
433    last_command_received: UIntGaugeVec,
434}
435
436impl GrpcServerMetrics {
437    /// Registers the GRPC server metrics into a `registry`.
438    pub fn register_with(registry: &MetricsRegistry) -> Self {
439        Self {
440            last_command_received: registry.register(metric!(
441                name: "mz_grpc_server_last_command_received",
442                help: "The time at which the server received its last command.",
443                var_labels: ["server_name"],
444            )),
445        }
446    }
447
448    pub fn for_server(&self, name: &'static str) -> PerGrpcServerMetrics {
449        PerGrpcServerMetrics {
450            last_command_received: self
451                .last_command_received
452                .get_delete_on_drop_metric(vec![name]),
453        }
454    }
455}
456
457#[derive(Clone, Debug)]
458pub struct PerGrpcServerMetrics {
459    last_command_received: DeleteOnDropGauge<AtomicU64, Vec<&'static str>>,
460}
461
462impl<C, R> transport::Metrics<C, R> for PerGrpcServerMetrics {
463    fn bytes_sent(&mut self, _len: usize) {}
464    fn bytes_received(&mut self, _len: usize) {}
465    fn message_sent(&mut self, _msg: &C) {}
466
467    fn message_received(&mut self, _msg: &R) {
468        match UNIX_EPOCH.elapsed() {
469            Ok(ts) => self.last_command_received.set(ts.as_secs()),
470            Err(e) => error!("failed to get system time: {e}"),
471        }
472    }
473}
474
475const VERSION_HEADER_KEY: &str = "x-mz-version";
476
477/// A gRPC interceptor that attaches a version as metadata to each request.
478#[derive(Debug, Clone)]
479pub struct VersionAttachInterceptor {
480    version: AsciiMetadataValue,
481}
482
483impl VersionAttachInterceptor {
484    fn new(version: Version) -> VersionAttachInterceptor {
485        VersionAttachInterceptor {
486            version: version
487                .to_string()
488                .try_into()
489                .expect("semver versions are valid metadata values"),
490        }
491    }
492}
493
494impl Interceptor for VersionAttachInterceptor {
495    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
496        request
497            .metadata_mut()
498            .insert(VERSION_HEADER_KEY, self.version.clone());
499        Ok(request)
500    }
501}
502
503/// A `tower` layer that validates requests for compatibility with the server.
504#[derive(Clone)]
505struct RequestValidationLayer {
506    version: Version,
507    host: Option<String>,
508}
509
510impl<S> tower::Layer<S> for RequestValidationLayer {
511    type Service = RequestValidation<S>;
512
513    fn layer(&self, inner: S) -> Self::Service {
514        let version = self
515            .version
516            .to_string()
517            .try_into()
518            .expect("version is a valid header value");
519        RequestValidation {
520            inner,
521            version,
522            host: self.host.clone(),
523        }
524    }
525}
526
527/// A `tower` middleware that validates requests for compatibility with the server.
528#[derive(Clone)]
529struct RequestValidation<S> {
530    inner: S,
531    version: http::HeaderValue,
532    host: Option<String>,
533}
534
535impl<S, B> Service<http::Request<B>> for RequestValidation<S>
536where
537    S: Service<http::Request<B>, Error = Box<dyn Error + Send + Sync + 'static>>,
538    S::Response: Send + 'static,
539    S::Future: Send + 'static,
540{
541    type Response = S::Response;
542    type Error = S::Error;
543    type Future = BoxFuture<'static, Result<S::Response, S::Error>>;
544
545    fn poll_ready(
546        &mut self,
547        cx: &mut std::task::Context<'_>,
548    ) -> std::task::Poll<Result<(), Self::Error>> {
549        self.inner.poll_ready(cx)
550    }
551
552    fn call(&mut self, req: http::Request<B>) -> Self::Future {
553        let error = |msg| {
554            let error: S::Error = Box::new(Status::permission_denied(msg));
555            Box::pin(future::ready(Err(error)))
556        };
557
558        let Some(req_version) = req.headers().get(VERSION_HEADER_KEY) else {
559            return error("request missing version header".into());
560        };
561        if req_version != self.version {
562            return error(format!(
563                "request has version {req_version:?} but {:?} required",
564                self.version
565            ));
566        }
567
568        let req_host = req.uri().host();
569        if let (Some(req_host), Some(host)) = (req_host, &self.host) {
570            if req_host != host {
571                return error(format!(
572                    "request has host {req_host:?} but {host:?} required"
573                ));
574            }
575        }
576
577        Box::pin(self.inner.call(req))
578    }
579}