// tokio_postgres/codec.rs

1use bytes::{Buf, Bytes, BytesMut};
2use fallible_iterator::FallibleIterator;
3use postgres_protocol::message::backend;
4use postgres_protocol::message::frontend::CopyData;
5use std::io;
6use tokio_util::codec::{Decoder, Encoder};
7
8pub enum FrontendMessage {
9    Raw(Bytes),
10    CopyData(CopyData<Box<dyn Buf + Send>>),
11}
12
13pub enum BackendMessage {
14    Normal {
15        messages: BackendMessages,
16        request_complete: bool,
17    },
18    Async(backend::Message),
19}
20
21pub struct BackendMessages(BytesMut);
22
23impl BackendMessages {
24    pub fn empty() -> BackendMessages {
25        BackendMessages(BytesMut::new())
26    }
27}
28
29impl FallibleIterator for BackendMessages {
30    type Item = backend::Message;
31    type Error = io::Error;
32
33    fn next(&mut self) -> io::Result<Option<backend::Message>> {
34        backend::Message::parse(&mut self.0)
35    }
36}
37
/// Field-less codec translating [`FrontendMessage`]s to bytes and bytes to
/// [`BackendMessage`]s; it keeps no state between calls.
pub struct PostgresCodec;
39
40impl Encoder<FrontendMessage> for PostgresCodec {
41    type Error = io::Error;
42
43    fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
44        match item {
45            FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
46            FrontendMessage::CopyData(data) => data.write(dst),
47        }
48
49        Ok(())
50    }
51}
52
53impl Decoder for PostgresCodec {
54    type Item = BackendMessage;
55    type Error = io::Error;
56
57    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
58        let mut idx = 0;
59        let mut request_complete = false;
60
61        while let Some(header) = backend::Header::parse(&src[idx..])? {
62            let len = header.len() as usize + 1;
63            if src[idx..].len() < len {
64                break;
65            }
66
67            match header.tag() {
68                backend::NOTICE_RESPONSE_TAG
69                | backend::NOTIFICATION_RESPONSE_TAG
70                | backend::PARAMETER_STATUS_TAG => {
71                    if idx == 0 {
72                        let message = backend::Message::parse(src)?.unwrap();
73                        return Ok(Some(BackendMessage::Async(message)));
74                    } else {
75                        break;
76                    }
77                }
78                _ => {}
79            }
80
81            idx += len;
82
83            if header.tag() == backend::READY_FOR_QUERY_TAG {
84                request_complete = true;
85                break;
86            }
87        }
88
89        if idx == 0 {
90            Ok(None)
91        } else {
92            Ok(Some(BackendMessage::Normal {
93                messages: BackendMessages(src.split_to(idx)),
94                request_complete,
95            }))
96        }
97    }
98}