mz_service/
grpc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! gRPC transport for the [client](crate::client) module.
11
12use async_stream::stream;
13use async_trait::async_trait;
14use futures::future::{self, BoxFuture};
15use futures::stream::{Stream, StreamExt, TryStreamExt};
16use http::uri::PathAndQuery;
17use hyper_util::rt::TokioIo;
18use mz_ore::metric;
19use mz_ore::metrics::{DeleteOnDropGauge, MetricsRegistry, UIntGaugeVec};
20use mz_ore::netio::{Listener, SocketAddr, SocketAddrType};
21use mz_proto::{ProtoType, RustType};
22use prometheus::core::AtomicU64;
23use semver::Version;
24use std::error::Error;
25use std::fmt::{self, Debug};
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::time::UNIX_EPOCH;
30use tokio::net::UnixStream;
31use tokio::select;
32use tokio::sync::mpsc::{self, UnboundedSender};
33use tokio::sync::{Mutex, oneshot};
34use tokio_stream::wrappers::UnboundedReceiverStream;
35use tonic::body::BoxBody;
36use tonic::codegen::InterceptedService;
37use tonic::metadata::AsciiMetadataValue;
38use tonic::server::NamedService;
39use tonic::service::Interceptor;
40use tonic::transport::{Channel, Endpoint, Server};
41use tonic::{IntoStreamingRequest, Request, Response, Status, Streaming};
42use tower::Service;
43use tracing::{debug, error, info, warn};
44
45use crate::client::{GenericClient, Partitionable, Partitioned};
46use crate::codec::{StatCodec, StatsCollector};
47use crate::params::GrpcClientParameters;
48
// Pulls in code generated at build time into `OUT_DIR` (`mz_service.params.rs`).
// NOTE(review): presumably the prost output for the `mz_service.params` proto
// package — confirm against the crate's build script.
include!(concat!(env!("OUT_DIR"), "/mz_service.params.rs"));

/// Maximum gRPC message size, applied to both encoding and decoding.
///
/// Use with generated servers, e.g. `Server::new(Svc).max_decoding_message_size(...)`.
/// A limit of `usize::MAX` effectively lifts tonic's message size limit.
pub const MAX_GRPC_MESSAGE_SIZE: usize = usize::MAX;

