iceberg/writer/file_writer/rolling_writer.rs

use arrow_array::RecordBatch;

use crate::spec::{DataFileBuilder, SchemaRef};
use crate::writer::CurrentFileStatus;
use crate::writer::file_writer::{FileWriter, FileWriterBuilder};
use crate::{Error, ErrorKind, Result};

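/// A [`FileWriterBuilder`] wrapper that produces [`RollingFileWriter`]s: writers that
/// delegate to an inner file writer and roll over to a new file once a target size is
/// reached.
///
/// A sketch of typical usage, mirroring the tests below (the inner Parquet builder and
/// the 1 MB threshold are illustrative choices, not requirements):
///
/// ```ignore
/// let rolling_writer_builder = RollingFileWriterBuilder::new(
///     parquet_writer_builder, // any FileWriterBuilder, e.g. a ParquetWriterBuilder
///     1024 * 1024,            // roll to a new file after roughly 1 MB is written
/// );
/// let mut writer = rolling_writer_builder.build().await?;
/// ```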
#[derive(Clone)]
pub struct RollingFileWriterBuilder<B: FileWriterBuilder> {
    inner_builder: B,
    target_file_size: usize,
}

impl<B: FileWriterBuilder> RollingFileWriterBuilder<B> {
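    /// Creates a new `RollingFileWriterBuilder`.
    ///
    /// `inner_builder` builds the underlying file writer; `target_file_size` is the
    /// threshold, in bytes, after which the rolling writer closes the current file
    /// and starts a new one.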
    pub fn new(inner_builder: B, target_file_size: usize) -> Self {
        Self {
            inner_builder,
            target_file_size,
        }
    }
}

impl<B: FileWriterBuilder> FileWriterBuilder for RollingFileWriterBuilder<B> {
    type R = RollingFileWriter<B>;

    async fn build(self) -> Result<Self::R> {
        Ok(RollingFileWriter {
            inner: None,
            inner_builder: self.inner_builder,
            target_file_size: self.target_file_size,
            data_file_builders: vec![],
        })
    }
}

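/// A [`FileWriter`] that writes to an inner writer and rolls over to a new file
/// whenever the bytes written so far exceed the configured target size. Closing it
/// returns the [`DataFileBuilder`]s for every file that was produced.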
pub struct RollingFileWriter<B: FileWriterBuilder> {
    inner: Option<B::R>,
    inner_builder: B,
    target_file_size: usize,
    data_file_builders: Vec<DataFileBuilder>,
}

impl<B: FileWriterBuilder> RollingFileWriter<B> {
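    /// Returns `true` when the bytes written to the current file exceed
    /// `target_file_size`, i.e. the writer should roll over to a new file.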
    fn should_roll(&self) -> bool {
        self.current_written_size() > self.target_file_size
    }
}

impl<B: FileWriterBuilder> FileWriter for RollingFileWriter<B> {
    async fn write(&mut self, input: &RecordBatch) -> Result<()> {
        if self.inner.is_none() {
            // Lazily create the inner writer on the first write.
            self.inner = Some(self.inner_builder.clone().build().await?);
        }

        if self.should_roll() {
            if let Some(inner) = self.inner.take() {
                // Close the current file and keep its data file builders.
                self.data_file_builders.extend(inner.close().await?);

                // Start a fresh inner writer for the next file.
                self.inner = Some(self.inner_builder.clone().build().await?);
            }
        }

        if let Some(writer) = self.inner.as_mut() {
            Ok(writer.write(input).await?)
        } else {
            Err(Error::new(
                ErrorKind::Unexpected,
                "Writer is not initialized!",
            ))
        }
    }

    async fn close(mut self) -> Result<Vec<DataFileBuilder>> {
        if let Some(current_writer) = self.inner {
            // Close the in-progress file and collect its data file builders.
            self.data_file_builders
                .extend(current_writer.close().await?);
        }

        Ok(self.data_file_builders)
    }
}

impl<B: FileWriterBuilder> CurrentFileStatus for RollingFileWriter<B> {
    fn current_file_path(&self) -> String {
        self.inner.as_ref().unwrap().current_file_path()
    }

    fn current_row_num(&self) -> usize {
        self.inner.as_ref().unwrap().current_row_num()
    }

    fn current_written_size(&self) -> usize {
        self.inner.as_ref().unwrap().current_written_size()
    }

    fn current_schema(&self) -> SchemaRef {
        self.inner.as_ref().unwrap().current_schema()
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use parquet::file::properties::WriterProperties;
    use rand::prelude::IteratorRandom;
    use tempfile::TempDir;

    use super::*;
    use crate::io::FileIOBuilder;
    use crate::spec::{DataFileFormat, NestedField, PrimitiveType, Schema, Type};
    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
    use crate::writer::file_writer::ParquetWriterBuilder;
    use crate::writer::file_writer::location_generator::{
        DefaultFileNameGenerator, DefaultLocationGenerator,
    };
    use crate::writer::tests::check_parquet_data_file;
    use crate::writer::{IcebergWriter, IcebergWriterBuilder, RecordBatch};

    fn make_test_schema() -> Result<Schema> {
        Schema::builder()
            .with_schema_id(1)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
    }

    fn make_test_arrow_schema() -> ArrowSchema {
        ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                1.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                2.to_string(),
            )])),
        ])
    }

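    // With a 1 MB target size and a single tiny batch, the writer should never roll
    // over, so exactly one data file is expected.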
    #[tokio::test]
    async fn test_rolling_writer_basic() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = make_test_schema()?;

        let parquet_writer_builder = ParquetWriterBuilder::new(
            WriterProperties::builder().build(),
            Arc::new(schema),
            None,
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let rolling_writer_builder = RollingFileWriterBuilder::new(
            parquet_writer_builder,
            1024 * 1024, // 1 MB target size, far larger than the single test batch
        );

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder, None, 0);

        let mut writer = data_file_writer_builder.build().await?;

        let arrow_schema = make_test_arrow_schema();

        let batch = RecordBatch::try_new(Arc::new(arrow_schema), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])?;

        writer.write(batch.clone()).await?;

        let data_files = writer.close().await?;

        assert_eq!(
            data_files.len(),
            1,
            "Expected only one data file to be created"
        );

        check_parquet_data_file(&file_io, &data_files[0], &batch).await;

        Ok(())
    }

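    // A 1 KB target size forces frequent rollover across the 10 batches, so multiple
    // data files are expected while the total record count stays the same.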
    #[tokio::test]
    async fn test_rolling_writer_with_rolling() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let file_io = FileIOBuilder::new_fs_io().build()?;
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = make_test_schema()?;

        let parquet_writer_builder = ParquetWriterBuilder::new(
            WriterProperties::builder().build(),
            Arc::new(schema),
            None,
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        let rolling_writer_builder = RollingFileWriterBuilder::new(parquet_writer_builder, 1024);

        let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder, None, 0);

        let mut writer = data_file_writer_builder.build().await?;

        let arrow_schema = make_test_arrow_schema();
        let arrow_schema_ref = Arc::new(arrow_schema.clone());

        let names = vec![
            "Alice", "Bob", "Charlie", "Dave", "Eve", "Frank", "Grace", "Heidi", "Ivan", "Judy",
            "Kelly", "Larry", "Mallory", "Shawn",
        ];

        let mut rng = rand::thread_rng();
        let batch_num = 10;
        let batch_rows = 100;
        let expected_rows = batch_num * batch_rows;

        for i in 0..batch_num {
            let int_values: Vec<i32> = (0..batch_rows).map(|row| i * batch_rows + row).collect();
            let str_values: Vec<&str> = (0..batch_rows)
                .map(|_| *names.iter().choose(&mut rng).unwrap())
                .collect();

            let int_array = Arc::new(Int32Array::from(int_values)) as ArrayRef;
            let str_array = Arc::new(StringArray::from(str_values)) as ArrayRef;

            let batch =
                RecordBatch::try_new(Arc::clone(&arrow_schema_ref), vec![int_array, str_array])
                    .expect("Failed to create RecordBatch");

            writer.write(batch).await?;
        }

        let data_files = writer.close().await?;

        assert!(
            data_files.len() > 4,
            "Expected more than 4 data files to be created, but got {}",
            data_files.len()
        );

        let total_records: u64 = data_files.iter().map(|file| file.record_count).sum();
        assert_eq!(
            total_records, expected_rows as u64,
            "Expected {} total records across all files",
            expected_rows
        );

        Ok(())
    }
}