// iceberg/writer/partitioning/clustered_writer.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This module provides the `ClusteredWriter` implementation.

use std::collections::HashSet;
use std::marker::PhantomData;

use async_trait::async_trait;

use crate::spec::{PartitionKey, Struct};
use crate::writer::partitioning::PartitioningWriter;
use crate::writer::{DefaultInput, DefaultOutput, IcebergWriter, IcebergWriterBuilder};
use crate::{Error, ErrorKind, Result};

/// A writer that writes data to a single partition at a time.
///
/// This writer expects input data to be sorted by partition key. It maintains only one
/// active writer at a time, making it memory efficient for sorted data.
///
/// # Type Parameters
///
/// * `B` - The inner writer builder type
/// * `I` - Input type (defaults to `RecordBatch`)
/// * `O` - Output collection type (defaults to `Vec<DataFile>`)
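///
/// # Example
///
/// A minimal usage sketch (marked `ignore` rather than compiled as a doctest;
/// `writer_builder`, the partition keys, and the batches stand in for any
/// concrete `IcebergWriterBuilder` and its inputs):
///
/// ```ignore
/// let mut writer = ClusteredWriter::new(writer_builder);
/// // Input batches must arrive grouped by partition key.
/// writer.write(partition_key_a, batch_a).await?;
/// writer.write(partition_key_b, batch_b).await?; // closes partition A
/// let data_files = writer.close().await?;
/// ```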
pub struct ClusteredWriter<B, I = DefaultInput, O = DefaultOutput>
where
    B: IcebergWriterBuilder<I, O>,
    O: IntoIterator + FromIterator<<O as IntoIterator>::Item>,
    <O as IntoIterator>::Item: Clone,
{
    /// Builder used to create a writer for each new partition.
    inner_builder: B,
    /// Writer for the partition currently being written, if any.
    current_writer: Option<B::R>,
    /// Partition value served by the current writer.
    current_partition: Option<Struct>,
    /// Partitions already closed; used to detect unsorted input.
    closed_partitions: HashSet<Struct>,
    /// Output items collected from closed writers.
    output: Vec<<O as IntoIterator>::Item>,
    _phantom: PhantomData<I>,
}

impl<B, I, O> ClusteredWriter<B, I, O>
where
    B: IcebergWriterBuilder<I, O>,
    I: Send + 'static,
    O: IntoIterator + FromIterator<<O as IntoIterator>::Item>,
    <O as IntoIterator>::Item: Send + Clone,
{
    /// Create a new `ClusteredWriter`.
    pub fn new(inner_builder: B) -> Self {
        Self {
            inner_builder,
            current_writer: None,
            current_partition: None,
            closed_partitions: HashSet::new(),
            output: Vec::new(),
            _phantom: PhantomData,
        }
    }

    /// Closes the current writer if it exists, flushes the written data to the
    /// output buffer, and records the partition as closed.
    async fn close_current_writer(&mut self) -> Result<()> {
        if let Some(mut writer) = self.current_writer.take() {
            self.output.extend(writer.close().await?);

            // Add the current partition to the set of closed partitions
            if let Some(current_partition) = self.current_partition.take() {
                self.closed_partitions.insert(current_partition);
            }
        }

        Ok(())
    }
}

#[async_trait]
impl<B, I, O> PartitioningWriter<I, O> for ClusteredWriter<B, I, O>
where
    B: IcebergWriterBuilder<I, O>,
    I: Send + 'static,
    O: IntoIterator + FromIterator<<O as IntoIterator>::Item> + Send + 'static,
    <O as IntoIterator>::Item: Send + Clone,
{
    async fn write(&mut self, partition_key: PartitionKey, input: I) -> Result<()> {
        let partition_value = partition_key.data();

        // Check if this partition has been closed already
        if self.closed_partitions.contains(partition_value) {
            return Err(Error::new(
                ErrorKind::Unexpected,
                format!(
                    "The input is not sorted! Cannot write to partition that was previously closed: {partition_key:?}"
                ),
            ));
        }

        // Check if we need to switch to a new partition
        let need_new_writer = match &self.current_partition {
            Some(current) => current != partition_value,
            None => true,
        };

        if need_new_writer {
            self.close_current_writer().await?;

            // Create a new writer for the new partition
            self.current_writer = Some(
                self.inner_builder
                    .build(Some(partition_key.clone()))
                    .await?,
            );
            self.current_partition = Some(partition_value.clone());
        }

        // Write to the current writer, which is guaranteed to exist at this point
        self.current_writer
            .as_mut()
            .expect("Writer should be initialized")
            .write(input)
            .await
    }

    async fn close(mut self) -> Result<O> {
        self.close_current_writer().await?;

        // Collect all output items into the output collection type
        Ok(O::from_iter(self.output))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use parquet::file::properties::WriterProperties;
    use tempfile::TempDir;

    use super::*;
    use crate::io::FileIOBuilder;
    use crate::spec::{DataFileFormat, NestedField, PrimitiveType, Type};
    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
    use crate::writer::file_writer::ParquetWriterBuilder;
    use crate::writer::file_writer::location_generator::{
        DefaultFileNameGenerator, DefaultLocationGenerator,
    };
    use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder;

    #[tokio::test]
    async fn test_clustered_writer_single_partition() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        // Create schema with partition field
        let schema = Arc::new(
            crate::spec::Schema::builder()
                .with_schema_id(1)
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                    NestedField::required(3, "region", Type::Primitive(PrimitiveType::String))
                        .into(),
                ])
                .build()?,
        );

        // Create partition spec and key
        let partition_spec = crate::spec::PartitionSpec::builder(schema.clone()).build()?;
        let partition_value =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("US"))]);
        let partition_key =
            crate::spec::PartitionKey::new(partition_spec, schema.clone(), partition_value.clone());

        // Create writer builder
        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        // Create rolling file writer builder
        let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            schema.clone(),
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        // Create data file writer builder
        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        // Create clustered writer
        let mut writer = ClusteredWriter::new(data_file_writer_builder);

        // Create test data with proper field ID metadata
        let arrow_schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
            Field::new("region", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
        ]);

        let batch1 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        let batch2 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["Charlie", "Dave"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        // Write data to the same partition (this should work)
        writer.write(partition_key.clone(), batch1).await?;
        writer.write(partition_key.clone(), batch2).await?;

        // Close writer and get data files
        let data_files = writer.close().await?;

        // Verify at least one file was created
        assert!(
            !data_files.is_empty(),
            "Expected at least one data file to be created"
        );

        // Verify that all data files have the correct partition value
        for data_file in &data_files {
            assert_eq!(data_file.partition, partition_value);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_clustered_writer_sorted_partitions() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        // Create schema with partition field
        let schema = Arc::new(
            crate::spec::Schema::builder()
                .with_schema_id(1)
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                    NestedField::required(3, "region", Type::Primitive(PrimitiveType::String))
                        .into(),
                ])
                .build()?,
        );

        // Create partition spec
        let partition_spec = crate::spec::PartitionSpec::builder(schema.clone()).build()?;

        // Create partition keys for different regions (in sorted order)
        let partition_value_asia =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("ASIA"))]);
        let partition_key_asia = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_asia.clone(),
        );

        let partition_value_eu =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("EU"))]);
        let partition_key_eu = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_eu.clone(),
        );

        let partition_value_us =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("US"))]);
        let partition_key_us = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_us.clone(),
        );

        // Create writer builder
        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        // Create rolling file writer builder
        let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            schema.clone(),
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        // Create data file writer builder
        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        // Create clustered writer
        let mut writer = ClusteredWriter::new(data_file_writer_builder);

        // Create test data with proper field ID metadata
        let arrow_schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
            Field::new("region", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
        ]);

        // Create batches for different partitions (in sorted order)
        let batch_asia = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            Arc::new(StringArray::from(vec!["ASIA", "ASIA"])),
        ])?;

        let batch_eu = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["Charlie", "Dave"])),
            Arc::new(StringArray::from(vec!["EU", "EU"])),
        ])?;

        let batch_us = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![5, 6])),
            Arc::new(StringArray::from(vec!["Eve", "Frank"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        // Write data in sorted partition order (this should work)
        writer.write(partition_key_asia.clone(), batch_asia).await?;
        writer.write(partition_key_eu.clone(), batch_eu).await?;
        writer.write(partition_key_us.clone(), batch_us).await?;

        // Close writer and get data files
        let data_files = writer.close().await?;

        // Verify files were created for all partitions
        assert!(
            data_files.len() >= 3,
            "Expected at least 3 data files (one per partition), got {}",
            data_files.len()
        );

        // Verify that we have files for each partition
        let mut partitions_found = std::collections::HashSet::new();
        for data_file in &data_files {
            partitions_found.insert(data_file.partition.clone());
        }

        assert!(
            partitions_found.contains(&partition_value_asia),
            "Missing ASIA partition"
        );
        assert!(
            partitions_found.contains(&partition_value_eu),
            "Missing EU partition"
        );
        assert!(
            partitions_found.contains(&partition_value_us),
            "Missing US partition"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_clustered_writer_unsorted_partitions_error() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        // Create schema with partition field
        let schema = Arc::new(
            crate::spec::Schema::builder()
                .with_schema_id(1)
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
                    NestedField::required(3, "region", Type::Primitive(PrimitiveType::String))
                        .into(),
                ])
                .build()?,
        );

        // Create partition spec
        let partition_spec = crate::spec::PartitionSpec::builder(schema.clone()).build()?;

        // Create partition keys for different regions
        let partition_value_us =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("US"))]);
        let partition_key_us = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_us.clone(),
        );

        let partition_value_eu =
            crate::spec::Struct::from_iter([Some(crate::spec::Literal::string("EU"))]);
        let partition_key_eu = crate::spec::PartitionKey::new(
            partition_spec.clone(),
            schema.clone(),
            partition_value_eu.clone(),
        );

        // Create writer builder
        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        // Create rolling file writer builder
        let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            schema.clone(),
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        // Create data file writer builder
        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        // Create clustered writer
        let mut writer = ClusteredWriter::new(data_file_writer_builder);

        // Create test data with proper field ID metadata
        let arrow_schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
            Field::new("region", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
        ]);

        // Create batches for different partitions
        let batch_us = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
            Arc::new(StringArray::from(vec!["US", "US"])),
        ])?;

        let batch_eu = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![3, 4])),
            Arc::new(StringArray::from(vec!["Charlie", "Dave"])),
            Arc::new(StringArray::from(vec!["EU", "EU"])),
        ])?;

        let batch_us2 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![5])),
            Arc::new(StringArray::from(vec!["Eve"])),
            Arc::new(StringArray::from(vec!["US"])),
        ])?;

        // Write data to US partition first
        writer.write(partition_key_us.clone(), batch_us).await?;

        // Write data to EU partition (this closes US partition)
        writer.write(partition_key_eu.clone(), batch_eu).await?;

        // Try to write to US partition again - this should fail because data is not sorted
        let result = writer.write(partition_key_us.clone(), batch_us2).await;

        assert!(result.is_err(), "Expected error when writing unsorted data");

        let error = result.unwrap_err();
        assert!(
            error.to_string().contains("The input is not sorted"),
            "Expected 'input is not sorted' error, got: {error}"
        );

        Ok(())
    }
}