iceberg/writer/file_writer/rolling_writer.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::fmt::{Debug, Formatter};

use arrow_array::RecordBatch;

use crate::io::{FileIO, OutputFile};
use crate::spec::{DataFileBuilder, PartitionKey, SchemaRef, TableProperties};
use crate::writer::CurrentFileStatus;
use crate::writer::file_writer::location_generator::{FileNameGenerator, LocationGenerator};
use crate::writer::file_writer::{FileWriter, FileWriterBuilder};
use crate::{Error, ErrorKind, Result};

/// Builder for [`RollingFileWriter`].
#[derive(Clone, Debug)]
pub struct RollingFileWriterBuilder<
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
> {
    inner_builder: B,
    schema: SchemaRef,
    target_file_size: usize,
    file_io: FileIO,
    location_generator: L,
    file_name_generator: F,
}

impl<B, L, F> RollingFileWriterBuilder<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    /// Creates a new `RollingFileWriterBuilder` with the specified target file size.
    ///
    /// # Parameters
    ///
    /// * `inner_builder` - The builder for the underlying file writer
    /// * `schema` - The schema for the data being written
    /// * `target_file_size` - The target file size in bytes that triggers rollover
    /// * `file_io` - The file IO interface for creating output files
    /// * `location_generator` - Generator for file locations
    /// * `file_name_generator` - Generator for file names
    ///
    /// # Returns
    ///
    /// A new `RollingFileWriterBuilder` instance
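    ///
    /// # Example
    ///
    /// A minimal usage sketch (not taken from this crate's tests); the
    /// `parquet_writer_builder`, `schema`, `file_io`, `location_gen`, and
    /// `file_name_gen` values are assumed to be constructed elsewhere.
    ///
    /// ```ignore
    /// let rolling_writer_builder = RollingFileWriterBuilder::new(
    ///     parquet_writer_builder,
    ///     schema,
    ///     512 * 1024 * 1024, // example target: roll to a new file after ~512 MB
    ///     file_io,
    ///     location_gen,
    ///     file_name_gen,
    /// );
    /// let mut rolling_writer = rolling_writer_builder.build();
    /// ```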
    pub fn new(
        inner_builder: B,
        schema: SchemaRef,
        target_file_size: usize,
        file_io: FileIO,
        location_generator: L,
        file_name_generator: F,
    ) -> Self {
        Self {
            inner_builder,
            schema,
            target_file_size,
            file_io,
            location_generator,
            file_name_generator,
        }
    }

    /// Creates a new `RollingFileWriterBuilder` with the default target file size.
    ///
    /// # Parameters
    ///
    /// * `inner_builder` - The builder for the underlying file writer
    /// * `schema` - The schema for the data being written
    /// * `file_io` - The file IO interface for creating output files
    /// * `location_generator` - Generator for file locations
    /// * `file_name_generator` - Generator for file names
    ///
    /// # Returns
    ///
    /// A new `RollingFileWriterBuilder` instance with the default target file size
    pub fn new_with_default_file_size(
        inner_builder: B,
        schema: SchemaRef,
        file_io: FileIO,
        location_generator: L,
        file_name_generator: F,
    ) -> Self {
        Self {
            inner_builder,
            schema,
            target_file_size: TableProperties::PROPERTY_WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
            file_io,
            location_generator,
            file_name_generator,
        }
    }

    /// Build a new [`RollingFileWriter`].
    pub fn build(&self) -> RollingFileWriter<B, L, F> {
        RollingFileWriter {
            inner: None,
            inner_builder: self.inner_builder.clone(),
            schema: self.schema.clone(),
            target_file_size: self.target_file_size,
            data_file_builders: vec![],
            file_io: self.file_io.clone(),
            location_generator: self.location_generator.clone(),
            file_name_generator: self.file_name_generator.clone(),
        }
    }
}

/// A writer that automatically rolls over to a new file when the data size
/// exceeds a target threshold.
///
/// This writer wraps another file writer that tracks the amount of data written.
/// When the data size exceeds the target size, it closes the current file and
/// starts writing to a new one.
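///
/// # Example
///
/// A hedged sketch of the write/close cycle; `rolling_writer_builder` and the
/// `batches` iterator of `RecordBatch` values are assumed to exist already.
///
/// ```ignore
/// let mut writer = rolling_writer_builder.build();
/// for batch in batches {
///     // `write` lazily opens the first file and rolls to a new one once the
///     // current file exceeds the configured target size.
///     writer.write(&None, &batch).await?;
/// }
/// // `close` flushes the last open file and returns one `DataFileBuilder`
/// // per physical file that was written.
/// let data_file_builders = writer.close().await?;
/// ```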
pub struct RollingFileWriter<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> {
    inner: Option<B::R>,
    inner_builder: B,
    schema: SchemaRef,
    target_file_size: usize,
    data_file_builders: Vec<DataFileBuilder>,
    file_io: FileIO,
    location_generator: L,
    file_name_generator: F,
}

impl<B, L, F> Debug for RollingFileWriter<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RollingFileWriter")
            .field("target_file_size", &self.target_file_size)
            .field("file_io", &self.file_io)
            .finish()
    }
}

impl<B, L, F> RollingFileWriter<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    /// Determines if the writer should roll over to a new file.
    ///
    /// # Returns
    ///
    /// `true` if a new file should be started, `false` otherwise
    fn should_roll(&self) -> bool {
        self.current_written_size() > self.target_file_size
    }

    fn new_output_file(&self, partition_key: &Option<PartitionKey>) -> Result<OutputFile> {
        self.file_io
            .new_output(self.location_generator.generate_location(
                partition_key.as_ref(),
                &self.file_name_generator.generate_file_name(),
            ))
    }

    /// Writes a record batch to the current file, rolling over to a new file if necessary.
    ///
    /// The roll-over check happens before the incoming batch is written, so the
    /// target size is a soft limit: a file may exceed it by roughly one batch.
    ///
    /// # Parameters
    ///
    /// * `partition_key` - Optional partition key for the data
    /// * `input` - The record batch to write
    ///
    /// # Returns
    ///
    /// A `Result` indicating success or failure
    ///
    /// # Errors
    ///
    /// Returns an error if the writer is not initialized or if writing fails
    pub async fn write(
        &mut self,
        partition_key: &Option<PartitionKey>,
        input: &RecordBatch,
    ) -> Result<()> {
        if self.inner.is_none() {
            // initialize inner writer
            self.inner = Some(
                self.inner_builder
                    .build(self.new_output_file(partition_key)?)
                    .await?,
            );
        }

        if self.should_roll()
            && let Some(inner) = self.inner.take()
        {
            // close the current writer, roll to a new file
            self.data_file_builders.extend(inner.close().await?);

            // start a new writer
            self.inner = Some(
                self.inner_builder
                    .build(self.new_output_file(partition_key)?)
                    .await?,
            );
        }

        // write the input
        if let Some(writer) = self.inner.as_mut() {
            Ok(writer.write(input).await?)
        } else {
            Err(Error::new(
                ErrorKind::Unexpected,
                "Writer is not initialized!",
            ))
        }
    }

    /// Closes the writer and returns all data file builders.
    ///
    /// # Returns
    ///
    /// A `Result` containing a vector of `DataFileBuilder` instances representing
    /// all files that were written, including any that were created due to rollover
    pub async fn close(mut self) -> Result<Vec<DataFileBuilder>> {
        // close the current writer and merge the output
        if let Some(current_writer) = self.inner {
            self.data_file_builders
                .extend(current_writer.close().await?);
        }

        Ok(self.data_file_builders)
    }
}

impl<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> CurrentFileStatus
    for RollingFileWriter<B, L, F>
{
    fn current_file_path(&self) -> String {
        self.inner.as_ref().unwrap().current_file_path()
    }

    fn current_row_num(&self) -> usize {
        self.inner
            .as_ref()
            .map(|inner| inner.current_row_num())
            .unwrap_or(0)
    }

    fn current_written_size(&self) -> usize {
        self.inner
            .as_ref()
            .map(|inner| inner.current_written_size())
            .unwrap_or(0)
    }

    fn current_schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use parquet::file::properties::WriterProperties;
    use rand::prelude::IteratorRandom;
    use tempfile::TempDir;

    use super::*;
    use crate::io::FileIOBuilder;
    use crate::spec::{DataFileFormat, NestedField, PrimitiveType, Schema, Type};
    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
    use crate::writer::file_writer::ParquetWriterBuilder;
    use crate::writer::file_writer::location_generator::{
        DefaultFileNameGenerator, DefaultLocationGenerator,
    };
    use crate::writer::tests::check_parquet_data_file;
    use crate::writer::{IcebergWriter, IcebergWriterBuilder, RecordBatch};

    fn make_test_schema() -> Result<Schema> {
        Schema::builder()
            .with_schema_id(1)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
    }

    fn make_test_arrow_schema() -> ArrowSchema {
        ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
        ])
    }

    #[tokio::test]
    async fn test_rolling_writer_basic() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        // Create schema
        let schema = Arc::new(make_test_schema()?);

        // Create writer builders
        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        // Set a large target size so no rolling occurs
        let rolling_file_writer_builder = RollingFileWriterBuilder::new(
            parquet_writer_builder,
            schema,
            1024 * 1024,
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_file_writer_builder);

        // Create writer
        let mut writer = data_file_writer_builder.build(None).await?;

        // Create test data
        let arrow_schema = make_test_arrow_schema();

        let batch = RecordBatch::try_new(Arc::new(arrow_schema), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])?;

        // Write data
        writer.write(batch.clone()).await?;

        // Close writer and get data files
        let data_files = writer.close().await?;

        // Verify only one file was created
        assert_eq!(
            data_files.len(),
            1,
            "Expected only one data file to be created"
        );

        // Verify file content
        check_parquet_data_file(&file_io, &data_files[0], &batch).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_rolling_writer_with_rolling() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        // Create schema
        let schema = Arc::new(make_test_schema()?);

        // Create writer builders
        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        // Set a very small target size to trigger rolling
        let rolling_writer_builder = RollingFileWriterBuilder::new(
            parquet_writer_builder,
            schema,
            1024,
            file_io,
            location_gen,
            file_name_gen,
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder);

        // Create writer
        let mut writer = data_file_writer_builder.build(None).await?;

        // Create test data
        let arrow_schema = make_test_arrow_schema();
        let arrow_schema_ref = Arc::new(arrow_schema.clone());

        let names = vec![
            "Alice", "Bob", "Charlie", "Dave", "Eve", "Frank", "Grace", "Heidi", "Ivan", "Judy",
            "Kelly", "Larry", "Mallory", "Shawn",
        ];

        let mut rng = rand::thread_rng();
        let batch_num = 10;
        let batch_rows = 100;
        let expected_rows = batch_num * batch_rows;

        for i in 0..batch_num {
            let int_values: Vec<i32> = (0..batch_rows).map(|row| i * batch_rows + row).collect();
            let str_values: Vec<&str> = (0..batch_rows)
                .map(|_| *names.iter().choose(&mut rng).unwrap())
                .collect();

            let int_array = Arc::new(Int32Array::from(int_values)) as ArrayRef;
            let str_array = Arc::new(StringArray::from(str_values)) as ArrayRef;

            let batch =
                RecordBatch::try_new(Arc::clone(&arrow_schema_ref), vec![int_array, str_array])
                    .expect("Failed to create RecordBatch");

            writer.write(batch).await?;
        }

        // Close writer and get data files
        let data_files = writer.close().await?;
        // Verify multiple files were created (more than 4)
        assert!(
            data_files.len() > 4,
            "Expected more than 4 data files to be created, but got {}",
            data_files.len()
        );

        // Verify total record count across all files
        let total_records: u64 = data_files.iter().map(|file| file.record_count).sum();
        assert_eq!(
            total_records, expected_rows as u64,
            "Expected {expected_rows} total records across all files"
        );

        Ok(())
    }
}