Skip to main content

mz_balancerd/
codec.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use async_trait::async_trait;
11use bytes::{Buf, BufMut, BytesMut};
12use bytesize::ByteSize;
13use futures::{SinkExt, TryStreamExt, sink};
14use mz_ore::cast::CastFrom;
15use mz_ore::future::OreSinkExt;
16use mz_ore::netio::AsyncReady;
17use mz_pgwire_common::{
18    Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, MAX_REQUEST_SIZE, Pgbuf,
19    parse_frame_len,
20};
21use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
22use tokio_util::codec::{Decoder, Encoder, Framed};
23
24/// Internal representation of a backend [message].
25///
26/// [message]: https://www.postgresql.org/docs/11/protocol-message-formats.html
27#[derive(Debug)]
28pub enum BackendMessage {
29    AuthenticationCleartextPassword,
30    ErrorResponse(ErrorResponse),
31}
32
33impl From<ErrorResponse> for BackendMessage {
34    fn from(err: ErrorResponse) -> BackendMessage {
35        BackendMessage::ErrorResponse(err)
36    }
37}
38
39/// A connection that manages the encoding and decoding of pgwire frames.
40pub struct FramedConn<A> {
41    inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
42}
43
44impl<A> FramedConn<A>
45where
46    A: AsyncRead + AsyncWrite + Unpin,
47{
48    /// Constructs a new framed connection.
49    ///
50    /// The underlying connection, `inner`, is expected to be something like a
51    /// TCP stream. Anything that implements [`AsyncRead`] and [`AsyncWrite`]
52    /// will do.
53    pub fn new(inner: Conn<A>) -> FramedConn<A> {
54        FramedConn {
55            inner: Framed::new(inner, Codec::new()).buffer(32),
56        }
57    }
58
59    /// Reads and decodes one frontend message from the client.
60    ///
61    /// Blocks until the client sends a complete message. If the client
62    /// terminates the stream, returns `None`. Returns an error if the client
63    /// sends a malformed message or if the connection underlying is broken.
64    ///
65    /// # Cancel safety
66    ///
67    /// This method is cancel safe. The returned future only holds onto a
68    /// reference to thea underlying stream, so dropping it will never lose a
69    /// value.
70    ///
71    /// <https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html#cancel-safety-1>
72    pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
73        let message = self.inner.try_next().await?;
74        Ok(message)
75    }
76
77    /// Encodes and sends one backend message to the client.
78    ///
79    /// Note that the connection is not flushed after calling this method. You
80    /// must call [`FramedConn::flush`] explicitly. Returns an error if the
81    /// underlying connection is broken.
82    ///
83    /// Please use `StateMachine::send` instead if calling from `StateMachine`,
84    /// as it applies session-based filters before calling this method.
85    pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
86    where
87        M: Into<BackendMessage>,
88    {
89        let message = message.into();
90        self.inner.enqueue(message).await
91    }
92
93    /// Flushes all outstanding messages.
94    pub async fn flush(&mut self) -> Result<(), io::Error> {
95        self.inner.flush().await
96    }
97}
98
99impl<A> FramedConn<A>
100where
101    A: AsyncRead + AsyncWrite + Unpin,
102{
103    pub fn inner(&self) -> &Conn<A> {
104        self.inner.get_ref().get_ref()
105    }
106    pub fn inner_mut(&mut self) -> &mut Conn<A> {
107        self.inner.get_mut().get_mut()
108    }
109}
110
111#[async_trait]
112impl<A> AsyncReady for FramedConn<A>
113where
114    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
115{
116    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
117        self.inner.get_ref().get_ref().ready(interest).await
118    }
119}
120
121struct Codec {
122    decode_state: DecodeState,
123}
124
125impl Codec {
126    /// Creates a new `Codec`.
127    pub fn new() -> Codec {
128        Codec {
129            decode_state: DecodeState::Head,
130        }
131    }
132}
133
134impl Default for Codec {
135    fn default() -> Codec {
136        Codec::new()
137    }
138}
139
140impl Encoder<BackendMessage> for Codec {
141    type Error = io::Error;
142
143    /// Encode a backend message into `dst`.
144    /// If this function returns an error result, `dst` is left unmodified.
145    fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
146        // Record the starting position so we can truncate on error.
147        // This prevents partial messages from being left in the buffer,
148        // which could be sent to the client and cause "lost synchronization" errors.
149        let start = dst.len();
150        match self.encode_inner(msg, dst) {
151            Ok(()) => Ok(()),
152            Err(e) => {
153                dst.truncate(start);
154                Err(e)
155            }
156        }
157    }
158}
159
160impl Codec {
161    /// This is the meat of the encoding logic. It's a separate function so that errors returned by
162    /// `?` can be handled in the outer `encode` function.
163    fn encode_inner(&self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
164        // Write type byte.
165        let byte = match &msg {
166            BackendMessage::AuthenticationCleartextPassword => b'R',
167            BackendMessage::ErrorResponse(r) => {
168                if r.severity.is_error() {
169                    b'E'
170                } else {
171                    b'N'
172                }
173            }
174        };
175        dst.put_u8(byte);
176
177        // Write message length placeholder. The true length is filled in later.
178        let base = dst.len();
179        dst.put_u32(0);
180
181        // Write message contents.
182        match msg {
183            BackendMessage::AuthenticationCleartextPassword => {
184                dst.put_u32(3);
185            }
186            BackendMessage::ErrorResponse(ErrorResponse {
187                severity,
188                code,
189                message,
190                detail,
191                hint,
192                position,
193            }) => {
194                dst.put_u8(b'S');
195                dst.put_string(severity.as_str());
196                dst.put_u8(b'C');
197                dst.put_string(code.code());
198                dst.put_u8(b'M');
199                dst.put_string(&message);
200                if let Some(detail) = &detail {
201                    dst.put_u8(b'D');
202                    dst.put_string(detail);
203                }
204                if let Some(hint) = &hint {
205                    dst.put_u8(b'H');
206                    dst.put_string(hint);
207                }
208                if let Some(position) = &position {
209                    dst.put_u8(b'P');
210                    dst.put_string(&position.to_string());
211                }
212                dst.put_u8(b'\0');
213            }
214        }
215
216        let len = dst.len() - base;
217
218        // Overwrite length placeholder with true length.
219        let len = i32::try_from(len).map_err(|_| {
220            io::Error::new(
221                io::ErrorKind::InvalidData,
222                "length of encoded message does not fit into an i32",
223            )
224        })?;
225        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
226
227        Ok(())
228    }
229}
230
231impl Decoder for Codec {
232    type Item = FrontendMessage;
233    type Error = io::Error;
234
235    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
236        if src.len() > MAX_REQUEST_SIZE {
237            return Err(io::Error::new(
238                io::ErrorKind::InvalidData,
239                format!(
240                    "request larger than {}",
241                    ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
242                ),
243            ));
244        }
245        loop {
246            match self.decode_state {
247                DecodeState::Head => {
248                    if src.len() < 5 {
249                        return Ok(None);
250                    }
251                    let msg_type = src[0];
252                    let frame_len = parse_frame_len(&src[1..])?;
253                    src.advance(5);
254                    src.reserve(frame_len);
255                    self.decode_state = DecodeState::Data(msg_type, frame_len);
256                }
257
258                DecodeState::Data(msg_type, frame_len) => {
259                    if src.len() < frame_len {
260                        return Ok(None);
261                    }
262                    let buf = src.split_to(frame_len).freeze();
263                    let buf = Cursor::new(&buf);
264                    let msg = match msg_type {
265                        // Termination.
266                        b'X' => decode_terminate(buf)?,
267
268                        // Authentication.
269                        b'p' => decode_password(buf)?,
270
271                        // Invalid.
272                        _ => {
273                            return Err(io::Error::new(
274                                io::ErrorKind::InvalidData,
275                                format!("unknown message type {}", msg_type),
276                            ));
277                        }
278                    };
279                    src.reserve(5);
280                    self.decode_state = DecodeState::Head;
281                    return Ok(Some(msg));
282                }
283            }
284        }
285    }
286}
287
288fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
289    // Nothing more to decode.
290    Ok(FrontendMessage::Terminate)
291}
292
293fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
294    Ok(FrontendMessage::Password {
295        password: buf.read_cstr()?.to_owned(),
296    })
297}