tokio_postgres/
copy_both.rs

1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::{simple_query, Error};
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use futures_channel::mpsc;
6use futures_util::{ready, Sink, SinkExt, Stream, StreamExt};
7use log::debug;
8use pin_project_lite::pin_project;
9use postgres_protocol::message::backend::Message;
10use postgres_protocol::message::frontend;
11use postgres_protocol::message::frontend::CopyData;
12use std::marker::{PhantomData, PhantomPinned};
13use std::pin::Pin;
14use std::task::{Context, Poll};
15
/// The state machine of `CopyBothReceiver`.
///
/// ```ignore
///       Setup
///         |
///         v
///      CopyBoth
///       /   \
///      v     v
///  CopyOut  CopyIn
///       \   /
///        v v
///      CopyNone
///         |
///         v
///    CopyComplete
///         |
///         v
///   CommandComplete
/// ```
///
/// In addition to the edges above, every state jumps straight to `CommandComplete` when the
/// backend reports an error or sends a message that is not valid in the current state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CopyBothState {
    /// The initial state, before the server has confirmed CopyBoth mode with `CopyBothResponse`.
    Setup,
    /// The server has sent `CopyBothResponse`; CopyData messages can go in both directions.
    CopyBoth,
    /// The server has sent `CopyDone`, so the server->client stream is closed and we're in
    /// CopyIn mode (only the client may still send CopyData).
    CopyIn,
    /// The user closed their sink and a `CopyDone` was emitted, so the client->server stream is
    /// closed and we're in CopyOut mode (only the server may still send CopyData).
    CopyOut,
    /// Both directions are closed; we are waiting for the CommandComplete messages.
    CopyNone,
    /// We have received the first CommandComplete message, which ends the copy.
    CopyComplete,
    /// We have received the final CommandComplete message for the statement, or the subprotocol
    /// was aborted by an error; only `ReadyForQuery` is expected from here on.
    CommandComplete,
}
53
/// A CopyBothReceiver is responsible for handling the CopyBoth subprotocol. It ensures that no
/// matter what the users do with their CopyBothDuplex handle we're always going to send the
/// correct messages to the backend in order to restore the connection into a usable state.
///
/// ```ignore
///                                          |
///          <tokio_postgres owned>          |    <userland owned>
///                                          |
///  pg -> Connection -> CopyBothReceiver ---+---> CopyBothDuplex
///                                          |          ^   \
///                                          |         /     v
///                                          |      Sink    Stream
/// ```
///
/// It is driven by being polled as a `Stream` of frontend messages. Each poll first processes
/// any pending backend messages. While the client->server direction is still open, it then
/// forwards the messages the user wrote, or emits a `CopyDone` once the user's sink is closed.
pub struct CopyBothReceiver {
    /// Receiver of backend messages from the underlying [Connection](crate::Connection)
    responses: Responses,
    /// Receiver of frontend messages sent by the user using `<CopyBothDuplex as Sink>`
    sink_receiver: mpsc::Receiver<FrontendMessage>,
    /// Sender of backend messages (`CopyBothResponse`, `CopyData`) and errors to be consumed by
    /// the user using `<CopyBothDuplex as Stream>`
    stream_sender: mpsc::Sender<Result<Message, Error>>,
    /// The current state of the subprotocol
    state: CopyBothState,
    /// Holds a buffered message until we are ready to send it to the user's stream. At most one
    /// message is held; no further backend message is read until it has been handed off.
    buffered_message: Option<Result<Message, Error>>,
}
79
80impl CopyBothReceiver {
81    pub(crate) fn new(
82        responses: Responses,
83        sink_receiver: mpsc::Receiver<FrontendMessage>,
84        stream_sender: mpsc::Sender<Result<Message, Error>>,
85    ) -> CopyBothReceiver {
86        CopyBothReceiver {
87            responses,
88            sink_receiver,
89            stream_sender,
90            state: CopyBothState::Setup,
91            buffered_message: None,
92        }
93    }
94
95    /// Convenience method to set the subprotocol into an unexpected message state
96    fn unexpected_message(&mut self) {
97        self.sink_receiver.close();
98        self.buffered_message = Some(Err(Error::unexpected_message()));
99        self.state = CopyBothState::CommandComplete;
100    }
101
102    /// Processes messages from the backend, it will resolve once all backend messages have been
103    /// processed
104    fn poll_backend(&mut self, cx: &mut Context<'_>) -> Poll<()> {
105        use CopyBothState::*;
106
107        loop {
108            // Deliver the buffered message (if any) to the user to ensure we can potentially
109            // buffer a new one in response to a server message
110            if let Some(message) = self.buffered_message.take() {
111                match self.stream_sender.poll_ready(cx) {
112                    Poll::Ready(_) => {
113                        // If the receiver has hung up we'll just drop the message
114                        let _ = self.stream_sender.start_send(message);
115                    }
116                    Poll::Pending => {
117                        // Stash the message and try again later
118                        self.buffered_message = Some(message);
119                        return Poll::Pending;
120                    }
121                }
122            }
123
124            match ready!(self.responses.poll_next_unpin(cx)) {
125                Some(Ok(Message::CopyBothResponse(body))) => match self.state {
126                    Setup => {
127                        self.buffered_message = Some(Ok(Message::CopyBothResponse(body)));
128                        self.state = CopyBoth;
129                    }
130                    _ => self.unexpected_message(),
131                },
132                Some(Ok(Message::CopyData(body))) => match self.state {
133                    CopyBoth | CopyOut => {
134                        self.buffered_message = Some(Ok(Message::CopyData(body)));
135                    }
136                    _ => self.unexpected_message(),
137                },
138                // The server->client stream is done
139                Some(Ok(Message::CopyDone)) => {
140                    match self.state {
141                        CopyBoth => self.state = CopyIn,
142                        CopyOut => self.state = CopyNone,
143                        _ => self.unexpected_message(),
144                    };
145                }
146                Some(Ok(Message::CommandComplete(_))) => {
147                    match self.state {
148                        CopyNone => self.state = CopyComplete,
149                        CopyComplete => {
150                            self.stream_sender.close_channel();
151                            self.sink_receiver.close();
152                            self.state = CommandComplete;
153                        }
154                        _ => self.unexpected_message(),
155                    };
156                }
157                // The server indicated an error, terminate our side if we haven't already
158                Some(Err(err)) => {
159                    match self.state {
160                        Setup | CopyBoth | CopyOut | CopyIn => {
161                            self.sink_receiver.close();
162                            self.buffered_message = Some(Err(err));
163                            self.state = CommandComplete;
164                        }
165                        _ => self.unexpected_message(),
166                    };
167                }
168                Some(Ok(Message::ReadyForQuery(_))) => match self.state {
169                    CommandComplete => {
170                        self.sink_receiver.close();
171                        self.stream_sender.close_channel();
172                    }
173                    _ => self.unexpected_message(),
174                },
175                Some(Ok(_)) => self.unexpected_message(),
176                None => return Poll::Ready(()),
177            }
178        }
179    }
180}
181
182/// The [Connection](crate::Connection) will keep polling this stream until it is exhausted. This
183/// is the mechanism that drives the CopyBoth subprotocol forward
184impl Stream for CopyBothReceiver {
185    type Item = FrontendMessage;
186
187    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
188        use CopyBothState::*;
189
190        match self.poll_backend(cx) {
191            Poll::Ready(()) => Poll::Ready(None),
192            Poll::Pending => match self.state {
193                Setup | CopyBoth | CopyIn => match ready!(self.sink_receiver.poll_next_unpin(cx)) {
194                    Some(msg) => Poll::Ready(Some(msg)),
195                    None => match self.state {
196                        // The user has cancelled their interest to this CopyBoth query but we're
197                        // still in the Setup phase. From this point the receiver will either enter
198                        // CopyBoth mode or will receive an Error response from PostgreSQL. When
199                        // either of those happens the state machine will terminate the connection
200                        // appropriately.
201                        Setup => Poll::Pending,
202                        CopyBoth => {
203                            self.state = CopyOut;
204                            let mut buf = BytesMut::new();
205                            frontend::copy_done(&mut buf);
206                            Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
207                        }
208                        CopyIn => {
209                            self.state = CopyNone;
210                            let mut buf = BytesMut::new();
211                            frontend::copy_done(&mut buf);
212                            Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
213                        }
214                        _ => unreachable!(),
215                    },
216                },
217                _ => Poll::Pending,
218            },
219        }
220    }
221}
222
pin_project! {
    /// A duplex stream for consuming streaming replication data.
    ///
    /// As a `Stream`, it yields the payload of each `CopyData` message received from the
    /// server; any other message is reported as an unexpected-message error. As a `Sink`, it
    /// sends written buffers to the server as `CopyData` messages.
    ///
    /// Writes of up to 4096 bytes are coalesced in an internal buffer until more than 4096 bytes
    /// are pending. Call `flush` (or `close`) to push out the remainder; bytes still buffered
    /// when the handle is dropped are not sent. Closing the sink flushes it and then disconnects
    /// the sending channel.
    ///
    /// Users should ensure that CopyBothDuplex is dropped before attempting to await on a new
    /// query. This will ensure that the connection returns into normal processing mode.
    ///
    /// ```no_run
    /// use tokio_postgres::Client;
    ///
    /// async fn foo(client: &Client) {
    ///   let duplex_stream = client.copy_both_simple::<&[u8]>("..").await;
    ///
    ///   // ⚠️ INCORRECT ⚠️
    ///   client.query("SELECT 1", &[]).await; // hangs forever
    ///
    ///   // duplex_stream drop-ed here
    /// }
    /// ```
    ///
    /// ```no_run
    /// use tokio_postgres::Client;
    ///
    /// async fn foo(client: &Client) {
    ///   let duplex_stream = client.copy_both_simple::<&[u8]>("..").await;
    ///
    ///   // ✅ CORRECT ✅
    ///   drop(duplex_stream);
    ///
    ///   client.query("SELECT 1", &[]).await;
    /// }
    /// ```
    pub struct CopyBothDuplex<T> {
        // Frontend messages (CopyData) written by the user, forwarded to the connection side.
        #[pin]
        sink_sender: mpsc::Sender<FrontendMessage>,
        // Backend messages or errors handed over by the connection side.
        #[pin]
        stream_receiver: mpsc::Receiver<Result<Message, Error>>,
        // Coalesced bytes that have not been sent as CopyData yet (see the `Sink` impl).
        buf: BytesMut,
        // Opts this type out of `Unpin`.
        #[pin]
        _p: PhantomPinned,
        // Records the buffer type `T` accepted by the `Sink` impl without storing one.
        _p2: PhantomData<T>,
    }
}
265
266impl<T> Stream for CopyBothDuplex<T> {
267    type Item = Result<Bytes, Error>;
268
269    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
270        Poll::Ready(match ready!(self.project().stream_receiver.poll_next(cx)) {
271            Some(Ok(Message::CopyData(body))) => Some(Ok(body.into_bytes())),
272            Some(Ok(_)) => Some(Err(Error::unexpected_message())),
273            Some(Err(err)) => Some(Err(err)),
274            None => None,
275        })
276    }
277}
278
279impl<T> Sink<T> for CopyBothDuplex<T>
280where
281    T: Buf + 'static + Send,
282{
283    type Error = Error;
284
285    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
286        self.project()
287            .sink_sender
288            .poll_ready(cx)
289            .map_err(|_| Error::closed())
290    }
291
292    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
293        let this = self.project();
294
295        let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
296            if this.buf.is_empty() {
297                Box::new(item)
298            } else {
299                Box::new(this.buf.split().freeze().chain(item))
300            }
301        } else {
302            this.buf.put(item);
303            if this.buf.len() > 4096 {
304                Box::new(this.buf.split().freeze())
305            } else {
306                return Ok(());
307            }
308        };
309
310        let data = CopyData::new(data).map_err(Error::encode)?;
311        this.sink_sender
312            .start_send(FrontendMessage::CopyData(data))
313            .map_err(|_| Error::closed())
314    }
315
316    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
317        let mut this = self.project();
318
319        if !this.buf.is_empty() {
320            ready!(this.sink_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
321            let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
322            let data = CopyData::new(data).map_err(Error::encode)?;
323            this.sink_sender
324                .as_mut()
325                .start_send(FrontendMessage::CopyData(data))
326                .map_err(|_| Error::closed())?;
327        }
328
329        this.sink_sender.poll_flush(cx).map_err(|_| Error::closed())
330    }
331
332    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
333        ready!(self.as_mut().poll_flush(cx))?;
334        let mut this = self.as_mut().project();
335        this.sink_sender.disconnect();
336        Poll::Ready(Ok(()))
337    }
338}
339
340pub async fn copy_both_simple<T>(
341    client: &InnerClient,
342    query: &str,
343) -> Result<CopyBothDuplex<T>, Error>
344where
345    T: Buf + 'static + Send,
346{
347    debug!("executing copy both query {}", query);
348
349    let buf = simple_query::encode(client, query)?;
350
351    let mut handles = client.start_copy_both()?;
352
353    handles
354        .sink_sender
355        .send(FrontendMessage::Raw(buf))
356        .await
357        .map_err(|_| Error::closed())?;
358
359    match handles.stream_receiver.next().await.transpose()? {
360        Some(Message::CopyBothResponse(_)) => {}
361        _ => return Err(Error::unexpected_message()),
362    }
363
364    Ok(CopyBothDuplex {
365        stream_receiver: handles.stream_receiver,
366        sink_sender: handles.sink_sender,
367        buf: BytesMut::new(),
368        _p: PhantomPinned,
369        _p2: PhantomData,
370    })
371}