1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::prepare::get_type;
5use crate::types::{BorrowToSql, IsNull};
6use crate::{Column, Error, Portal, Row, Statement};
7use bytes::{Bytes, BytesMut};
8use fallible_iterator::FallibleIterator;
9use futures_util::{ready, Stream};
10use log::{debug, log_enabled, Level};
11use pin_project_lite::pin_project;
12use postgres_protocol::message::backend::{CommandCompleteBody, Message};
13use postgres_protocol::message::frontend;
14use postgres_types::Type;
15use std::fmt;
16use std::marker::PhantomPinned;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
22
23impl<'a, T> fmt::Debug for BorrowToSqlParamsDebug<'a, T>
24where
25 T: BorrowToSql,
26{
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_list()
29 .entries(self.0.iter().map(|x| x.borrow_to_sql()))
30 .finish()
31 }
32}
33
34pub async fn query<P, I>(
35 client: &InnerClient,
36 statement: Statement,
37 params: I,
38) -> Result<RowStream, Error>
39where
40 P: BorrowToSql,
41 I: IntoIterator<Item = P>,
42 I::IntoIter: ExactSizeIterator,
43{
44 let buf = if log_enabled!(Level::Debug) {
45 let params = params.into_iter().collect::<Vec<_>>();
46 debug!(
47 "executing statement {} with parameters: {:?}",
48 statement.name(),
49 BorrowToSqlParamsDebug(params.as_slice()),
50 );
51 encode(client, &statement, params)?
52 } else {
53 encode(client, &statement, params)?
54 };
55 let responses = start(client, buf).await?;
56 Ok(RowStream {
57 statement,
58 responses,
59 rows_affected: None,
60 _p: PhantomPinned,
61 })
62}
63
64pub async fn query_typed<'a, P, I>(
65 client: &Arc<InnerClient>,
66 query: &str,
67 params: I,
68) -> Result<RowStream, Error>
69where
70 P: BorrowToSql,
71 I: IntoIterator<Item = (P, Type)>,
72{
73 let buf = {
74 let params = params.into_iter().collect::<Vec<_>>();
75 let param_oids = params.iter().map(|(_, t)| t.oid()).collect::<Vec<_>>();
76
77 client.with_buf(|buf| {
78 frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;
79 encode_bind_raw("", params, "", buf)?;
80 frontend::describe(b'S', "", buf).map_err(Error::encode)?;
81 frontend::execute("", 0, buf).map_err(Error::encode)?;
82 frontend::sync(buf);
83
84 Ok(buf.split().freeze())
85 })?
86 };
87
88 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
89
90 loop {
91 match responses.next().await? {
92 Message::ParseComplete | Message::BindComplete | Message::ParameterDescription(_) => {}
93 Message::NoData => {
94 return Ok(RowStream {
95 statement: Statement::unnamed(vec![], vec![]),
96 responses,
97 rows_affected: None,
98 _p: PhantomPinned,
99 });
100 }
101 Message::RowDescription(row_description) => {
102 let mut columns: Vec<Column> = vec![];
103 let mut it = row_description.fields();
104 while let Some(field) = it.next().map_err(Error::parse)? {
105 let type_ = get_type(client, field.type_oid()).await?;
106 let column = Column {
107 name: field.name().to_string(),
108 table_oid: Some(field.table_oid()).filter(|n| *n != 0),
109 column_id: Some(field.column_id()).filter(|n| *n != 0),
110 r#type: type_,
111 };
112 columns.push(column);
113 }
114 return Ok(RowStream {
115 statement: Statement::unnamed(vec![], columns),
116 responses,
117 rows_affected: None,
118 _p: PhantomPinned,
119 });
120 }
121 _ => return Err(Error::unexpected_message()),
122 }
123 }
124}
125
126pub async fn query_portal(
127 client: &InnerClient,
128 portal: &Portal,
129 max_rows: i32,
130) -> Result<RowStream, Error> {
131 let buf = client.with_buf(|buf| {
132 frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
133 frontend::sync(buf);
134 Ok(buf.split().freeze())
135 })?;
136
137 let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
138
139 Ok(RowStream {
140 statement: portal.statement().clone(),
141 responses,
142 rows_affected: None,
143 _p: PhantomPinned,
144 })
145}
146
147pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
149 let rows = body
150 .tag()
151 .map_err(Error::parse)?
152 .rsplit(' ')
153 .next()
154 .unwrap()
155 .parse()
156 .unwrap_or(0);
157 Ok(rows)
158}
159
160pub async fn execute<P, I>(
161 client: &InnerClient,
162 statement: Statement,
163 params: I,
164) -> Result<u64, Error>
165where
166 P: BorrowToSql,
167 I: IntoIterator<Item = P>,
168 I::IntoIter: ExactSizeIterator,
169{
170 let buf = if log_enabled!(Level::Debug) {
171 let params = params.into_iter().collect::<Vec<_>>();
172 debug!(
173 "executing statement {} with parameters: {:?}",
174 statement.name(),
175 BorrowToSqlParamsDebug(params.as_slice()),
176 );
177 encode(client, &statement, params)?
178 } else {
179 encode(client, &statement, params)?
180 };
181 let mut responses = start(client, buf).await?;
182
183 let mut rows = 0;
184 loop {
185 match responses.next().await? {
186 Message::DataRow(_) => {}
187 Message::CommandComplete(body) => {
188 rows = extract_row_affected(&body)?;
189 }
190 Message::EmptyQueryResponse => rows = 0,
191 Message::ReadyForQuery(_) => return Ok(rows),
192 _ => return Err(Error::unexpected_message()),
193 }
194 }
195}
196
197async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
198 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
199
200 match responses.next().await? {
201 Message::BindComplete => {}
202 _ => return Err(Error::unexpected_message()),
203 }
204
205 Ok(responses)
206}
207
208pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
209where
210 P: BorrowToSql,
211 I: IntoIterator<Item = P>,
212 I::IntoIter: ExactSizeIterator,
213{
214 client.with_buf(|buf| {
215 encode_bind(statement, params, "", buf)?;
216 frontend::execute("", 0, buf).map_err(Error::encode)?;
217 frontend::sync(buf);
218 Ok(buf.split().freeze())
219 })
220}
221
222pub fn encode_bind<P, I>(
223 statement: &Statement,
224 params: I,
225 portal: &str,
226 buf: &mut BytesMut,
227) -> Result<(), Error>
228where
229 P: BorrowToSql,
230 I: IntoIterator<Item = P>,
231 I::IntoIter: ExactSizeIterator,
232{
233 let params = params.into_iter();
234 if params.len() != statement.params().len() {
235 return Err(Error::parameters(params.len(), statement.params().len()));
236 }
237
238 encode_bind_raw(
239 statement.name(),
240 params.zip(statement.params().iter().cloned()),
241 portal,
242 buf,
243 )
244}
245
246fn encode_bind_raw<P, I>(
247 statement_name: &str,
248 params: I,
249 portal: &str,
250 buf: &mut BytesMut,
251) -> Result<(), Error>
252where
253 P: BorrowToSql,
254 I: IntoIterator<Item = (P, Type)>,
255 I::IntoIter: ExactSizeIterator,
256{
257 let (param_formats, params): (Vec<_>, Vec<_>) = params
258 .into_iter()
259 .map(|(p, ty)| (p.borrow_to_sql().encode_format(&ty) as i16, (p, ty)))
260 .unzip();
261
262 let mut error_idx = 0;
263 let r = frontend::bind(
264 portal,
265 statement_name,
266 param_formats,
267 params.into_iter().enumerate(),
268 |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(&ty, buf) {
269 Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
270 Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
271 Err(e) => {
272 error_idx = idx;
273 Err(e)
274 }
275 },
276 Some(1),
277 buf,
278 );
279 match r {
280 Ok(()) => Ok(()),
281 Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
282 Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
283 }
284}
285
pin_project! {
    /// A stream of rows returned by a query.
    ///
    /// Each `DataRow` backend message is decoded into a [`Row`]. The affected-row
    /// count from `CommandComplete` is recorded and exposed through `rows_affected`.
    pub struct RowStream {
        // Statement whose metadata is attached to every decoded row.
        statement: Statement,
        // Backend messages for the request that produced this stream.
        responses: Responses,
        // `None` until a `CommandComplete` message has been processed.
        rows_affected: Option<u64>,
        // Pinned `PhantomPinned` marker; makes the stream `!Unpin`.
        #[pin]
        _p: PhantomPinned,
    }
}
296
297impl Stream for RowStream {
298 type Item = Result<Row, Error>;
299
300 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301 let this = self.project();
302 loop {
303 match ready!(this.responses.poll_next(cx)?) {
304 Message::DataRow(body) => {
305 return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
306 }
307 Message::CommandComplete(body) => {
308 *this.rows_affected = Some(extract_row_affected(&body)?);
309 }
310 Message::EmptyQueryResponse | Message::PortalSuspended => {}
311 Message::ReadyForQuery(_) => return Poll::Ready(None),
312 _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
313 }
314 }
315 }
316}
317
318impl RowStream {
319 pub fn rows_affected(&self) -> Option<u64> {
323 self.rows_affected
324 }
325}