use std::collections::HashMap;
use std::marker::PhantomData;

use async_trait::async_trait;

use crate::spec::{PartitionKey, Struct};
use crate::writer::partitioning::PartitioningWriter;
use crate::writer::{DefaultInput, DefaultOutput, IcebergWriter, IcebergWriterBuilder};
use crate::{Error, ErrorKind, Result};

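/// Fans records out to a dedicated inner writer per partition.
///
/// `FanoutWriter` lazily creates a writer (via the wrapped
/// [`IcebergWriterBuilder`]) the first time each [`PartitionKey`] is seen and
/// keeps it open until [`PartitioningWriter::close`] is called. Because every
/// partition gets its own writer, the input does not need to be sorted or
/// grouped by partition: rows for different partitions may arrive interleaved,
/// at the cost of one open writer per distinct partition.
///
/// A minimal usage sketch (builder setup elided and variable names
/// illustrative; see the tests below for a complete example):
///
/// ```ignore
/// let mut writer = FanoutWriter::new(data_file_writer_builder);
/// writer.write(partition_key_us, batch_us).await?;
/// writer.write(partition_key_eu, batch_eu).await?;
/// let data_files = writer.close().await?;
/// ```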
pub struct FanoutWriter<B, I = DefaultInput, O = DefaultOutput>
where
    B: IcebergWriterBuilder<I, O>,
    O: IntoIterator + FromIterator<<O as IntoIterator>::Item>,
    <O as IntoIterator>::Item: Clone,
{
    inner_builder: B,
    partition_writers: HashMap<Struct, B::R>,
    output: Vec<<O as IntoIterator>::Item>,
    _phantom: PhantomData<I>,
}

impl<B, I, O> FanoutWriter<B, I, O>
where
    B: IcebergWriterBuilder<I, O>,
    I: Send + 'static,
    O: IntoIterator + FromIterator<<O as IntoIterator>::Item>,
    <O as IntoIterator>::Item: Send + Clone,
{
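    /// Creates a new `FanoutWriter` that builds one inner writer per distinct
    /// partition key from `inner_builder`.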
    pub fn new(inner_builder: B) -> Self {
        Self {
            inner_builder,
            partition_writers: HashMap::new(),
            output: Vec::new(),
            _phantom: PhantomData,
        }
    }

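    /// Returns the writer for `partition_key`, building and caching a new one
    /// via the inner builder the first time this partition is seen.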
    async fn get_or_create_writer(&mut self, partition_key: &PartitionKey) -> Result<&mut B::R> {
        if !self.partition_writers.contains_key(partition_key.data()) {
            let writer = self
                .inner_builder
                .build(Some(partition_key.clone()))
                .await?;
            self.partition_writers
                .insert(partition_key.data().clone(), writer);
        }

        self.partition_writers
            .get_mut(partition_key.data())
            .ok_or_else(|| {
                Error::new(
                    ErrorKind::Unexpected,
                    "Failed to get partition writer after creation",
                )
            })
    }
}

#[async_trait]
impl<B, I, O> PartitioningWriter<I, O> for FanoutWriter<B, I, O>
where
    B: IcebergWriterBuilder<I, O>,
    I: Send + 'static,
    O: IntoIterator + FromIterator<<O as IntoIterator>::Item> + Send + 'static,
    <O as IntoIterator>::Item: Send + Clone,
{
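    /// Routes `input` to the writer for `partition_key`, creating that writer
    /// on demand.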
    async fn write(&mut self, partition_key: PartitionKey, input: I) -> Result<()> {
        let writer = self.get_or_create_writer(&partition_key).await?;
        writer.write(input).await
    }

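    /// Closes every partition writer and collects their outputs into a
    /// single `O`.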
    async fn close(mut self) -> Result<O> {
        for (_, mut writer) in self.partition_writers {
            self.output.extend(writer.close().await?);
        }

        Ok(O::from_iter(self.output))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use parquet::file::properties::WriterProperties;
    use tempfile::TempDir;

    use super::*;
    use crate::io::FileIOBuilder;
    use crate::spec::{
        DataFileFormat, Literal, NestedField, PartitionKey, PartitionSpec, PrimitiveType, Struct,
        Type,
    };
    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
    use crate::writer::file_writer::ParquetWriterBuilder;
    use crate::writer::file_writer::location_generator::{
        DefaultFileNameGenerator, DefaultLocationGenerator,
    };
    use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder;

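    // Writes two batches to a single partition and verifies that every
    // produced data file carries that partition value.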
    #[tokio::test]
    async fn test_fanout_writer_single_partition() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = Arc::new(
            crate::spec::Schema::builder()
                .with_schema_id(1)
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                    NestedField::required(3, "region", Type::Primitive(PrimitiveType::String))
                        .into(),
                ])
                .build()?,
        );

        let partition_spec = PartitionSpec::builder(schema.clone()).build()?;
        let partition_value = Struct::from_iter([Some(Literal::string("US"))]);
        let partition_key =
            PartitionKey::new(partition_spec, schema.clone(), partition_value.clone());

        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            schema.clone(),
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        let mut writer = FanoutWriter::new(data_file_writer_builder);

        let arrow_schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
            Field::new("region", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
        ]);

        let batch1 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        let batch2 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["Charlie", "Dave"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        writer.write(partition_key.clone(), batch1).await?;
        writer.write(partition_key.clone(), batch2).await?;

        let data_files = writer.close().await?;

        assert!(
            !data_files.is_empty(),
            "Expected at least one data file to be created"
        );

        for data_file in &data_files {
            assert_eq!(data_file.partition, partition_value);
        }

        Ok(())
    }

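    // Interleaves writes across three partitions (US, EU, ASIA) and verifies
    // that the fanout writer produced data files for each of them.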
    #[tokio::test]
    async fn test_fanout_writer_multiple_partitions() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = Arc::new(
            crate::spec::Schema::builder()
                .with_schema_id(1)
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                    NestedField::required(3, "region", Type::Primitive(PrimitiveType::String))
                        .into(),
                ])
                .build()?,
        );

        let partition_spec = PartitionSpec::builder(schema.clone()).build()?;

        let partition_value_us = Struct::from_iter([Some(Literal::string("US"))]);
        let partition_key_us = PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_us.clone(),
        );

        let partition_value_eu = Struct::from_iter([Some(Literal::string("EU"))]);
        let partition_key_eu = PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_eu.clone(),
        );

        let partition_value_asia = Struct::from_iter([Some(Literal::string("ASIA"))]);
        let partition_key_asia = PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_asia.clone(),
        );

        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            schema.clone(),
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        let mut writer = FanoutWriter::new(data_file_writer_builder);

        let arrow_schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
            Field::new("region", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
        ]);

        let batch_us1 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        let batch_eu1 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["Charlie", "Dave"])),
            Arc::new(StringArray::from(vec!["EU", "EU"])),
        ])?;

        let batch_us2 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![5])),
            Arc::new(StringArray::from(vec!["Eve"])),
            Arc::new(StringArray::from(vec!["US"])),
        ])?;

        let batch_asia1 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![6, 7])),
            Arc::new(StringArray::from(vec!["Frank", "Grace"])),
            Arc::new(StringArray::from(vec!["ASIA", "ASIA"])),
        ])?;

        writer.write(partition_key_us.clone(), batch_us1).await?;
        writer.write(partition_key_eu.clone(), batch_eu1).await?;
        writer.write(partition_key_us.clone(), batch_us2).await?;
        writer
            .write(partition_key_asia.clone(), batch_asia1)
            .await?;

        let data_files = writer.close().await?;

        assert!(
            data_files.len() >= 3,
            "Expected at least 3 data files (one per partition), got {}",
            data_files.len()
        );

        let mut partitions_found = std::collections::HashSet::new();
        for data_file in &data_files {
            partitions_found.insert(data_file.partition.clone());
        }

        assert!(
            partitions_found.contains(&partition_value_us),
            "Missing US partition"
        );
        assert!(
            partitions_found.contains(&partition_value_eu),
            "Missing EU partition"
        );
        assert!(
            partitions_found.contains(&partition_value_asia),
            "Missing ASIA partition"
        );

        Ok(())
    }
}