1use crate::project::error::{
17 LoadError, ProjectError, ValidationError, ValidationErrorKind, ValidationErrors,
18};
19use crate::project::syntax::parser::{
20 LocatedStatement, parse_statements_with_context, statement_type_name,
21};
22use crate::project::syntax::profile_files::collect_all_sql_files;
23use mz_sql_parser::ast::{
24 AlterRoleStatement, CommentObjectType, CommentStatement, CreateRoleStatement,
25 GrantRoleStatement, Raw, Statement,
26};
27use std::collections::BTreeMap;
28use std::path::Path;
29
/// A role declared by one SQL file in the project's `roles/` directory.
///
/// Produced by `classify_role_statements` once every statement in the file has
/// been validated to belong to this role.
pub(crate) struct RoleDefinition {
    /// The role's name as derived from its file (e.g. `analyst` for
    /// `roles/analyst.sql`); the `CREATE ROLE` statement is checked to declare
    /// the same name, compared case-insensitively.
    pub name: String,
    /// The file's single `CREATE ROLE` statement.
    pub create_stmt: CreateRoleStatement,
    /// `ALTER ROLE` statements, each targeting this role.
    pub alter_stmts: Vec<AlterRoleStatement<Raw>>,
    /// `GRANT <roles> TO ...` statements whose granted roles include this role.
    pub grants: Vec<GrantRoleStatement<Raw>>,
    /// `COMMENT ON ROLE` statements targeting this role.
    pub comments: Vec<CommentStatement<Raw>>,
}
43
44pub(crate) fn load_roles(
48 root: &Path,
49 profile: &str,
50 variables: &BTreeMap<String, String>,
51) -> Result<Vec<RoleDefinition>, ProjectError> {
52 let roles_dir = root.join("roles");
53
54 if !roles_dir.exists() {
55 return Ok(vec![]);
56 }
57
58 if !roles_dir.is_dir() {
59 return Err(LoadError::RootNotDirectory { path: roles_dir }.into());
60 }
61
62 let all_files = collect_all_sql_files(&roles_dir)?;
63
64 let mut definitions = Vec::new();
65 let mut errors = Vec::new();
66
67 for object_files in all_files {
68 let expected_name = &object_files.name;
69
70 let mut all_variant_paths = Vec::new();
72 if let Some(ref default_path) = object_files.default {
73 all_variant_paths.push(default_path.clone());
74 }
75 for (_, override_path) in &object_files.overrides {
76 all_variant_paths.push(override_path.clone());
77 }
78
79 for path in &all_variant_paths {
80 let sql = std::fs::read_to_string(path).map_err(|e| LoadError::FileReadFailed {
81 path: path.clone(),
82 source: e,
83 })?;
84 let located = parse_statements_with_context(&sql, path.clone(), variables, true)?;
85
86 if let Err(mut errs) = classify_role_statements(expected_name, path, located) {
87 errors.append(&mut errs);
88 }
89 }
90
91 let active_path = object_files
93 .overrides
94 .get(profile)
95 .or(object_files.default.as_ref());
96
97 let active_path = match active_path {
98 Some(p) => p.clone(),
99 None => continue,
100 };
101
102 let sql = std::fs::read_to_string(&active_path).map_err(|e| LoadError::FileReadFailed {
103 path: active_path.clone(),
104 source: e,
105 })?;
106 let located = parse_statements_with_context(&sql, active_path.clone(), variables, true)?;
107
108 match classify_role_statements(expected_name, &active_path, located) {
109 Ok(def) => definitions.push(def),
110 Err(mut errs) => errors.append(&mut errs),
111 }
112 }
113
114 if !errors.is_empty() {
115 return Err(ValidationErrors::new(errors).into());
116 }
117
118 Ok(definitions)
119}
120
121fn classify_role_statements(
123 expected_name: &str,
124 path: &Path,
125 located_statements: Vec<LocatedStatement>,
126) -> Result<RoleDefinition, Vec<ValidationError>> {
127 let mut create_stmts: Vec<(CreateRoleStatement, usize)> = Vec::new();
128 let mut alter_stmts: Vec<AlterRoleStatement<Raw>> = Vec::new();
129 let mut grants: Vec<GrantRoleStatement<Raw>> = Vec::new();
130 let mut comments: Vec<CommentStatement<Raw>> = Vec::new();
131 let mut errors = Vec::new();
132
133 for LocatedStatement {
134 ast: stmt,
135 byte_offset,
136 } in located_statements
137 {
138 match stmt {
139 Statement::CreateRole(s) => {
140 create_stmts.push((s, byte_offset));
141 }
142 Statement::AlterRole(s) => {
143 let target_name = s.name.to_string();
145 if target_name.to_lowercase() != expected_name.to_lowercase() {
146 errors.push(ValidationError::with_file_sql_and_offset(
147 ValidationErrorKind::RoleAlterTargetMismatch {
148 target: target_name,
149 role_name: expected_name.to_string(),
150 },
151 path.to_path_buf(),
152 s.to_string(),
153 byte_offset,
154 ));
155 } else {
156 alter_stmts.push(s);
157 }
158 }
159 Statement::GrantRole(s) => {
160 let has_match = s
162 .role_names
163 .iter()
164 .any(|r| r.to_string().to_lowercase() == expected_name.to_lowercase());
165 if !has_match {
166 let target_names: Vec<String> =
167 s.role_names.iter().map(|r| r.to_string()).collect();
168 errors.push(ValidationError::with_file_sql_and_offset(
169 ValidationErrorKind::RoleGrantTargetMismatch {
170 target: target_names.join(", "),
171 role_name: expected_name.to_string(),
172 },
173 path.to_path_buf(),
174 s.to_string(),
175 byte_offset,
176 ));
177 } else {
178 grants.push(s);
179 }
180 }
181 Statement::Comment(s) => {
182 match &s.object {
184 CommentObjectType::Role { name } => {
185 let target_name = name.to_string();
186 if target_name.to_lowercase() != expected_name.to_lowercase() {
187 errors.push(ValidationError::with_file_sql_and_offset(
188 ValidationErrorKind::RoleCommentTargetMismatch {
189 target: target_name,
190 role_name: expected_name.to_string(),
191 },
192 path.to_path_buf(),
193 s.to_string(),
194 byte_offset,
195 ));
196 }
197 comments.push(s);
198 }
199 _ => {
200 errors.push(ValidationError::with_file_sql_and_offset(
201 ValidationErrorKind::InvalidRoleStatement {
202 statement_type: "COMMENT (not targeting a role)".to_string(),
203 role_name: expected_name.to_string(),
204 },
205 path.to_path_buf(),
206 s.to_string(),
207 byte_offset,
208 ));
209 }
210 }
211 }
212 other => {
213 errors.push(ValidationError::with_file_sql_and_offset(
214 ValidationErrorKind::InvalidRoleStatement {
215 statement_type: statement_type_name(&other).to_string(),
216 role_name: expected_name.to_string(),
217 },
218 path.to_path_buf(),
219 other.to_string(),
220 byte_offset,
221 ));
222 }
223 }
224 }
225
226 if create_stmts.is_empty() {
228 errors.push(ValidationError::with_file(
229 ValidationErrorKind::RoleMissingCreateStatement {
230 role_name: expected_name.to_string(),
231 },
232 path.to_path_buf(),
233 ));
234 } else if create_stmts.len() > 1 {
235 errors.push(ValidationError::with_file_and_offset(
237 ValidationErrorKind::RoleMultipleCreateStatements {
238 role_name: expected_name.to_string(),
239 },
240 path.to_path_buf(),
241 create_stmts[1].1,
242 ));
243 }
244
245 if !errors.is_empty() {
246 return Err(errors);
247 }
248
249 let (create_stmt, create_offset) = create_stmts.into_iter().next().unwrap();
250
251 let declared_name = create_stmt.name.to_string();
253 if declared_name.to_lowercase() != expected_name.to_lowercase() {
254 return Err(vec![ValidationError::with_file_and_offset(
255 ValidationErrorKind::RoleNameMismatch {
256 declared: declared_name,
257 expected: expected_name.to_string(),
258 },
259 path.to_path_buf(),
260 create_offset,
261 )]);
262 }
263
264 Ok(RoleDefinition {
265 name: expected_name.to_string(),
266 create_stmt,
267 alter_stmts,
268 grants,
269 comments,
270 })
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Creates a temporary project whose `roles/` directory holds the given
    /// `(file name, contents)` pairs, written in order.
    fn project_with_roles(files: &[(&str, &str)]) -> TempDir {
        let project = TempDir::new().unwrap();
        let roles = project.path().join("roles");
        fs::create_dir(&roles).unwrap();
        for &(file_name, contents) in files {
            fs::write(roles.join(file_name), contents).unwrap();
        }
        project
    }

    /// Loads `project`'s roles for the `default` profile with no variables.
    fn load_default(project: &TempDir) -> Result<Vec<RoleDefinition>, ProjectError> {
        load_roles(project.path(), "default", &BTreeMap::new())
    }

    /// Asserts that a project containing only `file_name` with `sql` fails to load.
    fn assert_load_fails(file_name: &str, sql: &str, reason: &str) {
        let project = project_with_roles(&[(file_name, sql)]);
        assert!(load_default(&project).is_err(), "{}", reason);
    }

    #[mz_ore::test]
    fn test_load_roles_no_directory() {
        let project = TempDir::new().unwrap();
        let roles = load_default(&project).unwrap();
        assert!(
            roles.is_empty(),
            "should return empty vec when roles/ doesn't exist"
        );
    }

    #[mz_ore::test]
    fn test_load_roles_basic() {
        let project = project_with_roles(&[(
            "analyst.sql",
            "CREATE ROLE analyst INHERIT;\n\
             ALTER ROLE analyst SET cluster TO 'analytics';\n\
             GRANT analyst TO joe, jane;\n\
             COMMENT ON ROLE analyst IS 'Read-only analytics access';",
        )]);

        let roles = load_default(&project).unwrap();
        assert_eq!(roles.len(), 1);
        let analyst = &roles[0];
        assert_eq!(analyst.name, "analyst");
        assert_eq!(analyst.alter_stmts.len(), 1);
        assert_eq!(analyst.grants.len(), 1);
        assert_eq!(analyst.comments.len(), 1);
    }

    #[mz_ore::test]
    fn test_load_roles_create_only() {
        let project = project_with_roles(&[("reader.sql", "CREATE ROLE reader;")]);

        let roles = load_default(&project).unwrap();
        assert_eq!(roles.len(), 1);
        let reader = &roles[0];
        assert_eq!(reader.name, "reader");
        assert!(reader.alter_stmts.is_empty());
        assert!(reader.grants.is_empty());
        assert!(reader.comments.is_empty());
    }

    #[mz_ore::test]
    fn test_load_roles_name_mismatch() {
        assert_load_fails(
            "analyst.sql",
            "CREATE ROLE wrong_name;",
            "should error when role name doesn't match filename",
        );
    }

    #[mz_ore::test]
    fn test_load_roles_missing_create() {
        assert_load_fails(
            "analyst.sql",
            "GRANT analyst TO joe;",
            "should error when no CREATE ROLE statement",
        );
    }

    #[mz_ore::test]
    fn test_load_roles_unsupported_statement() {
        assert_load_fails(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             CREATE TABLE foo (id INT);",
            "should error on unsupported statement type",
        );
    }

    #[mz_ore::test]
    fn test_load_roles_alter_target_mismatch() {
        assert_load_fails(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             ALTER ROLE other_role SET cluster TO 'analytics';",
            "should error when ALTER targets wrong role",
        );
    }

    #[mz_ore::test]
    fn test_load_roles_grant_target_mismatch() {
        assert_load_fails(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             GRANT other_role TO joe;",
            "should error when grant targets wrong role",
        );
    }

    #[mz_ore::test]
    fn test_load_roles_comment_target_mismatch() {
        assert_load_fails(
            "analyst.sql",
            "CREATE ROLE analyst;\n\
             COMMENT ON ROLE other_role IS 'wrong target';",
            "should error when comment targets wrong role",
        );
    }

    #[mz_ore::test]
    fn test_load_roles_multiple_files() {
        let project = project_with_roles(&[
            ("analyst.sql", "CREATE ROLE analyst;"),
            ("writer.sql", "CREATE ROLE writer;"),
        ]);

        let roles = load_default(&project).unwrap();
        let names: Vec<&str> = roles.iter().map(|role| role.name.as_str()).collect();
        assert_eq!(names, ["analyst", "writer"]);
    }
}