1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::{simple_query, Error};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_channel::mpsc;
use futures_util::{ready, Sink, SinkExt, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use postgres_protocol::message::frontend::CopyData;
use std::marker::{PhantomData, PhantomPinned};
use std::pin::Pin;
use std::task::{Context, Poll};
/// The state machine of CopyBothReceiver
///
/// ```ignore
///       Setup
///         |
///         v
///      CopyBoth
///       /   \
///      v     v
///  CopyOut  CopyIn
///       \   /
///        v v
///      CopyNone
///         |
///         v
///    CopyComplete
///         |
///         v
///   CommandComplete
/// ```
///
/// The diagram shows the regular path only. A backend error or an unexpected message moves the
/// state machine straight to `CommandComplete`, regardless of the state it was in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CopyBothState {
    /// The state before having entered the CopyBoth mode.
    Setup,
    /// Initial state where CopyData messages can go in both directions
    CopyBoth,
    /// The server->client stream is closed and we're in CopyIn mode
    CopyIn,
    /// The client->server stream is closed and we're in CopyOut mode
    CopyOut,
    /// Both directions are closed, we're waiting for CommandComplete messages
    CopyNone,
    /// We have received the first CommandComplete message for the copy
    CopyComplete,
    /// We have received the final CommandComplete message for the statement
    CommandComplete,
}
/// A CopyBothReceiver is responsible for handling the CopyBoth subprotocol. It ensures that no
/// matter what the users do with their CopyBothDuplex handle we're always going to send the
/// correct messages to the backend in order to restore the connection into a usable state.
///
/// ```ignore
///                                          |
///          <tokio_postgres owned>          |    <userland owned>
///                                          |
///  pg -> Connection -> CopyBothReceiver ---+---> CopyBothDuplex
///                                          |          ^   \
///                                          |         /     v
///                                          |      Sink    Stream
/// ```
pub struct CopyBothReceiver {
    /// Receiver of backend messages from the underlying [Connection](crate::Connection)
    responses: Responses,
    /// Receiver of frontend messages sent by the user using <CopyBothDuplex as Sink>
    sink_receiver: mpsc::Receiver<FrontendMessage>,
    /// Sender of CopyData contents to be consumed by the user using <CopyBothDuplex as Stream>
    stream_sender: mpsc::Sender<Result<Message, Error>>,
    /// The current state of the subprotocol
    state: CopyBothState,
    /// Holds a buffered message until we are ready to send it to the user's stream.
    /// At most one message is held at a time; it is handed to `stream_sender` before the next
    /// backend message is read.
    buffered_message: Option<Result<Message, Error>>,
}
impl CopyBothReceiver {
    /// Creates a receiver in the `Setup` state with no message buffered.
    ///
    /// * `responses` - backend messages for this query.
    /// * `sink_receiver` - frontend messages produced by the user's sink half.
    /// * `stream_sender` - channel feeding the user's stream half.
    pub(crate) fn new(
        responses: Responses,
        sink_receiver: mpsc::Receiver<FrontendMessage>,
        stream_sender: mpsc::Sender<Result<Message, Error>>,
    ) -> CopyBothReceiver {
        Self {
            state: CopyBothState::Setup,
            buffered_message: None,
            responses,
            sink_receiver,
            stream_sender,
        }
    }

    /// Convenience method to set the subprotocol into an unexpected message state
    fn unexpected_message(&mut self) {
        self.terminate_with(Error::unexpected_message());
    }

    /// Stops accepting frontend messages from the user, queues `error` for delivery to the
    /// user's stream and jumps to the `CommandComplete` state.
    fn terminate_with(&mut self, error: Error) {
        self.sink_receiver.close();
        self.buffered_message = Some(Err(error));
        self.state = CopyBothState::CommandComplete;
    }

    /// Hands the buffered message (if any) to the user's stream.
    ///
    /// Returns `Poll::Pending` when the channel has no capacity yet; the message is put back
    /// so that the next call retries. Returns `Poll::Ready(())` when nothing is left buffered.
    fn poll_deliver_buffered(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        let pending_message = match self.buffered_message.take() {
            Some(message) => message,
            None => return Poll::Ready(()),
        };

        if self.stream_sender.poll_ready(cx).is_pending() {
            // Stash the message and try again later
            self.buffered_message = Some(pending_message);
            return Poll::Pending;
        }

        // If the receiver has hung up we'll just drop the message
        let _ = self.stream_sender.start_send(pending_message);
        Poll::Ready(())
    }

