mz_balancerd/
codec.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

10use async_trait::async_trait;
11use bytes::{Buf, BufMut, BytesMut};
12use bytesize::ByteSize;
13use futures::{SinkExt, TryStreamExt, sink};
14use mz_ore::cast::CastFrom;
15use mz_ore::future::OreSinkExt;
16use mz_ore::netio::AsyncReady;
17use mz_pgwire_common::{
18    Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, MAX_REQUEST_SIZE, Pgbuf,
19    parse_frame_len,
20};
21use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
22use tokio_util::codec::{Decoder, Encoder, Framed};
23
24/// Internal representation of a backend [message].
25///
26/// [message]: https://www.postgresql.org/docs/11/protocol-message-formats.html
27#[derive(Debug)]
28pub enum BackendMessage {
29    AuthenticationCleartextPassword,
30    ErrorResponse(ErrorResponse),
31}
32
33impl From<ErrorResponse> for BackendMessage {
34    fn from(err: ErrorResponse) -> BackendMessage {
35        BackendMessage::ErrorResponse(err)
36    }
37}
38
39/// A connection that manages the encoding and decoding of pgwire frames.
40pub struct FramedConn<A> {
41    inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
42}
43
44impl<A> FramedConn<A>
45where
46    A: AsyncRead + AsyncWrite + Unpin,
47{
48    /// Constructs a new framed connection.
49    ///
50    /// The underlying connection, `inner`, is expected to be something like a
51    /// TCP stream. Anything that implements [`AsyncRead`] and [`AsyncWrite`]
52    /// will do.
53    pub fn new(inner: Conn<A>) -> FramedConn<A> {
54        FramedConn {
55            inner: Framed::new(inner, Codec::new()).buffer(32),
56        }
57    }
58
59    /// Reads and decodes one frontend message from the client.
60    ///
61    /// Blocks until the client sends a complete message. If the client
62    /// terminates the stream, returns `None`. Returns an error if the client
63    /// sends a malformed message or if the connection underlying is broken.
64    ///
65    /// # Cancel safety
66    ///
67    /// This method is cancel safe. The returned future only holds onto a
68    /// reference to thea underlying stream, so dropping it will never lose a
69    /// value.
70    ///
71    /// <https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html#cancel-safety-1>
72    pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
73        let message = self.inner.try_next().await?;
74        Ok(message)
75    }
76
77    /// Encodes and sends one backend message to the client.
78    ///
79    /// Note that the connection is not flushed after calling this method. You
80    /// must call [`FramedConn::flush`] explicitly. Returns an error if the
81    /// underlying connection is broken.
82    ///
83    /// Please use `StateMachine::send` instead if calling from `StateMachine`,
84    /// as it applies session-based filters before calling this method.
85    pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
86    where
87        M: Into<BackendMessage>,
88    {
89        let message = message.into();
90        self.inner.enqueue(message).await
91    }
92
93    /// Flushes all outstanding messages.
94    pub async fn flush(&mut self) -> Result<(), io::Error> {
95        self.inner.flush().await
96    }
97}
98
99impl<A> FramedConn<A>
100where
101    A: AsyncRead + AsyncWrite + Unpin,
102{
103    pub fn inner(&self) -> &Conn<A> {
104        self.inner.get_ref().get_ref()
105    }
106    pub fn inner_mut(&mut self) -> &mut Conn<A> {
107        self.inner.get_mut().get_mut()
108    }
109}
110
111#[async_trait]
112impl<A> AsyncReady for FramedConn<A>
113where
114    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
115{
116    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
117        self.inner.get_ref().get_ref().ready(interest).await
118    }
119}
120
121struct Codec {
122    decode_state: DecodeState,
123}
124
125impl Codec {
126    /// Creates a new `Codec`.
127    pub fn new() -> Codec {
128        Codec {
129            decode_state: DecodeState::Head,
130        }
131    }
132}
133
134impl Default for Codec {
135    fn default() -> Codec {
136        Codec::new()
137    }
138}
139
140impl Encoder<BackendMessage> for Codec {
141    type Error = io::Error;
142
143    fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
144        // Write type byte.
145        let byte = match &msg {
146            BackendMessage::AuthenticationCleartextPassword => b'R',
147            BackendMessage::ErrorResponse(r) => {
148                if r.severity.is_error() {
149                    b'E'
150                } else {
151                    b'N'
152                }
153            }
154        };
155        dst.put_u8(byte);
156
157        // Write message length placeholder. The true length is filled in later.
158        let base = dst.len();
159        dst.put_u32(0);
160
161        // Write message contents.
162        match msg {
163            BackendMessage::AuthenticationCleartextPassword => {
164                dst.put_u32(3);
165            }
166            BackendMessage::ErrorResponse(ErrorResponse {
167                severity,
168                code,
169                message,
170                detail,
171                hint,
172                position,
173            }) => {
174                dst.put_u8(b'S');
175                dst.put_string(severity.as_str());
176                dst.put_u8(b'C');
177                dst.put_string(code.code());
178                dst.put_u8(b'M');
179                dst.put_string(&message);
180                if let Some(detail) = &detail {
181                    dst.put_u8(b'D');
182                    dst.put_string(detail);
183                }
184                if let Some(hint) = &hint {
185                    dst.put_u8(b'H');
186                    dst.put_string(hint);
187                }
188                if let Some(position) = &position {
189                    dst.put_u8(b'P');
190                    dst.put_string(&position.to_string());
191                }
192                dst.put_u8(b'\0');
193            }
194        }
195
196        let len = dst.len() - base;
197
198        // Overwrite length placeholder with true length.
199        let len = i32::try_from(len).map_err(|_| {
200            io::Error::new(
201                io::ErrorKind::Other,
202                "length of encoded message does not fit into an i32",
203            )
204        })?;
205        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
206
207        Ok(())
208    }
209}
210
211impl Decoder for Codec {
212    type Item = FrontendMessage;
213    type Error = io::Error;
214
215    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
216        if src.len() > MAX_REQUEST_SIZE {
217            return Err(io::Error::new(
218                io::ErrorKind::InvalidData,
219                format!(
220                    "request larger than {}",
221                    ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
222                ),
223            ));
224        }
225        loop {
226            match self.decode_state {
227                DecodeState::Head => {
228                    if src.len() < 5 {
229                        return Ok(None);
230                    }
231                    let msg_type = src[0];
232                    let frame_len = parse_frame_len(&src[1..])?;
233                    src.advance(5);
234                    src.reserve(frame_len);
235                    self.decode_state = DecodeState::Data(msg_type, frame_len);
236                }
237
238                DecodeState::Data(msg_type, frame_len) => {
239                    if src.len() < frame_len {
240                        return Ok(None);
241                    }
242                    let buf = src.split_to(frame_len).freeze();
243                    let buf = Cursor::new(&buf);
244                    let msg = match msg_type {
245                        // Termination.
246                        b'X' => decode_terminate(buf)?,
247
248                        // Authentication.
249                        b'p' => decode_password(buf)?,
250
251                        // Invalid.
252                        _ => {
253                            return Err(io::Error::new(
254                                io::ErrorKind::InvalidData,
255                                format!("unknown message type {}", msg_type),
256                            ));
257                        }
258                    };
259                    src.reserve(5);
260                    self.decode_state = DecodeState::Head;
261                    return Ok(Some(msg));
262                }
263            }
264        }
265    }
266}
267
268fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
269    // Nothing more to decode.
270    Ok(FrontendMessage::Terminate)
271}
272
273fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
274    Ok(FrontendMessage::Password {
275        password: buf.read_cstr()?.to_owned(),
276    })
277}