mz_postgres_util/
schemas.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utilities to fetch schema information for Postgres sources.
11
12use std::collections::{BTreeMap, BTreeSet};
13
14use anyhow::anyhow;
15use tokio_postgres::Client;
16use tokio_postgres::types::Oid;
17
18use crate::PostgresError;
19use crate::desc::{PostgresColumnDesc, PostgresKeyDesc, PostgresSchemaDesc, PostgresTableDesc};
20
21pub async fn get_schemas(client: &Client) -> Result<Vec<PostgresSchemaDesc>, PostgresError> {
22    Ok(client
23        .query("SELECT oid, nspname, nspowner FROM pg_namespace", &[])
24        .await?
25        .into_iter()
26        .map(|row| {
27            let oid: Oid = row.get("oid");
28            let name: String = row.get("nspname");
29            let owner: Oid = row.get("nspowner");
30            PostgresSchemaDesc { oid, name, owner }
31        })
32        .collect::<Vec<_>>())
33}
34
35/// Fetches table schema information from an upstream Postgres source for tables
36/// that are part of a publication, given a connection string and the
37/// publication name. Returns a map from table OID to table schema information.
38///
39/// The `oids` parameter controls for which tables to fetch schema information. If `None`,
40/// schema information for all tables in the publication is fetched. If `Some`, only
41/// schema information for the tables with the specified OIDs is fetched.
42///
43/// # Errors
44///
45/// - Invalid connection string, user information, or user permissions.
46/// - Upstream publication does not exist or contains invalid values.
47pub async fn publication_info(
48    client: &Client,
49    publication: &str,
50    oids: Option<&[Oid]>,
51) -> Result<BTreeMap<Oid, PostgresTableDesc>, PostgresError> {
52    let server_version_num = client
53        .query_one("SHOW server_version_num", &[])
54        .await?
55        .get::<_, &str>("server_version_num")
56        .parse::<i32>()
57        .map_err(|e| PostgresError::Generic(anyhow!("unable to parse server_version_num: {e}")))?;
58
59    client
60        .query(
61            "SELECT oid FROM pg_publication WHERE pubname = $1",
62            &[&publication],
63        )
64        .await
65        .map_err(PostgresError::from)?
66        .get(0)
67        .ok_or_else(|| PostgresError::PublicationMissing(publication.to_string()))?;
68
69    let tables = if let Some(oids) = oids {
70        client
71            .query(
72                "SELECT
73                    c.oid, p.schemaname, p.tablename
74                FROM
75                    pg_catalog.pg_class AS c
76                    JOIN pg_namespace AS n ON c.relnamespace = n.oid
77                    JOIN pg_publication_tables AS p ON
78                            c.relname = p.tablename AND n.nspname = p.schemaname
79                WHERE
80                    p.pubname = $1
81                    AND c.oid = ANY ($2)",
82                &[&publication, &oids],
83            )
84            .await
85    } else {
86        client
87            .query(
88                "SELECT
89                    c.oid, p.schemaname, p.tablename
90                FROM
91                    pg_catalog.pg_class AS c
92                    JOIN pg_namespace AS n ON c.relnamespace = n.oid
93                    JOIN pg_publication_tables AS p ON
94                            c.relname = p.tablename AND n.nspname = p.schemaname
95                WHERE
96                    p.pubname = $1",
97                &[&publication],
98            )
99            .await
100    }?;
101
102    // The Postgres replication protocol does not support GENERATED columns
103    // so we exclude them from this query. But not all Postgres-like
104    // databases have the `pg_attribute.attgenerated` column.
105    let attgenerated = if server_version_num >= 120000 {
106        "a.attgenerated = ''"
107    } else {
108        "true"
109    };
110
111    let pg_columns = format!(
112        "
113        SELECT
114            a.attrelid AS table_oid,
115            a.attname AS name,
116            a.atttypid AS typoid,
117            a.attnum AS colnum,
118            a.atttypmod AS typmod,
119            a.attnotnull AS not_null,
120            b.oid IS NOT NULL AS primary_key
121        FROM pg_catalog.pg_attribute a
122        LEFT JOIN pg_catalog.pg_constraint b
123            ON a.attrelid = b.conrelid
124            AND b.contype = 'p'
125            AND a.attnum = ANY (b.conkey)
126        WHERE a.attnum > 0::pg_catalog.int2
127            AND NOT a.attisdropped
128            AND {attgenerated}
129            AND a.attrelid = ANY ($1)
130        ORDER BY a.attnum"
131    );
132
133    let table_oids = tables
134        .iter()
135        .map(|row| row.get("oid"))
136        .collect::<Vec<Oid>>();
137
138    let mut columns: BTreeMap<Oid, Vec<_>> = BTreeMap::new();
139    for row in client.query(&pg_columns, &[&table_oids]).await? {
140        let table_oid: Oid = row.get("table_oid");
141        let name: String = row.get("name");
142        let type_oid = row.get("typoid");
143        let col_num = row
144            .get::<_, i16>("colnum")
145            .try_into()
146            .expect("non-negative values");
147        let type_mod: i32 = row.get("typmod");
148        let not_null: bool = row.get("not_null");
149        let desc = PostgresColumnDesc {
150            name,
151            col_num,
152            type_oid,
153            type_mod,
154            nullable: !not_null,
155        };
156        columns.entry(table_oid).or_default().push(desc);
157    }
158
159    // PG 15 adds UNIQUE NULLS NOT DISTINCT, which would let us use `UNIQUE` constraints over
160    // nullable columns as keys; i.e. aligns a PG index's NULL handling with an arrangement's
161    // keys. For more info, see https://www.postgresql.org/about/featurematrix/detail/392/
162    let nulls_not_distinct = if server_version_num >= 150000 {
163        "pg_index.indnullsnotdistinct"
164    } else {
165        "false"
166    };
167    let pg_keys = format!(
168        "
169        SELECT
170            pg_constraint.conrelid AS table_oid,
171            pg_constraint.oid,
172            pg_constraint.conkey,
173            pg_constraint.conname,
174            pg_constraint.contype = 'p' AS is_primary,
175            {nulls_not_distinct} AS nulls_not_distinct
176        FROM
177            pg_constraint
178                JOIN
179                    pg_index
180                    ON pg_index.indexrelid = pg_constraint.conindid
181        WHERE
182            pg_constraint.conrelid = ANY ($1)
183                AND
184            pg_constraint.contype = ANY (ARRAY['p', 'u']);"
185    );
186
187    let mut keys: BTreeMap<Oid, BTreeSet<_>> = BTreeMap::new();
188    for row in client.query(&pg_keys, &[&table_oids]).await? {
189        let table_oid: Oid = row.get("table_oid");
190        let oid: Oid = row.get("oid");
191        let cols: Vec<i16> = row.get("conkey");
192        let name: String = row.get("conname");
193        let is_primary: bool = row.get("is_primary");
194        let nulls_not_distinct: bool = row.get("nulls_not_distinct");
195        let cols = cols
196            .into_iter()
197            .map(|col| u16::try_from(col).expect("non-negative colnums"))
198            .collect();
199        let desc = PostgresKeyDesc {
200            oid,
201            name,
202            cols,
203            is_primary,
204            nulls_not_distinct,
205        };
206        keys.entry(table_oid).or_default().insert(desc);
207    }
208
209    Ok(tables
210        .into_iter()
211        .map(|row| {
212            let oid: Oid = row.get("oid");
213            let columns = columns.remove(&oid).unwrap_or_default();
214            let keys = keys.remove(&oid).unwrap_or_default();
215            let desc = PostgresTableDesc {
216                oid,
217                namespace: row.get("schemaname"),
218                name: row.get("tablename"),
219                columns,
220                keys,
221            };
222            (oid, desc)
223        })
224        .collect())
225}