iceberg/writer/base_writer/
data_file_writer.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! This module provides `DataFileWriter`.
19
20use arrow_array::RecordBatch;
21
22use crate::spec::{DataContentType, DataFile, PartitionKey};
23use crate::writer::file_writer::FileWriterBuilder;
24use crate::writer::file_writer::location_generator::{FileNameGenerator, LocationGenerator};
25use crate::writer::file_writer::rolling_writer::{RollingFileWriter, RollingFileWriterBuilder};
26use crate::writer::{CurrentFileStatus, IcebergWriter, IcebergWriterBuilder};
27use crate::{Error, ErrorKind, Result};
28
29/// Builder for `DataFileWriter`.
30#[derive(Debug)]
31pub struct DataFileWriterBuilder<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> {
32    inner: RollingFileWriterBuilder<B, L, F>,
33}
34
35impl<B, L, F> DataFileWriterBuilder<B, L, F>
36where
37    B: FileWriterBuilder,
38    L: LocationGenerator,
39    F: FileNameGenerator,
40{
41    /// Create a new `DataFileWriterBuilder` using a `RollingFileWriterBuilder`.
42    pub fn new(inner: RollingFileWriterBuilder<B, L, F>) -> Self {
43        Self { inner }
44    }
45}
46
47#[async_trait::async_trait]
48impl<B, L, F> IcebergWriterBuilder for DataFileWriterBuilder<B, L, F>
49where
50    B: FileWriterBuilder,
51    L: LocationGenerator,
52    F: FileNameGenerator,
53{
54    type R = DataFileWriter<B, L, F>;
55
56    async fn build(&self, partition_key: Option<PartitionKey>) -> Result<Self::R> {
57        Ok(DataFileWriter {
58            inner: Some(self.inner.build()),
59            partition_key,
60        })
61    }
62}
63
64/// A writer write data is within one spec/partition.
65#[derive(Debug)]
66pub struct DataFileWriter<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> {
67    inner: Option<RollingFileWriter<B, L, F>>,
68    partition_key: Option<PartitionKey>,
69}
70
71#[async_trait::async_trait]
72impl<B, L, F> IcebergWriter for DataFileWriter<B, L, F>
73where
74    B: FileWriterBuilder,
75    L: LocationGenerator,
76    F: FileNameGenerator,
77{
78    async fn write(&mut self, batch: RecordBatch) -> Result<()> {
79        if let Some(writer) = self.inner.as_mut() {
80            writer.write(&self.partition_key, &batch).await
81        } else {
82            Err(Error::new(
83                ErrorKind::Unexpected,
84                "Writer is not initialized!",
85            ))
86        }
87    }
88
89    async fn close(&mut self) -> Result<Vec<DataFile>> {
90        if let Some(writer) = self.inner.take() {
91            writer
92                .close()
93                .await?
94                .into_iter()
95                .map(|mut res| {
96                    res.content(DataContentType::Data);
97                    if let Some(pk) = self.partition_key.as_ref() {
98                        res.partition(pk.data().clone());
99                        res.partition_spec_id(pk.spec().spec_id());
100                    }
101                    res.build().map_err(|e| {
102                        Error::new(
103                            ErrorKind::DataInvalid,
104                            format!("Failed to build data file: {e}"),
105                        )
106                    })
107                })
108                .collect()
109        } else {
110            Err(Error::new(
111                ErrorKind::Unexpected,
112                "Data file writer has been closed.",
113            ))
114        }
115    }
116}
117
118impl<B, L, F> CurrentFileStatus for DataFileWriter<B, L, F>
119where
120    B: FileWriterBuilder,
121    L: LocationGenerator,
122    F: FileNameGenerator,
123{
124    fn current_file_path(&self) -> String {
125        self.inner.as_ref().unwrap().current_file_path()
126    }
127
128    fn current_row_num(&self) -> usize {
129        self.inner.as_ref().unwrap().current_row_num()
130    }
131
132    fn current_written_size(&self) -> usize {
133        self.inner.as_ref().unwrap().current_written_size()
134    }
135
136    fn current_schema(&self) -> crate::spec::SchemaRef {
137        self.inner.as_ref().unwrap().current_schema()
138    }
139}
140
141#[cfg(test)]
142mod test {
143    use std::collections::HashMap;
144    use std::sync::Arc;
145
146    use arrow_array::{Int32Array, StringArray};
147    use arrow_schema::{DataType, Field};
148    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
149    use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions};
150    use parquet::file::properties::WriterProperties;
151    use tempfile::TempDir;
152
153    use crate::Result;
154    use crate::io::FileIOBuilder;
155    use crate::spec::{
156        DataContentType, DataFileFormat, Literal, NestedField, PartitionKey, PartitionSpec,
157        PrimitiveType, Schema, Struct, Type,
158    };
159    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
160    use crate::writer::file_writer::ParquetWriterBuilder;
161    use crate::writer::file_writer::location_generator::{
162        DefaultFileNameGenerator, DefaultLocationGenerator,
163    };
164    use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder;
165    use crate::writer::{IcebergWriter, IcebergWriterBuilder, RecordBatch};
166
167    #[tokio::test]
168    async fn test_parquet_writer() -> Result<()> {
169        let temp_dir = TempDir::new().unwrap();
170        let file_io = FileIOBuilder::new_fs_io().build().unwrap();
171        let location_gen = DefaultLocationGenerator::with_data_location(
172            temp_dir.path().to_str().unwrap().to_string(),
173        );
174        let file_name_gen =
175            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);
176
177        let schema = Arc::new(
178            Schema::builder()
179                .with_schema_id(3)
180                .with_fields(vec![
181                    NestedField::required(3, "foo", Type::Primitive(PrimitiveType::Int)).into(),
182                    NestedField::required(4, "bar", Type::Primitive(PrimitiveType::String)).into(),
183                ])
184                .build()?,
185        );
186
187        let pw = ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());
188
189        let rolling_file_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
190            pw,
191            schema,
192            file_io.clone(),
193            location_gen,
194            file_name_gen,
195        );
196
197        let mut data_file_writer = DataFileWriterBuilder::new(rolling_file_writer_builder)
198            .build(None)
199            .await
200            .unwrap();
201
202        let arrow_schema = arrow_schema::Schema::new(vec![
203            Field::new("foo", DataType::Int32, false).with_metadata(HashMap::from([(
204                PARQUET_FIELD_ID_META_KEY.to_string(),
205                3.to_string(),
206            )])),
207            Field::new("bar", DataType::Utf8, false).with_metadata(HashMap::from([(
208                PARQUET_FIELD_ID_META_KEY.to_string(),
209                4.to_string(),
210            )])),
211        ]);
212        let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
213            Arc::new(Int32Array::from(vec![1, 2, 3])),
214            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
215        ])?;
216        data_file_writer.write(batch).await?;
217
218        let data_files = data_file_writer.close().await.unwrap();
219        assert_eq!(data_files.len(), 1);
220
221        let data_file = &data_files[0];
222        assert_eq!(data_file.file_format, DataFileFormat::Parquet);
223        assert_eq!(data_file.content, DataContentType::Data);
224        assert_eq!(data_file.partition, Struct::empty());
225
226        let input_file = file_io.new_input(data_file.file_path.clone())?;
227        let input_content = input_file.read().await?;
228
229        let parquet_reader =
230            ArrowReaderMetadata::load(&input_content, ArrowReaderOptions::default())
231                .expect("Failed to load Parquet metadata");
232
233        let field_ids: Vec<i32> = parquet_reader
234            .parquet_schema()
235            .columns()
236            .iter()
237            .map(|col| col.self_type().get_basic_info().id())
238            .collect();
239
240        assert_eq!(field_ids, vec![3, 4]);
241        Ok(())
242    }
243
244    #[tokio::test]
245    async fn test_parquet_writer_with_partition() -> Result<()> {
246        let temp_dir = TempDir::new().unwrap();
247        let file_io = FileIOBuilder::new_fs_io().build().unwrap();
248        let location_gen = DefaultLocationGenerator::with_data_location(
249            temp_dir.path().to_str().unwrap().to_string(),
250        );
251        let file_name_gen = DefaultFileNameGenerator::new(
252            "test_partitioned".to_string(),
253            None,
254            DataFileFormat::Parquet,
255        );
256
257        let schema = Schema::builder()
258            .with_schema_id(5)
259            .with_fields(vec![
260                NestedField::required(5, "id", Type::Primitive(PrimitiveType::Int)).into(),
261                NestedField::required(6, "name", Type::Primitive(PrimitiveType::String)).into(),
262            ])
263            .build()?;
264        let schema_ref = Arc::new(schema);
265
266        let partition_value = Struct::from_iter([Some(Literal::int(1))]);
267        let partition_key = PartitionKey::new(
268            PartitionSpec::builder(schema_ref.clone()).build()?,
269            schema_ref.clone(),
270            partition_value.clone(),
271        );
272
273        let parquet_writer_builder =
274            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema_ref.clone());
275
276        let rolling_file_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
277            parquet_writer_builder,
278            schema_ref.clone(),
279            file_io.clone(),
280            location_gen,
281            file_name_gen,
282        );
283
284        let mut data_file_writer = DataFileWriterBuilder::new(rolling_file_writer_builder)
285            .build(Some(partition_key))
286            .await?;
287
288        let arrow_schema = arrow_schema::Schema::new(vec![
289            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
290                PARQUET_FIELD_ID_META_KEY.to_string(),
291                5.to_string(),
292            )])),
293            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
294                PARQUET_FIELD_ID_META_KEY.to_string(),
295                6.to_string(),
296            )])),
297        ]);
298        let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
299            Arc::new(Int32Array::from(vec![1, 2, 3])),
300            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
301        ])?;
302        data_file_writer.write(batch).await?;
303
304        let data_files = data_file_writer.close().await.unwrap();
305        assert_eq!(data_files.len(), 1);
306
307        let data_file = &data_files[0];
308        assert_eq!(data_file.file_format, DataFileFormat::Parquet);
309        assert_eq!(data_file.content, DataContentType::Data);
310        assert_eq!(data_file.partition, partition_value);
311
312        let input_file = file_io.new_input(data_file.file_path.clone())?;
313        let input_content = input_file.read().await?;
314
315        let parquet_reader =
316            ArrowReaderMetadata::load(&input_content, ArrowReaderOptions::default())?;
317
318        let field_ids: Vec<i32> = parquet_reader
319            .parquet_schema()
320            .columns()
321            .iter()
322            .map(|col| col.self_type().get_basic_info().id())
323            .collect();
324        assert_eq!(field_ids, vec![5, 6]);
325
326        let field_names: Vec<&str> = parquet_reader
327            .parquet_schema()
328            .columns()
329            .iter()
330            .map(|col| col.name())
331            .collect();
332        assert_eq!(field_names, vec!["id", "name"]);
333
334        Ok(())
335    }
336}