/// The transport used by gRPC clients in this module: a tonic [`Channel`] whose
/// requests pass through a [`VersionAttachInterceptor`].
pub type ClientTransport = InterceptedService<Channel, VersionAttachInterceptor>;
55
56/// Types that we send and receive over a service endpoint.
57pub trait ProtoServiceTypes: Debug + Clone + Send {
58    type PC: prost::Message + Clone + 'static;
59    type PR: prost::Message + Clone + Default + 'static;
60    type STATS: StatsCollector<Self::PC, Self::PR> + 'static;
61    const URL: &'static str;
62}
63
64/// A client to a remote dataflow server using gRPC and protobuf based
65/// communication.
66///
67/// The client opens a connection using the proto client stubs that are
68/// generated by tonic from a service definition. When the client is connected,
69/// it will call automatically the only RPC defined in the service description,
70/// encapsulated by the `BidiProtoClient` trait. This trait bound is not on the
71/// `Client` type parameter here, but it IS on the impl blocks. Bidirectional
72/// protobuf RPC sets up two streams that persist after the RPC has returned: A
73/// Request (Command) stream (for us, backed by a unbounded mpsc queue) going
74/// from this instance to the server and a response stream coming back
75/// (represented directly as a `Streaming<Response>` instance). The recv and send
76/// functions interact with the two mpsc channels or the streaming instance
77/// respectively.
78#[derive(Debug)]
79pub struct GrpcClient<G>
80where
81    G: ProtoServiceTypes,
82{
83    /// The sender for commands.
84    tx: UnboundedSender<G::PC>,
85    /// The receiver for responses.
86    rx: Streaming<G::PR>,
87}
88
89impl<G> GrpcClient<G>
90where
91    G: ProtoServiceTypes,
92{
93    /// Connects to the server at the given address, announcing the specified
94    /// client version.
95    pub async fn connect(
96        addr: String,
97        version: Version,
98        metrics: G::STATS,
99        params: &GrpcClientParameters,
100    ) -> Result<Self, anyhow::Error> {
101        debug!("GrpcClient {}: Attempt to connect", addr);
102
103        let channel = match SocketAddrType::guess(&addr) {
104            SocketAddrType::Inet => {
105                let mut endpoint = Endpoint::new(format!("http://{}", addr))?;
106                if let Some(connect_timeout) = params.connect_timeout {
107                    endpoint = endpoint.connect_timeout(connect_timeout);
108                }
109                if let Some(keep_alive_timeout) = params.http2_keep_alive_timeout {
110                    endpoint = endpoint.keep_alive_timeout(keep_alive_timeout);
111                }
112                if let Some(keep_alive_interval) = params.http2_keep_alive_interval {
113                    endpoint = endpoint.http2_keep_alive_interval(keep_alive_interval);
114                }
115                endpoint.connect().await?
116            }
117            SocketAddrType::Unix => {
118                let addr = addr.clone();
119                Endpoint::from_static("http://localhost") // URI is ignored
120                    .connect_with_connector(tower::service_fn(move |_| {
121                        let addr = addr.clone();
122                        async { UnixStream::connect(addr).await.map(TokioIo::new) }
123                    }))
124                    .await?
125            }
126            SocketAddrType::Turmoil => unimplemented!(),
127        };
128        let service = InterceptedService::new(channel, VersionAttachInterceptor::new(version));
129        let mut client = BidiProtoClient::new(service, G::URL, metrics);
130        let (tx, rx) = mpsc::unbounded_channel();
131        let rx = client
132            .establish_bidi_stream(UnboundedReceiverStream::new(rx))
133            .await?
134            .into_inner();
135        info!("GrpcClient {}: connected", &addr);
136        Ok(GrpcClient { tx, rx })
137    }
138
139    /// Like [`GrpcClient::connect`], but for multiple partitioned servers.
140    pub async fn connect_partitioned<C, R>(
141        dests: Vec<(String, G::STATS)>,
142        version: Version,
143        params: &GrpcClientParameters,
144    ) -> Result<Partitioned<Self, C, R>, anyhow::Error>
145    where
146        (C, R): Partitionable<C, R>,
147    {
148        let clients = future::try_join_all(
149            dests
150                .into_iter()
151                .map(|(addr, metrics)| Self::connect(addr, version.clone(), metrics, params)),
152        )
153        .await?;
154        Ok(Partitioned::new(clients))
155    }
156}
157
158#[async_trait]
159impl<G, C, R> GenericClient<C, R> for GrpcClient<G>
160where
161    C: RustType<G::PC> + Send + Sync + 'static,
162    R: RustType<G::PR> + Send + Sync + 'static,
163    G: ProtoServiceTypes,
164{
165    async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error> {
166        self.tx.send(cmd.into_proto())?;
167        Ok(())
168    }
169
170    /// # Cancel safety
171    ///
172    /// This method is cancel safe. If `recv` is used as the event in a [`tokio::select!`]
173    /// statement and some other branch completes first, it is guaranteed that no messages were
174    /// received by this client.
175    async fn recv(&mut self) -> Result<Option<R>, anyhow::Error> {
176        // `TryStreamExt::try_next` is cancel safe. The returned future only holds onto a
177        // reference to the underlying stream, so dropping it will never lose a value.
178        match self.rx.try_next().await? {
179            None => Ok(None),
180            Some(response) => Ok(Some(response.into_rust()?)),
181        }
182    }
183}
184
/// Encapsulates the core functionality of a tonic gRPC client for a service
/// that exposes a single bidirectional RPC stream.
///
/// The client calls back into the StatsCollector on each command send and
/// response receive.
///
/// See the documentation on [`GrpcClient`] for details.
pub struct BidiProtoClient<PC, PR, S>
where
    PC: prost::Message + 'static,
    PR: Default + prost::Message + 'static,
    S: StatsCollector<PC, PR>,
{
    /// The underlying tonic client. `new` sets both of its message size
    /// limits to `MAX_GRPC_MESSAGE_SIZE`.
    inner: tonic::client::Grpc<ClientTransport>,
    /// Path of the bidirectional RPC, converted with `PathAndQuery::from_static`.
    path: &'static str,
    /// Codec (cloned into each streaming call) that reports to the stats
    /// collector `S`.
    codec: StatCodec<PC, PR, S>,
}
202
203impl<PC, PR, S> BidiProtoClient<PC, PR, S>
204where
205    PC: Clone + prost::Message + 'static,
206    PR: Clone + Default + prost::Message + 'static,
207    S: StatsCollector<PC, PR> + 'static,
208{
209    fn new(inner: ClientTransport, path: &'static str, stats_collector: S) -> Self
210    where
211        Self: Sized,
212    {
213        let inner = tonic::client::Grpc::new(inner)
214            .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE)
215            .max_encoding_message_size(MAX_GRPC_MESSAGE_SIZE);
216        let codec = StatCodec::new(stats_collector);
217        BidiProtoClient { inner, path, codec }
218    }
219
220    async fn establish_bidi_stream(
221        &mut self,
222        rx: UnboundedReceiverStream<PC>,
223    ) -> Result<Response<Streaming<PR>>, Status> {
224        self.inner.ready().await.map_err(|e| {
225            tonic::Status::new(
226                tonic::Code::Unknown,
227                format!("Service was not ready: {}", e),
228            )
229        })?;
230        let path = PathAndQuery::from_static(self.path);
231        self.inner
232            .streaming(rx.into_streaming_request(), path, self.codec.clone())
233            .await
234    }
235}
236
/// A gRPC server that stitches a gRPC service with a single bidirectional
/// stream to a [`GenericClient`].
///
/// It is the counterpart of [`GrpcClient`].
///
/// To use, implement the tonic-generated `ProtoService` trait for this type.
/// The implementation of the bidirectional stream method should call
/// [`GrpcServer::forward_bidi_stream`] to stitch the bidirectional stream to a
/// client built by this server.
pub struct GrpcServer<F> {
    /// State shared with every response stream this server creates.
    state: Arc<GrpcServerState<F>>,
}

