tokio_postgres/
binary_copy.rs

1//! Utilities for working with the PostgreSQL binary copy format.
2
3use crate::types::{FromSql, IsNull, ToSql, Type, WrongType};
4use crate::{slice_iter, CopyInSink, CopyOutStream, Error};
5use byteorder::{BigEndian, ByteOrder};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use futures_util::{ready, SinkExt, Stream};
8use pin_project_lite::pin_project;
9use postgres_types::BorrowToSql;
10use std::convert::TryFrom;
11use std::io;
12use std::io::Cursor;
13use std::ops::Range;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
// Signature bytes that open a binary copy stream: written first by `BinaryCopyInWriter::new`
// and checked first when `BinaryCopyOutStream` parses the header.
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
// Size of the fixed part of the header: the signature, a 4-byte flags field, and a 4-byte
// header extension length.
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;
20
pin_project! {
    /// A type which serializes rows into the PostgreSQL binary copy format.
    ///
    /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
    pub struct BinaryCopyInWriter {
        // Sink that encoded bytes are sent to.
        #[pin]
        sink: CopyInSink<Bytes>,
        // Column types, in order; every written row must supply exactly one value per entry.
        types: Vec<Type>,
        // Encoded bytes that have not been sent to the sink yet. Starts out holding the
        // stream header and is handed off once it grows past 4096 bytes or on `finish`.
        buf: BytesMut,
    }
}
32
33impl BinaryCopyInWriter {
34    /// Creates a new writer which will write rows of the provided types to the provided sink.
35    pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
36        let mut buf = BytesMut::new();
37        buf.put_slice(MAGIC);
38        buf.put_i32(0); // flags
39        buf.put_i32(0); // header extension
40
41        BinaryCopyInWriter {
42            sink,
43            types: types.to_vec(),
44            buf,
45        }
46    }
47
48    /// Writes a single row.
49    ///
50    /// # Panics
51    ///
52    /// Panics if the number of values provided does not match the number expected.
53    pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
54        self.write_raw(slice_iter(values)).await
55    }
56
57    /// A maximally-flexible version of `write`.
58    ///
59    /// # Panics
60    ///
61    /// Panics if the number of values provided does not match the number expected.
62    pub async fn write_raw<P, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
63    where
64        P: BorrowToSql,
65        I: IntoIterator<Item = P>,
66        I::IntoIter: ExactSizeIterator,
67    {
68        let mut this = self.project();
69
70        let values = values.into_iter();
71        assert!(
72            values.len() == this.types.len(),
73            "expected {} values but got {}",
74            this.types.len(),
75            values.len(),
76        );
77
78        this.buf.put_i16(this.types.len() as i16);
79
80        for (i, (value, type_)) in values.zip(this.types).enumerate() {
81            let idx = this.buf.len();
82            this.buf.put_i32(0);
83            let len = match value
84                .borrow_to_sql()
85                .to_sql_checked(type_, this.buf)
86                .map_err(|e| Error::to_sql(e, i))?
87            {
88                IsNull::Yes => -1,
89                IsNull::No => i32::try_from(this.buf.len() - idx - 4)
90                    .map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
91            };
92            BigEndian::write_i32(&mut this.buf[idx..], len);
93        }
94
95        if this.buf.len() > 4096 {
96            this.sink.send(this.buf.split().freeze()).await?;
97        }
98
99        Ok(())
100    }
101
102    /// Completes the copy, returning the number of rows added.
103    ///
104    /// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
105    pub async fn finish(self: Pin<&mut Self>) -> Result<u64, Error> {
106        let mut this = self.project();
107
108        this.buf.put_i16(-1);
109        this.sink.send(this.buf.split().freeze()).await?;
110        this.sink.finish().await
111    }
112}
113
/// Information retained from the header of a binary copy stream once it has been parsed.
struct Header {
    /// Whether bit 16 of the header's flags field was set. When it is, each row is expected to
    /// carry one more field than the count it declares.
    has_oids: bool,
}
117
pin_project! {
    /// A stream of rows deserialized from the PostgreSQL binary copy format.
    pub struct BinaryCopyOutStream {
        // Underlying stream of raw chunks that rows are parsed from.
        #[pin]
        stream: CopyOutStream,
        // Column types, shared with every row produced by this stream.
        types: Arc<Vec<Type>>,
        // `None` until the stream header has been parsed out of the first chunk.
        header: Option<Header>,
    }
}
127
128impl BinaryCopyOutStream {
129    /// Creates a stream from a raw copy out stream and the types of the columns being returned.
130    pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream {
131        BinaryCopyOutStream {
132            stream,
133            types: Arc::new(types.to_vec()),
134            header: None,
135        }
136    }
137}
138
139impl Stream for BinaryCopyOutStream {
140    type Item = Result<BinaryCopyOutRow, Error>;
141
142    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
143        let this = self.project();
144
145        let chunk = match ready!(this.stream.poll_next(cx)) {
146            Some(Ok(chunk)) => chunk,
147            Some(Err(e)) => return Poll::Ready(Some(Err(e))),
148            None => return Poll::Ready(Some(Err(Error::closed()))),
149        };
150        let mut chunk = Cursor::new(chunk);
151
152        let has_oids = match &this.header {
153            Some(header) => header.has_oids,
154            None => {
155                check_remaining(&chunk, HEADER_LEN)?;
156                if !chunk.chunk().starts_with(MAGIC) {
157                    return Poll::Ready(Some(Err(Error::parse(io::Error::new(
158                        io::ErrorKind::InvalidData,
159                        "invalid magic value",
160                    )))));
161                }
162                chunk.advance(MAGIC.len());
163
164                let flags = chunk.get_i32();
165                let has_oids = (flags & (1 << 16)) != 0;
166
167                let header_extension = chunk.get_u32() as usize;
168                check_remaining(&chunk, header_extension)?;
169                chunk.advance(header_extension);
170
171                *this.header = Some(Header { has_oids });
172                has_oids
173            }
174        };
175
176        check_remaining(&chunk, 2)?;
177        let mut len = chunk.get_i16();
178        if len == -1 {
179            return Poll::Ready(None);
180        }
181
182        if has_oids {
183            len += 1;
184        }
185        if len as usize != this.types.len() {
186            return Poll::Ready(Some(Err(Error::parse(io::Error::new(
187                io::ErrorKind::InvalidInput,
188                format!("expected {} values but got {}", this.types.len(), len),
189            )))));
190        }
191
192        let mut ranges = vec![];
193        for _ in 0..len {
194            check_remaining(&chunk, 4)?;
195            let len = chunk.get_i32();
196            if len == -1 {
197                ranges.push(None);
198            } else {
199                let len = len as usize;
200                check_remaining(&chunk, len)?;
201                let start = chunk.position() as usize;
202                ranges.push(Some(start..start + len));
203                chunk.advance(len);
204            }
205        }
206
207        Poll::Ready(Some(Ok(BinaryCopyOutRow {
208            buf: chunk.into_inner(),
209            ranges,
210            types: this.types.clone(),
211        })))
212    }
213}
214
215fn check_remaining(buf: &Cursor<Bytes>, len: usize) -> Result<(), Error> {
216    if buf.remaining() < len {
217        Err(Error::parse(io::Error::new(
218            io::ErrorKind::UnexpectedEof,
219            "unexpected EOF",
220        )))
221    } else {
222        Ok(())
223    }
224}
225
/// A row of data parsed from a binary copy out stream.
pub struct BinaryCopyOutRow {
    // The chunk this row was parsed from; the ranges below index into it.
    buf: Bytes,
    // One entry per column: the byte range of the value inside `buf`, or `None` when the
    // field length was -1 (NULL).
    ranges: Vec<Option<Range<usize>>>,
    // Column types, shared with the stream that produced this row.
    types: Arc<Vec<Type>>,
}
232
233impl BinaryCopyOutRow {
234    /// Like `get`, but returns a `Result` rather than panicking.
235    pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Error>
236    where
237        T: FromSql<'a>,
238    {
239        let type_ = match self.types.get(idx) {
240            Some(type_) => type_,
241            None => return Err(Error::column(idx.to_string())),
242        };
243
244        if !T::accepts(type_) {
245            return Err(Error::from_sql(
246                Box::new(WrongType::new::<T>(type_.clone())),
247                idx,
248            ));
249        }
250
251        let r = match &self.ranges[idx] {
252            Some(range) => T::from_sql(type_, &self.buf[range.clone()]),
253            None => T::from_sql_null(type_),
254        };
255
256        r.map_err(|e| Error::from_sql(e, idx))
257    }
258
259    /// Deserializes a value from the row.
260    ///
261    /// # Panics
262    ///
263    /// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
264    pub fn get<'a, T>(&'a self, idx: usize) -> T
265    where
266        T: FromSql<'a>,
267    {
268        match self.try_get(idx) {
269            Ok(value) => value,
270            Err(e) => panic!("error retrieving column {}: {}", idx, e),
271        }
272    }
273}