tokio_postgres/
query.rs

1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::prepare::get_type;
5use crate::types::{BorrowToSql, IsNull};
6use crate::{Column, Error, Portal, Row, Statement};
7use bytes::{Bytes, BytesMut};
8use fallible_iterator::FallibleIterator;
9use futures_util::{ready, Stream};
10use log::{debug, log_enabled, Level};
11use pin_project_lite::pin_project;
12use postgres_protocol::message::backend::{CommandCompleteBody, Message};
13use postgres_protocol::message::frontend;
14use postgres_types::Type;
15use std::fmt;
16use std::marker::PhantomPinned;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
22
23impl<'a, T> fmt::Debug for BorrowToSqlParamsDebug<'a, T>
24where
25    T: BorrowToSql,
26{
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_list()
29            .entries(self.0.iter().map(|x| x.borrow_to_sql()))
30            .finish()
31    }
32}
33
34pub async fn query<P, I>(
35    client: &InnerClient,
36    statement: Statement,
37    params: I,
38) -> Result<RowStream, Error>
39where
40    P: BorrowToSql,
41    I: IntoIterator<Item = P>,
42    I::IntoIter: ExactSizeIterator,
43{
44    let buf = if log_enabled!(Level::Debug) {
45        let params = params.into_iter().collect::<Vec<_>>();
46        debug!(
47            "executing statement {} with parameters: {:?}",
48            statement.name(),
49            BorrowToSqlParamsDebug(params.as_slice()),
50        );
51        encode(client, &statement, params)?
52    } else {
53        encode(client, &statement, params)?
54    };
55    let responses = start(client, buf).await?;
56    Ok(RowStream {
57        statement,
58        responses,
59        rows_affected: None,
60        _p: PhantomPinned,
61    })
62}
63
64pub async fn query_typed<'a, P, I>(
65    client: &Arc<InnerClient>,
66    query: &str,
67    params: I,
68) -> Result<RowStream, Error>
69where
70    P: BorrowToSql,
71    I: IntoIterator<Item = (P, Type)>,
72{
73    let buf = {
74        let params = params.into_iter().collect::<Vec<_>>();
75        let param_oids = params.iter().map(|(_, t)| t.oid()).collect::<Vec<_>>();
76
77        client.with_buf(|buf| {
78            frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;
79            encode_bind_raw("", params, "", buf)?;
80            frontend::describe(b'S', "", buf).map_err(Error::encode)?;
81            frontend::execute("", 0, buf).map_err(Error::encode)?;
82            frontend::sync(buf);
83
84            Ok(buf.split().freeze())
85        })?
86    };
87
88    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
89
90    loop {
91        match responses.next().await? {
92            Message::ParseComplete | Message::BindComplete | Message::ParameterDescription(_) => {}
93            Message::NoData => {
94                return Ok(RowStream {
95                    statement: Statement::unnamed(vec![], vec![]),
96                    responses,
97                    rows_affected: None,
98                    _p: PhantomPinned,
99                });
100            }
101            Message::RowDescription(row_description) => {
102                let mut columns: Vec<Column> = vec![];
103                let mut it = row_description.fields();
104                while let Some(field) = it.next().map_err(Error::parse)? {
105                    let type_ = get_type(client, field.type_oid()).await?;
106                    let column = Column {
107                        name: field.name().to_string(),
108                        table_oid: Some(field.table_oid()).filter(|n| *n != 0),
109                        column_id: Some(field.column_id()).filter(|n| *n != 0),
110                        r#type: type_,
111                    };
112                    columns.push(column);
113                }
114                return Ok(RowStream {
115                    statement: Statement::unnamed(vec![], columns),
116                    responses,
117                    rows_affected: None,
118                    _p: PhantomPinned,
119                });
120            }
121            _ => return Err(Error::unexpected_message()),
122        }
123    }
124}
125
126pub async fn query_portal(
127    client: &InnerClient,
128    portal: &Portal,
129    max_rows: i32,
130) -> Result<RowStream, Error> {
131    let buf = client.with_buf(|buf| {
132        frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
133        frontend::sync(buf);
134        Ok(buf.split().freeze())
135    })?;
136
137    let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
138
139    Ok(RowStream {
140        statement: portal.statement().clone(),
141        responses,
142        rows_affected: None,
143        _p: PhantomPinned,
144    })
145}
146
147/// Extract the number of rows affected from [`CommandCompleteBody`].
148pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
149    let rows = body
150        .tag()
151        .map_err(Error::parse)?
152        .rsplit(' ')
153        .next()
154        .unwrap()
155        .parse()
156        .unwrap_or(0);
157    Ok(rows)
158}
159
160pub async fn execute<P, I>(
161    client: &InnerClient,
162    statement: Statement,
163    params: I,
164) -> Result<u64, Error>
165where
166    P: BorrowToSql,
167    I: IntoIterator<Item = P>,
168    I::IntoIter: ExactSizeIterator,
169{
170    let buf = if log_enabled!(Level::Debug) {
171        let params = params.into_iter().collect::<Vec<_>>();
172        debug!(
173            "executing statement {} with parameters: {:?}",
174            statement.name(),
175            BorrowToSqlParamsDebug(params.as_slice()),
176        );
177        encode(client, &statement, params)?
178    } else {
179        encode(client, &statement, params)?
180    };
181    let mut responses = start(client, buf).await?;
182
183    let mut rows = 0;
184    loop {
185        match responses.next().await? {
186            Message::DataRow(_) => {}
187            Message::CommandComplete(body) => {
188                rows = extract_row_affected(&body)?;
189            }
190            Message::EmptyQueryResponse => rows = 0,
191            Message::ReadyForQuery(_) => return Ok(rows),
192            _ => return Err(Error::unexpected_message()),
193        }
194    }
195}
196
197async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
198    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
199
200    match responses.next().await? {
201        Message::BindComplete => {}
202        _ => return Err(Error::unexpected_message()),
203    }
204
205    Ok(responses)
206}
207
208pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
209where
210    P: BorrowToSql,
211    I: IntoIterator<Item = P>,
212    I::IntoIter: ExactSizeIterator,
213{
214    client.with_buf(|buf| {
215        encode_bind(statement, params, "", buf)?;
216        frontend::execute("", 0, buf).map_err(Error::encode)?;
217        frontend::sync(buf);
218        Ok(buf.split().freeze())
219    })
220}
221
222pub fn encode_bind<P, I>(
223    statement: &Statement,
224    params: I,
225    portal: &str,
226    buf: &mut BytesMut,
227) -> Result<(), Error>
228where
229    P: BorrowToSql,
230    I: IntoIterator<Item = P>,
231    I::IntoIter: ExactSizeIterator,
232{
233    let params = params.into_iter();
234    if params.len() != statement.params().len() {
235        return Err(Error::parameters(params.len(), statement.params().len()));
236    }
237
238    encode_bind_raw(
239        statement.name(),
240        params.zip(statement.params().iter().cloned()),
241        portal,
242        buf,
243    )
244}
245
246fn encode_bind_raw<P, I>(
247    statement_name: &str,
248    params: I,
249    portal: &str,
250    buf: &mut BytesMut,
251) -> Result<(), Error>
252where
253    P: BorrowToSql,
254    I: IntoIterator<Item = (P, Type)>,
255    I::IntoIter: ExactSizeIterator,
256{
257    let (param_formats, params): (Vec<_>, Vec<_>) = params
258        .into_iter()
259        .map(|(p, ty)| (p.borrow_to_sql().encode_format(&ty) as i16, (p, ty)))
260        .unzip();
261
262    let mut error_idx = 0;
263    let r = frontend::bind(
264        portal,
265        statement_name,
266        param_formats,
267        params.into_iter().enumerate(),
268        |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(&ty, buf) {
269            Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
270            Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
271            Err(e) => {
272                error_idx = idx;
273                Err(e)
274            }
275        },
276        Some(1),
277        buf,
278    );
279    match r {
280        Ok(()) => Ok(()),
281        Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
282        Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
283    }
284}
285
286pin_project! {
287    /// A stream of table rows.
288    pub struct RowStream {
289        statement: Statement,
290        responses: Responses,
291        rows_affected: Option<u64>,
292        #[pin]
293        _p: PhantomPinned,
294    }
295}
296
297impl Stream for RowStream {
298    type Item = Result<Row, Error>;
299
300    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301        let this = self.project();
302        loop {
303            match ready!(this.responses.poll_next(cx)?) {
304                Message::DataRow(body) => {
305                    return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
306                }
307                Message::CommandComplete(body) => {
308                    *this.rows_affected = Some(extract_row_affected(&body)?);
309                }
310                Message::EmptyQueryResponse | Message::PortalSuspended => {}
311                Message::ReadyForQuery(_) => return Poll::Ready(None),
312                _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
313            }
314        }
315    }
316}
317
318impl RowStream {
319    /// Returns the number of rows affected by the query.
320    ///
321    /// This function will return `None` until the stream has been exhausted.
322    pub fn rows_affected(&self) -> Option<u64> {
323        self.rows_affected
324    }
325}