1use std::fmt::Debug;
15use std::marker::PhantomData;
16
17use prost::bytes::{Buf, BufMut};
18use prost::{DecodeError, Message};
19use tonic::codec::{Codec, Decoder, Encoder};
20use tonic::{Code, Status};
21
/// Observer that is told about every message handled by the stats-aware codec.
///
/// `C` is the message type handed to [`send_event`](StatsCollector::send_event)
/// and `R` the message type handed to
/// [`receive_event`](StatsCollector::receive_event). Every implementor must
/// also be `Clone + Debug + Send + Sync`.
pub trait StatsCollector<C, R>
where
    Self: Clone + Debug + Send + Sync,
{
    /// Reports a message of type `C` together with a byte count.
    ///
    /// In this file it is invoked by the encoder once a message has been
    /// written, with `size` set to the number of bytes written.
    fn send_event(&self, item: &C, size: usize);

    /// Reports a message of type `R` together with a byte count.
    ///
    /// In this file it is invoked by the decoder once a message has been
    /// decoded, with `size` set to the number of bytes that were available in
    /// the input buffer before decoding started.
    fn receive_event(&self, item: &R, size: usize);
}
26
27#[derive(Debug, Clone, Default)]
28pub struct StatEncoder<C, R, S> {
29 _pd: PhantomData<(C, R)>,
30 stats_collector: S,
31}
32
33impl<C, R, S> StatEncoder<C, R, S>
34where
35 S: StatsCollector<C, R>,
36{
37 pub fn new(stats_collector: S) -> StatEncoder<C, R, S> {
38 StatEncoder {
39 _pd: Default::default(),
40 stats_collector,
41 }
42 }
43}
44
45impl<C, R, S> Encoder for StatEncoder<C, R, S>
46where
47 C: Message,
48 S: StatsCollector<C, R>,
49{
50 type Item = C;
51 type Error = Status;
52
53 fn encode(
54 &mut self,
55 item: Self::Item,
56 buf: &mut tonic::codec::EncodeBuf<'_>,
57 ) -> Result<(), Self::Error> {
58 let initial_remaining = buf.remaining_mut();
59 item.encode(buf)
60 .expect("Message only errors if not enough space");
61 let encoded_len = initial_remaining - buf.remaining_mut();
62 self.stats_collector.send_event(&item, encoded_len);
63
64 Ok(())
65 }
66}
67
/// Decoder that pairs message types `C`/`R` with a stats collector `S`.
///
/// `C` and `R` appear only as a zero-sized marker; the only real data held is
/// the collector.
#[derive(Clone, Debug, Default)]
pub struct StatDecoder<C, R, S> {
    /// Zero-sized marker binding the decoder to its message types.
    _pd: PhantomData<(C, R)>,
    /// Collector that is notified by the decoder.
    stats_collector: S,
}

impl<C, R, S> StatDecoder<C, R, S> {
    /// Creates a decoder that reports to `stats_collector`.
    pub fn new(stats_collector: S) -> Self {
        Self {
            stats_collector,
            _pd: PhantomData,
        }
    }
}
82
83impl<C, R, S> Decoder for StatDecoder<C, R, S>
84where
85 R: Default + Message,
86 S: StatsCollector<C, R>,
87{
88 type Item = R;
89 type Error = Status;
90
91 fn decode(
92 &mut self,
93 buf: &mut tonic::codec::DecodeBuf<'_>,
94 ) -> Result<Option<Self::Item>, Self::Error> {
95 let remaining_before = buf.remaining();
96 let item = Message::decode(buf).map_err(from_decode_error)?;
97 self.stats_collector.receive_event(&item, remaining_before);
98 Ok(Some(item))
99 }
100}
101
102fn from_decode_error(error: DecodeError) -> Status {
103 Status::new(Code::Internal, error.to_string())
106}
107
/// Codec that pairs message types `C`/`R` with a stats collector `S`.
///
/// `C` and `R` appear only as a zero-sized marker; the only real data held is
/// the collector.
#[derive(Clone, Debug)]
pub struct StatCodec<C, R, S> {
    /// Zero-sized marker binding the codec to its message types.
    _pd: PhantomData<(C, R)>,
    /// Collector held by the codec.
    stats_collector: S,
}

impl<C, R, S> StatCodec<C, R, S> {
    /// Creates a codec that holds `stats_collector`.
    pub fn new(stats_collector: S) -> Self {
        Self {
            stats_collector,
            _pd: PhantomData,
        }
    }
}
122
123impl<C, R, S> Codec for StatCodec<C, R, S>
124where
125 C: Message + 'static,
126 R: Default + Message + 'static,
127 S: StatsCollector<C, R> + 'static,
128{
129 type Encode = C;
130 type Decode = R;
131 type Encoder = StatEncoder<C, R, S>;
132 type Decoder = StatDecoder<C, R, S>;
133
134 fn encoder(&mut self) -> Self::Encoder {
135 StatEncoder::new(self.stats_collector.clone())
136 }
137
138 fn decoder(&mut self) -> Self::Decoder {
139 StatDecoder::new(self.stats_collector.clone())
140 }
141}