mz_balancerd/
codec.rs
1use async_trait::async_trait;
11use bytes::{Buf, BufMut, BytesMut};
12use bytesize::ByteSize;
13use futures::{SinkExt, TryStreamExt, sink};
14use mz_ore::cast::CastFrom;
15use mz_ore::future::OreSinkExt;
16use mz_ore::netio::AsyncReady;
17use mz_pgwire_common::{
18 Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, MAX_REQUEST_SIZE, Pgbuf,
19 parse_frame_len,
20};
21use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
22use tokio_util::codec::{Decoder, Encoder, Framed};
23
24#[derive(Debug)]
28pub enum BackendMessage {
29 AuthenticationCleartextPassword,
30 ErrorResponse(ErrorResponse),
31}
32
33impl From<ErrorResponse> for BackendMessage {
34 fn from(err: ErrorResponse) -> BackendMessage {
35 BackendMessage::ErrorResponse(err)
36 }
37}
38
39pub struct FramedConn<A> {
41 inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
42}
43
44impl<A> FramedConn<A>
45where
46 A: AsyncRead + AsyncWrite + Unpin,
47{
48 pub fn new(inner: Conn<A>) -> FramedConn<A> {
54 FramedConn {
55 inner: Framed::new(inner, Codec::new()).buffer(32),
56 }
57 }
58
59 pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
73 let message = self.inner.try_next().await?;
74 Ok(message)
75 }
76
77 pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
86 where
87 M: Into<BackendMessage>,
88 {
89 let message = message.into();
90 self.inner.enqueue(message).await
91 }
92
93 pub async fn flush(&mut self) -> Result<(), io::Error> {
95 self.inner.flush().await
96 }
97}
98
99impl<A> FramedConn<A>
100where
101 A: AsyncRead + AsyncWrite + Unpin,
102{
103 pub fn inner(&self) -> &Conn<A> {
104 self.inner.get_ref().get_ref()
105 }
106 pub fn inner_mut(&mut self) -> &mut Conn<A> {
107 self.inner.get_mut().get_mut()
108 }
109}
110
111#[async_trait]
112impl<A> AsyncReady for FramedConn<A>
113where
114 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
115{
116 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
117 self.inner.get_ref().get_ref().ready(interest).await
118 }
119}
120
121struct Codec {
122 decode_state: DecodeState,
123}
124
125impl Codec {
126 pub fn new() -> Codec {
128 Codec {
129 decode_state: DecodeState::Head,
130 }
131 }
132}
133
134impl Default for Codec {
135 fn default() -> Codec {
136 Codec::new()
137 }
138}
139
140impl Encoder<BackendMessage> for Codec {
141 type Error = io::Error;
142
143 fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
144 let byte = match &msg {
146 BackendMessage::AuthenticationCleartextPassword => b'R',
147 BackendMessage::ErrorResponse(r) => {
148 if r.severity.is_error() {
149 b'E'
150 } else {
151 b'N'
152 }
153 }
154 };
155 dst.put_u8(byte);
156
157 let base = dst.len();
159 dst.put_u32(0);
160
161 match msg {
163 BackendMessage::AuthenticationCleartextPassword => {
164 dst.put_u32(3);
165 }
166 BackendMessage::ErrorResponse(ErrorResponse {
167 severity,
168 code,
169 message,
170 detail,
171 hint,
172 position,
173 }) => {
174 dst.put_u8(b'S');
175 dst.put_string(severity.as_str());
176 dst.put_u8(b'C');
177 dst.put_string(code.code());
178 dst.put_u8(b'M');
179 dst.put_string(&message);
180 if let Some(detail) = &detail {
181 dst.put_u8(b'D');
182 dst.put_string(detail);
183 }
184 if let Some(hint) = &hint {
185 dst.put_u8(b'H');
186 dst.put_string(hint);
187 }
188 if let Some(position) = &position {
189 dst.put_u8(b'P');
190 dst.put_string(&position.to_string());
191 }
192 dst.put_u8(b'\0');
193 }
194 }
195
196 let len = dst.len() - base;
197
198 let len = i32::try_from(len).map_err(|_| {
200 io::Error::new(
201 io::ErrorKind::Other,
202 "length of encoded message does not fit into an i32",
203 )
204 })?;
205 dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
206
207 Ok(())
208 }
209}
210
211impl Decoder for Codec {
212 type Item = FrontendMessage;
213 type Error = io::Error;
214
215 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
216 if src.len() > MAX_REQUEST_SIZE {
217 return Err(io::Error::new(
218 io::ErrorKind::InvalidData,
219 format!(
220 "request larger than {}",
221 ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
222 ),
223 ));
224 }
225 loop {
226 match self.decode_state {
227 DecodeState::Head => {
228 if src.len() < 5 {
229 return Ok(None);
230 }
231 let msg_type = src[0];
232 let frame_len = parse_frame_len(&src[1..])?;
233 src.advance(5);
234 src.reserve(frame_len);
235 self.decode_state = DecodeState::Data(msg_type, frame_len);
236 }
237
238 DecodeState::Data(msg_type, frame_len) => {
239 if src.len() < frame_len {
240 return Ok(None);
241 }
242 let buf = src.split_to(frame_len).freeze();
243 let buf = Cursor::new(&buf);
244 let msg = match msg_type {
245 b'X' => decode_terminate(buf)?,
247
248 b'p' => decode_password(buf)?,
250
251 _ => {
253 return Err(io::Error::new(
254 io::ErrorKind::InvalidData,
255 format!("unknown message type {}", msg_type),
256 ));
257 }
258 };
259 src.reserve(5);
260 self.decode_state = DecodeState::Head;
261 return Ok(Some(msg));
262 }
263 }
264 }
265 }
266}
267
268fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
269 Ok(FrontendMessage::Terminate)
271}
272
273fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
274 Ok(FrontendMessage::Password {
275 password: buf.read_cstr()?.to_owned(),
276 })
277}