Skip to main content

mz_sql/ast/
transform.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Provides a publicly available interface to transform our SQL ASTs.
11
12use std::collections::{BTreeMap, BTreeSet};
13
14use mz_ore::str::StrExt;
15use mz_repr::CatalogItemId;
16use mz_sql_parser::ast::CreateTableFromSourceStatement;
17
18use crate::ast::visit::{self, Visit};
19use crate::ast::visit_mut::{self, VisitMut};
20use crate::ast::{
21    AstInfo, CreateConnectionStatement, CreateIndexStatement, CreateMaterializedViewStatement,
22    CreateSecretStatement, CreateSinkStatement, CreateSourceStatement, CreateSubsourceStatement,
23    CreateTableStatement, CreateViewStatement, CreateWebhookSourceStatement, Expr, Ident, Query,
24    Raw, RawItemName, Statement, UnresolvedItemName, ViewDefinition,
25};
26use crate::names::FullItemName;
27
28/// Given a [`Statement`] rewrites all references of the schema name `cur_schema_name` to
29/// `new_schema_name`.
30pub fn create_stmt_rename_schema_refs(
31    create_stmt: &mut Statement<Raw>,
32    database: &str,
33    cur_schema: &str,
34    new_schema: &str,
35) -> Result<(), (String, String)> {
36    match create_stmt {
37        stmt @ Statement::CreateConnection(_)
38        | stmt @ Statement::CreateDatabase(_)
39        | stmt @ Statement::CreateSchema(_)
40        | stmt @ Statement::CreateWebhookSource(_)
41        | stmt @ Statement::CreateSource(_)
42        | stmt @ Statement::CreateSubsource(_)
43        | stmt @ Statement::CreateSink(_)
44        | stmt @ Statement::CreateView(_)
45        | stmt @ Statement::CreateMaterializedView(_)
46        | stmt @ Statement::CreateTable(_)
47        | stmt @ Statement::CreateTableFromSource(_)
48        | stmt @ Statement::CreateIndex(_)
49        | stmt @ Statement::CreateType(_)
50        | stmt @ Statement::CreateSecret(_) => {
51            let mut visitor = CreateSqlRewriteSchema {
52                database,
53                cur_schema,
54                new_schema,
55                error: None,
56            };
57            visitor.visit_statement_mut(stmt);
58
59            if let Some(e) = visitor.error.take() {
60                Err(e)
61            } else {
62                Ok(())
63            }
64        }
65        stmt => {
66            unreachable!("Internal error: only catalog items need to update item refs. {stmt:?}")
67        }
68    }
69}
70
71struct CreateSqlRewriteSchema<'a> {
72    database: &'a str,
73    cur_schema: &'a str,
74    new_schema: &'a str,
75    error: Option<(String, String)>,
76}
77
78impl<'a> CreateSqlRewriteSchema<'a> {
79    fn maybe_rewrite_idents(&mut self, name: &mut [Ident]) {
80        match name {
81            [schema, item] if schema.as_str() == self.cur_schema => {
82                // TODO(parkmycar): I _think_ when the database component is not specified we can
83                // always infer we're using the current database. But I'm not positive, so for now
84                // we'll bail in this case.
85                if self.error.is_none() {
86                    self.error = Some((schema.to_string(), item.to_string()));
87                }
88            }
89            [database, schema, _item] => {
90                if database.as_str() == self.database && schema.as_str() == self.cur_schema {
91                    *schema = Ident::new_unchecked(self.new_schema);
92                }
93            }
94            _ => (),
95        }
96    }
97}
98
99impl<'a, 'ast> VisitMut<'ast, Raw> for CreateSqlRewriteSchema<'a> {
100    fn visit_expr_mut(&mut self, e: &'ast mut Expr<Raw>) {
101        match e {
102            Expr::Identifier(id) => {
103                // The last ID component is a column name that should not be
104                // considered in the rewrite.
105                let i = id.len() - 1;
106                self.maybe_rewrite_idents(&mut id[..i]);
107            }
108            Expr::QualifiedWildcard(id) => {
109                self.maybe_rewrite_idents(id);
110            }
111            _ => visit_mut::visit_expr_mut(self, e),
112        }
113    }
114
115    fn visit_unresolved_item_name_mut(
116        &mut self,
117        unresolved_item_name: &'ast mut UnresolvedItemName,
118    ) {
119        self.maybe_rewrite_idents(&mut unresolved_item_name.0);
120    }
121
122    fn visit_item_name_mut(
123        &mut self,
124        item_name: &'ast mut <mz_sql_parser::ast::Raw as AstInfo>::ItemName,
125    ) {
126        match item_name {
127            RawItemName::Name(n) | RawItemName::Id(_, n, _) => self.maybe_rewrite_idents(&mut n.0),
128        }
129    }
130}
131
132/// Changes the `name` used in an item's `CREATE` statement. To complete a
133/// rename operation, you must also call `create_stmt_rename_refs` on all dependent
134/// items.
135pub fn create_stmt_rename(create_stmt: &mut Statement<Raw>, to_item_name: String) {
136    // TODO(sploiselle): Support renaming schemas and databases.
137    match create_stmt {
138        Statement::CreateIndex(CreateIndexStatement { name, .. }) => {
139            *name = Some(Ident::new_unchecked(to_item_name));
140        }
141        Statement::CreateSink(CreateSinkStatement {
142            name: Some(name), ..
143        })
144        | Statement::CreateSource(CreateSourceStatement { name, .. })
145        | Statement::CreateSubsource(CreateSubsourceStatement { name, .. })
146        | Statement::CreateView(CreateViewStatement {
147            definition: ViewDefinition { name, .. },
148            ..
149        })
150        | Statement::CreateMaterializedView(CreateMaterializedViewStatement { name, .. })
151        | Statement::CreateTable(CreateTableStatement { name, .. })
152        | Statement::CreateTableFromSource(CreateTableFromSourceStatement { name, .. })
153        | Statement::CreateSecret(CreateSecretStatement { name, .. })
154        | Statement::CreateConnection(CreateConnectionStatement { name, .. })
155        | Statement::CreateWebhookSource(CreateWebhookSourceStatement { name, .. }) => {
156            // The last name in an ItemName is the item name. The item name
157            // does not have a fixed index.
158            // TODO: https://github.com/MaterializeInc/database-issues/issues/1721
159            let item_name_len = name.0.len() - 1;
160            name.0[item_name_len] = Ident::new_unchecked(to_item_name);
161        }
162        item => unreachable!("Internal error: only catalog items can be renamed {item:?}"),
163    }
164}
165
166/// Updates all references of `from_name` in `create_stmt` to `to_name` or
167/// errors if request is ambiguous.
168///
169/// Requests are considered ambiguous if `create_stmt` is a
170/// `Statement::CreateView`, and any of the following apply to its `query`:
171/// - `to_name.item` is used as an [`Ident`] in `query`.
172/// - `from_name.item` does not unambiguously refer to an item in the query,
173///   e.g. it is also used as a schema, or not all references to the item are
174///   sufficiently qualified.
175/// - `to_name.item` does not unambiguously refer to an item in the query after
176///   the rename. Right now, given the first condition, this is just a coherence
177///   check, but will be more meaningful once the first restriction is lifted.
178pub fn create_stmt_rename_refs(
179    create_stmt: &mut Statement<Raw>,
180    from_name: FullItemName,
181    to_item_name: String,
182) -> Result<(), String> {
183    let from_item = UnresolvedItemName::from(from_name.clone());
184    let maybe_update_item_name = |item_name: &mut UnresolvedItemName| {
185        if item_name.0 == from_item.0 {
186            // The last name in an ItemName is the item name. The item name
187            // does not have a fixed index.
188            // TODO: https://github.com/MaterializeInc/database-issues/issues/1721
189            let item_name_len = item_name.0.len() - 1;
190            item_name.0[item_name_len] = Ident::new_unchecked(to_item_name.clone());
191        }
192    };
193
194    // TODO(sploiselle): Support renaming schemas and databases.
195    match create_stmt {
196        Statement::CreateIndex(CreateIndexStatement { on_name, .. }) => {
197            maybe_update_item_name(on_name.name_mut());
198        }
199        Statement::CreateSink(CreateSinkStatement { from, .. }) => {
200            maybe_update_item_name(from.name_mut());
201        }
202        Statement::CreateTableFromSource(CreateTableFromSourceStatement { source, .. }) => {
203            maybe_update_item_name(source.name_mut());
204        }
205        Statement::CreateView(CreateViewStatement {
206            definition: ViewDefinition { query, .. },
207            ..
208        }) => {
209            rewrite_query(from_name, to_item_name, query)?;
210        }
211        Statement::CreateMaterializedView(CreateMaterializedViewStatement {
212            replacement_for,
213            query,
214            ..
215        }) => {
216            if let Some(target) = replacement_for {
217                maybe_update_item_name(target.name_mut());
218            }
219            rewrite_query(from_name, to_item_name, query)?;
220        }
221        Statement::CreateSource(_)
222        | Statement::CreateSubsource(_)
223        | Statement::CreateTable(_)
224        | Statement::CreateSecret(_)
225        | Statement::CreateConnection(_)
226        | Statement::CreateWebhookSource(_) => {}
227        item => {
228            unreachable!("Internal error: only catalog items need to update item refs {item:?}")
229        }
230    }
231
232    Ok(())
233}
234
235/// Rewrites `query`'s references of `from` to `to` or errors if too ambiguous.
236fn rewrite_query(from: FullItemName, to: String, query: &mut Query<Raw>) -> Result<(), String> {
237    let from_ident = Ident::new_unchecked(from.item.clone());
238    let to_ident = Ident::new_unchecked(to);
239    let qual_depth =
240        QueryIdentAgg::determine_qual_depth(&from_ident, Some(to_ident.clone()), query)?;
241    CreateSqlRewriter::rewrite_query_with_qual_depth(from, to_ident.clone(), qual_depth, query);
242    // Ensure that our rewrite didn't didn't introduce ambiguous
243    // references to `to_name`.
244    match QueryIdentAgg::determine_qual_depth(&to_ident, None, query) {
245        Ok(_) => Ok(()),
246        Err(e) => Err(e),
247    }
248}
249
250fn ambiguous_err(n: &Ident, t: &str) -> String {
251    format!(
252        "{} potentially used ambiguously as item and {}",
253        n.as_str().quoted(),
254        t
255    )
256}
257
258/// Visits a [`Query`], assessing catalog item [`Ident`]s' use of a specified `Ident`.
259struct QueryIdentAgg<'a> {
260    /// The name whose usage you want to assess.
261    name: &'a Ident,
262    /// Tracks all second-level qualifiers used on `name` in a `BTreeMap`, as
263    /// well as any third-level qualifiers used on those second-level qualifiers
264    /// in a `BTreeSet`.
265    qualifiers: BTreeMap<Ident, BTreeSet<Ident>>,
266    /// Tracks the least qualified instance of `name` seen.
267    min_qual_depth: usize,
268    /// Provides an option to fail the visit if encounters a specified `Ident`.
269    fail_on: Option<Ident>,
270    err: Option<String>,
271}
272
273impl<'a> QueryIdentAgg<'a> {
274    /// Determines the depth of qualification needed to unambiguously reference
275    /// catalog items in a [`Query`].
276    ///
277    /// Includes an option to fail if a given `Ident` is encountered.
278    ///
279    /// `Result`s of `Ok(usize)` indicate that `name` can be unambiguously
280    /// referred to with `usize` parts, e.g. 2 requires schema and item name
281    /// qualification.
282    ///
283    /// `Result`s of `Err` indicate that we cannot unambiguously reference
284    /// `name` or encountered `fail_on`, if it's provided.
285    fn determine_qual_depth(
286        name: &Ident,
287        fail_on: Option<Ident>,
288        query: &Query<Raw>,
289    ) -> Result<usize, String> {
290        let mut v = QueryIdentAgg {
291            qualifiers: BTreeMap::new(),
292            min_qual_depth: usize::MAX,
293            err: None,
294            name,
295            fail_on,
296        };
297
298        // Aggregate identities in `v`.
299        v.visit_query(query);
300        // Not possible to have a qualification depth of 0;
301        assert!(v.min_qual_depth > 0);
302
303        if let Some(e) = v.err {
304            return Err(e);
305        }
306
307        // Check if there was more than one 3rd-level (e.g.
308        // database) qualification used for any reference to `name`.
309        let req_depth = if v.qualifiers.values().any(|v| v.len() > 1) {
310            3
311        // Check if there was more than one 2nd-level (e.g. schema)
312        // qualification used for any reference to `name`.
313        } else if v.qualifiers.len() > 1 {
314            2
315        } else {
316            1
317        };
318
319        if v.min_qual_depth < req_depth {
320            Err(format!(
321                "{} is not sufficiently qualified to support renaming",
322                name.as_str().quoted()
323            ))
324        } else {
325            Ok(req_depth)
326        }
327    }
328
329    // Assesses `v` for uses of `self.name` and `self.fail_on`.
330    fn check_failure(&mut self, v: &[Ident]) {
331        // Fail if we encounter `self.fail_on`.
332        if let Some(f) = &self.fail_on {
333            if v.iter().any(|i| i == f) {
334                self.err = Some(format!(
335                    "found reference to {}; cannot rename {} to any identity \
336                    used in any existing view definitions",
337                    f.as_str().quoted(),
338                    self.name.as_str().quoted()
339                ));
340            }
341        }
342    }
343}
344
345impl<'a, 'ast> Visit<'ast, Raw> for QueryIdentAgg<'a> {
346    fn visit_expr(&mut self, e: &'ast Expr<Raw>) {
347        match e {
348            Expr::Identifier(i) => {
349                self.check_failure(i);
350                if let Some(p) = i.iter().rposition(|e| e == self.name) {
351                    if p == i.len() - 1 {
352                        // `self.name` used as a column if it's in the final
353                        // position here, e.g. `SELECT view.col FROM ...`
354                        self.err = Some(ambiguous_err(self.name, "column"));
355                        return;
356                    }
357                    self.min_qual_depth = std::cmp::min(p + 1, self.min_qual_depth);
358                }
359            }
360            Expr::QualifiedWildcard(i) => {
361                self.check_failure(i);
362                if let Some(p) = i.iter().rposition(|e| e == self.name) {
363                    self.min_qual_depth = std::cmp::min(p + 1, self.min_qual_depth);
364                }
365            }
366            _ => visit::visit_expr(self, e),
367        }
368    }
369
370    fn visit_ident(&mut self, ident: &'ast Ident) {
371        self.check_failure(std::slice::from_ref(ident));
372        // This is an unqualified item using `self.name`, e.g. an alias, which
373        // we cannot unambiguously resolve.
374        if ident == self.name {
375            self.err = Some(ambiguous_err(self.name, "alias or column"));
376        }
377    }
378
379    fn visit_unresolved_item_name(&mut self, unresolved_item_name: &'ast UnresolvedItemName) {
380        let names = &unresolved_item_name.0;
381        self.check_failure(names);
382        // Every item is used as an `ItemName` at least once, which
383        // lets use track all items named `self.name`.
384        if let Some(p) = names.iter().rposition(|e| e == self.name) {
385            // Name used as last element of `<db>.<schema>.<item>`
386            if p == names.len() - 1 && names.len() == 3 {
387                self.qualifiers
388                    .entry(names[1].clone())
389                    .or_default()
390                    .insert(names[0].clone());
391                self.min_qual_depth = std::cmp::min(3, self.min_qual_depth);
392            } else {
393                // Any other use is a database or schema
394                self.err = Some(ambiguous_err(self.name, "database, schema, or function"))
395            }
396        }
397    }
398
399    fn visit_item_name(&mut self, item_name: &'ast <Raw as AstInfo>::ItemName) {
400        match item_name {
401            RawItemName::Name(n) | RawItemName::Id(_, n, _) => self.visit_unresolved_item_name(n),
402        }
403    }
404}
405
406struct CreateSqlRewriter {
407    from: Vec<Ident>,
408    to: Ident,
409}
410
411impl CreateSqlRewriter {
412    fn rewrite_query_with_qual_depth(
413        from_name: FullItemName,
414        to_name: Ident,
415        qual_depth: usize,
416        query: &mut Query<Raw>,
417    ) {
418        let from = match qual_depth {
419            1 => vec![Ident::new_unchecked(from_name.item)],
420            2 => vec![
421                Ident::new_unchecked(from_name.schema),
422                Ident::new_unchecked(from_name.item),
423            ],
424            3 => vec![
425                Ident::new_unchecked(from_name.database.to_string()),
426                Ident::new_unchecked(from_name.schema),
427                Ident::new_unchecked(from_name.item),
428            ],
429            _ => unreachable!(),
430        };
431        let mut v = CreateSqlRewriter { from, to: to_name };
432        v.visit_query_mut(query);
433    }
434
435    fn maybe_rewrite_idents(&self, name: &mut [Ident]) {
436        if name.len() > 0 && name.ends_with(&self.from) {
437            name[name.len() - 1] = self.to.clone();
438        }
439    }
440}
441
442impl<'ast> VisitMut<'ast, Raw> for CreateSqlRewriter {
443    fn visit_expr_mut(&mut self, e: &'ast mut Expr<Raw>) {
444        match e {
445            Expr::Identifier(id) => {
446                // The last ID component is a column name that should not be
447                // considered in the rewrite.
448                let i = id.len() - 1;
449                self.maybe_rewrite_idents(&mut id[..i]);
450            }
451            Expr::QualifiedWildcard(id) => {
452                self.maybe_rewrite_idents(id);
453            }
454            _ => visit_mut::visit_expr_mut(self, e),
455        }
456    }
457    fn visit_unresolved_item_name_mut(
458        &mut self,
459        unresolved_item_name: &'ast mut UnresolvedItemName,
460    ) {
461        self.maybe_rewrite_idents(&mut unresolved_item_name.0);
462    }
463    fn visit_item_name_mut(
464        &mut self,
465        item_name: &'ast mut <mz_sql_parser::ast::Raw as AstInfo>::ItemName,
466    ) {
467        match item_name {
468            RawItemName::Name(n) | RawItemName::Id(_, n, _) => self.maybe_rewrite_idents(&mut n.0),
469        }
470    }
471}
472
473/// Updates all `CatalogItemId`s from the keys of `ids` to the values of `ids` within `create_stmt`.
474pub fn create_stmt_replace_ids(
475    create_stmt: &mut Statement<Raw>,
476    ids: &BTreeMap<CatalogItemId, CatalogItemId>,
477) {
478    let mut id_replacer = CreateSqlIdReplacer { ids };
479    id_replacer.visit_statement_mut(create_stmt);
480}
481
482struct CreateSqlIdReplacer<'a> {
483    ids: &'a BTreeMap<CatalogItemId, CatalogItemId>,
484}
485
486impl<'ast> VisitMut<'ast, Raw> for CreateSqlIdReplacer<'_> {
487    fn visit_item_name_mut(
488        &mut self,
489        item_name: &'ast mut <mz_sql_parser::ast::Raw as AstInfo>::ItemName,
490    ) {
491        match item_name {
492            RawItemName::Id(id, _, _) => {
493                let old_id = match id.parse() {
494                    Ok(old_id) => old_id,
495                    Err(e) => panic!("invalid persisted global id {id}: {e}"),
496                };
497                if let Some(new_id) = self.ids.get(&old_id) {
498                    *id = new_id.to_string();
499                }
500            }
501            RawItemName::Name(_) => {}
502        }
503    }
504}