1use async_stream::stream;
13use async_trait::async_trait;
14use futures::future::{self, BoxFuture};
15use futures::stream::{Stream, StreamExt, TryStreamExt};
16use http::uri::PathAndQuery;
17use hyper_util::rt::TokioIo;
18use mz_ore::metric;
19use mz_ore::metrics::{DeleteOnDropGauge, MetricsRegistry, UIntGaugeVec};
20use mz_ore::netio::{Listener, SocketAddr, SocketAddrType};
21use mz_proto::{ProtoType, RustType};
22use prometheus::core::AtomicU64;
23use semver::Version;
24use std::error::Error;
25use std::fmt::{self, Debug};
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::time::UNIX_EPOCH;
30use tokio::net::UnixStream;
31use tokio::select;
32use tokio::sync::mpsc::{self, UnboundedSender};
33use tokio::sync::{Mutex, oneshot};
34use tokio_stream::wrappers::UnboundedReceiverStream;
35use tonic::body::BoxBody;
36use tonic::codegen::InterceptedService;
37use tonic::metadata::AsciiMetadataValue;
38use tonic::server::NamedService;
39use tonic::service::Interceptor;
40use tonic::transport::{Channel, Endpoint, Server};
41use tonic::{IntoStreamingRequest, Request, Response, Status, Streaming};
42use tower::Service;
43use tracing::{debug, error, info, warn};
44
45use crate::client::{GenericClient, Partitionable, Partitioned};
46use crate::codec::{StatCodec, StatsCollector};
47use crate::params::GrpcClientParameters;
48
49include!(concat!(env!("OUT_DIR"), "/mz_service.params.rs"));
50
51pub const MAX_GRPC_MESSAGE_SIZE: usize = usize::MAX;
53
54pub type ClientTransport = InterceptedService<Channel, VersionAttachInterceptor>;
55
56pub trait ProtoServiceTypes: Debug + Clone + Send {
58 type PC: prost::Message + Clone + 'static;
59 type PR: prost::Message + Clone + Default + 'static;
60 type STATS: StatsCollector<Self::PC, Self::PR> + 'static;
61 const URL: &'static str;
62}
63
64#[derive(Debug)]
79pub struct GrpcClient<G>
80where
81 G: ProtoServiceTypes,
82{
83 tx: UnboundedSender<G::PC>,
85 rx: Streaming<G::PR>,
87}
88
89impl<G> GrpcClient<G>
90where
91 G: ProtoServiceTypes,
92{
93 pub async fn connect(
96 addr: String,
97 version: Version,
98 metrics: G::STATS,
99 params: &GrpcClientParameters,
100 ) -> Result<Self, anyhow::Error> {
101 debug!("GrpcClient {}: Attempt to connect", addr);
102
103 let channel = match SocketAddrType::guess(&addr) {
104 SocketAddrType::Inet => {
105 let mut endpoint = Endpoint::new(format!("http://{}", addr))?;
106 if let Some(connect_timeout) = params.connect_timeout {
107 endpoint = endpoint.connect_timeout(connect_timeout);
108 }
109 if let Some(keep_alive_timeout) = params.http2_keep_alive_timeout {
110 endpoint = endpoint.keep_alive_timeout(keep_alive_timeout);
111 }
112 if let Some(keep_alive_interval) = params.http2_keep_alive_interval {
113 endpoint = endpoint.http2_keep_alive_interval(keep_alive_interval);
114 }
115 endpoint.connect().await?
116 }
117 SocketAddrType::Unix => {
118 let addr = addr.clone();
119 Endpoint::from_static("http://localhost") .connect_with_connector(tower::service_fn(move |_| {
121 let addr = addr.clone();
122 async { UnixStream::connect(addr).await.map(TokioIo::new) }
123 }))
124 .await?
125 }
126 SocketAddrType::Turmoil => unimplemented!(),
127 };
128 let service = InterceptedService::new(channel, VersionAttachInterceptor::new(version));
129 let mut client = BidiProtoClient::new(service, G::URL, metrics);
130 let (tx, rx) = mpsc::unbounded_channel();
131 let rx = client
132 .establish_bidi_stream(UnboundedReceiverStream::new(rx))
133 .await?
134 .into_inner();
135 info!("GrpcClient {}: connected", &addr);
136 Ok(GrpcClient { tx, rx })
137 }
138
139 pub async fn connect_partitioned<C, R>(
141 dests: Vec<(String, G::STATS)>,
142 version: Version,
143 params: &GrpcClientParameters,
144 ) -> Result<Partitioned<Self, C, R>, anyhow::Error>
145 where
146 (C, R): Partitionable<C, R>,
147 {
148 let clients = future::try_join_all(
149 dests
150 .into_iter()
151 .map(|(addr, metrics)| Self::connect(addr, version.clone(), metrics, params)),
152 )
153 .await?;
154 Ok(Partitioned::new(clients))
155 }
156}
157
158#[async_trait]
159impl<G, C, R> GenericClient<C, R> for GrpcClient<G>
160where
161 C: RustType<G::PC> + Send + Sync + 'static,
162 R: RustType<G::PR> + Send + Sync + 'static,
163 G: ProtoServiceTypes,
164{
165 async fn send(&mut self, cmd: C) -> Result<(), anyhow::Error> {
166 self.tx.send(cmd.into_proto())?;
167 Ok(())
168 }
169
170 async fn recv(&mut self) -> Result<Option<R>, anyhow::Error> {
176 match self.rx.try_next().await? {
179 None => Ok(None),
180 Some(response) => Ok(Some(response.into_rust()?)),
181 }
182 }
183}
184
185pub struct BidiProtoClient<PC, PR, S>
193where
194 PC: prost::Message + 'static,
195 PR: Default + prost::Message + 'static,
196 S: StatsCollector<PC, PR>,
197{
198 inner: tonic::client::Grpc<ClientTransport>,
199 path: &'static str,
200 codec: StatCodec<PC, PR, S>,
201}
202
203impl<PC, PR, S> BidiProtoClient<PC, PR, S>
204where
205 PC: Clone + prost::Message + 'static,
206 PR: Clone + Default + prost::Message + 'static,
207 S: StatsCollector<PC, PR> + 'static,
208{
209 fn new(inner: ClientTransport, path: &'static str, stats_collector: S) -> Self
210 where
211 Self: Sized,
212 {
213 let inner = tonic::client::Grpc::new(inner)
214 .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE)
215 .max_encoding_message_size(MAX_GRPC_MESSAGE_SIZE);
216 let codec = StatCodec::new(stats_collector);
217 BidiProtoClient { inner, path, codec }
218 }
219
220 async fn establish_bidi_stream(
221 &mut self,
222 rx: UnboundedReceiverStream<PC>,
223 ) -> Result<Response<Streaming<PR>>, Status> {
224 self.inner.ready().await.map_err(|e| {
225 tonic::Status::new(
226 tonic::Code::Unknown,
227 format!("Service was not ready: {}", e),
228 )
229 })?;
230 let path = PathAndQuery::from_static(self.path);
231 self.inner
232 .streaming(rx.into_streaming_request(), path, self.codec.clone())
233 .await
234 }
235}
236
237pub struct GrpcServer<F> {
247 state: Arc<GrpcServerState<F>>,
248}
249
250struct GrpcServerState<F> {
251 cancel_tx: Mutex<oneshot::Sender<()>>,
252 client_builder: F,
253 metrics: PerGrpcServerMetrics,
254}
255
256impl<F, G> GrpcServer<F>
257where
258 F: Fn() -> G + Send + Sync + 'static,
259{
260 pub fn serve<S, Fs>(
269 metrics: &GrpcServerMetrics,
270 listen_addr: SocketAddr,
271 version: Version,
272 host: Option<String>,
273 client_builder: F,
274 service_builder: Fs,
275 ) -> impl Future<Output = Result<(), anyhow::Error>> + use<S, Fs, F, G>
276 where
277 S: Service<
278 http::Request<BoxBody>,
279 Response = http::Response<BoxBody>,
280 Error = std::convert::Infallible,
281 > + NamedService
282 + Clone
283 + Send
284 + 'static,
285 S::Future: Send + 'static,
286 Fs: FnOnce(Self) -> S + Send + 'static,
287 {
288 let (cancel_tx, _cancel_rx) = oneshot::channel();
289 let state = GrpcServerState {
290 cancel_tx: Mutex::new(cancel_tx),
291 client_builder,
292 metrics: metrics.for_server(S::NAME),
293 };
294 let server = Self {
295 state: Arc::new(state),
296 };
297 let service = service_builder(server);
298
299 if host.is_none() {
300 warn!("no host provided; request destination host checking is disabled");
301 }
302 let validation = RequestValidationLayer { version, host };
303
304 info!("Starting to listen on {}", listen_addr);
305
306 async {
307 let listener = Listener::bind(listen_addr).await?;
308
309 Server::builder()
310 .layer(validation)
311 .add_service(service)
312 .serve_with_incoming(listener)
313 .await?;
314 Ok(())
315 }
316 }
317
318 pub async fn forward_bidi_stream<C, R, PC, PR>(
324 &self,
325 request: Request<Streaming<PC>>,
326 ) -> Result<Response<ResponseStream<PR>>, Status>
327 where
328 G: GenericClient<C, R> + 'static,
329 C: RustType<PC> + Send + Sync + 'static + fmt::Debug,
330 R: RustType<PR> + Send + Sync + 'static + fmt::Debug,
331 PC: fmt::Debug + Send + Sync + 'static,
332 PR: fmt::Debug + Send + Sync + 'static,
333 {
334 info!("GrpcServer: remote client connected");
335
336 let (cancel_tx, mut cancel_rx) = oneshot::channel();
345 *self.state.cancel_tx.lock().await = cancel_tx;
346
347 let mut request = request.into_inner();
350 let state = Arc::clone(&self.state);
351 let stream = stream! {
352 let mut client = (state.client_builder)();
353 loop {
354 select! {
355 command = request.next() => {
356 let command = match command {
357 None => break,
358 Some(Ok(command)) => command,
359 Some(Err(e)) => {
360 error!("error handling client: {e}");
361 break;
362 }
363 };
364
365 match UNIX_EPOCH.elapsed() {
366 Ok(ts) => state.metrics.last_command_received.set(ts.as_secs()),
367 Err(e) => error!("failed to get system time: {e}"),
368 }
369
370 let command = match command.into_rust() {
371 Ok(command) => command,
372 Err(e) => {
373 error!("error converting command from protobuf: {}", e);
374 break;
375 }
376 };
377
378 if let Err(e) = client.send(command).await {
379 yield Err(Status::unknown(e.to_string()));
380 }
381 }
382 response = client.recv() => {
383 match response {
384 Ok(Some(response)) => yield Ok(response.into_proto()),
385 Ok(None) => break,
386 Err(e) => yield Err(Status::unknown(e.to_string())),
387 }
388 }
389 _ = &mut cancel_rx => break,
390 }
391 }
392 };
393 Ok(Response::new(ResponseStream::new(stream)))
394 }
395}
396
397pub struct ResponseStream<PR>(Pin<Box<dyn Stream<Item = Result<PR, Status>> + Send>>);
402
403impl<PR> ResponseStream<PR> {
404 fn new<S>(stream: S) -> Self
405 where
406 S: Stream<Item = Result<PR, Status>> + Send + 'static,
407 {
408 Self(Box::pin(stream))
409 }
410}
411
412impl<PR> Stream for ResponseStream<PR> {
413 type Item = Result<PR, Status>;
414
415 fn poll_next(
416 mut self: Pin<&mut Self>,
417 cx: &mut std::task::Context<'_>,
418 ) -> std::task::Poll<Option<Self::Item>> {
419 self.0.poll_next_unpin(cx)
420 }
421}
422
423impl<PR> Drop for ResponseStream<PR> {
424 fn drop(&mut self) {
425 info!("GrpcServer: response stream disconnected");
426 }
427}
428
429#[derive(Debug)]
431pub struct GrpcServerMetrics {
432 last_command_received: UIntGaugeVec,
433}
434
435impl GrpcServerMetrics {
436 pub fn register_with(registry: &MetricsRegistry) -> Self {
438 Self {
439 last_command_received: registry.register(metric!(
440 name: "mz_grpc_server_last_command_received",
441 help: "The time at which the server received its last command.",
442 var_labels: ["server_name"],
443 )),
444 }
445 }
446
447 fn for_server(&self, name: &'static str) -> PerGrpcServerMetrics {
448 PerGrpcServerMetrics {
449 last_command_received: self
450 .last_command_received
451 .get_delete_on_drop_metric(vec![name]),
452 }
453 }
454}
455
456#[derive(Debug)]
457struct PerGrpcServerMetrics {
458 last_command_received: DeleteOnDropGauge<AtomicU64, Vec<&'static str>>,
459}
460
461const VERSION_HEADER_KEY: &str = "x-mz-version";
462
463#[derive(Debug, Clone)]
465pub struct VersionAttachInterceptor {
466 version: AsciiMetadataValue,
467}
468
469impl VersionAttachInterceptor {
470 fn new(version: Version) -> VersionAttachInterceptor {
471 VersionAttachInterceptor {
472 version: version
473 .to_string()
474 .try_into()
475 .expect("semver versions are valid metadata values"),
476 }
477 }
478}
479
480impl Interceptor for VersionAttachInterceptor {
481 fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
482 request
483 .metadata_mut()
484 .insert(VERSION_HEADER_KEY, self.version.clone());
485 Ok(request)
486 }
487}
488
489#[derive(Clone)]
491struct RequestValidationLayer {
492 version: Version,
493 host: Option<String>,
494}
495
496impl<S> tower::Layer<S> for RequestValidationLayer {
497 type Service = RequestValidation<S>;
498
499 fn layer(&self, inner: S) -> Self::Service {
500 let version = self
501 .version
502 .to_string()
503 .try_into()
504 .expect("version is a valid header value");
505 RequestValidation {
506 inner,
507 version,
508 host: self.host.clone(),
509 }
510 }
511}
512
513#[derive(Clone)]
515struct RequestValidation<S> {
516 inner: S,
517 version: http::HeaderValue,
518 host: Option<String>,
519}
520
521impl<S, B> Service<http::Request<B>> for RequestValidation<S>
522where
523 S: Service<http::Request<B>, Error = Box<dyn Error + Send + Sync + 'static>>,
524 S::Response: Send + 'static,
525 S::Future: Send + 'static,
526{
527 type Response = S::Response;
528 type Error = S::Error;
529 type Future = BoxFuture<'static, Result<S::Response, S::Error>>;
530
531 fn poll_ready(
532 &mut self,
533 cx: &mut std::task::Context<'_>,
534 ) -> std::task::Poll<Result<(), Self::Error>> {
535 self.inner.poll_ready(cx)
536 }
537
538 fn call(&mut self, req: http::Request<B>) -> Self::Future {
539 let error = |msg| {
540 let error: S::Error = Box::new(Status::permission_denied(msg));
541 Box::pin(future::ready(Err(error)))
542 };
543
544 let Some(req_version) = req.headers().get(VERSION_HEADER_KEY) else {
545 return error("request missing version header".into());
546 };
547 if req_version != self.version {
548 return error(format!(
549 "request has version {req_version:?} but {:?} required",
550 self.version
551 ));
552 }
553
554 let req_host = req.uri().host();
555 if let (Some(req_host), Some(host)) = (req_host, &self.host) {
556 if req_host != host {
557 return error(format!(
558 "request has host {req_host:?} but {host:?} required"
559 ));
560 }
561 }
562
563 Box::pin(self.inner.call(req))
564 }
565}