Skip to main content

mz_postgres_util/
schemas.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utilities to fetch schema information for Postgres sources.
11
12use std::collections::{BTreeMap, BTreeSet};
13
14use tokio_postgres::Client;
15use tokio_postgres::types::Oid;
16
17use crate::desc::{PostgresColumnDesc, PostgresKeyDesc, PostgresSchemaDesc, PostgresTableDesc};
18use crate::{PostgresError, simple_query_opt};
19
20pub async fn get_schemas(client: &Client) -> Result<Vec<PostgresSchemaDesc>, PostgresError> {
21    Ok(client
22        .query("SELECT oid, nspname, nspowner FROM pg_namespace", &[])
23        .await?
24        .into_iter()
25        .map(|row| {
26            let oid: Oid = row.get("oid");
27            let name: String = row.get("nspname");
28            let owner: Oid = row.get("nspowner");
29            PostgresSchemaDesc { oid, name, owner }
30        })
31        .collect::<Vec<_>>())
32}
33
34/// Get the major version of the PostgreSQL server.
35pub async fn get_pg_major_version(client: &Client) -> Result<u32, PostgresError> {
36    // server_version_num is an integer like 140005 for version 14.5
37    let query = "SHOW server_version_num";
38    let row = simple_query_opt(client, query).await?;
39    let version_num: u32 = row
40        .and_then(|r| r.get("server_version_num").map(|s| s.parse().ok()))
41        .flatten()
42        .ok_or_else(|| {
43            PostgresError::Generic(anyhow::anyhow!("failed to get PostgreSQL version"))
44        })?;
45    // server_version_num format: XXYYZZ where XX is major, YY is minor, ZZ is patch
46    // For PG >= 10, it's XXXYYZZ (3 digit major)
47    Ok(version_num / 10000)
48}
49
50/// Fetches table schema information from an upstream Postgres source for tables
51/// that are part of a publication, given a connection string and the
52/// publication name. Returns a map from table OID to table schema information.
53///
54/// The `oids` parameter controls for which tables to fetch schema information. If `None`,
55/// schema information for all tables in the publication is fetched. If `Some`, only
56/// schema information for the tables with the specified OIDs is fetched.
57///
58/// # Errors
59///
60/// - Invalid connection string, user information, or user permissions.
61/// - Upstream publication does not exist or contains invalid values.
62pub async fn publication_info(
63    client: &Client,
64    publication: &str,
65    oids: Option<&[Oid]>,
66) -> Result<BTreeMap<Oid, PostgresTableDesc>, PostgresError> {
67    let server_major_version = get_pg_major_version(client).await?;
68
69    client
70        .query(
71            "SELECT oid FROM pg_publication WHERE pubname = $1",
72            &[&publication],
73        )
74        .await
75        .map_err(PostgresError::from)?
76        .get(0)
77        .ok_or_else(|| PostgresError::PublicationMissing(publication.to_string()))?;
78
79    let tables = if let Some(oids) = oids {
80        client
81            .query(
82                "SELECT
83                    c.oid, p.schemaname, p.tablename
84                FROM
85                    pg_catalog.pg_class AS c
86                    JOIN pg_namespace AS n ON c.relnamespace = n.oid
87                    JOIN pg_publication_tables AS p ON
88                            c.relname = p.tablename AND n.nspname = p.schemaname
89                WHERE
90                    p.pubname = $1
91                    AND c.oid = ANY ($2)",
92                &[&publication, &oids],
93            )
94            .await
95    } else {
96        client
97            .query(
98                "SELECT
99                    c.oid, p.schemaname, p.tablename
100                FROM
101                    pg_catalog.pg_class AS c
102                    JOIN pg_namespace AS n ON c.relnamespace = n.oid
103                    JOIN pg_publication_tables AS p ON
104                            c.relname = p.tablename AND n.nspname = p.schemaname
105                WHERE
106                    p.pubname = $1",
107                &[&publication],
108            )
109            .await
110    }?;
111
112    // The Postgres replication protocol does not support GENERATED columns
113    // so we exclude them from this query. But not all Postgres-like
114    // databases have the `pg_attribute.attgenerated` column.
115    let attgenerated = if server_major_version >= 12 {
116        "a.attgenerated = ''"
117    } else {
118        "true"
119    };
120
121    let pg_columns = format!(
122        "
123        SELECT
124            a.attrelid AS table_oid,
125            a.attname AS name,
126            a.atttypid AS typoid,
127            a.attnum AS colnum,
128            a.atttypmod AS typmod,
129            a.attnotnull AS not_null,
130            b.oid IS NOT NULL AS primary_key
131        FROM pg_catalog.pg_attribute a
132        LEFT JOIN pg_catalog.pg_constraint b
133            ON a.attrelid = b.conrelid
134            AND b.contype = 'p'
135            AND a.attnum = ANY (b.conkey)
136        WHERE a.attnum > 0::pg_catalog.int2
137            AND NOT a.attisdropped
138            AND {attgenerated}
139            AND a.attrelid = ANY ($1)
140        ORDER BY a.attnum"
141    );
142
143    let table_oids = tables
144        .iter()
145        .map(|row| row.get("oid"))
146        .collect::<Vec<Oid>>();
147
148    let mut columns: BTreeMap<Oid, Vec<_>> = BTreeMap::new();
149    for row in client.query(&pg_columns, &[&table_oids]).await? {
150        let table_oid: Oid = row.get("table_oid");
151        let name: String = row.get("name");
152        let type_oid = row.get("typoid");
153        let col_num = row
154            .get::<_, i16>("colnum")
155            .try_into()
156            .expect("non-negative values");
157        let type_mod: i32 = row.get("typmod");
158        let not_null: bool = row.get("not_null");
159        let desc = PostgresColumnDesc {
160            name,
161            col_num,
162            type_oid,
163            type_mod,
164            nullable: !not_null,
165        };
166        columns.entry(table_oid).or_default().push(desc);
167    }
168
169    // PG 15 adds UNIQUE NULLS NOT DISTINCT, which would let us use `UNIQUE` constraints over
170    // nullable columns as keys; i.e. aligns a PG index's NULL handling with an arrangement's
171    // keys. For more info, see https://www.postgresql.org/about/featurematrix/detail/392/
172    let nulls_not_distinct = if server_major_version >= 15 {
173        "pg_index.indnullsnotdistinct"
174    } else {
175        "false"
176    };
177    let pg_keys = format!(
178        "
179        SELECT
180            pg_constraint.conrelid AS table_oid,
181            pg_constraint.oid,
182            pg_constraint.conkey,
183            pg_constraint.conname,
184            pg_constraint.contype = 'p' AS is_primary,
185            {nulls_not_distinct} AS nulls_not_distinct
186        FROM
187            pg_constraint
188                JOIN
189                    pg_index
190                    ON pg_index.indexrelid = pg_constraint.conindid
191        WHERE
192            pg_constraint.conrelid = ANY ($1)
193                AND
194            pg_constraint.contype = ANY (ARRAY['p', 'u']);"
195    );
196
197    let mut keys: BTreeMap<Oid, BTreeSet<_>> = BTreeMap::new();
198    for row in client.query(&pg_keys, &[&table_oids]).await? {
199        let table_oid: Oid = row.get("table_oid");
200        let oid: Oid = row.get("oid");
201        let cols: Vec<i16> = row.get("conkey");
202        let name: String = row.get("conname");
203        let is_primary: bool = row.get("is_primary");
204        let nulls_not_distinct: bool = row.get("nulls_not_distinct");
205        let cols = cols
206            .into_iter()
207            .map(|col| u16::try_from(col).expect("non-negative colnums"))
208            .collect();
209        let desc = PostgresKeyDesc {
210            oid,
211            name,
212            cols,
213            is_primary,
214            nulls_not_distinct,
215        };
216        keys.entry(table_oid).or_default().insert(desc);
217    }
218
219    Ok(tables
220        .into_iter()
221        .map(|row| {
222            let oid: Oid = row.get("oid");
223            let columns = columns.remove(&oid).unwrap_or_default();
224            let keys = keys.remove(&oid).unwrap_or_default();
225            let desc = PostgresTableDesc {
226                oid,
227                namespace: row.get("schemaname"),
228                name: row.get("tablename"),
229                columns,
230                keys,
231            };
232            (oid, desc)
233        })
234        .collect())
235}