1use crate::types::{FromSql, IsNull, ToSql, Type, WrongType};
4use crate::{slice_iter, CopyInSink, CopyOutStream, Error};
5use byteorder::{BigEndian, ByteOrder};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use futures_util::{ready, SinkExt, Stream};
8use pin_project_lite::pin_project;
9use postgres_types::BorrowToSql;
10use std::convert::TryFrom;
11use std::io;
12use std::io::Cursor;
13use std::ops::Range;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
/// Signature bytes that the binary copy stream is expected to start with.
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
/// Size in bytes of the fixed header: the signature followed by the 32-bit
/// flags field and the 32-bit header extension length.
const HEADER_LEN: usize = MAGIC.len() + std::mem::size_of::<i32>() + std::mem::size_of::<u32>();
20
21pin_project! {
22 pub struct BinaryCopyInWriter {
26 #[pin]
27 sink: CopyInSink<Bytes>,
28 types: Vec<Type>,
29 buf: BytesMut,
30 }
31}
32
33impl BinaryCopyInWriter {
34 pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
36 let mut buf = BytesMut::new();
37 buf.put_slice(MAGIC);
38 buf.put_i32(0); buf.put_i32(0); BinaryCopyInWriter {
42 sink,
43 types: types.to_vec(),
44 buf,
45 }
46 }
47
48 pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
54 self.write_raw(slice_iter(values)).await
55 }
56
57 pub async fn write_raw<P, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
63 where
64 P: BorrowToSql,
65 I: IntoIterator<Item = P>,
66 I::IntoIter: ExactSizeIterator,
67 {
68 let mut this = self.project();
69
70 let values = values.into_iter();
71 assert!(
72 values.len() == this.types.len(),
73 "expected {} values but got {}",
74 this.types.len(),
75 values.len(),
76 );
77
78 this.buf.put_i16(this.types.len() as i16);
79
80 for (i, (value, type_)) in values.zip(this.types).enumerate() {
81 let idx = this.buf.len();
82 this.buf.put_i32(0);
83 let len = match value
84 .borrow_to_sql()
85 .to_sql_checked(type_, this.buf)
86 .map_err(|e| Error::to_sql(e, i))?
87 {
88 IsNull::Yes => -1,
89 IsNull::No => i32::try_from(this.buf.len() - idx - 4)
90 .map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
91 };
92 BigEndian::write_i32(&mut this.buf[idx..], len);
93 }
94
95 if this.buf.len() > 4096 {
96 this.sink.send(this.buf.split().freeze()).await?;
97 }
98
99 Ok(())
100 }
101
102 pub async fn finish(self: Pin<&mut Self>) -> Result<u64, Error> {
106 let mut this = self.project();
107
108 this.buf.put_i16(-1);
109 this.sink.send(this.buf.split().freeze()).await?;
110 this.sink.finish().await
111 }
112}
113
/// Information retained from the binary copy header after it has been parsed.
struct Header {
    /// Whether bit 16 of the header flags was set. When it is, each row is
    /// expected to carry one field more than its declared field count.
    has_oids: bool,
}
117
118pin_project! {
119 pub struct BinaryCopyOutStream {
121 #[pin]
122 stream: CopyOutStream,
123 types: Arc<Vec<Type>>,
124 header: Option<Header>,
125 }
126}
127
128impl BinaryCopyOutStream {
129 pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream {
131 BinaryCopyOutStream {
132 stream,
133 types: Arc::new(types.to_vec()),
134 header: None,
135 }
136 }
137}
138
139impl Stream for BinaryCopyOutStream {
140 type Item = Result<BinaryCopyOutRow, Error>;
141
142 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
143 let this = self.project();
144
145 let chunk = match ready!(this.stream.poll_next(cx)) {
146 Some(Ok(chunk)) => chunk,
147 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
148 None => return Poll::Ready(Some(Err(Error::closed()))),
149 };
150 let mut chunk = Cursor::new(chunk);
151
152 let has_oids = match &this.header {
153 Some(header) => header.has_oids,
154 None => {
155 check_remaining(&chunk, HEADER_LEN)?;
156 if !chunk.chunk().starts_with(MAGIC) {
157 return Poll::Ready(Some(Err(Error::parse(io::Error::new(
158 io::ErrorKind::InvalidData,
159 "invalid magic value",
160 )))));
161 }
162 chunk.advance(MAGIC.len());
163
164 let flags = chunk.get_i32();
165 let has_oids = (flags & (1 << 16)) != 0;
166
167 let header_extension = chunk.get_u32() as usize;
168 check_remaining(&chunk, header_extension)?;
169 chunk.advance(header_extension);
170
171 *this.header = Some(Header { has_oids });
172 has_oids
173 }
174 };
175
176 check_remaining(&chunk, 2)?;
177 let mut len = chunk.get_i16();
178 if len == -1 {
179 return Poll::Ready(None);
180 }
181
182 if has_oids {
183 len += 1;
184 }
185 if len as usize != this.types.len() {
186 return Poll::Ready(Some(Err(Error::parse(io::Error::new(
187 io::ErrorKind::InvalidInput,
188 format!("expected {} values but got {}", this.types.len(), len),
189 )))));
190 }
191
192 let mut ranges = vec![];
193 for _ in 0..len {
194 check_remaining(&chunk, 4)?;
195 let len = chunk.get_i32();
196 if len == -1 {
197 ranges.push(None);
198 } else {
199 let len = len as usize;
200 check_remaining(&chunk, len)?;
201 let start = chunk.position() as usize;
202 ranges.push(Some(start..start + len));
203 chunk.advance(len);
204 }
205 }
206
207 Poll::Ready(Some(Ok(BinaryCopyOutRow {
208 buf: chunk.into_inner(),
209 ranges,
210 types: this.types.clone(),
211 })))
212 }
213}
214
215fn check_remaining(buf: &Cursor<Bytes>, len: usize) -> Result<(), Error> {
216 if buf.remaining() < len {
217 Err(Error::parse(io::Error::new(
218 io::ErrorKind::UnexpectedEof,
219 "unexpected EOF",
220 )))
221 } else {
222 Ok(())
223 }
224}
225
226pub struct BinaryCopyOutRow {
228 buf: Bytes,
229 ranges: Vec<Option<Range<usize>>>,
230 types: Arc<Vec<Type>>,
231}
232
233impl BinaryCopyOutRow {
234 pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Error>
236 where
237 T: FromSql<'a>,
238 {
239 let type_ = match self.types.get(idx) {
240 Some(type_) => type_,
241 None => return Err(Error::column(idx.to_string())),
242 };
243
244 if !T::accepts(type_) {
245 return Err(Error::from_sql(
246 Box::new(WrongType::new::<T>(type_.clone())),
247 idx,
248 ));
249 }
250
251 let r = match &self.ranges[idx] {
252 Some(range) => T::from_sql(type_, &self.buf[range.clone()]),
253 None => T::from_sql_null(type_),
254 };
255
256 r.map_err(|e| Error::from_sql(e, idx))
257 }
258
259 pub fn get<'a, T>(&'a self, idx: usize) -> T
265 where
266 T: FromSql<'a>,
267 {
268 match self.try_get(idx) {
269 Ok(value) => value,
270 Err(e) => panic!("error retrieving column {}: {}", idx, e),
271 }
272 }
273}