iceberg/writer/file_writer/rolling_writer.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::RecordBatch;

use crate::spec::{DataFileBuilder, SchemaRef};
use crate::writer::CurrentFileStatus;
use crate::writer::file_writer::{FileWriter, FileWriterBuilder};
use crate::{Error, ErrorKind, Result};

/// Builder for creating a `RollingFileWriter` that rolls over to a new file
/// when the data size exceeds a target threshold.
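///
/// # Example
///
/// A minimal usage sketch, mirroring the test setup at the bottom of this file; the
/// construction of `schema`, `file_io`, `location_gen`, `file_name_gen`, and `batch` is
/// elided and assumed to be in scope.
///
/// ```ignore
/// let parquet_writer_builder = ParquetWriterBuilder::new(
///     WriterProperties::builder().build(),
///     Arc::new(schema),
///     None,
///     file_io.clone(),
///     location_gen,
///     file_name_gen,
/// );
///
/// // Roll over to a new data file once roughly 512 MiB has been written.
/// let rolling_writer_builder =
///     RollingFileWriterBuilder::new(parquet_writer_builder, 512 * 1024 * 1024);
///
/// let mut writer = DataFileWriterBuilder::new(rolling_writer_builder, None, 0)
///     .build()
///     .await?;
/// writer.write(batch).await?;
/// let data_files = writer.close().await?;
/// ```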
#[derive(Clone)]
pub struct RollingFileWriterBuilder<B: FileWriterBuilder> {
    inner_builder: B,
    target_file_size: usize,
}

impl<B: FileWriterBuilder> RollingFileWriterBuilder<B> {
    /// Creates a new `RollingFileWriterBuilder` with the specified inner builder and target size.
    ///
    /// # Arguments
    ///
    /// * `inner_builder` - The builder for the underlying file writer
    /// * `target_file_size` - The target size in bytes before rolling over to a new file
    ///
    /// NOTE: `target_file_size` does not exactly match the final size on physical storage.
    /// Because the size used for the rollover check is based on the Arrow in-memory format,
    /// rollover cannot be controlled precisely, and the actual file size on disk is expected
    /// to be slightly larger than `target_file_size`.
    pub fn new(inner_builder: B, target_file_size: usize) -> Self {
        Self {
            inner_builder,
            target_file_size,
        }
    }
}

impl<B: FileWriterBuilder> FileWriterBuilder for RollingFileWriterBuilder<B> {
    type R = RollingFileWriter<B>;

    async fn build(self) -> Result<Self::R> {
        Ok(RollingFileWriter {
            inner: None,
            inner_builder: self.inner_builder,
            target_file_size: self.target_file_size,
            data_file_builders: vec![],
        })
    }
}

/// A writer that automatically rolls over to a new file when the data size
/// exceeds a target threshold.
///
/// This writer wraps another file writer that tracks the amount of data written.
/// When the data size exceeds the target size, it closes the current file and
/// starts writing to a new one.
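///
/// On `close`, it returns the `DataFileBuilder`s collected from every file it has written.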
pub struct RollingFileWriter<B: FileWriterBuilder> {
    inner: Option<B::R>,
    inner_builder: B,
    target_file_size: usize,
    data_file_builders: Vec<DataFileBuilder>,
}

impl<B: FileWriterBuilder> RollingFileWriter<B> {
    /// Determines if the writer should roll over to a new file.
    ///
    /// # Returns
    ///
    /// `true` if a new file should be started, `false` otherwise
    fn should_roll(&self) -> bool {
        self.current_written_size() > self.target_file_size
    }
}

impl<B: FileWriterBuilder> FileWriter for RollingFileWriter<B> {
    async fn write(&mut self, input: &RecordBatch) -> Result<()> {
        if self.inner.is_none() {
            // initialize inner writer
            self.inner = Some(self.inner_builder.clone().build().await?);
        }

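        // NOTE: the roll check happens before the incoming batch is written, so the
        // current file is only closed after it has already exceeded `target_file_size`;
        // the overshoot is therefore bounded by roughly one batch.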
        if self.should_roll() {
            if let Some(inner) = self.inner.take() {
                // close the current writer, roll to a new file
                self.data_file_builders.extend(inner.close().await?);

                // start a new writer
                self.inner = Some(self.inner_builder.clone().build().await?);
            }
        }

        // write the input
        if let Some(writer) = self.inner.as_mut() {
            Ok(writer.write(input).await?)
        } else {
            Err(Error::new(
                ErrorKind::Unexpected,
                "Writer is not initialized!",
            ))
        }
    }

    async fn close(mut self) -> Result<Vec<DataFileBuilder>> {
        // close the current writer and merge the output
        if let Some(current_writer) = self.inner {
            self.data_file_builders
                .extend(current_writer.close().await?);
        }

        Ok(self.data_file_builders)
    }
}

impl<B: FileWriterBuilder> CurrentFileStatus for RollingFileWriter<B> {
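    // These delegations assume a file is currently open (`inner` is `Some`);
    // calling them before the first `write` will panic on the `unwrap`.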
    fn current_file_path(&self) -> String {
        self.inner.as_ref().unwrap().current_file_path()
    }

    fn current_row_num(&self) -> usize {
        self.inner.as_ref().unwrap().current_row_num()
    }

    fn current_written_size(&self) -> usize {
        self.inner.as_ref().unwrap().current_written_size()
    }

    fn current_schema(&self) -> SchemaRef {
        self.inner.as_ref().unwrap().current_schema()
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use parquet::file::properties::WriterProperties;
    use rand::prelude::IteratorRandom;
    use tempfile::TempDir;

    use super::*;
    use crate::io::FileIOBuilder;
    use crate::spec::{DataFileFormat, NestedField, PrimitiveType, Schema, Type};
    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
    use crate::writer::file_writer::ParquetWriterBuilder;
    use crate::writer::file_writer::location_generator::{
        DefaultFileNameGenerator, DefaultLocationGenerator,
    };
    use crate::writer::tests::check_parquet_data_file;
    use crate::writer::{IcebergWriter, IcebergWriterBuilder, RecordBatch};

    fn make_test_schema() -> Result<Schema> {
        Schema::builder()
            .with_schema_id(1)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
    }

    fn make_test_arrow_schema() -> ArrowSchema {
        ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
        ])
    }

    #[tokio::test]
    async fn test_rolling_writer_basic() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        // Create schema
        let schema = make_test_schema()?;

        // Create writer builders
        let parquet_writer_builder = ParquetWriterBuilder::new(
            WriterProperties::builder().build(),
            Arc::new(schema),
            None,
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        // Set a large target size so no rolling occurs
        let rolling_writer_builder = RollingFileWriterBuilder::new(
            parquet_writer_builder,
            1024 * 1024, // 1MB, large enough to not trigger rolling
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder, None, 0);

        // Create writer
        let mut writer = data_file_writer_builder.build().await?;

        // Create test data
        let arrow_schema = make_test_arrow_schema();

        let batch = RecordBatch::try_new(Arc::new(arrow_schema), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])?;

        // Write data
        writer.write(batch.clone()).await?;

        // Close writer and get data files
        let data_files = writer.close().await?;

        // Verify only one file was created
        assert_eq!(
            data_files.len(),
            1,
            "Expected only one data file to be created"
        );

        // Verify file content
        check_parquet_data_file(&file_io, &data_files[0], &batch).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_rolling_writer_with_rolling() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        // Create schema
        let schema = make_test_schema()?;

        // Create writer builders
        let parquet_writer_builder = ParquetWriterBuilder::new(
            WriterProperties::builder().build(),
            Arc::new(schema),
            None,
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        // Set a very small target size to trigger rolling
        let rolling_writer_builder = RollingFileWriterBuilder::new(parquet_writer_builder, 1024);

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder, None, 0);

        // Create writer
        let mut writer = data_file_writer_builder.build().await?;

        // Create test data
        let arrow_schema = make_test_arrow_schema();
        let arrow_schema_ref = Arc::new(arrow_schema.clone());

        let names = vec![
            "Alice", "Bob", "Charlie", "Dave", "Eve", "Frank", "Grace", "Heidi", "Ivan", "Judy",
            "Kelly", "Larry", "Mallory", "Shawn",
        ];

        let mut rng = rand::thread_rng();
        let batch_num = 10;
        let batch_rows = 100;
        let expected_rows = batch_num * batch_rows;

        for i in 0..batch_num {
            let int_values: Vec<i32> = (0..batch_rows).map(|row| i * batch_rows + row).collect();
            let str_values: Vec<&str> = (0..batch_rows)
                .map(|_| *names.iter().choose(&mut rng).unwrap())
                .collect();

            let int_array = Arc::new(Int32Array::from(int_values)) as ArrayRef;
            let str_array = Arc::new(StringArray::from(str_values)) as ArrayRef;

            let batch =
                RecordBatch::try_new(Arc::clone(&arrow_schema_ref), vec![int_array, str_array])
                    .expect("Failed to create RecordBatch");

            writer.write(batch).await?;
        }

        // Close writer and get data files
        let data_files = writer.close().await?;

        // Verify that multiple files were created (more than 4)
        assert!(
            data_files.len() > 4,
            "Expected more than 4 data files to be created, but got {}",
            data_files.len()
        );

        // Verify total record count across all files
        let total_records: u64 = data_files.iter().map(|file| file.record_count).sum();
        assert_eq!(
            total_records, expected_rows as u64,
            "Expected {} total records across all files",
            expected_rows
        );

        Ok(())
    }
}