1use std::collections::{BTreeMap, BTreeSet};
13
14use mz_ore::str::StrExt;
15use mz_repr::CatalogItemId;
16use mz_sql_parser::ast::CreateTableFromSourceStatement;
17
18use crate::ast::visit::{self, Visit};
19use crate::ast::visit_mut::{self, VisitMut};
20use crate::ast::{
21 AstInfo, CreateConnectionStatement, CreateIndexStatement, CreateMaterializedViewStatement,
22 CreateSecretStatement, CreateSinkStatement, CreateSourceStatement, CreateSubsourceStatement,
23 CreateTableStatement, CreateViewStatement, CreateWebhookSourceStatement, Expr, Ident, Query,
24 Raw, RawItemName, Statement, UnresolvedItemName, ViewDefinition,
25};
26use crate::names::FullItemName;
27
28pub fn create_stmt_rename_schema_refs(
31 create_stmt: &mut Statement<Raw>,
32 database: &str,
33 cur_schema: &str,
34 new_schema: &str,
35) -> Result<(), (String, String)> {
36 match create_stmt {
37 stmt @ Statement::CreateConnection(_)
38 | stmt @ Statement::CreateDatabase(_)
39 | stmt @ Statement::CreateSchema(_)
40 | stmt @ Statement::CreateWebhookSource(_)
41 | stmt @ Statement::CreateSource(_)
42 | stmt @ Statement::CreateSubsource(_)
43 | stmt @ Statement::CreateSink(_)
44 | stmt @ Statement::CreateView(_)
45 | stmt @ Statement::CreateMaterializedView(_)
46 | stmt @ Statement::CreateTable(_)
47 | stmt @ Statement::CreateTableFromSource(_)
48 | stmt @ Statement::CreateIndex(_)
49 | stmt @ Statement::CreateType(_)
50 | stmt @ Statement::CreateSecret(_) => {
51 let mut visitor = CreateSqlRewriteSchema {
52 database,
53 cur_schema,
54 new_schema,
55 error: None,
56 };
57 visitor.visit_statement_mut(stmt);
58
59 if let Some(e) = visitor.error.take() {
60 Err(e)
61 } else {
62 Ok(())
63 }
64 }
65 stmt => {
66 unreachable!("Internal error: only catalog items need to update item refs. {stmt:?}")
67 }
68 }
69}
70
71struct CreateSqlRewriteSchema<'a> {
72 database: &'a str,
73 cur_schema: &'a str,
74 new_schema: &'a str,
75 error: Option<(String, String)>,
76}
77
78impl<'a> CreateSqlRewriteSchema<'a> {
79 fn maybe_rewrite_idents(&mut self, name: &mut [Ident]) {
80 match name {
81 [schema, item] if schema.as_str() == self.cur_schema => {
82 if self.error.is_none() {
86 self.error = Some((schema.to_string(), item.to_string()));
87 }
88 }
89 [database, schema, _item] => {
90 if database.as_str() == self.database && schema.as_str() == self.cur_schema {
91 *schema = Ident::new_unchecked(self.new_schema);
92 }
93 }
94 _ => (),
95 }
96 }
97}
98
99impl<'a, 'ast> VisitMut<'ast, Raw> for CreateSqlRewriteSchema<'a> {
100 fn visit_expr_mut(&mut self, e: &'ast mut Expr<Raw>) {
101 match e {
102 Expr::Identifier(id) => {
103 let i = id.len() - 1;
106 self.maybe_rewrite_idents(&mut id[..i]);
107 }
108 Expr::QualifiedWildcard(id) => {
109 self.maybe_rewrite_idents(id);
110 }
111 _ => visit_mut::visit_expr_mut(self, e),
112 }
113 }
114
115 fn visit_unresolved_item_name_mut(
116 &mut self,
117 unresolved_item_name: &'ast mut UnresolvedItemName,
118 ) {
119 self.maybe_rewrite_idents(&mut unresolved_item_name.0);
120 }
121
122 fn visit_item_name_mut(
123 &mut self,
124 item_name: &'ast mut <mz_sql_parser::ast::Raw as AstInfo>::ItemName,
125 ) {
126 match item_name {
127 RawItemName::Name(n) | RawItemName::Id(_, n, _) => self.maybe_rewrite_idents(&mut n.0),
128 }
129 }
130}
131
132pub fn create_stmt_rename(create_stmt: &mut Statement<Raw>, to_item_name: String) {
136 match create_stmt {
138 Statement::CreateIndex(CreateIndexStatement { name, .. }) => {
139 *name = Some(Ident::new_unchecked(to_item_name));
140 }
141 Statement::CreateSink(CreateSinkStatement {
142 name: Some(name), ..
143 })
144 | Statement::CreateSource(CreateSourceStatement { name, .. })
145 | Statement::CreateSubsource(CreateSubsourceStatement { name, .. })
146 | Statement::CreateView(CreateViewStatement {
147 definition: ViewDefinition { name, .. },
148 ..
149 })
150 | Statement::CreateMaterializedView(CreateMaterializedViewStatement { name, .. })
151 | Statement::CreateTable(CreateTableStatement { name, .. })
152 | Statement::CreateTableFromSource(CreateTableFromSourceStatement { name, .. })
153 | Statement::CreateSecret(CreateSecretStatement { name, .. })
154 | Statement::CreateConnection(CreateConnectionStatement { name, .. })
155 | Statement::CreateWebhookSource(CreateWebhookSourceStatement { name, .. }) => {
156 let item_name_len = name.0.len() - 1;
160 name.0[item_name_len] = Ident::new_unchecked(to_item_name);
161 }
162 item => unreachable!("Internal error: only catalog items can be renamed {item:?}"),
163 }
164}
165
166pub fn create_stmt_rename_refs(
179 create_stmt: &mut Statement<Raw>,
180 from_name: FullItemName,
181 to_item_name: String,
182) -> Result<(), String> {
183 let from_item = UnresolvedItemName::from(from_name.clone());
184 let maybe_update_item_name = |item_name: &mut UnresolvedItemName| {
185 if item_name.0 == from_item.0 {
186 let item_name_len = item_name.0.len() - 1;
190 item_name.0[item_name_len] = Ident::new_unchecked(to_item_name.clone());
191 }
192 };
193
194 match create_stmt {
196 Statement::CreateIndex(CreateIndexStatement { on_name, .. }) => {
197 maybe_update_item_name(on_name.name_mut());
198 }
199 Statement::CreateSink(CreateSinkStatement { from, .. }) => {
200 maybe_update_item_name(from.name_mut());
201 }
202 Statement::CreateTableFromSource(CreateTableFromSourceStatement { source, .. }) => {
203 maybe_update_item_name(source.name_mut());
204 }
205 Statement::CreateView(CreateViewStatement {
206 definition: ViewDefinition { query, .. },
207 ..
208 }) => {
209 rewrite_query(from_name, to_item_name, query)?;
210 }
211 Statement::CreateMaterializedView(CreateMaterializedViewStatement {
212 replacement_for,
213 query,
214 ..
215 }) => {
216 if let Some(target) = replacement_for {
217 maybe_update_item_name(target.name_mut());
218 }
219 rewrite_query(from_name, to_item_name, query)?;
220 }
221 Statement::CreateSource(_)
222 | Statement::CreateSubsource(_)
223 | Statement::CreateTable(_)
224 | Statement::CreateSecret(_)
225 | Statement::CreateConnection(_)
226 | Statement::CreateWebhookSource(_) => {}
227 item => {
228 unreachable!("Internal error: only catalog items need to update item refs {item:?}")
229 }
230 }
231
232 Ok(())
233}
234
235fn rewrite_query(from: FullItemName, to: String, query: &mut Query<Raw>) -> Result<(), String> {
237 let from_ident = Ident::new_unchecked(from.item.clone());
238 let to_ident = Ident::new_unchecked(to);
239 let qual_depth =
240 QueryIdentAgg::determine_qual_depth(&from_ident, Some(to_ident.clone()), query)?;
241 CreateSqlRewriter::rewrite_query_with_qual_depth(from, to_ident.clone(), qual_depth, query);
242 match QueryIdentAgg::determine_qual_depth(&to_ident, None, query) {
245 Ok(_) => Ok(()),
246 Err(e) => Err(e),
247 }
248}
249
250fn ambiguous_err(n: &Ident, t: &str) -> String {
251 format!(
252 "{} potentially used ambiguously as item and {}",
253 n.as_str().quoted(),
254 t
255 )
256}
257
258struct QueryIdentAgg<'a> {
260 name: &'a Ident,
262 qualifiers: BTreeMap<Ident, BTreeSet<Ident>>,
266 min_qual_depth: usize,
268 fail_on: Option<Ident>,
270 err: Option<String>,
271}
272
273impl<'a> QueryIdentAgg<'a> {
274 fn determine_qual_depth(
286 name: &Ident,
287 fail_on: Option<Ident>,
288 query: &Query<Raw>,
289 ) -> Result<usize, String> {
290 let mut v = QueryIdentAgg {
291 qualifiers: BTreeMap::new(),
292 min_qual_depth: usize::MAX,
293 err: None,
294 name,
295 fail_on,
296 };
297
298 v.visit_query(query);
300 assert!(v.min_qual_depth > 0);
302
303 if let Some(e) = v.err {
304 return Err(e);
305 }
306
307 let req_depth = if v.qualifiers.values().any(|v| v.len() > 1) {
310 3
311 } else if v.qualifiers.len() > 1 {
314 2
315 } else {
316 1
317 };
318
319 if v.min_qual_depth < req_depth {
320 Err(format!(
321 "{} is not sufficiently qualified to support renaming",
322 name.as_str().quoted()
323 ))
324 } else {
325 Ok(req_depth)
326 }
327 }
328
329 fn check_failure(&mut self, v: &[Ident]) {
331 if let Some(f) = &self.fail_on {
333 if v.iter().any(|i| i == f) {
334 self.err = Some(format!(
335 "found reference to {}; cannot rename {} to any identity \
336 used in any existing view definitions",
337 f.as_str().quoted(),
338 self.name.as_str().quoted()
339 ));
340 }
341 }
342 }
343}
344
345impl<'a, 'ast> Visit<'ast, Raw> for QueryIdentAgg<'a> {
346 fn visit_expr(&mut self, e: &'ast Expr<Raw>) {
347 match e {
348 Expr::Identifier(i) => {
349 self.check_failure(i);
350 if let Some(p) = i.iter().rposition(|e| e == self.name) {
351 if p == i.len() - 1 {
352 self.err = Some(ambiguous_err(self.name, "column"));
355 return;
356 }
357 self.min_qual_depth = std::cmp::min(p + 1, self.min_qual_depth);
358 }
359 }
360 Expr::QualifiedWildcard(i) => {
361 self.check_failure(i);
362 if let Some(p) = i.iter().rposition(|e| e == self.name) {
363 self.min_qual_depth = std::cmp::min(p + 1, self.min_qual_depth);
364 }
365 }
366 _ => visit::visit_expr(self, e),
367 }
368 }
369
370 fn visit_ident(&mut self, ident: &'ast Ident) {
371 self.check_failure(std::slice::from_ref(ident));
372 if ident == self.name {
375 self.err = Some(ambiguous_err(self.name, "alias or column"));
376 }
377 }
378
379 fn visit_unresolved_item_name(&mut self, unresolved_item_name: &'ast UnresolvedItemName) {
380 let names = &unresolved_item_name.0;
381 self.check_failure(names);
382 if let Some(p) = names.iter().rposition(|e| e == self.name) {
385 if p == names.len() - 1 && names.len() == 3 {
387 self.qualifiers
388 .entry(names[1].clone())
389 .or_default()
390 .insert(names[0].clone());
391 self.min_qual_depth = std::cmp::min(3, self.min_qual_depth);
392 } else {
393 self.err = Some(ambiguous_err(self.name, "database, schema, or function"))
395 }
396 }
397 }
398
399 fn visit_item_name(&mut self, item_name: &'ast <Raw as AstInfo>::ItemName) {
400 match item_name {
401 RawItemName::Name(n) | RawItemName::Id(_, n, _) => self.visit_unresolved_item_name(n),
402 }
403 }
404}
405
406struct CreateSqlRewriter {
407 from: Vec<Ident>,
408 to: Ident,
409}
410
411impl CreateSqlRewriter {
412 fn rewrite_query_with_qual_depth(
413 from_name: FullItemName,
414 to_name: Ident,
415 qual_depth: usize,
416 query: &mut Query<Raw>,
417 ) {
418 let from = match qual_depth {
419 1 => vec![Ident::new_unchecked(from_name.item)],
420 2 => vec![
421 Ident::new_unchecked(from_name.schema),
422 Ident::new_unchecked(from_name.item),
423 ],
424 3 => vec![
425 Ident::new_unchecked(from_name.database.to_string()),
426 Ident::new_unchecked(from_name.schema),
427 Ident::new_unchecked(from_name.item),
428 ],
429 _ => unreachable!(),
430 };
431 let mut v = CreateSqlRewriter { from, to: to_name };
432 v.visit_query_mut(query);
433 }
434
435 fn maybe_rewrite_idents(&self, name: &mut [Ident]) {
436 if name.len() > 0 && name.ends_with(&self.from) {
437 name[name.len() - 1] = self.to.clone();
438 }
439 }
440}
441
442impl<'ast> VisitMut<'ast, Raw> for CreateSqlRewriter {
443 fn visit_expr_mut(&mut self, e: &'ast mut Expr<Raw>) {
444 match e {
445 Expr::Identifier(id) => {
446 let i = id.len() - 1;
449 self.maybe_rewrite_idents(&mut id[..i]);
450 }
451 Expr::QualifiedWildcard(id) => {
452 self.maybe_rewrite_idents(id);
453 }
454 _ => visit_mut::visit_expr_mut(self, e),
455 }
456 }
457 fn visit_unresolved_item_name_mut(
458 &mut self,
459 unresolved_item_name: &'ast mut UnresolvedItemName,
460 ) {
461 self.maybe_rewrite_idents(&mut unresolved_item_name.0);
462 }
463 fn visit_item_name_mut(
464 &mut self,
465 item_name: &'ast mut <mz_sql_parser::ast::Raw as AstInfo>::ItemName,
466 ) {
467 match item_name {
468 RawItemName::Name(n) | RawItemName::Id(_, n, _) => self.maybe_rewrite_idents(&mut n.0),
469 }
470 }
471}
472
473pub fn create_stmt_replace_ids(
475 create_stmt: &mut Statement<Raw>,
476 ids: &BTreeMap<CatalogItemId, CatalogItemId>,
477) {
478 let mut id_replacer = CreateSqlIdReplacer { ids };
479 id_replacer.visit_statement_mut(create_stmt);
480}
481
482struct CreateSqlIdReplacer<'a> {
483 ids: &'a BTreeMap<CatalogItemId, CatalogItemId>,
484}
485
486impl<'ast> VisitMut<'ast, Raw> for CreateSqlIdReplacer<'_> {
487 fn visit_item_name_mut(
488 &mut self,
489 item_name: &'ast mut <mz_sql_parser::ast::Raw as AstInfo>::ItemName,
490 ) {
491 match item_name {
492 RawItemName::Id(id, _, _) => {
493 let old_id = match id.parse() {
494 Ok(old_id) => old_id,
495 Err(e) => panic!("invalid persisted global id {id}: {e}"),
496 };
497 if let Some(new_id) = self.ids.get(&old_id) {
498 *id = new_id.to_string();
499 }
500 }
501 RawItemName::Name(_) => {}
502 }
503 }
504}