    /// Applies a successfully received backend message to the state machine. Any message that
    /// is not valid for the current state puts the receiver into the unexpected message state.
    fn on_backend_message(&mut self, message: Message) {
        use CopyBothState::*;

        match (message, self.state) {
            (Message::CopyBothResponse(body), Setup) => {
                self.buffered_message = Some(Ok(Message::CopyBothResponse(body)));
                self.state = CopyBoth;
            }
            (Message::CopyData(body), CopyBoth) | (Message::CopyData(body), CopyOut) => {
                self.buffered_message = Some(Ok(Message::CopyData(body)));
            }
            // The server->client stream is done
            (Message::CopyDone, CopyBoth) => self.state = CopyIn,
            (Message::CopyDone, CopyOut) => self.state = CopyNone,
            (Message::CommandComplete(_), CopyNone) => self.state = CopyComplete,
            (Message::CommandComplete(_), CopyComplete) => {
                self.stream_sender.close_channel();
                self.sink_receiver.close();
                self.state = CommandComplete;
            }
            (Message::ReadyForQuery(_), CommandComplete) => {
                self.sink_receiver.close();
                self.stream_sender.close_channel();
            }
            _ => self.unexpected_message(),
        }
    }

    /// Handles an error reported on the backend channel: while the copy is still in progress
    /// the error itself is forwarded to the user, otherwise it is treated as an unexpected
    /// message.
    fn on_backend_error(&mut self, err: Error) {
        use CopyBothState::*;

        match self.state {
            // The server indicated an error, terminate our side if we haven't already
            Setup | CopyBoth | CopyOut | CopyIn => self.terminate_with(err),
            CopyNone | CopyComplete | CommandComplete => self.unexpected_message(),
        }
    }

    /// Processes messages from the backend, it will resolve once all backend messages have been
    /// processed
    fn poll_backend(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        loop {
            // Deliver the buffered message (if any) to the user to ensure we can potentially
            // buffer a new one in response to a server message
            ready!(self.poll_deliver_buffered(cx));

            match ready!(self.responses.poll_next_unpin(cx)) {
                None => return Poll::Ready(()),
                Some(Ok(message)) => self.on_backend_message(message),
                Some(Err(err)) => self.on_backend_error(err),
            }
        }
    }
}
/// The [Connection](crate::Connection) will keep polling this stream until it is exhausted. This
/// is the mechanism that drives the CopyBoth subprotocol forward
impl Stream for CopyBothReceiver {
    type Item = FrontendMessage;

    /// Drives the backend side first, then forwards the next frontend message coming from the
    /// user. Once the user's sink half is gone, a CopyDone message is emitted on their behalf.
    /// The stream ends when the backend messages are exhausted.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
        use CopyBothState::*;

        if self.poll_backend(cx).is_ready() {
            return Poll::Ready(None);
        }

        // Frontend messages are only forwarded while the client->server direction is open
        match self.state {
            Setup | CopyBoth | CopyIn => {}
            _ => return Poll::Pending,
        }

        if let Some(msg) = ready!(self.sink_receiver.poll_next_unpin(cx)) {
            return Poll::Ready(Some(msg));
        }

        // The user's sink half is closed: close the client->server direction
        let next_state = match self.state {
            // The user has cancelled their interest to this CopyBoth query but we're
            // still in the Setup phase. From this point the receiver will either enter
            // CopyBoth mode or will receive an Error response from PostgreSQL. When
            // either of those happens the state machine will terminate the connection
            // appropriately.
            Setup => return Poll::Pending,
            CopyBoth => CopyOut,
            CopyIn => CopyNone,
            _ => unreachable!(),
        };
        self.state = next_state;

