Skip to main content

mz_deploy/project/resolve/normalize/
mod_rewriter.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! AST-based rewriting of database and schema names in mod statements.
11//!
12//! When a profile suffix is active, database names in mod files must be
13//! suffixed (e.g., `app` -> `app_dev`). When a staging suffix is active,
14//! schema names must be suffixed (e.g., `public` -> `public_staging`).
15//!
16//! This module provides [`rewrite_database_names`] and [`rewrite_schema_names`],
17//! which apply these transformations at the AST level using the auto-generated
18//! [`VisitMut`] traversal. This is safer than raw text substitution because it
19//! only touches actual identifier nodes, not string literals or comments.
20//!
21//! ## Supported Statement Types
22//!
23//! The visitor handles all statement types permitted in mod files:
24//! - `COMMENT ON DATABASE/SCHEMA`
25//! - `GRANT ... ON DATABASE/SCHEMA`
26//! - `ALTER DEFAULT PRIVILEGES IN DATABASE/SCHEMA`
27//!
//! The auto-generated traversal reaches database/schema names through
//! different hooks depending on AST position:
//!
//! - `visit_database_name_mut` / `visit_schema_name_mut` — associated type
//!   hooks used by COMMENT ON DATABASE/SCHEMA and ALTER DEFAULT PRIVILEGES
//! - `visit_object_name_mut` — associated type hook for
//!   `UnresolvedObjectName`, the target of GRANT ON DATABASE/SCHEMA; the
//!   rewriters match its `Database` / `Schema` variants directly
//! - `visit_unresolved_database_name_mut` / `visit_unresolved_schema_name_mut`
//!   — concrete struct hooks, overridden as well so that any position that
//!   embeds the structs directly is also covered
//!
//! Each hook ultimately yields a `&mut UnresolvedDatabaseName` /
//! `&mut UnresolvedSchemaName` (directly or inside an enum variant), so the
//! rewriting logic is shared via helper functions. The rewriters override all
//! of these hooks so that every mod statement type is covered.
40
41use mz_sql_parser::ast::visit_mut::{self, VisitMut};
42use mz_sql_parser::ast::{
43    Ident, Raw, Statement, UnresolvedDatabaseName, UnresolvedObjectName, UnresolvedSchemaName,
44};
45
46/// Append `suffix` to the database name if it matches `database_name`.
47fn rewrite_database_name(node: &mut UnresolvedDatabaseName, database_name: &str, suffix: &str) {
48    if node.0.as_str() == database_name {
49        node.0 =
50            Ident::new(format!("{}{}", database_name, suffix)).expect("valid database identifier");
51    }
52}
53
54/// Append `suffix` to the schema part (last ident) if it matches `schema_name`.
55fn rewrite_schema_name(node: &mut UnresolvedSchemaName, schema_name: &str, suffix: &str) {
56    if let Some(last) = node.0.last() {
57        if last.as_str() == schema_name {
58            let idx = node.0.len() - 1;
59            node.0[idx] =
60                Ident::new(format!("{}{}", schema_name, suffix)).expect("valid schema identifier");
61        }
62    }
63}
64
65/// Visitor that rewrites database name identifiers by appending a suffix.
66///
67/// Overrides both `visit_database_name_mut` (associated type hook, used by
68/// COMMENT and ALTER DEFAULT PRIVILEGES) and `visit_unresolved_database_name_mut`
69/// (concrete struct hook, used by GRANT via `UnresolvedObjectName::Database`).
70struct DatabaseNameRewriter<'a> {
71    database_name: &'a str,
72    suffix: &'a str,
73}
74
75impl<'a> VisitMut<'_, Raw> for DatabaseNameRewriter<'a> {
76    fn visit_database_name_mut(&mut self, node: &mut UnresolvedDatabaseName) {
77        rewrite_database_name(node, self.database_name, self.suffix);
78    }
79
80    fn visit_object_name_mut(&mut self, node: &mut UnresolvedObjectName) {
81        if let UnresolvedObjectName::Database(db_name) = node {
82            rewrite_database_name(db_name, self.database_name, self.suffix);
83        }
84    }
85
86    fn visit_unresolved_database_name_mut(&mut self, node: &mut UnresolvedDatabaseName) {
87        rewrite_database_name(node, self.database_name, self.suffix);
88    }
89}
90
91/// Visitor that rewrites schema name identifiers by appending a suffix.
92///
93/// Overrides both `visit_schema_name_mut` (associated type hook, used by
94/// COMMENT and ALTER DEFAULT PRIVILEGES) and `visit_unresolved_schema_name_mut`
95/// (concrete struct hook, used by GRANT via `UnresolvedObjectName::Schema`).
96struct SchemaNameRewriter<'a> {
97    schema_name: &'a str,
98    suffix: &'a str,
99}
100
101impl<'a> VisitMut<'_, Raw> for SchemaNameRewriter<'a> {
102    fn visit_object_name_mut(&mut self, node: &mut UnresolvedObjectName) {
103        if let UnresolvedObjectName::Schema(schema_name) = node {
104            rewrite_schema_name(schema_name, self.schema_name, self.suffix);
105        }
106    }
107
108    fn visit_schema_name_mut(&mut self, node: &mut UnresolvedSchemaName) {
109        rewrite_schema_name(node, self.schema_name, self.suffix);
110    }
111
112    fn visit_unresolved_schema_name_mut(&mut self, node: &mut UnresolvedSchemaName) {
113        rewrite_schema_name(node, self.schema_name, self.suffix);
114    }
115}
116
117/// Rewrite database names in parsed mod statements by appending a suffix.
118///
119/// Applies suffix-based renaming to all [`UnresolvedDatabaseName`] nodes
120/// in the given statements using the auto-generated [`VisitMut`] traversal.
121/// Only identifiers that exactly match `database_name` are rewritten;
122/// string literals, comments, and other identifiers are untouched.
123///
124/// # Arguments
125/// * `statements` - Parsed mod statements to rewrite (mutated in place)
126/// * `database_name` - The original database name to match
127/// * `suffix` - The suffix to append (e.g., `"_dev"`)
128pub(crate) fn rewrite_database_names(
129    statements: &mut [Statement<Raw>],
130    database_name: &str,
131    suffix: &str,
132) {
133    let mut rewriter = DatabaseNameRewriter {
134        database_name,
135        suffix,
136    };
137    for stmt in statements {
138        visit_mut::visit_statement_mut(&mut rewriter, stmt);
139    }
140}
141
142/// Rewrite schema names in parsed mod statements by appending a suffix.
143///
144/// Applies suffix-based renaming to all [`UnresolvedSchemaName`] nodes
145/// in the given statements using the auto-generated [`VisitMut`] traversal.
146/// Matches the last identifier in each schema name (the schema part),
147/// so both `schema` and `db.schema` forms are handled.
148///
149/// # Arguments
150/// * `statements` - Parsed mod statements to rewrite (mutated in place)
151/// * `schema_name` - The original schema name to match
152/// * `suffix` - The suffix to append (e.g., `"_staging"`)
153pub(crate) fn rewrite_schema_names(
154    statements: &mut [Statement<Raw>],
155    schema_name: &str,
156    suffix: &str,
157) {
158    let mut rewriter = SchemaNameRewriter {
159        schema_name,
160        suffix,
161    };
162    for stmt in statements {
163        visit_mut::visit_statement_mut(&mut rewriter, stmt);
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use mz_sql_parser::parser::parse_statements;

    /// Parse `sql`, asserting that it contains exactly one statement.
    fn parse_one(sql: &str) -> Statement<Raw> {
        let stmts = parse_statements(sql).expect("valid SQL");
        assert_eq!(stmts.len(), 1, "expected exactly one statement");
        stmts.into_iter().next().unwrap().ast
    }

    /// Parse `sql`, run the database rewriter once, and render the result.
    fn rewrite_db(sql: &str, db_name: &str, suffix: &str) -> String {
        let mut stmts = vec![parse_one(sql)];
        rewrite_database_names(&mut stmts, db_name, suffix);
        format!("{}", stmts[0])
    }

    /// Parse `sql`, run the schema rewriter once, and render the result.
    fn rewrite_schema(sql: &str, schema_name: &str, suffix: &str) -> String {
        let mut stmts = vec![parse_one(sql)];
        rewrite_schema_names(&mut stmts, schema_name, suffix);
        format!("{}", stmts[0])
    }

    // --- Database name rewriting ---

    #[mz_ore::test]
    fn test_comment_on_database() {
        let result = rewrite_db(
            "COMMENT ON DATABASE app IS 'app description'",
            "app",
            "_dev",
        );
        assert!(
            result.contains("app_dev"),
            "database name should be rewritten: {result}"
        );
        assert!(
            result.contains("app description"),
            "string literal should be untouched: {result}"
        );
    }

    #[mz_ore::test]
    fn test_grant_on_database() {
        let result = rewrite_db("GRANT ALL ON DATABASE app TO role1", "app", "_dev");
        assert!(
            result.contains("app_dev"),
            "database name should be rewritten: {result}"
        );
    }

    #[mz_ore::test]
    fn test_alter_default_privileges_database() {
        let result = rewrite_db(
            "ALTER DEFAULT PRIVILEGES FOR ALL ROLES IN DATABASE app GRANT SELECT ON TABLES TO role1",
            "app",
            "_dev",
        );
        assert!(
            result.contains("app_dev"),
            "database name should be rewritten: {result}"
        );
    }

    #[mz_ore::test]
    fn test_database_no_match_passthrough() {
        let original = "COMMENT ON DATABASE other IS 'untouched'";
        let result = rewrite_db(original, "app", "_dev");
        assert!(
            !result.contains("_dev"),
            "non-matching database should be untouched: {result}"
        );
    }

    #[mz_ore::test]
    fn test_database_string_literal_equal_to_name_untouched() {
        // The literal is exactly the database name; only the identifier may change.
        let result = rewrite_db("COMMENT ON DATABASE app IS 'app'", "app", "_dev");
        assert!(
            result.contains("app_dev"),
            "database name should be rewritten: {result}"
        );
        assert!(
            result.contains("'app'"),
            "string literal equal to the database name should be untouched: {result}"
        );
    }

    #[mz_ore::test]
    fn test_database_rewrite_does_not_double_suffix() {
        // Re-running the rewrite must not turn `app_dev` into `app_dev_dev`.
        let mut stmts = vec![parse_one("GRANT ALL ON DATABASE app TO role1")];
        rewrite_database_names(&mut stmts, "app", "_dev");
        rewrite_database_names(&mut stmts, "app", "_dev");
        let result = stmts[0].to_string();
        assert!(
            result.contains("app_dev"),
            "database name should be rewritten: {result}"
        );
        assert!(
            !result.contains("app_dev_dev"),
            "re-applying the rewrite should not suffix twice: {result}"
        );
    }

    // --- Schema name rewriting ---

    #[mz_ore::test]
    fn test_comment_on_schema_qualified() {
        let result = rewrite_schema(
            "COMMENT ON SCHEMA app.public IS 'schema description'",
            "public",
            "_staging",
        );
        assert!(
            result.contains("public_staging"),
            "schema name should be rewritten: {result}"
        );
        assert!(
            result.contains("schema description"),
            "string literal should be untouched: {result}"
        );
    }

    #[mz_ore::test]
    fn test_comment_on_schema_unqualified() {
        let result = rewrite_schema(
            "COMMENT ON SCHEMA public IS 'schema description'",
            "public",
            "_staging",
        );
        assert!(
            result.contains("public_staging"),
            "unqualified schema name should be rewritten: {result}"
        );
    }

    #[mz_ore::test]
    fn test_grant_on_schema() {
        let result = rewrite_schema(
            "GRANT USAGE ON SCHEMA app.public TO role1",
            "public",
            "_staging",
        );
        assert!(
            result.contains("public_staging"),
            "schema name should be rewritten: {result}"
        );
    }

    #[mz_ore::test]
    fn test_alter_default_privileges_schema() {
        let result = rewrite_schema(
            "ALTER DEFAULT PRIVILEGES FOR ALL ROLES IN SCHEMA app.public GRANT SELECT ON TABLES TO role1",
            "public",
            "_staging",
        );
        assert!(
            result.contains("public_staging"),
            "schema name should be rewritten: {result}"
        );
    }

    #[mz_ore::test]
    fn test_schema_no_match_passthrough() {
        let original = "COMMENT ON SCHEMA app.other IS 'untouched'";
        let result = rewrite_schema(original, "public", "_staging");
        assert!(
            !result.contains("_staging"),
            "non-matching schema should be untouched: {result}"
        );
    }

    #[mz_ore::test]
    fn test_schema_database_part_not_matched() {
        // Only the schema part (last identifier) is matched, never the database.
        let original = "COMMENT ON SCHEMA public.other IS 'untouched'";
        let result = rewrite_schema(original, "public", "_staging");
        assert!(
            !result.contains("_staging"),
            "database part equal to the schema name should be untouched: {result}"
        );
    }
}