Skip to main content

mz_deploy/project/
roles.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
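//! Role definition loading and validation.
//!
//! Loads role definitions from the optional `<root>/roles/` directory. Each
//! `.sql` file defines a single role, named after the file, with exactly one
//! required `CREATE ROLE` statement and optional `ALTER ROLE`, `GRANT ROLE`,
//! and `COMMENT ON ROLE` statements, all of which must target that role.
//!
//! A role file may have profile-specific variants. Every variant is
//! validated, but only the variant for the active profile (falling back to
//! the default file) is returned.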
15
16use crate::project::error::{
17    LoadError, ProjectError, ValidationError, ValidationErrorKind, ValidationErrors,
18};
19use crate::project::syntax::parser::{
20    LocatedStatement, parse_statements_with_context, statement_type_name,
21};
22use crate::project::syntax::profile_files::collect_all_sql_files;
23use mz_sql_parser::ast::{
24    AlterRoleStatement, CommentObjectType, CommentStatement, CreateRoleStatement,
25    GrantRoleStatement, Raw, Statement,
26};
27use std::collections::BTreeMap;
28use std::path::Path;
29
30/// A parsed role definition from a `.sql` file in the `roles/` directory.
31pub(crate) struct RoleDefinition {
32    /// Role name (derived from filename and validated against CREATE statement).
33    pub name: String,
34    /// The CREATE ROLE statement.
35    pub create_stmt: CreateRoleStatement,
36    /// Optional ALTER ROLE statements for this role.
37    pub alter_stmts: Vec<AlterRoleStatement<Raw>>,
38    /// Optional GRANT ROLE statements granting this role to members.
39    pub grants: Vec<GrantRoleStatement<Raw>>,
40    /// Optional COMMENT statements targeting this role.
41    pub comments: Vec<CommentStatement<Raw>>,
42}
43
44/// Load all role definitions from `<root>/roles/`.
45///
46/// Returns an empty vec if `roles/` doesn't exist (the directory is optional).
47pub(crate) fn load_roles(
48    root: &Path,
49    profile: &str,
50    variables: &BTreeMap<String, String>,
51) -> Result<Vec<RoleDefinition>, ProjectError> {
52    let roles_dir = root.join("roles");
53
54    if !roles_dir.exists() {
55        return Ok(vec![]);
56    }
57
58    if !roles_dir.is_dir() {
59        return Err(LoadError::RootNotDirectory { path: roles_dir }.into());
60    }
61
62    let all_files = collect_all_sql_files(&roles_dir)?;
63
64    let mut definitions = Vec::new();
65    let mut errors = Vec::new();
66
67    for object_files in all_files {
68        let expected_name = &object_files.name;
69
70        // Validate all variants independently
71        let mut all_variant_paths = Vec::new();
72        if let Some(ref default_path) = object_files.default {
73            all_variant_paths.push(default_path.clone());
74        }
75        for (_, override_path) in &object_files.overrides {
76            all_variant_paths.push(override_path.clone());
77        }
78
79        for path in &all_variant_paths {
80            let sql = std::fs::read_to_string(path).map_err(|e| LoadError::FileReadFailed {
81                path: path.clone(),
82                source: e,
83            })?;
84            let located = parse_statements_with_context(&sql, path.clone(), variables, true)?;
85
86            if let Err(mut errs) = classify_role_statements(expected_name, path, located) {
87                errors.append(&mut errs);
88            }
89        }
90
91        // Resolve the active variant: prefer profile match, fall back to default
92        let active_path = object_files
93            .overrides
94            .get(profile)
95            .or(object_files.default.as_ref());
96
97        let active_path = match active_path {
98            Some(p) => p.clone(),
99            None => continue,
100        };
101
102        let sql = std::fs::read_to_string(&active_path).map_err(|e| LoadError::FileReadFailed {
103            path: active_path.clone(),
104            source: e,
105        })?;
106        let located = parse_statements_with_context(&sql, active_path.clone(), variables, true)?;
107
108        match classify_role_statements(expected_name, &active_path, located) {
109            Ok(def) => definitions.push(def),
110            Err(mut errs) => errors.append(&mut errs),
111        }
112    }
113
114    if !errors.is_empty() {
115        return Err(ValidationErrors::new(errors).into());
116    }
117
118    Ok(definitions)
119}
120
121/// Classify parsed statements into a `RoleDefinition`, returning validation errors.
122fn classify_role_statements(
123    expected_name: &str,
124    path: &Path,
125    located_statements: Vec<LocatedStatement>,
126) -> Result<RoleDefinition, Vec<ValidationError>> {
127    let mut create_stmts: Vec<(CreateRoleStatement, usize)> = Vec::new();
128    let mut alter_stmts: Vec<AlterRoleStatement<Raw>> = Vec::new();
129    let mut grants: Vec<GrantRoleStatement<Raw>> = Vec::new();
130    let mut comments: Vec<CommentStatement<Raw>> = Vec::new();
131    let mut errors = Vec::new();
132
133    for LocatedStatement {
134        ast: stmt,
135        byte_offset,
136    } in located_statements
137    {
138        match stmt {
139            Statement::CreateRole(s) => {
140                create_stmts.push((s, byte_offset));
141            }
142            Statement::AlterRole(s) => {
143                // Validate that the ALTER targets this role
144                let target_name = s.name.to_string();
145                if target_name.to_lowercase() != expected_name.to_lowercase() {
146                    errors.push(ValidationError::with_file_sql_and_offset(
147                        ValidationErrorKind::RoleAlterTargetMismatch {
148                            target: target_name,
149                            role_name: expected_name.to_string(),
150                        },
151                        path.to_path_buf(),
152                        s.to_string(),
153                        byte_offset,
154                    ));
155                } else {
156                    alter_stmts.push(s);
157                }
158            }
159            Statement::GrantRole(s) => {
160                // Validate that this role is among the roles being granted
161                let has_match = s
162                    .role_names
163                    .iter()
164                    .any(|r| r.to_string().to_lowercase() == expected_name.to_lowercase());
165                if !has_match {
166                    let target_names: Vec<String> =
167                        s.role_names.iter().map(|r| r.to_string()).collect();
168                    errors.push(ValidationError::with_file_sql_and_offset(
169                        ValidationErrorKind::RoleGrantTargetMismatch {
170                            target: target_names.join(", "),
171                            role_name: expected_name.to_string(),
172                        },
173                        path.to_path_buf(),
174                        s.to_string(),
175                        byte_offset,
176                    ));
177                } else {
178                    grants.push(s);
179                }
180            }
181            Statement::Comment(s) => {
182                // Validate that the comment targets this role
183                match &s.object {
184                    CommentObjectType::Role { name } => {
185                        let target_name = name.to_string();
186                        if target_name.to_lowercase() != expected_name.to_lowercase() {
187                            errors.push(ValidationError::with_file_sql_and_offset(
188                                ValidationErrorKind::RoleCommentTargetMismatch {
189                                    target: target_name,
190                                    role_name: expected_name.to_string(),
191                                },
192                                path.to_path_buf(),
193                                s.to_string(),
194                                byte_offset,
195                            ));
196                        }
197                        comments.push(s);
198                    }
199                    _ => {
200                        errors.push(ValidationError::with_file_sql_and_offset(
201                            ValidationErrorKind::InvalidRoleStatement {
202                                statement_type: "COMMENT (not targeting a role)".to_string(),
203                                role_name: expected_name.to_string(),
204                            },
205                            path.to_path_buf(),
206                            s.to_string(),
207                            byte_offset,
208                        ));
209                    }
210                }
211            }
212            other => {
213                errors.push(ValidationError::with_file_sql_and_offset(
214                    ValidationErrorKind::InvalidRoleStatement {
215                        statement_type: statement_type_name(&other).to_string(),
216                        role_name: expected_name.to_string(),
217                    },
218                    path.to_path_buf(),
219                    other.to_string(),
220                    byte_offset,
221                ));
222            }
223        }
224    }
225
226    // Validate exactly one CREATE ROLE (file-level errors)
227    if create_stmts.is_empty() {
228        errors.push(ValidationError::with_file(
229            ValidationErrorKind::RoleMissingCreateStatement {
230                role_name: expected_name.to_string(),
231            },
232            path.to_path_buf(),
233        ));
234    } else if create_stmts.len() > 1 {
235        // Point to the second CREATE ROLE
236        errors.push(ValidationError::with_file_and_offset(
237            ValidationErrorKind::RoleMultipleCreateStatements {
238                role_name: expected_name.to_string(),
239            },
240            path.to_path_buf(),
241            create_stmts[1].1,
242        ));
243    }
244
245    if !errors.is_empty() {
246        return Err(errors);
247    }
248
249    let (create_stmt, create_offset) = create_stmts.into_iter().next().unwrap();
250
251    // Validate role name matches filename
252    let declared_name = create_stmt.name.to_string();
253    if declared_name.to_lowercase() != expected_name.to_lowercase() {
254        return Err(vec![ValidationError::with_file_and_offset(
255            ValidationErrorKind::RoleNameMismatch {
256                declared: declared_name,
257                expected: expected_name.to_string(),
258            },
259            path.to_path_buf(),
260            create_offset,
261        )]);
262    }
263
264    Ok(RoleDefinition {
265        name: expected_name.to_string(),
266        create_stmt,
267        alter_stmts,
268        grants,
269        comments,
270    })
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    fn create_test_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    /// Creates a project root whose `roles/` directory contains the given
    /// `(file name, SQL contents)` pairs.
    fn project_with_roles(files: &[(&str, &str)]) -> TempDir {
        let dir = create_test_dir();
        let roles_dir = dir.path().join("roles");
        fs::create_dir(&roles_dir).unwrap();
        for &(file_name, sql) in files {
            fs::write(roles_dir.join(file_name), sql).unwrap();
        }
        dir
    }

    /// Loads roles from `dir` for the `default` profile with no variables.
    fn load_default(dir: &TempDir) -> Result<Vec<RoleDefinition>, ProjectError> {
        load_roles(dir.path(), "default", &BTreeMap::new())
    }

    #[mz_ore::test]
    fn test_load_roles_no_directory() {
        let dir = create_test_dir();
        let result = load_default(&dir).unwrap();
        assert!(
            result.is_empty(),
            "should return empty vec when roles/ doesn't exist"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_path_not_directory() {
        let dir = create_test_dir();
        fs::write(dir.path().join("roles"), "").unwrap();
        assert!(
            load_default(&dir).is_err(),
            "should error when roles/ exists but is not a directory"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_basic() {
        let dir = project_with_roles(&[(
            "analyst.sql",
            "CREATE ROLE analyst INHERIT;\n\
             ALTER ROLE analyst SET cluster TO 'analytics';\n\
             GRANT analyst TO joe, jane;\n\
             COMMENT ON ROLE analyst IS 'Read-only analytics access';",
        )]);

        let result = load_default(&dir).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "analyst");
        assert_eq!(result[0].alter_stmts.len(), 1);
        assert_eq!(result[0].grants.len(), 1);
        assert_eq!(result[0].comments.len(), 1);
    }

    #[mz_ore::test]
    fn test_load_roles_create_only() {
        let dir = project_with_roles(&[("reader.sql", "CREATE ROLE reader;")]);

        let result = load_default(&dir).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "reader");
        assert!(result[0].alter_stmts.is_empty());
        assert!(result[0].grants.is_empty());
        assert!(result[0].comments.is_empty());
    }

    #[mz_ore::test]
    fn test_load_roles_name_mismatch() {
        let dir = project_with_roles(&[("analyst.sql", "CREATE ROLE wrong_name;")]);
        assert!(
            load_default(&dir).is_err(),
            "should error when role name doesn't match filename"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_missing_create() {
        let dir = project_with_roles(&[("analyst.sql", "GRANT analyst TO joe;")]);
        assert!(
            load_default(&dir).is_err(),
            "should error when no CREATE ROLE statement"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_multiple_create() {
        let dir = project_with_roles(&[(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             CREATE ROLE analyst;",
        )]);
        assert!(
            load_default(&dir).is_err(),
            "should error on more than one CREATE ROLE statement"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_unsupported_statement() {
        let dir = project_with_roles(&[(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             CREATE TABLE foo (id INT);",
        )]);
        assert!(
            load_default(&dir).is_err(),
            "should error on unsupported statement type"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_alter_target_mismatch() {
        let dir = project_with_roles(&[(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             ALTER ROLE other_role SET cluster TO 'analytics';",
        )]);
        assert!(
            load_default(&dir).is_err(),
            "should error when ALTER targets wrong role"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_grant_target_mismatch() {
        let dir = project_with_roles(&[(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             GRANT other_role TO joe;",
        )]);
        assert!(
            load_default(&dir).is_err(),
            "should error when grant targets wrong role"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_comment_target_mismatch() {
        let dir = project_with_roles(&[(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             COMMENT ON ROLE other_role IS 'wrong target';",
        )]);
        assert!(
            load_default(&dir).is_err(),
            "should error when comment targets wrong role"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_comment_on_non_role() {
        let dir = project_with_roles(&[(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             COMMENT ON TABLE foo IS 'not a role';",
        )]);
        assert!(
            load_default(&dir).is_err(),
            "should error when a COMMENT does not target a role"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_multiple_files() {
        let dir = project_with_roles(&[
            ("analyst.sql", "CREATE ROLE analyst;"),
            ("writer.sql", "CREATE ROLE writer;"),
        ]);

        let result = load_default(&dir).unwrap();
        assert_eq!(result.len(), 2);
        // Sorted by filename
        assert_eq!(result[0].name, "analyst");
        assert_eq!(result[1].name, "writer");
    }
}