tokio_postgres/connection.rs

1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::copy_both::CopyBothReceiver;
3use crate::copy_in::CopyInReceiver;
4use crate::error::DbError;
5use crate::maybe_tls_stream::MaybeTlsStream;
6use crate::{AsyncMessage, Error, Notification};
7use bytes::BytesMut;
8use fallible_iterator::FallibleIterator;
9use futures_channel::mpsc;
10use futures_util::{ready, stream::FusedStream, Sink, Stream, StreamExt};
11use log::{info, trace};
12use postgres_protocol::message::backend::Message;
13use postgres_protocol::message::frontend;
14use std::collections::{HashMap, VecDeque};
15use std::future::Future;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::Framed;
20
/// The frontend messages that make up a single request to the server.
pub enum RequestMessages {
    /// One message, handed to the framed stream as-is.
    Single(FrontendMessage),
    /// Messages pulled from a `CopyInReceiver` one at a time and written until it is exhausted.
    CopyIn(CopyInReceiver),
    /// Messages pulled from a `CopyBothReceiver` one at a time and written until it is exhausted.
    CopyBoth(CopyBothReceiver),
}
26
/// A request delivered to a `Connection` through its request channel.
pub struct Request {
    /// The messages to write to the server for this request.
    pub messages: RequestMessages,
    /// Channel on which the backend messages answering this request are delivered.
    pub sender: mpsc::Sender<BackendMessages>,
}
31
/// Bookkeeping for one in-flight request: where its backend messages should be sent.
pub struct Response {
    // Taken from the originating `Request::sender` when the request is polled.
    sender: mpsc::Sender<BackendMessages>,
}
35
/// Lifecycle of a connection's I/O loop.
///
/// The state only ever moves forward: `Active` -> `Terminating` -> `Closing`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum State {
    /// Reading responses and writing requests normally.
    Active,
    /// All requests are done and a `Terminate` message is being written; reads have stopped.
    Terminating,
    /// `Terminate` has been handed to the socket; only shutting the stream down remains.
    Closing,
}
42
/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
/// server, and should generally be spawned off onto an executor to run in the background.
///
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
    /// The server stream, framed with the PostgreSQL wire-protocol codec.
    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    /// Runtime parameters; updated whenever a `ParameterStatus` message arrives.
    parameters: HashMap<String, String>,
    /// Requests waiting to be written (presumably queued by the associated `Client`).
    receiver: mpsc::UnboundedReceiver<Request>,
    /// A request whose writing was interrupted (e.g. a COPY stream that had nothing ready);
    /// it is resumed before a new request is taken from `receiver`.
    pending_request: Option<RequestMessages>,
    /// Backend messages to process before reading more from `stream`, e.g. a batch whose
    /// response channel was not ready to accept it.
    pending_responses: VecDeque<BackendMessage>,
    /// Response channels of in-flight requests, oldest first.
    responses: VecDeque<Response>,
    /// Where the connection is in its active -> terminating -> closing lifecycle.
    state: State,
}
60
61impl<S, T> Connection<S, T>
62where
63    S: AsyncRead + AsyncWrite + Unpin,
64    T: AsyncRead + AsyncWrite + Unpin,
65{
66    pub(crate) fn new(
67        stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
68        pending_responses: VecDeque<BackendMessage>,
69        parameters: HashMap<String, String>,
70        receiver: mpsc::UnboundedReceiver<Request>,
71    ) -> Connection<S, T> {
72        Connection {
73            stream,
74            parameters,
75            receiver,
76            pending_request: None,
77            pending_responses,
78            responses: VecDeque::new(),
79            state: State::Active,
80        }
81    }
82
83    fn poll_response(
84        &mut self,
85        cx: &mut Context<'_>,
86    ) -> Poll<Option<Result<BackendMessage, Error>>> {
87        if let Some(message) = self.pending_responses.pop_front() {
88            trace!("retrying pending response");
89            return Poll::Ready(Some(Ok(message)));
90        }
91
92        Pin::new(&mut self.stream)
93            .poll_next(cx)
94            .map(|o| o.map(|r| r.map_err(Error::io)))
95    }
96
97    fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
98        if self.state != State::Active {
99            trace!("poll_read: done");
100            return Ok(None);
101        }
102
103        loop {
104            let message = match self.poll_response(cx)? {
105                Poll::Ready(Some(message)) => message,
106                Poll::Ready(None) => return Err(Error::closed()),
107                Poll::Pending => {
108                    trace!("poll_read: waiting on response");
109                    return Ok(None);
110                }
111            };
112
113            let (mut messages, request_complete) = match message {
114                BackendMessage::Async(Message::NoticeResponse(body)) => {
115                    let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
116                    return Ok(Some(AsyncMessage::Notice(error)));
117                }
118                BackendMessage::Async(Message::NotificationResponse(body)) => {
119                    let notification = Notification {
120                        process_id: body.process_id(),
121                        channel: body.channel().map_err(Error::parse)?.to_string(),
122                        payload: body.message().map_err(Error::parse)?.to_string(),
123                    };
124                    return Ok(Some(AsyncMessage::Notification(notification)));
125                }
126                BackendMessage::Async(Message::ParameterStatus(body)) => {
127                    self.parameters.insert(
128                        body.name().map_err(Error::parse)?.to_string(),
129                        body.value().map_err(Error::parse)?.to_string(),
130                    );
131                    continue;
132                }
133                BackendMessage::Async(_) => unreachable!(),
134                BackendMessage::Normal {
135                    messages,
136                    request_complete,
137                } => (messages, request_complete),
138            };
139
140            let mut response = match self.responses.pop_front() {
141                Some(response) => response,
142                None => match messages.next().map_err(Error::parse)? {
143                    Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
144                    _ => return Err(Error::unexpected_message()),
145                },
146            };
147
148            match response.sender.poll_ready(cx) {
149                Poll::Ready(Ok(())) => {
150                    let _ = response.sender.start_send(messages);
151                    if !request_complete {
152                        self.responses.push_front(response);
153                    }
154                }
155                Poll::Ready(Err(_)) => {
156                    // we need to keep paging through the rest of the messages even if the receiver's hung up
157                    if !request_complete {
158                        self.responses.push_front(response);
159                    }
160                }
161                Poll::Pending => {
162                    self.responses.push_front(response);
163                    self.pending_responses.push_back(BackendMessage::Normal {
164                        messages,
165                        request_complete,
166                    });
167                    trace!("poll_read: waiting on sender");
168                    return Ok(None);
169                }
170            }
171        }
172    }
173
174    fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
175        if let Some(messages) = self.pending_request.take() {
176            trace!("retrying pending request");
177            return Poll::Ready(Some(messages));
178        }
179
180        if self.receiver.is_terminated() {
181            return Poll::Ready(None);
182        }
183
184        match self.receiver.poll_next_unpin(cx) {
185            Poll::Ready(Some(request)) => {
186                trace!("polled new request");
187                self.responses.push_back(Response {
188                    sender: request.sender,
189                });
190                Poll::Ready(Some(request.messages))
191            }
192            Poll::Ready(None) => Poll::Ready(None),
193            Poll::Pending => Poll::Pending,
194        }
195    }
196
197    fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
198        loop {
199            if self.state == State::Closing {
200                trace!("poll_write: done");
201                return Ok(false);
202            }
203
204            if Pin::new(&mut self.stream)
205                .poll_ready(cx)
206                .map_err(Error::io)?
207                .is_pending()
208            {
209                trace!("poll_write: waiting on socket");
210                return Ok(false);
211            }
212
213            let request = match self.poll_request(cx) {
214                Poll::Ready(Some(request)) => request,
215                Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
216                    trace!("poll_write: at eof, terminating");
217                    self.state = State::Terminating;
218                    let mut request = BytesMut::new();
219                    frontend::terminate(&mut request);
220                    RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
221                }
222                Poll::Ready(None) => {
223                    trace!(
224                        "poll_write: at eof, pending responses {}",
225                        self.responses.len()
226                    );
227                    return Ok(true);
228                }
229                Poll::Pending => {
230                    trace!("poll_write: waiting on request");
231                    return Ok(true);
232                }
233            };
234
235            match request {
236                RequestMessages::Single(request) => {
237                    Pin::new(&mut self.stream)
238                        .start_send(request)
239                        .map_err(Error::io)?;
240                    if self.state == State::Terminating {
241                        trace!("poll_write: sent eof, closing");
242                        self.state = State::Closing;
243                    }
244                }
245                RequestMessages::CopyIn(mut receiver) => {
246                    let message = match receiver.poll_next_unpin(cx) {
247                        Poll::Ready(Some(message)) => message,
248                        Poll::Ready(None) => {
249                            trace!("poll_write: finished copy_in request");
250                            continue;
251                        }
252                        Poll::Pending => {
253                            trace!("poll_write: waiting on copy_in stream");
254                            self.pending_request = Some(RequestMessages::CopyIn(receiver));
255                            return Ok(true);
256                        }
257                    };
258                    Pin::new(&mut self.stream)
259                        .start_send(message)
260                        .map_err(Error::io)?;
261                    self.pending_request = Some(RequestMessages::CopyIn(receiver));
262                }
263                RequestMessages::CopyBoth(mut receiver) => {
264                    let message = match receiver.poll_next_unpin(cx) {
265                        Poll::Ready(Some(message)) => message,
266                        Poll::Ready(None) => {
267                            trace!("poll_write: finished copy_both request");
268                            continue;
269                        }
270                        Poll::Pending => {
271                            trace!("poll_write: waiting on copy_both stream");
272                            self.pending_request = Some(RequestMessages::CopyBoth(receiver));
273                            return Ok(true);
274                        }
275                    };
276                    Pin::new(&mut self.stream)
277                        .start_send(message)
278                        .map_err(Error::io)?;
279                    self.pending_request = Some(RequestMessages::CopyBoth(receiver));
280                }
281            }
282        }
283    }
284
285    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
286        match Pin::new(&mut self.stream)
287            .poll_flush(cx)
288            .map_err(Error::io)?
289        {
290            Poll::Ready(()) => trace!("poll_flush: flushed"),
291            Poll::Pending => trace!("poll_flush: waiting on socket"),
292        }
293        Ok(())
294    }
295
296    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
297        if self.state != State::Closing {
298            return Poll::Pending;
299        }
300
301        match Pin::new(&mut self.stream)
302            .poll_close(cx)
303            .map_err(Error::io)?
304        {
305            Poll::Ready(()) => {
306                trace!("poll_shutdown: complete");
307                Poll::Ready(Ok(()))
308            }
309            Poll::Pending => {
310                trace!("poll_shutdown: waiting on socket");
311                Poll::Pending
312            }
313        }
314    }
315
316    /// Returns the value of a runtime parameter for this connection.
317    pub fn parameter(&self, name: &str) -> Option<&str> {
318        self.parameters.get(name).map(|s| &**s)
319    }
320
321    /// Polls for asynchronous messages from the server.
322    ///
323    /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
324    /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
325    ///
326    /// Return values of `None` or `Some(Err(_))` are "terminal"; callers should not invoke this method again after
327    /// receiving one of those values.
328    pub fn poll_message(
329        &mut self,
330        cx: &mut Context<'_>,
331    ) -> Poll<Option<Result<AsyncMessage, Error>>> {
332        let message = self.poll_read(cx)?;
333        let want_flush = self.poll_write(cx)?;
334        if want_flush {
335            self.poll_flush(cx)?;
336        }
337        match message {
338            Some(message) => Poll::Ready(Some(Ok(message))),
339            None => match self.poll_shutdown(cx) {
340                Poll::Ready(Ok(())) => Poll::Ready(None),
341                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
342                Poll::Pending => Poll::Pending,
343            },
344        }
345    }
346}
347
348impl<S, T> Future for Connection<S, T>
349where
350    S: AsyncRead + AsyncWrite + Unpin,
351    T: AsyncRead + AsyncWrite + Unpin,
352{
353    type Output = Result<(), Error>;
354
355    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
356        while let Some(message) = ready!(self.poll_message(cx)?) {
357            if let AsyncMessage::Notice(notice) = message {
358                info!("{}: {}", notice.severity(), notice.message());
359            }
360        }
361        Poll::Ready(Ok(()))
362    }
363}