use arrow_array::RecordBatch;

use crate::spec::{DataContentType, DataFile, PartitionKey};
use crate::writer::file_writer::FileWriterBuilder;
use crate::writer::file_writer::location_generator::{FileNameGenerator, LocationGenerator};
use crate::writer::file_writer::rolling_writer::{RollingFileWriter, RollingFileWriterBuilder};
use crate::writer::{CurrentFileStatus, IcebergWriter, IcebergWriterBuilder};
use crate::{Error, ErrorKind, Result};

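/// Builder for [`DataFileWriter`].
///
/// A rough usage sketch (mirroring the tests at the bottom of this file; the
/// variable names are illustrative and error handling is elided):
///
/// ```ignore
/// let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
///     parquet_writer_builder,
///     schema,
///     file_io,
///     location_generator,
///     file_name_generator,
/// );
/// let mut writer = DataFileWriterBuilder::new(rolling_writer_builder)
///     .build(None) // or `Some(partition_key)` for a partitioned write
///     .await?;
/// writer.write(record_batch).await?;
/// let data_files: Vec<DataFile> = writer.close().await?;
/// ```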
#[derive(Debug)]
pub struct DataFileWriterBuilder<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> {
    inner: RollingFileWriterBuilder<B, L, F>,
}

impl<B, L, F> DataFileWriterBuilder<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
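    /// Creates a new `DataFileWriterBuilder` that wraps the given [`RollingFileWriterBuilder`].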
    pub fn new(inner: RollingFileWriterBuilder<B, L, F>) -> Self {
        Self { inner }
    }
}

#[async_trait::async_trait]
impl<B, L, F> IcebergWriterBuilder for DataFileWriterBuilder<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    type R = DataFileWriter<B, L, F>;

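    /// Builds a [`DataFileWriter`] bound to the given partition key, or an
    /// unpartitioned writer when `partition_key` is `None`.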
    async fn build(&self, partition_key: Option<PartitionKey>) -> Result<Self::R> {
        Ok(DataFileWriter {
            inner: Some(self.inner.build()),
            partition_key,
        })
    }
}

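/// A writer that produces [`DataFile`]s with [`DataContentType::Data`] content,
/// delegating the actual file writing (and file rolling) to an inner
/// [`RollingFileWriter`].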
#[derive(Debug)]
pub struct DataFileWriter<B: FileWriterBuilder, L: LocationGenerator, F: FileNameGenerator> {
    inner: Option<RollingFileWriter<B, L, F>>,
    partition_key: Option<PartitionKey>,
}

#[async_trait::async_trait]
impl<B, L, F> IcebergWriter for DataFileWriter<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    async fn write(&mut self, batch: RecordBatch) -> Result<()> {
        if let Some(writer) = self.inner.as_mut() {
            writer.write(&self.partition_key, &batch).await
        } else {
            Err(Error::new(
                ErrorKind::Unexpected,
                "Data file writer is not initialized or has already been closed.",
            ))
        }
    }

    async fn close(&mut self) -> Result<Vec<DataFile>> {
        if let Some(writer) = self.inner.take() {
            writer
                .close()
                .await?
                .into_iter()
                .map(|mut res| {
                    // Every file emitted by this writer is data content; attach the
                    // partition value and spec id when a partition key is present.
                    res.content(DataContentType::Data);
                    if let Some(pk) = self.partition_key.as_ref() {
                        res.partition(pk.data().clone());
                        res.partition_spec_id(pk.spec().spec_id());
                    }
                    res.build().map_err(|e| {
                        Error::new(
                            ErrorKind::DataInvalid,
                            format!("Failed to build data file: {e}"),
                        )
                    })
                })
                .collect()
        } else {
            Err(Error::new(
                ErrorKind::Unexpected,
                "Data file writer has been closed.",
            ))
        }
    }
}

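// These status accessors delegate to the active rolling writer; they assume the
// writer is still open and will panic on `unwrap` if called after `close()`.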
impl<B, L, F> CurrentFileStatus for DataFileWriter<B, L, F>
where
    B: FileWriterBuilder,
    L: LocationGenerator,
    F: FileNameGenerator,
{
    fn current_file_path(&self) -> String {
        self.inner.as_ref().unwrap().current_file_path()
    }

    fn current_row_num(&self) -> usize {
        self.inner.as_ref().unwrap().current_row_num()
    }

    fn current_written_size(&self) -> usize {
        self.inner.as_ref().unwrap().current_written_size()
    }

    fn current_schema(&self) -> crate::spec::SchemaRef {
        self.inner.as_ref().unwrap().current_schema()
    }
}

#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field};
    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions};
    use parquet::file::properties::WriterProperties;
    use tempfile::TempDir;

    use crate::Result;
    use crate::io::FileIOBuilder;
    use crate::spec::{
        DataContentType, DataFileFormat, Literal, NestedField, PartitionKey, PartitionSpec,
        PrimitiveType, Schema, Struct, Type,
    };
    use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
    use crate::writer::file_writer::ParquetWriterBuilder;
    use crate::writer::file_writer::location_generator::{
        DefaultFileNameGenerator, DefaultLocationGenerator,
    };
    use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder;
    use crate::writer::{IcebergWriter, IcebergWriterBuilder, RecordBatch};

    #[tokio::test]
    async fn test_parquet_writer() -> Result<()> {
        let temp_dir = TempDir::new().unwrap();
        let file_io = FileIOBuilder::new_fs_io().build().unwrap();
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen =
            DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);

        let schema = Arc::new(
            Schema::builder()
                .with_schema_id(3)
                .with_fields(vec![
                    NestedField::required(3, "foo", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(4, "bar", Type::Primitive(PrimitiveType::String)).into(),
                ])
                .build()?,
        );

        let pw = ParquetWriterBuilder::new(WriterProperties::builder().build(), schema.clone());

        let rolling_file_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            pw,
            schema,
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        // Build an unpartitioned data file writer on top of the rolling Parquet writer.
        let mut data_file_writer = DataFileWriterBuilder::new(rolling_file_writer_builder)
            .build(None)
            .await
            .unwrap();

        let arrow_schema = arrow_schema::Schema::new(vec![
            Field::new("foo", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                3.to_string(),
            )])),
            Field::new("bar", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                4.to_string(),
            )])),
        ]);
        let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])?;
        data_file_writer.write(batch).await?;

        let data_files = data_file_writer.close().await.unwrap();
        assert_eq!(data_files.len(), 1);

        let data_file = &data_files[0];
        assert_eq!(data_file.file_format, DataFileFormat::Parquet);
        assert_eq!(data_file.content, DataContentType::Data);
        assert_eq!(data_file.partition, Struct::empty());

        // Read the written file back and verify the Iceberg field IDs survived the
        // Parquet round-trip.
        let input_file = file_io.new_input(data_file.file_path.clone())?;
        let input_content = input_file.read().await?;

        let parquet_reader =
            ArrowReaderMetadata::load(&input_content, ArrowReaderOptions::default())
                .expect("Failed to load Parquet metadata");

        let field_ids: Vec<i32> = parquet_reader
            .parquet_schema()
            .columns()
            .iter()
            .map(|col| col.self_type().get_basic_info().id())
            .collect();

        assert_eq!(field_ids, vec![3, 4]);
        Ok(())
    }

    #[tokio::test]
    async fn test_parquet_writer_with_partition() -> Result<()> {
        let temp_dir = TempDir::new().unwrap();
        let file_io = FileIOBuilder::new_fs_io().build().unwrap();
        let location_gen = DefaultLocationGenerator::with_data_location(
            temp_dir.path().to_str().unwrap().to_string(),
        );
        let file_name_gen = DefaultFileNameGenerator::new(
            "test_partitioned".to_string(),
            None,
            DataFileFormat::Parquet,
        );

        let schema = Schema::builder()
            .with_schema_id(5)
            .with_fields(vec![
                NestedField::required(5, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(6, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()?;
        let schema_ref = Arc::new(schema);

        let partition_value = Struct::from_iter([Some(Literal::int(1))]);
        let partition_key = PartitionKey::new(
            PartitionSpec::builder(schema_ref.clone()).build()?,
            schema_ref.clone(),
            partition_value.clone(),
        );

        let parquet_writer_builder =
            ParquetWriterBuilder::new(WriterProperties::builder().build(), schema_ref.clone());

        let rolling_file_writer_builder = RollingFileWriterBuilder::new_with_default_file_size(
            parquet_writer_builder,
            schema_ref.clone(),
            file_io.clone(),
            location_gen,
            file_name_gen,
        );

        // Build the writer with a partition key so the emitted DataFile carries the
        // partition value and spec id.
        let mut data_file_writer = DataFileWriterBuilder::new(rolling_file_writer_builder)
            .build(Some(partition_key))
            .await?;

        let arrow_schema = arrow_schema::Schema::new(vec![
            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                5.to_string(),
            )])),
            Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([(
                PARQUET_FIELD_ID_META_KEY.to_string(),
                6.to_string(),
            )])),
        ]);
        let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
        ])?;
        data_file_writer.write(batch).await?;

        let data_files = data_file_writer.close().await.unwrap();
        assert_eq!(data_files.len(), 1);

        let data_file = &data_files[0];
        assert_eq!(data_file.file_format, DataFileFormat::Parquet);
        assert_eq!(data_file.content, DataContentType::Data);
        assert_eq!(data_file.partition, partition_value);

        let input_file = file_io.new_input(data_file.file_path.clone())?;
        let input_content = input_file.read().await?;

        let parquet_reader =
            ArrowReaderMetadata::load(&input_content, ArrowReaderOptions::default())?;

        let field_ids: Vec<i32> = parquet_reader
            .parquet_schema()
            .columns()
            .iter()
            .map(|col| col.self_type().get_basic_info().id())
            .collect();
        assert_eq!(field_ids, vec![5, 6]);

        let field_names: Vec<&str> = parquet_reader
            .parquet_schema()
            .columns()
            .iter()
            .map(|col| col.name())
            .collect();
        assert_eq!(field_names, vec!["id", "name"]);

        Ok(())
    }
}