use std::collections::HashSet;
use std::marker::PhantomData;

use async_trait::async_trait;

use crate::spec::{PartitionKey, Struct};
use crate::writer::partitioning::PartitioningWriter;
use crate::writer::{DefaultInput, DefaultOutput, IcebergWriter, IcebergWriterBuilder};
use crate::{Error, ErrorKind, Result};

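/// A [`PartitioningWriter`] for clustered input: incoming data must be grouped
/// (sorted) by partition key. Only one inner writer is open at a time; when a
/// new partition value arrives the current writer is closed, and any attempt to
/// write to an already-closed partition fails.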
pub struct ClusteredWriter<B, I = DefaultInput, O = DefaultOutput>
where
    B: IcebergWriterBuilder<I, O>,
    O: IntoIterator + FromIterator<<O as IntoIterator>::Item>,
    <O as IntoIterator>::Item: Clone,
{
    inner_builder: B,
    current_writer: Option<B::R>,
    current_partition: Option<Struct>,
    closed_partitions: HashSet<Struct>,
    output: Vec<<O as IntoIterator>::Item>,
    _phantom: PhantomData<I>,
}

impl<B, I, O> ClusteredWriter<B, I, O>
where
    B: IcebergWriterBuilder<I, O>,
    I: Send + 'static,
    O: IntoIterator + FromIterator<<O as IntoIterator>::Item>,
    <O as IntoIterator>::Item: Send + Clone,
{
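    /// Creates a new `ClusteredWriter` that builds one inner writer per
    /// partition using the given `inner_builder`.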
    pub fn new(inner_builder: B) -> Self {
        Self {
            inner_builder,
            current_writer: None,
            current_partition: None,
            closed_partitions: HashSet::new(),
            output: Vec::new(),
            _phantom: PhantomData,
        }
    }

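    /// Closes the currently open writer (if any), collects its output, and
    /// marks the current partition as closed so later writes to it are rejected.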
    async fn close_current_writer(&mut self) -> Result<()> {
        if let Some(mut writer) = self.current_writer.take() {
            self.output.extend(writer.close().await?);

            if let Some(current_partition) = self.current_partition.take() {
                self.closed_partitions.insert(current_partition);
            }
        }

        Ok(())
    }
}

#[async_trait]
impl<B, I, O> PartitioningWriter<I, O> for ClusteredWriter<B, I, O>
where
    B: IcebergWriterBuilder<I, O>,
    I: Send + 'static,
    O: IntoIterator + FromIterator<<O as IntoIterator>::Item> + Send + 'static,
    <O as IntoIterator>::Item: Send + Clone,
{
    async fn write(&mut self, partition_key: PartitionKey, input: I) -> Result<()> {
        let partition_value = partition_key.data();

        // Writing to a partition that has already been closed means the input
        // is not clustered by partition key.
        if self.closed_partitions.contains(partition_value) {
            return Err(Error::new(
                ErrorKind::Unexpected,
                format!(
                    "The input is not sorted! Cannot write to partition that was previously closed: {partition_key:?}"
                ),
            ));
        }

        // A new inner writer is needed when the partition changes or when no
        // writer has been opened yet.
        let need_new_writer = match &self.current_partition {
            Some(current) => current != partition_value,
            None => true,
        };

        if need_new_writer {
            self.close_current_writer().await?;

            // Build a writer for the new partition and remember its value.
            self.current_writer = Some(
                self.inner_builder
                    .build(Some(partition_key.clone()))
                    .await?,
            );
            self.current_partition = Some(partition_value.clone());
        }

        self.current_writer
            .as_mut()
            .expect("Writer should be initialized")
            .write(input)
            .await
    }

    async fn close(mut self) -> Result<O> {
        self.close_current_writer().await?;

        Ok(O::from_iter(self.output))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use parquet::file::properties::WriterProperties;
    use tempfile::TempDir;

    use super::*;
    use crate::io::FileIOBuilder;
    use crate::spec::{DataFileFormat, NestedField, PrimitiveType, Type};
    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
    use crate::writer::file_writer::ParquetWriterBuilder;
    use crate::writer::file_writer::location_generator::{
        DefaultFileNameGenerator, DefaultLocationGenerator,
    };
    use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder;

    #[tokio::test]
    async fn test_clustered_writer_single_partition() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = Arc::new(
            crate::spec::Schema::builder()
                .with_schema_id(1)
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                    NestedField::required(3, "region", Type::Primitive(PrimitiveType::String))
                        .into(),
                ])
                .build()?,
        );

        let partition_spec = crate::spec::PartitionSpec::builder(schema.clone()).build()?;
        let partition_value =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("US"))]);
        let partition_key =
            crate::spec::PartitionKey::new(partition_spec, schema.clone(), partition_value.clone());

        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            schema.clone(),
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        let mut writer = ClusteredWriter::new(data_file_writer_builder);

        let arrow_schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
            Field::new("region", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
        ]);

        let batch1 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        let batch2 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["Charlie", "Dave"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        writer.write(partition_key.clone(), batch1).await?;
        writer.write(partition_key.clone(), batch2).await?;

        let data_files = writer.close().await?;

        assert!(
            !data_files.is_empty(),
            "Expected at least one data file to be created"
        );

        for data_file in &data_files {
            assert_eq!(data_file.partition, partition_value);
        }

        Ok(())
    }

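    // Writes one batch per partition in sorted order (ASIA, EU, US) and checks
    // that each partition value shows up in the produced data files.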
    #[tokio::test]
    async fn test_clustered_writer_sorted_partitions() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = Arc::new(
            crate::spec::Schema::builder()
                .with_schema_id(1)
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                    NestedField::required(3, "region", Type::Primitive(PrimitiveType::String))
                        .into(),
                ])
                .build()?,
        );

        let partition_spec = crate::spec::PartitionSpec::builder(schema.clone()).build()?;

        let partition_value_asia =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("ASIA"))]);
        let partition_key_asia = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_asia.clone(),
        );

        let partition_value_eu =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("EU"))]);
        let partition_key_eu = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_eu.clone(),
        );

        let partition_value_us =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("US"))]);
        let partition_key_us = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_us.clone(),
        );

        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            schema.clone(),
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        let mut writer = ClusteredWriter::new(data_file_writer_builder);

        let arrow_schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
            Field::new("region", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
        ]);

        let batch_asia = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            Arc::new(StringArray::from(vec!["ASIA", "ASIA"])),
        ])?;

        let batch_eu = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["Charlie", "Dave"])),
            Arc::new(StringArray::from(vec!["EU", "EU"])),
        ])?;

        let batch_us = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![5, 6])),
            Arc::new(StringArray::from(vec!["Eve", "Frank"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        writer.write(partition_key_asia.clone(), batch_asia).await?;
        writer.write(partition_key_eu.clone(), batch_eu).await?;
        writer.write(partition_key_us.clone(), batch_us).await?;

        let data_files = writer.close().await?;

        assert!(
            data_files.len() >= 3,
            "Expected at least 3 data files (one per partition), got {}",
            data_files.len()
        );

        let mut partitions_found = std::collections::HashSet::new();
        for data_file in &data_files {
            partitions_found.insert(data_file.partition.clone());
        }

        assert!(
            partitions_found.contains(&partition_value_asia),
            "Missing ASIA partition"
        );
        assert!(
            partitions_found.contains(&partition_value_eu),
            "Missing EU partition"
        );
        assert!(
            partitions_found.contains(&partition_value_us),
            "Missing US partition"
        );

        Ok(())
    }

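    // Writes US, then EU, then US again; the second write to US must fail
    // because that partition was already closed.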
    #[tokio::test]
    async fn test_clustered_writer_unsorted_partitions_error() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = Arc::new(
            crate::spec::Schema::builder()
                .with_schema_id(1)
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                    NestedField::required(3, "region", Type::Primitive(PrimitiveType::String))
                        .into(),
                ])
                .build()?,
        );

        let partition_spec = crate::spec::PartitionSpec::builder(schema.clone()).build()?;

        let partition_value_us =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("US"))]);
        let partition_key_us = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_us.clone(),
        );

        let partition_value_eu =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("EU"))]);
        let partition_key_eu = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_eu.clone(),
        );

        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            schema.clone(),
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        let mut writer = ClusteredWriter::new(data_file_writer_builder);

        let arrow_schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
            Field::new("region", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
        ]);

        let batch_us = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        let batch_eu = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["Charlie", "Dave"])),
            Arc::new(StringArray::from(vec!["EU", "EU"])),
        ])?;

        let batch_us2 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![5])),
            Arc::new(StringArray::from(vec!["Eve"])),
            Arc::new(StringArray::from(vec!["US"])),
        ])?;

        // First write to the US partition succeeds.
        writer.write(partition_key_us.clone(), batch_us).await?;

        // Switching to the EU partition closes the US partition.
        writer.write(partition_key_eu.clone(), batch_eu).await?;

        // Writing to the US partition again must fail because it was closed.
        let result = writer.write(partition_key_us.clone(), batch_us2).await;

        assert!(result.is_err(), "Expected error when writing unsorted data");

        let error = result.unwrap_err();
        assert!(
            error.to_string().contains("The input is not sorted"),
            "Expected 'input is not sorted' error, got: {error}"
        );

        Ok(())
    }
}