1use crate::project::error::{
17 LoadError, ProjectError, ValidationError, ValidationErrorKind, ValidationErrors,
18};
19use crate::project::syntax::parser::{
20 LocatedStatement, parse_statements_with_context, statement_type_name,
21};
22use crate::project::syntax::profile_files::collect_all_sql_files;
23use mz_sql_parser::ast::{
24 CommentObjectType, CommentStatement, CreateNetworkPolicyStatement, GrantPrivilegesStatement,
25 GrantTargetSpecification, GrantTargetSpecificationInner, ObjectType, Raw, RawNetworkPolicyName,
26 Statement,
27};
28use std::collections::BTreeMap;
29use std::path::Path;
30
/// A validated network policy, built from the SQL file that is active for the
/// selected profile (see `load_network_policies`).
pub(crate) struct NetworkPolicyDefinition {
    /// Policy name, taken from the `name` of the file group produced by
    /// `collect_all_sql_files` (e.g. `office_access` for `office_access.sql`).
    pub name: String,
    /// The file's single `CREATE NETWORK POLICY` statement.
    pub create_stmt: CreateNetworkPolicyStatement<Raw>,
    /// `GRANT ... ON NETWORK POLICY` statements from the file; all target this policy.
    pub grants: Vec<GrantPrivilegesStatement<Raw>>,
    /// `COMMENT ON NETWORK POLICY` statements from the file; all target this policy.
    pub comments: Vec<CommentStatement<Raw>>,
}
42
43pub(crate) fn load_network_policies(
47 root: &Path,
48 profile: &str,
49 variables: &BTreeMap<String, String>,
50) -> Result<Vec<NetworkPolicyDefinition>, ProjectError> {
51 let policies_dir = root.join("network-policies");
52
53 if !policies_dir.exists() {
54 return Ok(vec![]);
55 }
56
57 if !policies_dir.is_dir() {
58 return Err(LoadError::RootNotDirectory { path: policies_dir }.into());
59 }
60
61 let all_files = collect_all_sql_files(&policies_dir)?;
62
63 let mut definitions = Vec::new();
64 let mut errors = Vec::new();
65
66 for object_files in all_files {
67 let expected_name = &object_files.name;
68
69 let mut all_variant_paths = Vec::new();
71 if let Some(ref default_path) = object_files.default {
72 all_variant_paths.push(default_path.clone());
73 }
74 for (_, override_path) in &object_files.overrides {
75 all_variant_paths.push(override_path.clone());
76 }
77
78 for path in &all_variant_paths {
79 let sql = std::fs::read_to_string(path).map_err(|e| LoadError::FileReadFailed {
80 path: path.clone(),
81 source: e,
82 })?;
83 let located = parse_statements_with_context(&sql, path.clone(), variables, true)?;
84
85 if let Err(mut errs) = classify_network_policy_statements(expected_name, path, located)
86 {
87 errors.append(&mut errs);
88 }
89 }
90
91 let active_path = object_files
93 .overrides
94 .get(profile)
95 .or(object_files.default.as_ref());
96
97 let active_path = match active_path {
98 Some(p) => p.clone(),
99 None => continue,
100 };
101
102 let sql = std::fs::read_to_string(&active_path).map_err(|e| LoadError::FileReadFailed {
103 path: active_path.clone(),
104 source: e,
105 })?;
106 let located = parse_statements_with_context(&sql, active_path.clone(), variables, true)?;
107
108 match classify_network_policy_statements(expected_name, &active_path, located) {
109 Ok(def) => definitions.push(def),
110 Err(mut errs) => errors.append(&mut errs),
111 }
112 }
113
114 if !errors.is_empty() {
115 return Err(ValidationErrors::new(errors).into());
116 }
117
118 Ok(definitions)
119}
120
121fn classify_network_policy_statements(
123 expected_name: &str,
124 path: &Path,
125 located_statements: Vec<LocatedStatement>,
126) -> Result<NetworkPolicyDefinition, Vec<ValidationError>> {
127 let mut create_stmts: Vec<(CreateNetworkPolicyStatement<Raw>, usize)> = Vec::new();
128 let mut grants: Vec<GrantPrivilegesStatement<Raw>> = Vec::new();
129 let mut comments: Vec<CommentStatement<Raw>> = Vec::new();
130 let mut errors = Vec::new();
131
132 for LocatedStatement {
133 ast: stmt,
134 byte_offset,
135 } in located_statements
136 {
137 match stmt {
138 Statement::CreateNetworkPolicy(s) => {
139 create_stmts.push((s, byte_offset));
140 }
141 Statement::GrantPrivileges(s) => {
142 match &s.target {
144 GrantTargetSpecification::Object {
145 object_type: ObjectType::NetworkPolicy,
146 object_spec_inner: GrantTargetSpecificationInner::Objects { names },
147 } => {
148 for name in names {
150 let target_name = name.to_string();
151 if target_name.to_lowercase() != expected_name.to_lowercase() {
152 errors.push(ValidationError::with_file_sql_and_offset(
153 ValidationErrorKind::NetworkPolicyGrantTargetMismatch {
154 target: target_name,
155 policy_name: expected_name.to_string(),
156 },
157 path.to_path_buf(),
158 s.to_string(),
159 byte_offset,
160 ));
161 }
162 }
163 grants.push(s);
164 }
165 _ => {
166 errors.push(ValidationError::with_file_sql_and_offset(
167 ValidationErrorKind::InvalidNetworkPolicyStatement {
168 statement_type: "GRANT (not targeting a network policy)"
169 .to_string(),
170 policy_name: expected_name.to_string(),
171 },
172 path.to_path_buf(),
173 s.to_string(),
174 byte_offset,
175 ));
176 }
177 }
178 }
179 Statement::Comment(s) => {
180 match &s.object {
182 CommentObjectType::NetworkPolicy { name } => {
183 let target_name = match name {
184 RawNetworkPolicyName::Unresolved(ident) => ident.to_string(),
185 RawNetworkPolicyName::Resolved(id) => id.clone(),
186 };
187 if target_name.to_lowercase() != expected_name.to_lowercase() {
188 errors.push(ValidationError::with_file_sql_and_offset(
189 ValidationErrorKind::NetworkPolicyCommentTargetMismatch {
190 target: target_name,
191 policy_name: expected_name.to_string(),
192 },
193 path.to_path_buf(),
194 s.to_string(),
195 byte_offset,
196 ));
197 }
198 comments.push(s);
199 }
200 _ => {
201 errors.push(ValidationError::with_file_sql_and_offset(
202 ValidationErrorKind::InvalidNetworkPolicyStatement {
203 statement_type: "COMMENT (not targeting a network policy)"
204 .to_string(),
205 policy_name: expected_name.to_string(),
206 },
207 path.to_path_buf(),
208 s.to_string(),
209 byte_offset,
210 ));
211 }
212 }
213 }
214 other => {
215 errors.push(ValidationError::with_file_sql_and_offset(
216 ValidationErrorKind::InvalidNetworkPolicyStatement {
217 statement_type: statement_type_name(&other).to_string(),
218 policy_name: expected_name.to_string(),
219 },
220 path.to_path_buf(),
221 other.to_string(),
222 byte_offset,
223 ));
224 }
225 }
226 }
227
228 if create_stmts.is_empty() {
230 errors.push(ValidationError::with_file(
231 ValidationErrorKind::NetworkPolicyMissingCreateStatement {
232 policy_name: expected_name.to_string(),
233 },
234 path.to_path_buf(),
235 ));
236 } else if create_stmts.len() > 1 {
237 errors.push(ValidationError::with_file_and_offset(
239 ValidationErrorKind::NetworkPolicyMultipleCreateStatements {
240 policy_name: expected_name.to_string(),
241 },
242 path.to_path_buf(),
243 create_stmts[1].1,
244 ));
245 }
246
247 if !errors.is_empty() {
248 return Err(errors);
249 }
250
251 let (create_stmt, create_offset) = create_stmts.into_iter().next().unwrap();
252
253 let declared_name = create_stmt.name.to_string();
255 if declared_name.to_lowercase() != expected_name.to_lowercase() {
256 return Err(vec![ValidationError::with_file_and_offset(
257 ValidationErrorKind::NetworkPolicyNameMismatch {
258 declared: declared_name,
259 expected: expected_name.to_string(),
260 },
261 path.to_path_buf(),
262 create_offset,
263 )]);
264 }
265
266 Ok(NetworkPolicyDefinition {
267 name: expected_name.to_string(),
268 create_stmt,
269 grants,
270 comments,
271 })
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    fn create_test_dir() -> TempDir {
        TempDir::new().unwrap()
    }

    /// Creates a temporary project whose `network-policies/` directory holds one
    /// file per `(file_name, contents)` entry, written in the given order.
    fn project_with_policy_files(files: &[(&str, &str)]) -> TempDir {
        let project = create_test_dir();
        let policies_dir = project.path().join("network-policies");
        fs::create_dir(&policies_dir).unwrap();
        for &(file_name, contents) in files {
            fs::write(policies_dir.join(file_name), contents).unwrap();
        }
        project
    }

    /// Loads the project's network policies for the `default` profile, with no
    /// template variables.
    fn load_default_profile(
        project: &TempDir,
    ) -> Result<Vec<NetworkPolicyDefinition>, ProjectError> {
        load_network_policies(project.path(), "default", &BTreeMap::new())
    }

    #[mz_ore::test]
    fn test_load_network_policies_no_directory() {
        let project = create_test_dir();
        let policies = load_default_profile(&project).unwrap();
        assert!(
            policies.is_empty(),
            "should return empty vec when network-policies/ doesn't exist"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_basic() {
        let project = project_with_policy_files(&[(
            "office_access.sql",
            "CREATE NETWORK POLICY office_access (RULES (office (action = 'allow', direction = 'ingress', address = '1.2.3.4/28')));\n\
             COMMENT ON NETWORK POLICY office_access IS 'Office network access';",
        )]);

        let policies = load_default_profile(&project).unwrap();
        assert_eq!(policies.len(), 1);
        assert_eq!(policies[0].name, "office_access");
        assert_eq!(policies[0].comments.len(), 1);
    }

    #[mz_ore::test]
    fn test_load_network_policies_create_only() {
        let project = project_with_policy_files(&[(
            "office_access.sql",
            "CREATE NETWORK POLICY office_access (RULES (office (action = 'allow', direction = 'ingress', address = '1.2.3.4/28')));",
        )]);

        let policies = load_default_profile(&project).unwrap();
        assert_eq!(policies.len(), 1);
        let policy = &policies[0];
        assert_eq!(policy.name, "office_access");
        assert!(policy.grants.is_empty());
        assert!(policy.comments.is_empty());
    }

    #[mz_ore::test]
    fn test_load_network_policies_name_mismatch() {
        let project = project_with_policy_files(&[(
            "office_access.sql",
            "CREATE NETWORK POLICY wrong_name (RULES (office (action = 'allow', direction = 'ingress', address = '1.2.3.4/28')));",
        )]);

        assert!(
            load_default_profile(&project).is_err(),
            "should error when policy name doesn't match filename"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_missing_create() {
        let project = project_with_policy_files(&[(
            "office_access.sql",
            "COMMENT ON NETWORK POLICY office_access IS 'Office network access';",
        )]);

        assert!(
            load_default_profile(&project).is_err(),
            "should error when no CREATE NETWORK POLICY statement"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_unsupported_statement() {
        let project = project_with_policy_files(&[(
            "office_access.sql",
            "CREATE NETWORK POLICY office_access (RULES (office (action = 'allow', direction = 'ingress', address = '1.2.3.4/28')));\n\
             CREATE TABLE foo (id INT);",
        )]);

        assert!(
            load_default_profile(&project).is_err(),
            "should error on unsupported statement type"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_comment_target_mismatch() {
        let project = project_with_policy_files(&[(
            "office_access.sql",
            "CREATE NETWORK POLICY office_access (RULES (office (action = 'allow', direction = 'ingress', address = '1.2.3.4/28')));\n\
             COMMENT ON NETWORK POLICY other_policy IS 'wrong target';",
        )]);

        assert!(
            load_default_profile(&project).is_err(),
            "should error when comment targets wrong policy"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_multiple_files() {
        let project = project_with_policy_files(&[
            (
                "office_access.sql",
                "CREATE NETWORK POLICY office_access (RULES (office (action = 'allow', direction = 'ingress', address = '1.2.3.4/28')));",
            ),
            (
                "vpn_access.sql",
                "CREATE NETWORK POLICY vpn_access (RULES (vpn (action = 'allow', direction = 'ingress', address = '10.0.0.0/8')));",
            ),
        ]);

        let policies = load_default_profile(&project).unwrap();
        assert_eq!(policies.len(), 2);
        assert_eq!(policies[0].name, "office_access");
        assert_eq!(policies[1].name, "vpn_access");
    }
}