tokio_postgres/
connect_raw.rs

1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::config::{self, Config, ReplicationMode};
3use crate::connect_tls::connect_tls;
4use crate::maybe_tls_stream::MaybeTlsStream;
5use crate::tls::{TlsConnect, TlsStream};
6use crate::{Client, Connection, Error};
7use bytes::BytesMut;
8use fallible_iterator::FallibleIterator;
9use futures_channel::mpsc;
10use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
11use postgres_protocol::authentication;
12use postgres_protocol::authentication::sasl;
13use postgres_protocol::authentication::sasl::ScramSha256;
14use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
15use postgres_protocol::message::frontend;
16use std::borrow::Cow;
17use std::collections::{HashMap, VecDeque};
18use std::io;
19use std::pin::Pin;
20use std::task::{Context, Poll};
21use tokio::io::{AsyncRead, AsyncWrite};
22use tokio_util::codec::Framed;
23
24pub struct StartupStream<S, T> {
25    inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
26    buf: BackendMessages,
27    delayed: VecDeque<BackendMessage>,
28}
29
30impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
31where
32    S: AsyncRead + AsyncWrite + Unpin,
33    T: AsyncRead + AsyncWrite + Unpin,
34{
35    type Error = io::Error;
36
37    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
38        Pin::new(&mut self.inner).poll_ready(cx)
39    }
40
41    fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
42        Pin::new(&mut self.inner).start_send(item)
43    }
44
45    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46        Pin::new(&mut self.inner).poll_flush(cx)
47    }
48
49    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
50        Pin::new(&mut self.inner).poll_close(cx)
51    }
52}
53
54impl<S, T> Stream for StartupStream<S, T>
55where
56    S: AsyncRead + AsyncWrite + Unpin,
57    T: AsyncRead + AsyncWrite + Unpin,
58{
59    type Item = io::Result<Message>;
60
61    fn poll_next(
62        mut self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64    ) -> Poll<Option<io::Result<Message>>> {
65        loop {
66            match self.buf.next() {
67                Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
68                Ok(None) => {}
69                Err(e) => return Poll::Ready(Some(Err(e))),
70            }
71
72            match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
73                Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
74                Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
75                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
76                None => return Poll::Ready(None),
77            }
78        }
79    }
80}
81
82pub async fn connect_raw<S, T>(
83    stream: S,
84    tls: T,
85    has_hostname: bool,
86    config: &Config,
87) -> Result<(Client, Connection<S, T::Stream>), Error>
88where
89    S: AsyncRead + AsyncWrite + Unpin,
90    T: TlsConnect<S>,
91{
92    let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?;
93
94    let mut stream = StartupStream {
95        inner: Framed::new(stream, PostgresCodec),
96        buf: BackendMessages::empty(),
97        delayed: VecDeque::new(),
98    };
99
100    let user = config
101        .user
102        .as_deref()
103        .map_or_else(|| Cow::Owned(whoami::username()), Cow::Borrowed);
104
105    startup(&mut stream, config, &user).await?;
106    authenticate(&mut stream, config, &user).await?;
107    let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
108
109    let (sender, receiver) = mpsc::unbounded();
110    let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
111    let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
112
113    Ok((client, connection))
114}
115
116async fn startup<S, T>(
117    stream: &mut StartupStream<S, T>,
118    config: &Config,
119    user: &str,
120) -> Result<(), Error>
121where
122    S: AsyncRead + AsyncWrite + Unpin,
123    T: AsyncRead + AsyncWrite + Unpin,
124{
125    let mut params = vec![("client_encoding", "UTF8")];
126    params.push(("user", user));
127    if let Some(dbname) = &config.dbname {
128        params.push(("database", &**dbname));
129    }
130    if let Some(options) = &config.options {
131        params.push(("options", &**options));
132    }
133    if let Some(application_name) = &config.application_name {
134        params.push(("application_name", &**application_name));
135    }
136    if let Some(replication_mode) = &config.replication_mode {
137        match replication_mode {
138            ReplicationMode::Physical => params.push(("replication", "true")),
139            ReplicationMode::Logical => params.push(("replication", "database")),
140        }
141    }
142
143    let mut buf = BytesMut::new();
144    frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
145
146    stream
147        .send(FrontendMessage::Raw(buf.freeze()))
148        .await
149        .map_err(Error::io)
150}
151
152async fn authenticate<S, T>(
153    stream: &mut StartupStream<S, T>,
154    config: &Config,
155    user: &str,
156) -> Result<(), Error>
157where
158    S: AsyncRead + AsyncWrite + Unpin,
159    T: TlsStream + Unpin,
160{
161    match stream.try_next().await.map_err(Error::io)? {
162        Some(Message::AuthenticationOk) => {
163            can_skip_channel_binding(config)?;
164            return Ok(());
165        }
166        Some(Message::AuthenticationCleartextPassword) => {
167            can_skip_channel_binding(config)?;
168
169            let pass = config
170                .password
171                .as_ref()
172                .ok_or_else(|| Error::config("password missing".into()))?;
173
174            authenticate_password(stream, pass).await?;
175        }
176        Some(Message::AuthenticationMd5Password(body)) => {
177            can_skip_channel_binding(config)?;
178
179            let pass = config
180                .password
181                .as_ref()
182                .ok_or_else(|| Error::config("password missing".into()))?;
183
184            let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
185            authenticate_password(stream, output.as_bytes()).await?;
186        }
187        Some(Message::AuthenticationSasl(body)) => {
188            authenticate_sasl(stream, body, config).await?;
189        }
190        Some(Message::AuthenticationKerberosV5)
191        | Some(Message::AuthenticationScmCredential)
192        | Some(Message::AuthenticationGss)
193        | Some(Message::AuthenticationSspi) => {
194            return Err(Error::authentication(
195                "unsupported authentication method".into(),
196            ))
197        }
198        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
199        Some(_) => return Err(Error::unexpected_message()),
200        None => return Err(Error::closed()),
201    }
202
203    match stream.try_next().await.map_err(Error::io)? {
204        Some(Message::AuthenticationOk) => Ok(()),
205        Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
206        Some(_) => Err(Error::unexpected_message()),
207        None => Err(Error::closed()),
208    }
209}
210
211fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
212    match config.channel_binding {
213        config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
214        config::ChannelBinding::Require => Err(Error::authentication(
215            "server did not use channel binding".into(),
216        )),
217    }
218}
219
220async fn authenticate_password<S, T>(
221    stream: &mut StartupStream<S, T>,
222    password: &[u8],
223) -> Result<(), Error>
224where
225    S: AsyncRead + AsyncWrite + Unpin,
226    T: AsyncRead + AsyncWrite + Unpin,
227{
228    let mut buf = BytesMut::new();
229    frontend::password_message(password, &mut buf).map_err(Error::encode)?;
230
231    stream
232        .send(FrontendMessage::Raw(buf.freeze()))
233        .await
234        .map_err(Error::io)
235}
236
237async fn authenticate_sasl<S, T>(
238    stream: &mut StartupStream<S, T>,
239    body: AuthenticationSaslBody,
240    config: &Config,
241) -> Result<(), Error>
242where
243    S: AsyncRead + AsyncWrite + Unpin,
244    T: TlsStream + Unpin,
245{
246    let password = config
247        .password
248        .as_ref()
249        .ok_or_else(|| Error::config("password missing".into()))?;
250
251    let mut has_scram = false;
252    let mut has_scram_plus = false;
253    let mut mechanisms = body.mechanisms();
254    while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
255        match mechanism {
256            sasl::SCRAM_SHA_256 => has_scram = true,
257            sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
258            _ => {}
259        }
260    }
261
262    let channel_binding = stream
263        .inner
264        .get_ref()
265        .channel_binding()
266        .tls_server_end_point
267        .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
268        .map(sasl::ChannelBinding::tls_server_end_point);
269
270    let (channel_binding, mechanism) = if has_scram_plus {
271        match channel_binding {
272            Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
273            None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
274        }
275    } else if has_scram {
276        match channel_binding {
277            Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
278            None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
279        }
280    } else {
281        return Err(Error::authentication("unsupported SASL mechanism".into()));
282    };
283
284    if mechanism != sasl::SCRAM_SHA_256_PLUS {
285        can_skip_channel_binding(config)?;
286    }
287
288    let mut scram = ScramSha256::new(password, channel_binding);
289
290    let mut buf = BytesMut::new();
291    frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
292    stream
293        .send(FrontendMessage::Raw(buf.freeze()))
294        .await
295        .map_err(Error::io)?;
296
297    let body = match stream.try_next().await.map_err(Error::io)? {
298        Some(Message::AuthenticationSaslContinue(body)) => body,
299        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
300        Some(_) => return Err(Error::unexpected_message()),
301        None => return Err(Error::closed()),
302    };
303
304    scram
305        .update(body.data())
306        .map_err(|e| Error::authentication(e.into()))?;
307
308    let mut buf = BytesMut::new();
309    frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
310    stream
311        .send(FrontendMessage::Raw(buf.freeze()))
312        .await
313        .map_err(Error::io)?;
314
315    let body = match stream.try_next().await.map_err(Error::io)? {
316        Some(Message::AuthenticationSaslFinal(body)) => body,
317        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
318        Some(_) => return Err(Error::unexpected_message()),
319        None => return Err(Error::closed()),
320    };
321
322    scram
323        .finish(body.data())
324        .map_err(|e| Error::authentication(e.into()))?;
325
326    Ok(())
327}
328
329async fn read_info<S, T>(
330    stream: &mut StartupStream<S, T>,
331) -> Result<(i32, i32, HashMap<String, String>), Error>
332where
333    S: AsyncRead + AsyncWrite + Unpin,
334    T: AsyncRead + AsyncWrite + Unpin,
335{
336    let mut process_id = 0;
337    let mut secret_key = 0;
338    let mut parameters = HashMap::new();
339
340    loop {
341        match stream.try_next().await.map_err(Error::io)? {
342            Some(Message::BackendKeyData(body)) => {
343                process_id = body.process_id();
344                secret_key = body.secret_key();
345            }
346            Some(Message::ParameterStatus(body)) => {
347                parameters.insert(
348                    body.name().map_err(Error::parse)?.to_string(),
349                    body.value().map_err(Error::parse)?.to_string(),
350                );
351            }
352            Some(msg @ Message::NoticeResponse(_)) => {
353                stream.delayed.push_back(BackendMessage::Async(msg))
354            }
355            Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
356            Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
357            Some(_) => return Err(Error::unexpected_message()),
358            None => return Err(Error::closed()),
359        }
360    }
361}