tokio_postgres/
prepare.rs

1use crate::client::InnerClient;
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::error::SqlState;
5use crate::types::{Field, Kind, Oid, Type};
6use crate::{query, slice_iter};
7use crate::{Column, Error, Statement};
8use bytes::Bytes;
9use fallible_iterator::FallibleIterator;
10use futures_util::{pin_mut, TryStreamExt};
11use log::debug;
12use postgres_protocol::message::backend::Message;
13use postgres_protocol::message::frontend;
14use std::future::Future;
15use std::pin::Pin;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::Arc;
18
/// Catalog lookup for a single type, with the type OID bound as `$1`.
///
/// Result columns, in order: type name, `typtype` category code, element type
/// (`typelem`), range subtype (NULL when no `pg_range` row joins, because of the
/// outer join), domain base type (`typbasetype`), schema name, and `typrelid`.
const TYPEINFO_QUERY: &str = concat!(
    "SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid\n",
    "FROM pg_catalog.pg_type t\n",
    "LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid\n",
    "INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid\n",
    "WHERE t.oid = $1\n"
);
26
/// Fallback for `TYPEINFO_QUERY` on servers without `pg_range`.
///
/// Range types weren't added until Postgres 9.2, so `pg_range` may not exist. This
/// query keeps the same column layout but always reports a NULL range subtype.
const TYPEINFO_FALLBACK_QUERY: &str = concat!(
    "SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid\n",
    "FROM pg_catalog.pg_type t\n",
    "INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid\n",
    "WHERE t.oid = $1\n"
);
34
/// Lists the labels of the enum type whose OID is bound as `$1`, ordered by
/// `enumsortorder`.
const TYPEINFO_ENUM_QUERY: &str = concat!(
    "SELECT enumlabel\n",
    "FROM pg_catalog.pg_enum\n",
    "WHERE enumtypid = $1\n",
    "ORDER BY enumsortorder\n"
);
41
/// Fallback for `TYPEINFO_ENUM_QUERY`.
///
/// Postgres 9.0 didn't have `enumsortorder`, so labels are ordered by the
/// `pg_enum` row OID instead.
const TYPEINFO_ENUM_FALLBACK_QUERY: &str = concat!(
    "SELECT enumlabel\n",
    "FROM pg_catalog.pg_enum\n",
    "WHERE enumtypid = $1\n",
    "ORDER BY oid\n"
);
49
/// Lists `(attname, atttypid)` for the relation whose OID is bound as `$1`.
///
/// Dropped attributes and attributes with `attnum <= 0` are excluded. Rows are
/// ordered by `attnum`.
const TYPEINFO_COMPOSITE_QUERY: &str = concat!(
    "SELECT attname, atttypid\n",
    "FROM pg_catalog.pg_attribute\n",
    "WHERE attrelid = $1\n",
    "AND NOT attisdropped\n",
    "AND attnum > 0\n",
    "ORDER BY attnum\n"
);
58
/// Counter behind generated prepared-statement names; `prepare` turns each value into `s{n}`.
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
60
61pub async fn prepare(
62    client: &Arc<InnerClient>,
63    query: &str,
64    types: &[Type],
65) -> Result<Statement, Error> {
66    let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
67    let buf = encode(client, &name, query, types)?;
68    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
69
70    match responses.next().await? {
71        Message::ParseComplete => {}
72        _ => return Err(Error::unexpected_message()),
73    }
74
75    let parameter_description = match responses.next().await? {
76        Message::ParameterDescription(body) => body,
77        _ => return Err(Error::unexpected_message()),
78    };
79
80    let row_description = match responses.next().await? {
81        Message::RowDescription(body) => Some(body),
82        Message::NoData => None,
83        _ => return Err(Error::unexpected_message()),
84    };
85
86    let mut parameters = vec![];
87    let mut it = parameter_description.parameters();
88    while let Some(oid) = it.next().map_err(Error::parse)? {
89        let type_ = get_type(client, oid).await?;
90        parameters.push(type_);
91    }
92
93    let mut columns = vec![];
94    if let Some(row_description) = row_description {
95        let mut it = row_description.fields();
96        while let Some(field) = it.next().map_err(Error::parse)? {
97            let type_ = get_type(client, field.type_oid()).await?;
98            let column = Column {
99                name: field.name().to_string(),
100                table_oid: Some(field.table_oid()).filter(|n| *n != 0),
101                column_id: Some(field.column_id()).filter(|n| *n != 0),
102                r#type: type_,
103            };
104            columns.push(column);
105        }
106    }
107
108    Ok(Statement::new(client, name, parameters, columns))
109}
110
111fn prepare_rec<'a>(
112    client: &'a Arc<InnerClient>,
113    query: &'a str,
114    types: &'a [Type],
115) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> {
116    Box::pin(prepare(client, query, types))
117}
118
119fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
120    if types.is_empty() {
121        debug!("preparing query {}: {}", name, query);
122    } else {
123        debug!("preparing query {} with types {:?}: {}", name, types, query);
124    }
125
126    client.with_buf(|buf| {
127        frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
128        frontend::describe(b'S', name, buf).map_err(Error::encode)?;
129        frontend::sync(buf);
130        Ok(buf.split().freeze())
131    })
132}
133
134pub(crate) async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
135    if let Some(type_) = Type::from_oid(oid) {
136        return Ok(type_);
137    }
138
139    if let Some(type_) = client.type_(oid) {
140        return Ok(type_);
141    }
142
143    let stmt = typeinfo_statement(client).await?;
144
145    let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
146    pin_mut!(rows);
147
148    let row = match rows.try_next().await? {
149        Some(row) => row,
150        None => return Err(Error::unexpected_message()),
151    };
152
153    let name: String = row.try_get(0)?;
154    let type_: i8 = row.try_get(1)?;
155    let elem_oid: Oid = row.try_get(2)?;
156    let rngsubtype: Option<Oid> = row.try_get(3)?;
157    let basetype: Oid = row.try_get(4)?;
158    let schema: String = row.try_get(5)?;
159    let relid: Oid = row.try_get(6)?;
160
161    let kind = if type_ == b'e' as i8 {
162        let variants = get_enum_variants(client, oid).await?;
163        Kind::Enum(variants)
164    } else if type_ == b'p' as i8 {
165        Kind::Pseudo
166    } else if basetype != 0 {
167        let type_ = get_type_rec(client, basetype).await?;
168        Kind::Domain(type_)
169    } else if elem_oid != 0 {
170        let type_ = get_type_rec(client, elem_oid).await?;
171        Kind::Array(type_)
172    } else if relid != 0 {
173        let fields = get_composite_fields(client, relid).await?;
174        Kind::Composite(fields)
175    } else if let Some(rngsubtype) = rngsubtype {
176        let type_ = get_type_rec(client, rngsubtype).await?;
177        Kind::Range(type_)
178    } else {
179        Kind::Simple
180    };
181
182    let type_ = Type::new(name, oid, kind, schema);
183    client.set_type(oid, &type_);
184
185    Ok(type_)
186}
187
188fn get_type_rec<'a>(
189    client: &'a Arc<InnerClient>,
190    oid: Oid,
191) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
192    Box::pin(get_type(client, oid))
193}
194
195async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
196    if let Some(stmt) = client.typeinfo() {
197        return Ok(stmt);
198    }
199
200    let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await {
201        Ok(stmt) => stmt,
202        Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {
203            prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await?
204        }
205        Err(e) => return Err(e),
206    };
207
208    client.set_typeinfo(&stmt);
209    Ok(stmt)
210}
211
212async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
213    let stmt = typeinfo_enum_statement(client).await?;
214
215    query::query(client, stmt, slice_iter(&[&oid]))
216        .await?
217        .and_then(|row| async move { row.try_get(0) })
218        .try_collect()
219        .await
220}
221
222async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
223    if let Some(stmt) = client.typeinfo_enum() {
224        return Ok(stmt);
225    }
226
227    let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await {
228        Ok(stmt) => stmt,
229        Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => {
230            prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
231        }
232        Err(e) => return Err(e),
233    };
234
235    client.set_typeinfo_enum(&stmt);
236    Ok(stmt)
237}
238
239async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
240    let stmt = typeinfo_composite_statement(client).await?;
241
242    let rows = query::query(client, stmt, slice_iter(&[&oid]))
243        .await?
244        .try_collect::<Vec<_>>()
245        .await?;
246
247    let mut fields = vec![];
248    for row in rows {
249        let name = row.try_get(0)?;
250        let oid = row.try_get(1)?;
251        let type_ = get_type_rec(client, oid).await?;
252        fields.push(Field::new(name, type_));
253    }
254
255    Ok(fields)
256}
257
258async fn typeinfo_composite_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
259    if let Some(stmt) = client.typeinfo_composite() {
260        return Ok(stmt);
261    }
262
263    let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?;
264
265    client.set_typeinfo_composite(&stmt);
266    Ok(stmt)
267}