/// State shared between a [`GrpcServer`] and the response streams created by
/// [`GrpcServer::forward_bidi_stream`].
struct GrpcServerState<F> {
    /// Sender half of the cancellation token for the most recently connected
    /// stream. Replacing it drops the previous sender, which ends the stream
    /// holding the matching receiver.
    cancel_tx: Mutex<oneshot::Sender<()>>,
    /// Builds a fresh client for each forwarded bidirectional stream.
    client_builder: F,
    /// Metrics for this server, labeled with the service name.
    metrics: PerGrpcServerMetrics,
}
255
impl<F, G> GrpcServer<F>
where
    F: Fn() -> G + Send + Sync + 'static,
{
    /// Starts the server, listening for gRPC connections on `listen_addr`.
    ///
    /// The trait bounds on `S` are intimidating, but `S` is simply the return
    /// type of `service_builder`. That function turns a `GrpcServer<F>` into a
    /// [`Service`] that represents a gRPC server. This is always encapsulated
    /// by the tonic-generated `ProtoServer::new` method for a specific Protobuf
    /// service.
    ///
    /// Before reaching the service, requests are validated against `version`
    /// and, if provided, `host`. When `host` is `None`, a warning is logged and
    /// request destination host checking is disabled.
    ///
    /// `client_builder` is invoked once per forwarded bidirectional stream to
    /// build the client that the stream is connected to.
    ///
    /// `service_builder` runs eagerly, but the listener is bound only when the
    /// returned future is first polled. The future resolves when serving stops,
    /// with an error if binding or serving fails.
    pub fn serve<S, Fs>(
        metrics: &GrpcServerMetrics,
        listen_addr: SocketAddr,
        version: Version,
        host: Option<String>,
        client_builder: F,
        service_builder: Fs,
    ) -> impl Future<Output = Result<(), anyhow::Error>> + use<S, Fs, F, G>
    where
        S: Service<
                http::Request<BoxBody>,
                Response = http::Response<BoxBody>,
                Error = std::convert::Infallible,
            > + NamedService
            + Clone
            + Send
            + 'static,
        S::Future: Send + 'static,
        Fs: FnOnce(Self) -> S + Send + 'static,
    {
        // Placeholder cancellation token. Its receiver is discarded when this
        // function returns. Each `forward_bidi_stream` call replaces this sender
        // with its own.
        let (cancel_tx, _cancel_rx) = oneshot::channel();
        let state = GrpcServerState {
            cancel_tx: Mutex::new(cancel_tx),
            client_builder,
            metrics: metrics.for_server(S::NAME),
        };
        let server = Self {
            state: Arc::new(state),
        };
        let service = service_builder(server);

        if host.is_none() {
            warn!("no host provided; request destination host checking is disabled");
        }
        let validation = RequestValidationLayer { version, host };

        info!("Starting to listen on {}", listen_addr);

        async {
            let listener = Listener::bind(listen_addr).await?;

            Server::builder()
                .layer(validation)
                .add_service(service)
                .serve_with_incoming(listener)
                .await?;
            Ok(())
        }
    }

    /// Handles a bidirectional stream request by forwarding commands to and
    /// responses from the server's underlying client.
    ///
    /// Call this method from the implementation of the tonic-generated
    /// `ProtoService`.
    ///
    /// Each call builds a fresh client with the server's `client_builder`.
    /// Installing this call's cancellation token drops the token of any
    /// previously connected stream, so a newly connected stream supersedes the
    /// previous one.
    ///
    /// The returned response stream ends when any of the following happens:
    ///
    /// * the remote client closes its command stream;
    /// * the command stream yields an error;
    /// * a command fails protobuf conversion (logged, with no status sent);
    /// * the client's response stream ends;
    /// * the stream is superseded.
    ///
    /// Errors from the client's `send`/`recv` are forwarded to the remote
    /// client as `Unknown` statuses. This method itself always returns `Ok`.
    pub async fn forward_bidi_stream<C, R, PC, PR>(
        &self,
        request: Request<Streaming<PC>>,
    ) -> Result<Response<ResponseStream<PR>>, Status>
    where
        G: GenericClient<C, R> + 'static,
        C: RustType<PC> + Send + Sync + 'static + fmt::Debug,
        R: RustType<PR> + Send + Sync + 'static + fmt::Debug,
        PC: fmt::Debug + Send + Sync + 'static,
        PR: fmt::Debug + Send + Sync + 'static,
    {
        info!("GrpcServer: remote client connected");

        // Install our cancellation token. This may drop an existing
        // cancellation token. We're allowed to run until someone else drops our
        // cancellation token.
        //
        // TODO(benesch): rather than blindly dropping the existing cancellation
        // token, we should check epochs, and only drop the existing connection
        // if it is at a lower epoch.
        // See: https://github.com/MaterializeInc/database-issues/issues/3840
        let (cancel_tx, mut cancel_rx) = oneshot::channel();
        *self.state.cancel_tx.lock().await = cancel_tx;

        // Construct a new client and forward commands and responses until
        // canceled.
        let mut request = request.into_inner();
        let state = Arc::clone(&self.state);
        let stream = stream! {
            let mut client = (state.client_builder)();
            loop {
                // This relies on `request.next()` and `client.recv()` being
                // cancel safe, so no message is lost when another branch wins.
                select! {
                    command = request.next() => {
                        let command = match command {
                            None => break,
                            Some(Ok(command)) => command,
                            Some(Err(e)) => {
                                error!("error handling client: {e}");
                                break;
                            }
                        };

                        // Record the wall-clock time (seconds since the UNIX
                        // epoch) at which this command arrived.
                        match UNIX_EPOCH.elapsed() {
                            Ok(ts) => state.metrics.last_command_received.set(ts.as_secs()),
                            Err(e) => error!("failed to get system time: {e}"),
                        }

                        let command = match command.into_rust() {
                            Ok(command) => command,
                            Err(e) => {
                                error!("error converting command from protobuf: {}", e);
                                break;
                            }
                        };

                        // A send failure is reported to the remote client but
                        // does not end the loop.
                        if let Err(e) = client.send(command).await {
                            yield Err(Status::unknown(e.to_string()));
                        }
                    }
                    response = client.recv() => {
                        match response {
                            Ok(Some(response)) => yield Ok(response.into_proto()),
                            Ok(None) => break,
                            Err(e) => yield Err(Status::unknown(e.to_string())),
                        }
                    }
                    // Fires once our sender is replaced (dropped) by a newer
                    // connection.
                    _ = &mut cancel_rx => break,
                }
            }
        };
        Ok(Response::new(ResponseStream::new(stream)))
    }
}
396
397/// A stream returning responses to GRPC clients.
398///
399/// This is defined as a struct, rather than a type alias, so that we can define a `Drop` impl that
400/// logs stream termination.
401pub struct ResponseStream<PR>(Pin<Box<dyn Stream<Item = Result<PR, Status>> + Send>>);
402
403impl<PR> ResponseStream<PR> {
404    fn new<S>(stream: S) -> Self
405    where
406        S: Stream<Item = Result<PR, Status>> + Send + 'static,
407    {
408        Self(Box::pin(stream))
409    }
410}
411
412impl<PR> Stream for ResponseStream<PR> {
413    type Item = Result<PR, Status>;
414
415    fn poll_next(
416        mut self: Pin<&mut Self>,
417        cx: &mut std::task::Context<'_>,
418    ) -> std::task::Poll<Option<Self::Item>> {
419        self.0.poll_next_unpin(cx)
420    }
421}
422
423impl<PR> Drop for ResponseStream<PR> {
424    fn drop(&mut self) {
425        info!("GrpcServer: response stream disconnected");
426    }
427}
428
429/// Metrics for a [`GrpcServer`].
430#[derive(Debug)]
431pub struct GrpcServerMetrics {
432    last_command_received: UIntGaugeVec,
433}
434
435impl GrpcServerMetrics {
436    /// Registers the GRPC server metrics into a `registry`.
437    pub fn register_with(registry: &MetricsRegistry) -> Self {
438        Self {
439            last_command_received: registry.register(metric!(
440                name: "mz_grpc_server_last_command_received",
441                help: "The time at which the server received its last command.",
442                var_labels: ["server_name"],
443            )),
444        }
445    }
446
447    fn for_server(&self, name: &'static str) -> PerGrpcServerMetrics {
448        PerGrpcServerMetrics {
449            last_command_received: self
450                .last_command_received
451                .get_delete_on_drop_metric(vec![name]),
452        }
453    }
454}
455
456#[derive(Debug)]
457struct PerGrpcServerMetrics {
458    last_command_received: DeleteOnDropGauge<AtomicU64, Vec<&'static str>>,
459}
460
/// Metadata/header key under which clients send their version and servers
/// check it.
const VERSION_HEADER_KEY: &str = "x-mz-version";
462
463/// A gRPC interceptor that attaches a version as metadata to each request.
464#[derive(Debug, Clone)]
465pub struct VersionAttachInterceptor {
466    version: AsciiMetadataValue,
467}
468
469impl VersionAttachInterceptor {
470    fn new(version: Version) -> VersionAttachInterceptor {
471        VersionAttachInterceptor {
472            version: version
473                .to_string()
474                .try_into()
475                .expect("semver versions are valid metadata values"),
476        }
477    }
478}
479
480impl Interceptor for VersionAttachInterceptor {
481    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
482        request
483            .metadata_mut()
484            .insert(VERSION_HEADER_KEY, self.version.clone());
485        Ok(request)
486    }
487}
488
489/// A `tower` layer that validates requests for compatibility with the server.
490#[derive(Clone)]
491struct RequestValidationLayer {
492    version: Version,
493    host: Option<String>,
494}
495
496impl<S> tower::Layer<S> for RequestValidationLayer {
497    type Service = RequestValidation<S>;
498
499    fn layer(&self, inner: S) -> Self::Service {
500        let version = self
501            .version
502            .to_string()
503            .try_into()
504            .expect("version is a valid header value");
505        RequestValidation {
506            inner,
507            version,
508            host: self.host.clone(),
509        }
510    }
511}
512
513/// A `tower` middleware that validates requests for compatibility with the server.
514#[derive(Clone)]
515struct RequestValidation<S> {
516    inner: S,
517    version: http::HeaderValue,
518    host: Option<String>,
519}
520
521impl<S, B> Service<http::Request<B>> for RequestValidation<S>
522where
523    S: Service<http::Request<B>, Error = Box<dyn Error + Send + Sync + 'static>>,
524    S::Response: Send + 'static,
525    S::Future: Send + 'static,
526{
527    type Response = S::Response;
528    type Error = S::Error;
529    type Future = BoxFuture<'static, Result<S::Response, S::Error>>;
530
531    fn poll_ready(
532        &mut self,
533        cx: &mut std::task::Context<'_>,
534    ) -> std::task::Poll<Result<(), Self::Error>> {
535        self.inner.poll_ready(cx)
536    }
537
538    fn call(&mut self, req: http::Request<B>) -> Self::Future {
539        let error = |msg| {
540            let error: S::Error = Box::new(Status::permission_denied(msg));
541            Box::pin(future::ready(Err(error)))
542        };
543
544        let Some(req_version) = req.headers().get(VERSION_HEADER_KEY) else {
545            return error("request missing version header".into());
546        };
547        if req_version != self.version {
548            return error(format!(
549                "request has version {req_version:?} but {:?} required",
550                self.version
551            ));
552        }
553
554        let req_host = req.uri().host();
555        if let (Some(req_host), Some(host)) = (req_host, &self.host) {
556            if req_host != host {
557                return error(format!(
558                    "request has host {req_host:?} but {host:?} required"
559                ));
560            }
561        }
562
563        Box::pin(self.inner.call(req))
564    }
565}