parquet/schema/
visitor.rs1use crate::basic::{ConvertedType, Repetition};
21use crate::errors::ParquetError::General;
22use crate::errors::Result;
23use crate::schema::types::{Type, TypePtr};
24
25pub trait TypeVisitor<R, C> {
27    fn visit_primitive(&mut self, primitive_type: TypePtr, context: C) -> Result<R>;
29
30    fn visit_list(&mut self, list_type: TypePtr, context: C) -> Result<R> {
52        match list_type.as_ref() {
53            Type::PrimitiveType { .. } => {
54                panic!("{list_type:?} is a list type and must be a group type")
55            }
56            Type::GroupType {
57                basic_info: _,
58                fields,
59            } if fields.len() == 1 => {
60                let list_item = fields.first().unwrap();
61
62                match list_item.as_ref() {
63                    Type::PrimitiveType { .. } => {
64                        if list_item.get_basic_info().repetition() == Repetition::REPEATED {
65                            self.visit_list_with_item(list_type.clone(), list_item.clone(), context)
66                        } else {
67                            Err(General(
68                                "Primitive element type of list must be repeated.".to_string(),
69                            ))
70                        }
71                    }
72                    Type::GroupType {
73                        basic_info: _,
74                        fields,
75                    } => {
76                        if fields.len() == 1
77                            && list_item.name() != "array"
78                            && list_item.name() != format!("{}_tuple", list_type.name())
79                        {
80                            self.visit_list_with_item(
81                                list_type.clone(),
82                                fields.first().unwrap().clone(),
83                                context,
84                            )
85                        } else {
86                            self.visit_list_with_item(list_type.clone(), list_item.clone(), context)
87                        }
88                    }
89                }
90            }
91            _ => Err(General(
92                "Group element type of list can only contain one field.".to_string(),
93            )),
94        }
95    }
96
97    fn visit_struct(&mut self, struct_type: TypePtr, context: C) -> Result<R>;
99
100    fn visit_map(&mut self, map_type: TypePtr, context: C) -> Result<R>;
102
103    fn dispatch(&mut self, cur_type: TypePtr, context: C) -> Result<R> {
105        if cur_type.is_primitive() {
106            self.visit_primitive(cur_type, context)
107        } else {
108            match cur_type.get_basic_info().converted_type() {
109                ConvertedType::LIST => self.visit_list(cur_type, context),
110                ConvertedType::MAP | ConvertedType::MAP_KEY_VALUE => {
111                    self.visit_map(cur_type, context)
112                }
113                _ => self.visit_struct(cur_type, context),
114            }
115        }
116    }
117
118    fn visit_list_with_item(
120        &mut self,
121        list_type: TypePtr,
122        item_type: TypePtr,
123        context: C,
124    ) -> Result<R>;
125}
126
#[cfg(test)]
mod tests {
    use super::TypeVisitor;
    use crate::basic::Type as PhysicalType;
    use crate::errors::Result;
    use crate::schema::parser::parse_message_type;
    use crate::schema::types::TypePtr;
    use std::sync::Arc;

    struct TestVisitorContext {}

    /// Visitor that checks every visited top-level field against the schema root
    /// and records which kinds of visit callbacks were invoked.
    struct TestVisitor {
        primitive_visited: bool,
        struct_visited: bool,
        list_visited: bool,
        root_type: TypePtr,
    }

    impl TypeVisitor<bool, TestVisitorContext> for TestVisitor {
        fn visit_primitive(
            &mut self,
            primitive_type: TypePtr,
            _context: TestVisitorContext,
        ) -> Result<bool> {
            assert_eq!(
                self.get_field_by_name(primitive_type.name()).as_ref(),
                primitive_type.as_ref()
            );
            self.primitive_visited = true;
            Ok(true)
        }

        fn visit_struct(
            &mut self,
            struct_type: TypePtr,
            _context: TestVisitorContext,
        ) -> Result<bool> {
            assert_eq!(
                self.get_field_by_name(struct_type.name()).as_ref(),
                struct_type.as_ref()
            );
            self.struct_visited = true;
            Ok(true)
        }

        fn visit_map(&mut self, _map_type: TypePtr, _context: TestVisitorContext) -> Result<bool> {
            unimplemented!()
        }

        fn visit_list_with_item(
            &mut self,
            list_type: TypePtr,
            item_type: TypePtr,
            _context: TestVisitorContext,
        ) -> Result<bool> {
            assert_eq!(
                self.get_field_by_name(list_type.name()).as_ref(),
                list_type.as_ref()
            );
            assert_eq!("element", item_type.name());
            assert_eq!(PhysicalType::INT32, item_type.get_physical_type());
            self.list_visited = true;
            Ok(true)
        }
    }

    impl TestVisitor {
        fn new(root: TypePtr) -> Self {
            Self {
                primitive_visited: false,
                struct_visited: false,
                list_visited: false,
                root_type: root,
            }
        }