        let mut copy_done = BytesMut::new();
        frontend::copy_done(&mut copy_done);
        Poll::Ready(Some(FrontendMessage::Raw(copy_done.freeze())))
    }
}
pin_project! {
    /// A duplex stream for consuming streaming replication data.
    ///
    /// Users should ensure that CopyBothDuplex is dropped before attempting to await on a new
    /// query. This will ensure that the connection returns into normal processing mode.
    ///
    /// ```no_run
    /// use tokio_postgres::Client;
    ///
    /// async fn foo(client: &Client) {
    ///   let duplex_stream = client.copy_both_simple::<&[u8]>("..").await;
    ///
    ///   // ⚠️ INCORRECT ⚠️
    ///   client.query("SELECT 1", &[]).await; // hangs forever
    ///
    ///   // duplex_stream dropped here
    /// }
    /// ```
    ///
    /// ```no_run
    /// use tokio_postgres::Client;
    ///
    /// async fn foo(client: &Client) {
    ///   let duplex_stream = client.copy_both_simple::<&[u8]>("..").await;
    ///
    ///   // ✅ CORRECT ✅
    ///   drop(duplex_stream);
    ///
    ///   client.query("SELECT 1", &[]).await;
    /// }
    /// ```
    pub struct CopyBothDuplex<T> {
        // Carries the frontend messages produced by the Sink implementation
        #[pin]
        sink_sender: mpsc::Sender<FrontendMessage>,
        // Yields the backend messages surfaced by the Stream implementation
        #[pin]
        stream_receiver: mpsc::Receiver<Result<Message, Error>>,
        // Accumulates small items so that they are sent together as a single CopyData message
        buf: BytesMut,
        #[pin]
        _p: PhantomPinned,
        // Marks the item type accepted by the Sink implementation; no T is stored
        _p2: PhantomData<T>,
    }
}
impl<T> Stream for CopyBothDuplex<T> {
type Item = Result<Bytes, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(match ready!(self.project().stream_receiver.poll_next(cx)) {
Some(Ok(Message::CopyData(body))) => Some(Ok(body.into_bytes())),
Some(Ok(_)) => Some(Err(Error::unexpected_message())),
Some(Err(err)) => Some(Err(err)),
None => None,
})
}
}
impl<T> Sink<T> for CopyBothDuplex<T>
where
T: Buf + 'static + Send,
{
type Error = Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.project()
.sink_sender
.poll_ready(cx)
.map_err(|_| Error::closed())
}
fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
let this = self.project();
let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
if this.buf.is_empty() {
Box::new(item)
} else {
Box::new(this.buf.split().freeze().chain(item))
}
} else {
this.buf.put(item);
if this.buf.len() > 4096 {
Box::new(this.buf.split().freeze())
} else {
return Ok(());
}
};
let data = CopyData::new(data).map_err(Error::encode)?;
this.sink_sender
.start_send(FrontendMessage::CopyData(data))
.map_err(|_| Error::closed())
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
let mut this = self.project();
if !this.buf.is_empty() {
ready!(this.sink_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
let data = CopyData::new(data).map_err(Error::encode)?;
this.sink_sender
.as_mut()
.start_send(FrontendMessage::CopyData(data))
.map_err(|_| Error::closed())?;
}
this.sink_sender.poll_flush(cx).map_err(|_| Error::closed())
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
ready!(self.as_mut().poll_flush(cx))?;
let mut this = self.as_mut().project();
this.sink_sender.disconnect();
Poll::Ready(Ok(()))
}
}
pub async fn copy_both_simple<T>(
client: &InnerClient,
query: &str,
) -> Result<CopyBothDuplex<T>, Error>
where
T: Buf + 'static + Send,
{
debug!("executing copy both query {}", query);
let buf = simple_query::encode(client, query)?;
let mut handles = client.start_copy_both()?;
handles
.sink_sender
.send(FrontendMessage::Raw(buf))
.await
.map_err(|_| Error::closed())?;
match handles.stream_receiver.next().await.transpose()? {
Some(Message::CopyBothResponse(_)) => {}
_ => return Err(Error::unexpected_message()),
}
Ok(CopyBothDuplex {
stream_receiver: handles.stream_receiver,
sink_sender: handles.sink_sender,
buf: BytesMut::new(),
_p: PhantomPinned,
_p2: PhantomData,
})
}