Skip to main content

mz_deploy/project/
network_policies.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Network policy definition loading and validation.
11//!
12//! Loads network policy definitions from `<root>/network-policies/` directory. Each `.sql` file
13//! defines a single network policy with a required `CREATE NETWORK POLICY` statement and optional
14//! `GRANT` and `COMMENT` statements.
15
16use crate::project::error::{
17    LoadError, ProjectError, ValidationError, ValidationErrorKind, ValidationErrors,
18};
19use crate::project::syntax::parser::{
20    LocatedStatement, parse_statements_with_context, statement_type_name,
21};
22use crate::project::syntax::profile_files::collect_all_sql_files;
23use mz_sql_parser::ast::{
24    CommentObjectType, CommentStatement, CreateNetworkPolicyStatement, GrantPrivilegesStatement,
25    GrantTargetSpecification, GrantTargetSpecificationInner, ObjectType, Raw, RawNetworkPolicyName,
26    Statement,
27};
28use std::collections::BTreeMap;
29use std::path::Path;
30
31/// A parsed network policy definition from a `.sql` file in the `network-policies/` directory.
32pub(crate) struct NetworkPolicyDefinition {
33    /// Network policy name (derived from filename and validated against CREATE statement).
34    pub name: String,
35    /// The CREATE NETWORK POLICY statement.
36    pub create_stmt: CreateNetworkPolicyStatement<Raw>,
37    /// Optional GRANT statements targeting this network policy.
38    pub grants: Vec<GrantPrivilegesStatement<Raw>>,
39    /// Optional COMMENT statements targeting this network policy.
40    pub comments: Vec<CommentStatement<Raw>>,
41}
42
43/// Load all network policy definitions from `<root>/network-policies/`.
44///
45/// Returns an empty vec if `network-policies/` doesn't exist (the directory is optional).
46pub(crate) fn load_network_policies(
47    root: &Path,
48    profile: &str,
49    variables: &BTreeMap<String, String>,
50) -> Result<Vec<NetworkPolicyDefinition>, ProjectError> {
51    let policies_dir = root.join("network-policies");
52
53    if !policies_dir.exists() {
54        return Ok(vec![]);
55    }
56
57    if !policies_dir.is_dir() {
58        return Err(LoadError::RootNotDirectory { path: policies_dir }.into());
59    }
60
61    let all_files = collect_all_sql_files(&policies_dir)?;
62
63    let mut definitions = Vec::new();
64    let mut errors = Vec::new();
65
66    for object_files in all_files {
67        let expected_name = &object_files.name;
68
69        // Validate all variants independently
70        let mut all_variant_paths = Vec::new();
71        if let Some(ref default_path) = object_files.default {
72            all_variant_paths.push(default_path.clone());
73        }
74        for (_, override_path) in &object_files.overrides {
75            all_variant_paths.push(override_path.clone());
76        }
77
78        for path in &all_variant_paths {
79            let sql = std::fs::read_to_string(path).map_err(|e| LoadError::FileReadFailed {
80                path: path.clone(),
81                source: e,
82            })?;
83            let located = parse_statements_with_context(&sql, path.clone(), variables, true)?;
84
85            if let Err(mut errs) = classify_network_policy_statements(expected_name, path, located)
86            {
87                errors.append(&mut errs);
88            }
89        }
90
91        // Resolve the active variant: prefer profile match, fall back to default
92        let active_path = object_files
93            .overrides
94            .get(profile)
95            .or(object_files.default.as_ref());
96
97        let active_path = match active_path {
98            Some(p) => p.clone(),
99            None => continue,
100        };
101
102        let sql = std::fs::read_to_string(&active_path).map_err(|e| LoadError::FileReadFailed {
103            path: active_path.clone(),
104            source: e,
105        })?;
106        let located = parse_statements_with_context(&sql, active_path.clone(), variables, true)?;
107
108        match classify_network_policy_statements(expected_name, &active_path, located) {
109            Ok(def) => definitions.push(def),
110            Err(mut errs) => errors.append(&mut errs),
111        }
112    }
113
114    if !errors.is_empty() {
115        return Err(ValidationErrors::new(errors).into());
116    }
117
118    Ok(definitions)
119}
120
121/// Classify parsed statements into a `NetworkPolicyDefinition`, returning validation errors.
122fn classify_network_policy_statements(
123    expected_name: &str,
124    path: &Path,
125    located_statements: Vec<LocatedStatement>,
126) -> Result<NetworkPolicyDefinition, Vec<ValidationError>> {
127    let mut create_stmts: Vec<(CreateNetworkPolicyStatement<Raw>, usize)> = Vec::new();
128    let mut grants: Vec<GrantPrivilegesStatement<Raw>> = Vec::new();
129    let mut comments: Vec<CommentStatement<Raw>> = Vec::new();
130    let mut errors = Vec::new();
131
132    for LocatedStatement {
133        ast: stmt,
134        byte_offset,
135    } in located_statements
136    {
137        match stmt {
138            Statement::CreateNetworkPolicy(s) => {
139                create_stmts.push((s, byte_offset));
140            }
141            Statement::GrantPrivileges(s) => {
142                // Validate that the grant targets a network policy
143                match &s.target {
144                    GrantTargetSpecification::Object {
145                        object_type: ObjectType::NetworkPolicy,
146                        object_spec_inner: GrantTargetSpecificationInner::Objects { names },
147                    } => {
148                        // Validate policy name matches
149                        for name in names {
150                            let target_name = name.to_string();
151                            if target_name.to_lowercase() != expected_name.to_lowercase() {
152                                errors.push(ValidationError::with_file_sql_and_offset(
153                                    ValidationErrorKind::NetworkPolicyGrantTargetMismatch {
154                                        target: target_name,
155                                        policy_name: expected_name.to_string(),
156                                    },
157                                    path.to_path_buf(),
158                                    s.to_string(),
159                                    byte_offset,
160                                ));
161                            }
162                        }
163                        grants.push(s);
164                    }
165                    _ => {
166                        errors.push(ValidationError::with_file_sql_and_offset(
167                            ValidationErrorKind::InvalidNetworkPolicyStatement {
168                                statement_type: "GRANT (not targeting a network policy)"
169                                    .to_string(),
170                                policy_name: expected_name.to_string(),
171                            },
172                            path.to_path_buf(),
173                            s.to_string(),
174                            byte_offset,
175                        ));
176                    }
177                }
178            }
179            Statement::Comment(s) => {
180                // Validate that the comment targets a network policy
181                match &s.object {
182                    CommentObjectType::NetworkPolicy { name } => {
183                        let target_name = match name {
184                            RawNetworkPolicyName::Unresolved(ident) => ident.to_string(),
185                            RawNetworkPolicyName::Resolved(id) => id.clone(),
186                        };
187                        if target_name.to_lowercase() != expected_name.to_lowercase() {
188                            errors.push(ValidationError::with_file_sql_and_offset(
189                                ValidationErrorKind::NetworkPolicyCommentTargetMismatch {
190                                    target: target_name,
191                                    policy_name: expected_name.to_string(),
192                                },
193                                path.to_path_buf(),
194                                s.to_string(),
195                                byte_offset,
196                            ));
197                        }
198                        comments.push(s);
199                    }
200                    _ => {
201                        errors.push(ValidationError::with_file_sql_and_offset(
202                            ValidationErrorKind::InvalidNetworkPolicyStatement {
203                                statement_type: "COMMENT (not targeting a network policy)"
204                                    .to_string(),
205                                policy_name: expected_name.to_string(),
206                            },
207                            path.to_path_buf(),
208                            s.to_string(),
209                            byte_offset,
210                        ));
211                    }
212                }
213            }
214            other => {
215                errors.push(ValidationError::with_file_sql_and_offset(
216                    ValidationErrorKind::InvalidNetworkPolicyStatement {
217                        statement_type: statement_type_name(&other).to_string(),
218                        policy_name: expected_name.to_string(),
219                    },
220                    path.to_path_buf(),
221                    other.to_string(),
222                    byte_offset,
223                ));
224            }
225        }
226    }
227
228    // Validate exactly one CREATE NETWORK POLICY
229    if create_stmts.is_empty() {
230        errors.push(ValidationError::with_file(
231            ValidationErrorKind::NetworkPolicyMissingCreateStatement {
232                policy_name: expected_name.to_string(),
233            },
234            path.to_path_buf(),
235        ));
236    } else if create_stmts.len() > 1 {
237        // Point to the second CREATE NETWORK POLICY
238        errors.push(ValidationError::with_file_and_offset(
239            ValidationErrorKind::NetworkPolicyMultipleCreateStatements {
240                policy_name: expected_name.to_string(),
241            },
242            path.to_path_buf(),
243            create_stmts[1].1,
244        ));
245    }
246
247    if !errors.is_empty() {
248        return Err(errors);
249    }
250
251    let (create_stmt, create_offset) = create_stmts.into_iter().next().unwrap();
252
253    // Validate policy name matches filename
254    let declared_name = create_stmt.name.to_string();
255    if declared_name.to_lowercase() != expected_name.to_lowercase() {
256        return Err(vec![ValidationError::with_file_and_offset(
257            ValidationErrorKind::NetworkPolicyNameMismatch {
258                declared: declared_name,
259                expected: expected_name.to_string(),
260            },
261            path.to_path_buf(),
262            create_offset,
263        )]);
264    }
265
266    Ok(NetworkPolicyDefinition {
267        name: expected_name.to_string(),
268        create_stmt,
269        grants,
270        comments,
271    })
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// A valid definition for the `office_access` policy, shared by several tests.
    const OFFICE_ACCESS_CREATE: &str = "CREATE NETWORK POLICY office_access (RULES (office (action = 'allow', direction = 'ingress', address = '1.2.3.4/28')));";

    /// Returns a temporary project root that already contains an empty
    /// `network-policies/` directory.
    fn project_with_policies_dir() -> TempDir {
        let root = TempDir::new().unwrap();
        fs::create_dir(root.path().join("network-policies")).unwrap();
        root
    }

    /// Writes `sql` to `network-policies/<file_name>` under `root`.
    fn write_policy(root: &TempDir, file_name: &str, sql: &str) {
        fs::write(root.path().join("network-policies").join(file_name), sql).unwrap();
    }

    /// Loads the policies under `root` for the `default` profile with no variables.
    fn load(root: &TempDir) -> Result<Vec<NetworkPolicyDefinition>, ProjectError> {
        load_network_policies(root.path(), "default", &BTreeMap::new())
    }

    #[mz_ore::test]
    fn test_load_network_policies_no_directory() {
        let root = TempDir::new().unwrap();
        let loaded = load(&root).unwrap();
        assert!(
            loaded.is_empty(),
            "should return empty vec when network-policies/ doesn't exist"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_basic() {
        let root = project_with_policies_dir();
        let sql = format!(
            "{OFFICE_ACCESS_CREATE}\nCOMMENT ON NETWORK POLICY office_access IS 'Office network access';"
        );
        write_policy(&root, "office_access.sql", &sql);

        let loaded = load(&root).unwrap();
        assert_eq!(loaded.len(), 1);
        let policy = &loaded[0];
        assert_eq!(policy.name, "office_access");
        assert_eq!(policy.comments.len(), 1);
    }

    #[mz_ore::test]
    fn test_load_network_policies_create_only() {
        let root = project_with_policies_dir();
        write_policy(&root, "office_access.sql", OFFICE_ACCESS_CREATE);

        let loaded = load(&root).unwrap();
        assert_eq!(loaded.len(), 1);
        let policy = &loaded[0];
        assert_eq!(policy.name, "office_access");
        assert!(policy.grants.is_empty());
        assert!(policy.comments.is_empty());
    }

    #[mz_ore::test]
    fn test_load_network_policies_name_mismatch() {
        let root = project_with_policies_dir();
        write_policy(
            &root,
            "office_access.sql",
            "CREATE NETWORK POLICY wrong_name (RULES (office (action = 'allow', direction = 'ingress', address = '1.2.3.4/28')));",
        );

        assert!(
            load(&root).is_err(),
            "should error when policy name doesn't match filename"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_missing_create() {
        let root = project_with_policies_dir();
        write_policy(
            &root,
            "office_access.sql",
            "COMMENT ON NETWORK POLICY office_access IS 'Office network access';",
        );

        assert!(
            load(&root).is_err(),
            "should error when no CREATE NETWORK POLICY statement"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_unsupported_statement() {
        let root = project_with_policies_dir();
        let sql = format!("{OFFICE_ACCESS_CREATE}\nCREATE TABLE foo (id INT);");
        write_policy(&root, "office_access.sql", &sql);

        assert!(
            load(&root).is_err(),
            "should error on unsupported statement type"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_comment_target_mismatch() {
        let root = project_with_policies_dir();
        let sql = format!(
            "{OFFICE_ACCESS_CREATE}\nCOMMENT ON NETWORK POLICY other_policy IS 'wrong target';"
        );
        write_policy(&root, "office_access.sql", &sql);

        assert!(
            load(&root).is_err(),
            "should error when comment targets wrong policy"
        );
    }

    #[mz_ore::test]
    fn test_load_network_policies_multiple_files() {
        let root = project_with_policies_dir();
        write_policy(&root, "office_access.sql", OFFICE_ACCESS_CREATE);
        write_policy(
            &root,
            "vpn_access.sql",
            "CREATE NETWORK POLICY vpn_access (RULES (vpn (action = 'allow', direction = 'ingress', address = '10.0.0.0/8')));",
        );

        let loaded = load(&root).unwrap();
        // Policies come back in filename order.
        let names: Vec<&str> = loaded.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(names, ["office_access", "vpn_access"]);
    }
}