mz_service/
codec.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
//! gRPC stats-collecting codec for the [client](crate::client) module.
//!
//! The implementation of `StatCodec` is based on tonic's `ProstCodec`.
13
14use std::fmt::Debug;
15use std::marker::PhantomData;
16
17use prost::bytes::{Buf, BufMut};
18use prost::{DecodeError, Message};
19use tonic::codec::{Codec, Decoder, Encoder};
20use tonic::{Code, Status};
21
/// Observer notified about the messages that flow through a [`StatCodec`].
///
/// `C` is the type of message that is sent (encoded) and `R` the type of
/// message that is received (decoded).
pub trait StatsCollector<C, R>: Send + Sync + Clone + Debug {
    /// Invoked once `message` has been encoded; `encoded_len` is the number
    /// of bytes the encoding added to the output buffer.
    fn send_event(&self, message: &C, encoded_len: usize);

    /// Invoked once `message` has been decoded; `encoded_len` is the number
    /// of bytes that were available in the input buffer when decoding began.
    fn receive_event(&self, message: &R, encoded_len: usize);
}
26
/// Tonic [`Encoder`] that serializes messages of type `C` and reports each
/// message, together with its encoded size, to a [`StatsCollector`].
///
/// `R` is not used by the encoder itself; it is carried so that the encoder
/// is parameterized the same way as its collector (`StatsCollector<C, R>`).
#[derive(Debug, Clone, Default)]
pub struct StatEncoder<C, R, S> {
    _pd: PhantomData<(C, R)>,
    stats_collector: S,
}

impl<C, R, S> StatEncoder<C, R, S> {
    /// Creates an encoder that reports send events to `stats_collector`.
    ///
    /// Construction needs no trait bounds, matching `StatDecoder::new` and
    /// `StatCodec::new`. The `StatsCollector` bound is only required by the
    /// `Encoder` implementation, which is where events are actually emitted.
    pub fn new(stats_collector: S) -> StatEncoder<C, R, S> {
        StatEncoder {
            _pd: PhantomData,
            stats_collector,
        }
    }
}
44
45impl<C, R, S> Encoder for StatEncoder<C, R, S>
46where
47    C: Message,
48    S: StatsCollector<C, R>,
49{
50    type Item = C;
51    type Error = Status;
52
53    fn encode(
54        &mut self,
55        item: Self::Item,
56        buf: &mut tonic::codec::EncodeBuf<'_>,
57    ) -> Result<(), Self::Error> {
58        let initial_remaining = buf.remaining_mut();
59        item.encode(buf)
60            .expect("Message only errors if not enough space");
61        let encoded_len = initial_remaining - buf.remaining_mut();
62        self.stats_collector.send_event(&item, encoded_len);
63
64        Ok(())
65    }
66}
67
/// Tonic [`Decoder`] that deserializes messages of type `R` and reports each
/// one, together with its encoded size, to a [`StatsCollector`].
#[derive(Debug, Clone, Default)]
pub struct StatDecoder<C, R, S> {
    _pd: PhantomData<(C, R)>,
    stats_collector: S,
}

impl<C, R, S> StatDecoder<C, R, S> {
    /// Builds a decoder that forwards receive events to `stats_collector`.
    pub fn new(stats_collector: S) -> Self {
        Self {
            stats_collector,
            _pd: PhantomData,
        }
    }
}
82
83impl<C, R, S> Decoder for StatDecoder<C, R, S>
84where
85    R: Default + Message,
86    S: StatsCollector<C, R>,
87{
88    type Item = R;
89    type Error = Status;
90
91    fn decode(
92        &mut self,
93        buf: &mut tonic::codec::DecodeBuf<'_>,
94    ) -> Result<Option<Self::Item>, Self::Error> {
95        let remaining_before = buf.remaining();
96        let item = Message::decode(buf).map_err(from_decode_error)?;
97        self.stats_collector.receive_event(&item, remaining_before);
98        Ok(Some(item))
99    }
100}
101
102fn from_decode_error(error: DecodeError) -> Status {
103    // Map Protobuf parse errors to an INTERNAL status code, as per
104    // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
105    Status::new(Code::Internal, error.to_string())
106}
107
/// Tonic codec that encodes `C`, decodes `R`, and reports every message it
/// handles to the stats collector `S`.
#[derive(Debug, Clone)]
pub struct StatCodec<C, R, S> {
    _pd: PhantomData<(C, R)>,
    stats_collector: S,
}

impl<C, R, S> StatCodec<C, R, S> {
    /// Wraps `stats_collector` in a codec. The per-call encoders and decoders
    /// each receive their own clone of it.
    pub fn new(stats_collector: S) -> Self {
        Self {
            stats_collector,
            _pd: PhantomData,
        }
    }
}
122
123impl<C, R, S> Codec for StatCodec<C, R, S>
124where
125    C: Message + 'static,
126    R: Default + Message + 'static,
127    S: StatsCollector<C, R> + 'static,
128{
129    type Encode = C;
130    type Decode = R;
131    type Encoder = StatEncoder<C, R, S>;
132    type Decoder = StatDecoder<C, R, S>;
133
134    fn encoder(&mut self) -> Self::Encoder {
135        StatEncoder::new(self.stats_collector.clone())
136    }
137
138    fn decoder(&mut self) -> Self::Decoder {
139        StatDecoder::new(self.stats_collector.clone())
140    }
141}