1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::copy_both::CopyBothReceiver;
3use crate::copy_in::CopyInReceiver;
4use crate::error::DbError;
5use crate::maybe_tls_stream::MaybeTlsStream;
6use crate::{AsyncMessage, Error, Notification};
7use bytes::BytesMut;
8use fallible_iterator::FallibleIterator;
9use futures_channel::mpsc;
10use futures_util::{ready, stream::FusedStream, Sink, Stream, StreamExt};
11use log::{info, trace};
12use postgres_protocol::message::backend::Message;
13use postgres_protocol::message::frontend;
14use std::collections::{HashMap, VecDeque};
15use std::future::Future;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::Framed;
20
/// The frontend message(s) that make up a single request to the server.
pub enum RequestMessages {
    /// One message, written to the sink in a single step.
    Single(FrontendMessage),
    /// A COPY-in data stream (presumably `COPY ... FROM STDIN` — confirm).
    /// The receiver is polled for one message per write-loop iteration until
    /// it ends.
    CopyIn(CopyInReceiver),
    /// A bidirectional COPY stream (presumably CopyBoth / replication mode —
    /// confirm). It is drained the same way as `CopyIn`.
    CopyBoth(CopyBothReceiver),
}
26
/// A request handed to the connection task for transmission.
pub struct Request {
    /// The messages to write to the server for this request.
    pub messages: RequestMessages,
    /// Channel on which the backend message batches answering this request
    /// are delivered.
    pub sender: mpsc::Sender<BackendMessages>,
}
31
/// Connection-side record of an outstanding request.
///
/// It is used to route the request's backend messages back to the requester.
pub struct Response {
    // Copied from `Request::sender` when the request is taken off the channel.
    sender: mpsc::Sender<BackendMessages>,
}
35
/// Lifecycle of the connection task.
///
/// The connection starts `Active`. Once the request channel is exhausted and no
/// responses are outstanding, it enters `Terminating` while a `Terminate`
/// message is queued. It moves to `Closing` as soon as that message has been
/// handed to the sink, after which only the socket shutdown remains.
///
/// The enum is fieldless, so it derives `Copy` and `Eq` alongside `PartialEq`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum State {
    /// Reading responses and writing requests normally.
    Active,
    /// A `Terminate` message is being queued; nothing more is read.
    Terminating,
    /// `Terminate` has been sent; only closing the sink remains.
    Closing,
}
42
/// The I/O half of a database connection.
///
/// It reads backend messages from `stream` and writes the frontend messages of
/// the [`Request`]s that arrive on `receiver`. It does nothing unless polled,
/// either through its `Future` implementation or through
/// [`Connection::poll_message`].
///
/// As a `Future`, it resolves to `Ok(())` once the connection has shut down
/// cleanly, or to the first error encountered. Along the way, notices are
/// logged at info level and notifications are discarded.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
    /// The framed transport to the server.
    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    /// Latest runtime parameter values, updated from `ParameterStatus` messages.
    parameters: HashMap<String, String>,
    /// Channel of requests to send to the server.
    receiver: mpsc::UnboundedReceiver<Request>,
    /// A COPY request that is still streaming data.
    /// It is resumed before any new request is taken from `receiver`.
    pending_request: Option<RequestMessages>,
    /// Backend messages to process before reading more from `stream`.
    pending_responses: VecDeque<BackendMessage>,
    /// Response channels for outstanding requests, oldest first.
    responses: VecDeque<Response>,
    /// Current position in the connection's shutdown sequence.
    state: State,
}
60
impl<S, T> Connection<S, T>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Builds a connection around an already-framed stream.
    ///
    /// - `pending_responses` are replayed by `poll_response` before anything
    ///   new is read from `stream`.
    /// - `parameters` seeds the table returned by [`Connection::parameter`].
    /// - `receiver` is the channel that [`Request`]s are pulled from.
    ///
    /// The connection starts `Active`, with no in-flight request and no
    /// outstanding responses.
    pub(crate) fn new(
        stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
        pending_responses: VecDeque<BackendMessage>,
        parameters: HashMap<String, String>,
        receiver: mpsc::UnboundedReceiver<Request>,
    ) -> Connection<S, T> {
        Connection {
            stream,
            parameters,
            receiver,
            pending_request: None,
            pending_responses,
            responses: VecDeque::new(),
            state: State::Active,
        }
    }

    /// Returns the next backend message to process.
    ///
    /// Messages parked in `pending_responses` are returned first, from the
    /// front of the queue. For example, `poll_read` parks a batch there when
    /// its response channel was full. Otherwise the framed stream is polled.
    /// `Ready(None)` means the stream has ended. I/O errors are converted with
    /// `Error::io`.
    fn poll_response(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<BackendMessage, Error>>> {
        if let Some(message) = self.pending_responses.pop_front() {
            trace!("retrying pending response");
            return Poll::Ready(Some(Ok(message)));
        }

        Pin::new(&mut self.stream)
            .poll_next(cx)
            .map(|o| o.map(|r| r.map_err(Error::io)))
    }

    /// Drives the read half of the connection.
    ///
    /// Returns `Ok(Some(_))` as soon as a notice or notification arrives.
    /// Returns `Ok(None)` when the connection is no longer `Active`, or when no
    /// further progress can be made right now because the stream or a
    /// response channel is not ready.
    ///
    /// `ParameterStatus` messages update `self.parameters` and are not
    /// surfaced. Every other message batch is forwarded to the sender of the
    /// oldest outstanding request.
    ///
    /// # Errors
    ///
    /// - `Error::closed` if the stream ends.
    /// - A parse error for malformed messages.
    /// - A database or unexpected-message error if a batch arrives while no
    ///   request is waiting for it.
    fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
        if self.state != State::Active {
            trace!("poll_read: done");
            return Ok(None);
        }

        loop {
            let message = match self.poll_response(cx)? {
                Poll::Ready(Some(message)) => message,
                Poll::Ready(None) => return Err(Error::closed()),
                Poll::Pending => {
                    trace!("poll_read: waiting on response");
                    return Ok(None);
                }
            };

            let (mut messages, request_complete) = match message {
                BackendMessage::Async(Message::NoticeResponse(body)) => {
                    let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
                    return Ok(Some(AsyncMessage::Notice(error)));
                }
                BackendMessage::Async(Message::NotificationResponse(body)) => {
                    let notification = Notification {
                        process_id: body.process_id(),
                        channel: body.channel().map_err(Error::parse)?.to_string(),
                        payload: body.message().map_err(Error::parse)?.to_string(),
                    };
                    return Ok(Some(AsyncMessage::Notification(notification)));
                }
                BackendMessage::Async(Message::ParameterStatus(body)) => {
                    // Record the new value and keep reading; nothing is surfaced.
                    self.parameters.insert(
                        body.name().map_err(Error::parse)?.to_string(),
                        body.value().map_err(Error::parse)?.to_string(),
                    );
                    continue;
                }
                // NOTE(review): presumably the codec only tags the three
                // message types above as `Async` — confirm in `PostgresCodec`.
                BackendMessage::Async(_) => unreachable!(),
                BackendMessage::Normal {
                    messages,
                    request_complete,
                } => (messages, request_complete),
            };

            let mut response = match self.responses.pop_front() {
                Some(response) => response,
                // No request is waiting for this batch. An error response is
                // surfaced as the connection's error; anything else is a
                // protocol violation.
                None => match messages.next().map_err(Error::parse)? {
                    Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
                    _ => return Err(Error::unexpected_message()),
                },
            };

            match response.sender.poll_ready(cx) {
                Poll::Ready(Ok(())) => {
                    // A failed send (receiver gone) is deliberately ignored;
                    // the batch is simply dropped.
                    let _ = response.sender.start_send(messages);
                    // More batches belong to this request, so keep its channel
                    // at the head of the queue.
                    if !request_complete {
                        self.responses.push_front(response);
                    }
                }
                Poll::Ready(Err(_)) => {
                    // The receiver has hung up. Keep routing (and discarding)
                    // the rest of this request's messages so that later
                    // batches still line up with their own requests.
                    if !request_complete {
                        self.responses.push_front(response);
                    }
                }
                Poll::Pending => {
                    // The channel is full. Put the response back and park the
                    // batch in `pending_responses`, which `poll_response`
                    // drains before reading the socket again.
                    self.responses.push_front(response);
                    self.pending_responses.push_back(BackendMessage::Normal {
                        messages,
                        request_complete,
                    });
                    trace!("poll_read: waiting on sender");
                    return Ok(None);
                }
            }
        }
    }

    /// Returns the next request to write.
    ///
    /// A request parked in `pending_request` (an unfinished COPY stream) takes
    /// priority. Otherwise the request channel is polled. Each newly received
    /// request has its response sender appended to `responses`, so responses
    /// are matched to requests in submission order.
    ///
    /// `Ready(None)` means the request channel has terminated.
    fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
        if let Some(messages) = self.pending_request.take() {
            trace!("retrying pending request");
            return Poll::Ready(Some(messages));
        }

        // Short-circuit once the channel has already reported termination.
        if self.receiver.is_terminated() {
            return Poll::Ready(None);
        }

        match self.receiver.poll_next_unpin(cx) {
            Poll::Ready(Some(request)) => {
                trace!("polled new request");
                self.responses.push_back(Response {
                    sender: request.sender,
                });
                Poll::Ready(Some(request.messages))
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Drives the write half of the connection.
    ///
    /// Loops while the sink can accept another frame, feeding it requests from
    /// `poll_request`. COPY requests send one message per iteration and are
    /// re-parked in `pending_request` until their receiver is exhausted.
    ///
    /// When the request channel has terminated and no responses are
    /// outstanding, a `Terminate` message is sent and the state advances
    /// through `Terminating` to `Closing`.
    ///
    /// Returns `Ok(true)` when the loop stopped waiting on requests, meaning the
    /// caller should flush whatever was buffered. Returns `Ok(false)` when the
    /// sink is not ready or the connection is already `Closing`.
    fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
        loop {
            // Once `Terminate` has been sent, nothing more is written;
            // `poll_shutdown` takes over from here.
            if self.state == State::Closing {
                trace!("poll_write: done");
                return Ok(false);
            }

            // Backpressure: stop until the sink can accept another frame.
            if Pin::new(&mut self.stream)
                .poll_ready(cx)
                .map_err(Error::io)?
                .is_pending()
            {
                trace!("poll_write: waiting on socket");
                return Ok(false);
            }

            let request = match self.poll_request(cx) {
                Poll::Ready(Some(request)) => request,
                // The channel is done and every response has been delivered,
                // so send `Terminate`. The state guard ensures this happens
                // only once.
                Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
                    trace!("poll_write: at eof, terminating");
                    self.state = State::Terminating;
                    let mut request = BytesMut::new();
                    frontend::terminate(&mut request);
                    RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
                }
                // The channel is done but responses are still outstanding;
                // keep reading until they drain.
                Poll::Ready(None) => {
                    trace!(
                        "poll_write: at eof, pending responses {}",
                        self.responses.len()
                    );
                    return Ok(true);
                }
                Poll::Pending => {
                    trace!("poll_write: waiting on request");
                    return Ok(true);
                }
            };

            match request {
                RequestMessages::Single(request) => {
                    Pin::new(&mut self.stream)
                        .start_send(request)
                        .map_err(Error::io)?;
                    if self.state == State::Terminating {
                        trace!("poll_write: sent eof, closing");
                        self.state = State::Closing;
                    }
                }
                RequestMessages::CopyIn(mut receiver) => {
                    let message = match receiver.poll_next_unpin(cx) {
                        Poll::Ready(Some(message)) => message,
                        Poll::Ready(None) => {
                            trace!("poll_write: finished copy_in request");
                            continue;
                        }
                        Poll::Pending => {
                            trace!("poll_write: waiting on copy_in stream");
                            self.pending_request = Some(RequestMessages::CopyIn(receiver));
                            return Ok(true);
                        }
                    };
                    Pin::new(&mut self.stream)
                        .start_send(message)
                        .map_err(Error::io)?;
                    // Re-park the stream so the next iteration resumes it
                    // before taking any new request.
                    self.pending_request = Some(RequestMessages::CopyIn(receiver));
                }
                RequestMessages::CopyBoth(mut receiver) => {
                    let message = match receiver.poll_next_unpin(cx) {
                        Poll::Ready(Some(message)) => message,
                        Poll::Ready(None) => {
                            trace!("poll_write: finished copy_both request");
                            continue;
                        }
                        Poll::Pending => {
                            trace!("poll_write: waiting on copy_both stream");
                            self.pending_request = Some(RequestMessages::CopyBoth(receiver));
                            return Ok(true);
                        }
                    };
                    Pin::new(&mut self.stream)
                        .start_send(message)
                        .map_err(Error::io)?;
                    // Re-park the stream so the next iteration resumes it
                    // before taking any new request.
                    self.pending_request = Some(RequestMessages::CopyBoth(receiver));
                }
            }
        }
    }

    /// Flushes buffered frames to the socket.
    ///
    /// A flush that is still pending is not an error; it is only traced, and
    /// the connection will be polled again.
    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
        match Pin::new(&mut self.stream)
            .poll_flush(cx)
            .map_err(Error::io)?
        {
            Poll::Ready(()) => trace!("poll_flush: flushed"),
            Poll::Pending => trace!("poll_flush: waiting on socket"),
        }
        Ok(())
    }

    /// Closes the sink once `Terminate` has been sent.
    ///
    /// Returns `Pending` in any state other than `Closing`.
    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        if self.state != State::Closing {
            return Poll::Pending;
        }

        match Pin::new(&mut self.stream)
            .poll_close(cx)
            .map_err(Error::io)?
        {
            Poll::Ready(()) => {
                trace!("poll_shutdown: complete");
                Poll::Ready(Ok(()))
            }
            Poll::Pending => {
                trace!("poll_shutdown: waiting on socket");
                Poll::Pending
            }
        }
    }

    /// Returns the latest known value of a server runtime parameter.
    ///
    /// Values come from the initial parameter table and from any
    /// `ParameterStatus` messages received since.
    pub fn parameter(&self, name: &str) -> Option<&str> {
        self.parameters.get(name).map(|s| &**s)
    }

    /// Drives the connection and polls for asynchronous server messages.
    ///
    /// Each call does the following, in order:
    /// 1. Reads until a notice or notification arrives, or until no further
    ///    progress can be made.
    /// 2. Writes queued requests.
    /// 3. Flushes, if the write step asked for it.
    /// 4. Returns the notice or notification, if one was found.
    ///
    /// If no message was found, the connection's shutdown is advanced instead.
    /// `Ready(None)` means the connection has closed cleanly.
    ///
    /// Unlike the `Future` implementation, which only logs notices and discards
    /// notifications, this method hands both to the caller.
    pub fn poll_message(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<AsyncMessage, Error>>> {
        let message = self.poll_read(cx)?;
        let want_flush = self.poll_write(cx)?;
        if want_flush {
            self.poll_flush(cx)?;
        }
        match message {
            Some(message) => Poll::Ready(Some(Ok(message))),
            None => match self.poll_shutdown(cx) {
                Poll::Ready(Ok(())) => Poll::Ready(None),
                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
                Poll::Pending => Poll::Pending,
            },
        }
    }
}
347
348impl<S, T> Future for Connection<S, T>
349where
350 S: AsyncRead + AsyncWrite + Unpin,
351 T: AsyncRead + AsyncWrite + Unpin,
352{
353 type Output = Result<(), Error>;
354
355 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
356 while let Some(message) = ready!(self.poll_message(cx)?) {
357 if let AsyncMessage::Notice(notice) = message {
358 info!("{}: {}", notice.severity(), notice.message());
359 }
360 }
361 Poll::Ready(Ok(()))
362 }
363}