Skip to main content

mz_postgres_util/
schemas.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utilities to fetch schema information for Postgres sources.
11
12use std::collections::{BTreeMap, BTreeSet};
13
14use tokio_postgres::Client;
15use tokio_postgres::types::Oid;
16
17use crate::desc::{PostgresColumnDesc, PostgresKeyDesc, PostgresSchemaDesc, PostgresTableDesc};
18use crate::{PostgresError, simple_query_opt};
19
20pub async fn get_schemas(client: &Client) -> Result<Vec<PostgresSchemaDesc>, PostgresError> {
21    Ok(client
22        .query("SELECT oid, nspname, nspowner FROM pg_namespace", &[])
23        .await?
24        .into_iter()
25        .map(|row| {
26            let oid: Oid = row.get("oid");
27            let name: String = row.get("nspname");
28            let owner: Oid = row.get("nspowner");
29            PostgresSchemaDesc { oid, name, owner }
30        })
31        .collect::<Vec<_>>())
32}
33
34/// Get the major version of the PostgreSQL server.
35pub async fn get_pg_major_version(client: &Client) -> Result<u32, PostgresError> {
36    // server_version_num is an integer like 140005 for version 14.5
37    // NOTE: We use the statement SELECT instead of SHOW because older Aurora
38    // versions don't support SHOW via a replication channel.
39    let query = "SELECT current_setting('server_version_num')";
40    let row = simple_query_opt(client, query).await?;
41    let version_num: u32 = row
42        .and_then(|r| r.get("current_setting").map(|s| s.parse().ok()))
43        .flatten()
44        .ok_or_else(|| {
45            PostgresError::Generic(anyhow::anyhow!("failed to get PostgreSQL version"))
46        })?;
47    // server_version_num format: XXYYZZ where XX is major, YY is minor, ZZ is patch
48    // For PG >= 10, it's XXXYYZZ (3 digit major)
49    Ok(version_num / 10000)
50}
51
52/// Fetches table schema information from an upstream Postgres source for tables
53/// that are part of a publication, given a connection string and the
54/// publication name. Returns a map from table OID to table schema information.
55///
56/// The `oids` parameter controls for which tables to fetch schema information. If `None`,
57/// schema information for all tables in the publication is fetched. If `Some`, only
58/// schema information for the tables with the specified OIDs is fetched.
59///
60/// # Errors
61///
62/// - Invalid connection string, user information, or user permissions.
63/// - Upstream publication does not exist or contains invalid values.
64pub async fn publication_info(
65    client: &Client,
66    publication: &str,
67    oids: Option<&[Oid]>,
68) -> Result<BTreeMap<Oid, PostgresTableDesc>, PostgresError> {
69    let server_major_version = get_pg_major_version(client).await?;
70
71    client
72        .query(
73            "SELECT oid FROM pg_publication WHERE pubname = $1",
74            &[&publication],
75        )
76        .await
77        .map_err(PostgresError::from)?
78        .get(0)
79        .ok_or_else(|| PostgresError::PublicationMissing(publication.to_string()))?;
80
81    let tables = if let Some(oids) = oids {
82        client
83            .query(
84                "SELECT
85                    c.oid, p.schemaname, p.tablename
86                FROM
87                    pg_catalog.pg_class AS c
88                    JOIN pg_namespace AS n ON c.relnamespace = n.oid
89                    JOIN pg_publication_tables AS p ON
90                            c.relname = p.tablename AND n.nspname = p.schemaname
91                WHERE
92                    p.pubname = $1
93                    AND c.oid = ANY ($2)",
94                &[&publication, &oids],
95            )
96            .await
97    } else {
98        client
99            .query(
100                "SELECT
101                    c.oid, p.schemaname, p.tablename
102                FROM
103                    pg_catalog.pg_class AS c
104                    JOIN pg_namespace AS n ON c.relnamespace = n.oid
105                    JOIN pg_publication_tables AS p ON
106                            c.relname = p.tablename AND n.nspname = p.schemaname
107                WHERE
108                    p.pubname = $1",
109                &[&publication],
110            )
111            .await
112    }?;
113
114    // The Postgres replication protocol does not support GENERATED columns
115    // so we exclude them from this query. But not all Postgres-like
116    // databases have the `pg_attribute.attgenerated` column.
117    let attgenerated = if server_major_version >= 12 {
118        "a.attgenerated = ''"
119    } else {
120        "true"
121    };
122
123    let pg_columns = format!(
124        "
125        SELECT
126            a.attrelid AS table_oid,
127            a.attname AS name,
128            a.atttypid AS typoid,
129            a.attnum AS colnum,
130            a.atttypmod AS typmod,
131            a.attnotnull AS not_null,
132            b.oid IS NOT NULL AS primary_key
133        FROM pg_catalog.pg_attribute a
134        LEFT JOIN pg_catalog.pg_constraint b
135            ON a.attrelid = b.conrelid
136            AND b.contype = 'p'
137            AND a.attnum = ANY (b.conkey)
138        WHERE a.attnum > 0::pg_catalog.int2
139            AND NOT a.attisdropped
140            AND {attgenerated}
141            AND a.attrelid = ANY ($1)
142        ORDER BY a.attnum"
143    );
144
145    let table_oids = tables
146        .iter()
147        .map(|row| row.get("oid"))
148        .collect::<Vec<Oid>>();
149
150    let mut columns: BTreeMap<Oid, Vec<_>> = BTreeMap::new();
151    for row in client.query(&pg_columns, &[&table_oids]).await? {
152        let table_oid: Oid = row.get("table_oid");
153        let name: String = row.get("name");
154        let type_oid = row.get("typoid");
155        let col_num = row
156            .get::<_, i16>("colnum")
157            .try_into()
158            .expect("non-negative values");
159        let type_mod: i32 = row.get("typmod");
160        let not_null: bool = row.get("not_null");
161        let desc = PostgresColumnDesc {
162            name,
163            col_num,
164            type_oid,
165            type_mod,
166            nullable: !not_null,
167        };
168        columns.entry(table_oid).or_default().push(desc);
169    }
170
171    // PG 15 adds UNIQUE NULLS NOT DISTINCT, which would let us use `UNIQUE` constraints over
172    // nullable columns as keys; i.e. aligns a PG index's NULL handling with an arrangement's
173    // keys. For more info, see https://www.postgresql.org/about/featurematrix/detail/392/
174    let nulls_not_distinct = if server_major_version >= 15 {
175        "pg_index.indnullsnotdistinct"
176    } else {
177        "false"
178    };
179    let pg_keys = format!(
180        "
181        SELECT
182            pg_constraint.conrelid AS table_oid,
183            pg_constraint.oid,
184            pg_constraint.conkey,
185            pg_constraint.conname,
186            pg_constraint.contype = 'p' AS is_primary,
187            {nulls_not_distinct} AS nulls_not_distinct
188        FROM
189            pg_constraint
190                JOIN
191                    pg_index
192                    ON pg_index.indexrelid = pg_constraint.conindid
193        WHERE
194            pg_constraint.conrelid = ANY ($1)
195                AND
196            pg_constraint.contype = ANY (ARRAY['p', 'u']);"
197    );
198
199    let mut keys: BTreeMap<Oid, BTreeSet<_>> = BTreeMap::new();
200    for row in client.query(&pg_keys, &[&table_oids]).await? {
201        let table_oid: Oid = row.get("table_oid");
202        let oid: Oid = row.get("oid");
203        let cols: Vec<i16> = row.get("conkey");
204        let name: String = row.get("conname");
205        let is_primary: bool = row.get("is_primary");
206        let nulls_not_distinct: bool = row.get("nulls_not_distinct");
207        let cols = cols
208            .into_iter()
209            .map(|col| u16::try_from(col).expect("non-negative colnums"))
210            .collect();
211        let desc = PostgresKeyDesc {
212            oid,
213            name,
214            cols,
215            is_primary,
216            nulls_not_distinct,
217        };
218        keys.entry(table_oid).or_default().insert(desc);
219    }
220
221    Ok(tables
222        .into_iter()
223        .map(|row| {
224            let oid: Oid = row.get("oid");
225            let columns = columns.remove(&oid).unwrap_or_default();
226            let keys = keys.remove(&oid).unwrap_or_default();
227            let desc = PostgresTableDesc {
228                oid,
229                namespace: row.get("schemaname"),
230                name: row.get("tablename"),
231                columns,
232                keys,
233            };
234            (oid, desc)
235        })
236        .collect())
237}