        /// Returns the top-level field of the root schema named `name`; panics if absent.
        fn get_field_by_name(&self, name: &str) -> TypePtr {
            self.root_type
                .get_fields()
                .iter()
                .find(|t| t.name() == name)
                .cloned()
                .unwrap()
        }
    }

    /// Records the `(list name, item name)` pair passed to `visit_list_with_item`,
    /// so tests can check which element type the default `visit_list` extracted.
    #[derive(Default)]
    struct ListItemRecorder {
        items: Vec<(String, String)>,
    }

    impl TypeVisitor<(), ()> for ListItemRecorder {
        fn visit_primitive(&mut self, _primitive_type: TypePtr, _context: ()) -> Result<()> {
            Ok(())
        }

        fn visit_struct(&mut self, _struct_type: TypePtr, _context: ()) -> Result<()> {
            Ok(())
        }

        fn visit_map(&mut self, _map_type: TypePtr, _context: ()) -> Result<()> {
            Ok(())
        }

        fn visit_list_with_item(
            &mut self,
            list_type: TypePtr,
            item_type: TypePtr,
            _context: (),
        ) -> Result<()> {
            self.items
                .push((list_type.name().to_string(), item_type.name().to_string()));
            Ok(())
        }
    }

    /// Parses `message_type` and dispatches every top-level field through a
    /// `ListItemRecorder`, returning the recorded `(list, item)` name pairs.
    fn list_items(message_type: &str) -> Result<Vec<(String, String)>> {
        let schema = parse_message_type(message_type).unwrap();
        let mut recorder = ListItemRecorder::default();
        for field in schema.get_fields() {
            recorder.dispatch(field.clone(), ())?;
        }
        Ok(recorder.items)
    }

    fn pair(list: &str, item: &str) -> (String, String) {
        (list.to_string(), item.to_string())
    }

    #[test]
    fn test_visitor() {
        // The closing `}` of the message was previously missing; the schema only
        // parsed because the parser stops at end of input.
        let message_type = "
          message spark_schema {
            REQUIRED INT32 a;
            OPTIONAL group inner_schema {
              REQUIRED INT32 b;
              REQUIRED DOUBLE c;
            }

            OPTIONAL group e (LIST) {
              REPEATED group list {
                REQUIRED INT32 element;
              }
            }
          }
        ";

        let parquet_type = Arc::new(parse_message_type(message_type).unwrap());

        let mut visitor = TestVisitor::new(parquet_type.clone());
        for f in parquet_type.get_fields() {
            let c = TestVisitorContext {};
            assert!(visitor.dispatch(f.clone(), c).unwrap());
        }

        assert!(visitor.struct_visited);
        assert!(visitor.primitive_visited);
        assert!(visitor.list_visited);
    }

    #[test]
    fn test_visit_list_element_extraction() {
        let message_type = "
          message schema {
            OPTIONAL group standard (LIST) {
              REPEATED group list {
                OPTIONAL INT32 element;
              }
            }
            OPTIONAL group repeated_primitive (LIST) {
              REPEATED INT32 element;
            }
            OPTIONAL group multi_field (LIST) {
              REPEATED group element {
                REQUIRED INT32 a;
                REQUIRED INT32 b;
              }
            }
            OPTIONAL group legacy_array (LIST) {
              REPEATED group array {
                REQUIRED INT32 a;
              }
            }
            OPTIONAL group legacy (LIST) {
              REPEATED group legacy_tuple {
                REQUIRED INT32 a;
              }
            }
            OPTIONAL group not_legacy (LIST) {
              REPEATED group legacy_tuple {
                REQUIRED INT32 a;
              }
            }
          }
        ";

        assert_eq!(
            list_items(message_type).unwrap(),
            vec![
                pair("standard", "element"),
                pair("repeated_primitive", "element"),
                pair("multi_field", "element"),
                pair("legacy_array", "array"),
                pair("legacy", "legacy_tuple"),
                // `_tuple` only marks the legacy layout when prefixed by the list's name.
                pair("not_legacy", "a"),
            ]
        );
    }

    #[test]
    fn test_visit_list_rejects_non_repeated_primitive_item() {
        let err = list_items(
            "
          message schema {
            OPTIONAL group e (LIST) {
              OPTIONAL INT32 element;
            }
          }
        ",
        )
        .unwrap_err();
        assert!(err
            .to_string()
            .contains("Primitive element type of list must be repeated."));
    }

    #[test]
    fn test_visit_list_rejects_multiple_children() {
        let err = list_items(
            "
          message schema {
            OPTIONAL group e (LIST) {
              REPEATED INT32 a;
              REPEATED INT32 b;
            }
          }
        ",
        )
        .unwrap_err();
        assert!(err
            .to_string()
            .contains("Group element type of list can only contain one field."));
    }
}