1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::{simple_query, Error};
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use futures_channel::mpsc;
6use futures_util::{ready, Sink, SinkExt, Stream, StreamExt};
7use log::debug;
8use pin_project_lite::pin_project;
9use postgres_protocol::message::backend::Message;
10use postgres_protocol::message::frontend;
11use postgres_protocol::message::frontend::CopyData;
12use std::marker::{PhantomData, PhantomPinned};
13use std::pin::Pin;
14use std::task::{Context, Poll};
15
/// Protocol state of a CopyBoth session, as tracked by [`CopyBothReceiver`].
///
/// Transitions performed in this module:
///
/// ```text
///                 Setup
///                   | CopyBothResponse (backend)
///                   v
///                CopyBoth
///  CopyDone     /        \    user sink closed
///  (backend)   v          v   (CopyDone sent)
///          CopyIn        CopyOut
///  user sink    \        /    CopyDone
///  closed        v      v     (backend)
///                CopyNone
///                   | CommandComplete (backend)
///                   v
///              CopyComplete
///                   | CommandComplete (backend)
///                   v
///             CommandComplete
/// ```
///
/// Backend errors and protocol violations move the machine directly to `CommandComplete`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CopyBothState {
    /// The query has been issued; waiting for the server's `CopyBothResponse`.
    Setup,
    /// Copy data may flow in both directions.
    CopyBoth,
    /// The server has sent `CopyDone`; only client-to-server data remains.
    CopyIn,
    /// The client has sent `CopyDone`; only server-to-client data remains.
    CopyOut,
    /// Both sides have sent `CopyDone`; waiting for the first `CommandComplete`.
    CopyNone,
    /// The first `CommandComplete` has arrived; waiting for the second one.
    CopyComplete,
    /// Terminal state; a following `ReadyForQuery` closes both channels.
    CommandComplete,
}
53
/// Connection-side half of a CopyBoth session.
///
/// Backend messages for the copy query are validated against a [`CopyBothState`] machine and
/// forwarded through `stream_sender`. Frontend messages arriving on `sink_receiver` (plus the
/// `CopyDone` messages this type generates) are yielded from its `Stream` implementation.
/// NOTE(review): presumably the connection writes those frontend messages to the socket —
/// confirm in the client/connection code.
pub struct CopyBothReceiver {
    /// Response stream for the copy query.
    responses: Responses,
    /// Frontend messages produced by the user's sink half.
    sink_receiver: mpsc::Receiver<FrontendMessage>,
    /// Backend messages (or errors) destined for the user's stream half.
    stream_sender: mpsc::Sender<Result<Message, Error>>,
    /// Current position in the CopyBoth protocol.
    state: CopyBothState,
    /// A message waiting for capacity on `stream_sender`.
    buffered_message: Option<Result<Message, Error>>,
}
79
80impl CopyBothReceiver {
81 pub(crate) fn new(
82 responses: Responses,
83 sink_receiver: mpsc::Receiver<FrontendMessage>,
84 stream_sender: mpsc::Sender<Result<Message, Error>>,
85 ) -> CopyBothReceiver {
86 CopyBothReceiver {
87 responses,
88 sink_receiver,
89 stream_sender,
90 state: CopyBothState::Setup,
91 buffered_message: None,
92 }
93 }
94
95 fn unexpected_message(&mut self) {
97 self.sink_receiver.close();
98 self.buffered_message = Some(Err(Error::unexpected_message()));
99 self.state = CopyBothState::CommandComplete;
100 }
101
102 fn poll_backend(&mut self, cx: &mut Context<'_>) -> Poll<()> {
105 use CopyBothState::*;
106
107 loop {
108 if let Some(message) = self.buffered_message.take() {
111 match self.stream_sender.poll_ready(cx) {
112 Poll::Ready(_) => {
113 let _ = self.stream_sender.start_send(message);
115 }
116 Poll::Pending => {
117 self.buffered_message = Some(message);
119 return Poll::Pending;
120 }
121 }
122 }
123
124 match ready!(self.responses.poll_next_unpin(cx)) {
125 Some(Ok(Message::CopyBothResponse(body))) => match self.state {
126 Setup => {
127 self.buffered_message = Some(Ok(Message::CopyBothResponse(body)));
128 self.state = CopyBoth;
129 }
130 _ => self.unexpected_message(),
131 },
132 Some(Ok(Message::CopyData(body))) => match self.state {
133 CopyBoth | CopyOut => {
134 self.buffered_message = Some(Ok(Message::CopyData(body)));
135 }
136 _ => self.unexpected_message(),
137 },
138 Some(Ok(Message::CopyDone)) => {
140 match self.state {
141 CopyBoth => self.state = CopyIn,
142 CopyOut => self.state = CopyNone,
143 _ => self.unexpected_message(),
144 };
145 }
146 Some(Ok(Message::CommandComplete(_))) => {
147 match self.state {
148 CopyNone => self.state = CopyComplete,
149 CopyComplete => {
150 self.stream_sender.close_channel();
151 self.sink_receiver.close();
152 self.state = CommandComplete;
153 }
154 _ => self.unexpected_message(),
155 };
156 }
157 Some(Err(err)) => {
159 match self.state {
160 Setup | CopyBoth | CopyOut | CopyIn => {
161 self.sink_receiver.close();
162 self.buffered_message = Some(Err(err));
163 self.state = CommandComplete;
164 }
165 _ => self.unexpected_message(),
166 };
167 }
168 Some(Ok(Message::ReadyForQuery(_))) => match self.state {
169 CommandComplete => {
170 self.sink_receiver.close();
171 self.stream_sender.close_channel();
172 }
173 _ => self.unexpected_message(),
174 },
175 Some(Ok(_)) => self.unexpected_message(),
176 None => return Poll::Ready(()),
177 }
178 }
179 }
180}
181
182impl Stream for CopyBothReceiver {
185 type Item = FrontendMessage;
186
187 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
188 use CopyBothState::*;
189
190 match self.poll_backend(cx) {
191 Poll::Ready(()) => Poll::Ready(None),
192 Poll::Pending => match self.state {
193 Setup | CopyBoth | CopyIn => match ready!(self.sink_receiver.poll_next_unpin(cx)) {
194 Some(msg) => Poll::Ready(Some(msg)),
195 None => match self.state {
196 Setup => Poll::Pending,
202 CopyBoth => {
203 self.state = CopyOut;
204 let mut buf = BytesMut::new();
205 frontend::copy_done(&mut buf);
206 Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
207 }
208 CopyIn => {
209 self.state = CopyNone;
210 let mut buf = BytesMut::new();
211 frontend::copy_done(&mut buf);
212 Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
213 }
214 _ => unreachable!(),
215 },
216 },
217 _ => Poll::Pending,
218 },
219 }
220 }
221}
222
pin_project! {
    /// User-facing duplex handle for a CopyBoth session.
    ///
    /// As a `Stream`, it yields the payload of each `CopyData` message sent by the server. As a
    /// `Sink<T>`, it frames written buffers as `CopyData` messages, coalescing writes of up to
    /// 4096 bytes. Closing the sink disconnects its frontend channel; the connection-side
    /// receiver responds to a closed channel by sending `CopyDone`.
    pub struct CopyBothDuplex<T> {
        // Frontend messages handed to the connection-side receiver.
        #[pin]
        sink_sender: mpsc::Sender<FrontendMessage>,
        // Backend messages (or errors) forwarded by the connection-side receiver.
        #[pin]
        stream_receiver: mpsc::Receiver<Result<Message, Error>>,
        // Small writes accumulated until they exceed 4096 bytes or the sink is flushed.
        buf: BytesMut,
        // Opts the struct out of `Unpin`.
        #[pin]
        _p: PhantomPinned,
        // Ties the sink item type `T` to the struct without storing a value of it.
        _p2: PhantomData<T>,
    }
}
265
266impl<T> Stream for CopyBothDuplex<T> {
267 type Item = Result<Bytes, Error>;
268
269 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
270 Poll::Ready(match ready!(self.project().stream_receiver.poll_next(cx)) {
271 Some(Ok(Message::CopyData(body))) => Some(Ok(body.into_bytes())),
272 Some(Ok(_)) => Some(Err(Error::unexpected_message())),
273 Some(Err(err)) => Some(Err(err)),
274 None => None,
275 })
276 }
277}
278
279impl<T> Sink<T> for CopyBothDuplex<T>
280where
281 T: Buf + 'static + Send,
282{
283 type Error = Error;
284
285 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
286 self.project()
287 .sink_sender
288 .poll_ready(cx)
289 .map_err(|_| Error::closed())
290 }
291
292 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
293 let this = self.project();
294
295 let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
296 if this.buf.is_empty() {
297 Box::new(item)
298 } else {
299 Box::new(this.buf.split().freeze().chain(item))
300 }
301 } else {
302 this.buf.put(item);
303 if this.buf.len() > 4096 {
304 Box::new(this.buf.split().freeze())
305 } else {
306 return Ok(());
307 }
308 };
309
310 let data = CopyData::new(data).map_err(Error::encode)?;
311 this.sink_sender
312 .start_send(FrontendMessage::CopyData(data))
313 .map_err(|_| Error::closed())
314 }
315
316 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
317 let mut this = self.project();
318
319 if !this.buf.is_empty() {
320 ready!(this.sink_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
321 let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
322 let data = CopyData::new(data).map_err(Error::encode)?;
323 this.sink_sender
324 .as_mut()
325 .start_send(FrontendMessage::CopyData(data))
326 .map_err(|_| Error::closed())?;
327 }
328
329 this.sink_sender.poll_flush(cx).map_err(|_| Error::closed())
330 }
331
332 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
333 ready!(self.as_mut().poll_flush(cx))?;
334 let mut this = self.as_mut().project();
335 this.sink_sender.disconnect();
336 Poll::Ready(Ok(()))
337 }
338}
339
340pub async fn copy_both_simple<T>(
341 client: &InnerClient,
342 query: &str,
343) -> Result<CopyBothDuplex<T>, Error>
344where
345 T: Buf + 'static + Send,
346{
347 debug!("executing copy both query {}", query);
348
349 let buf = simple_query::encode(client, query)?;
350
351 let mut handles = client.start_copy_both()?;
352
353 handles
354 .sink_sender
355 .send(FrontendMessage::Raw(buf))
356 .await
357 .map_err(|_| Error::closed())?;
358
359 match handles.stream_receiver.next().await.transpose()? {
360 Some(Message::CopyBothResponse(_)) => {}
361 _ => return Err(Error::unexpected_message()),
362 }
363
364 Ok(CopyBothDuplex {
365 stream_receiver: handles.stream_receiver,
366 sink_sender: handles.sink_sender,
367 buf: BytesMut::new(),
368 _p: PhantomPinned,
369 _p2: PhantomData,
370 })
371}