1use std::any::Any;
19use std::marker::PhantomData;
20use std::sync::Arc;
21
22use arrow_array::{Array, ArrayRef, OffsetSizeTrait};
23use arrow_buffer::ArrowNativeType;
24use arrow_schema::DataType as ArrowType;
25use bytes::Bytes;
26
27use crate::arrow::array_reader::byte_array::{ByteArrayDecoder, ByteArrayDecoderPlain};
28use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
29use crate::arrow::buffer::{dictionary_buffer::DictionaryBuffer, offset_buffer::OffsetBuffer};
30use crate::arrow::record_reader::GenericRecordReader;
31use crate::arrow::schema::parquet_to_arrow_field;
32use crate::basic::{ConvertedType, Encoding};
33use crate::column::page::PageIterator;
34use crate::column::reader::decoder::ColumnValueDecoder;
35use crate::encodings::rle::RleDecoder;
36use crate::errors::{ParquetError, Result};
37use crate::schema::types::ColumnDescPtr;
38use crate::util::bit_util::FromBytes;
39
/// Expands to a `match` over an Arrow `(key, value)` type pair.
///
/// Each listed arm maps a pair of patterns to the `(key, offset)` native types used to
/// instantiate [`ByteArrayDictionaryReader`]; the first arm that matches constructs the
/// reader and returns it boxed inside `Ok`. Any pair that is not listed produces an
/// "unsupported data type" error that mentions `$dtype`.
macro_rules! make_reader {
    (
        ($page_iter:expr, $desc:expr, $dtype:expr) => match ($key:expr, $value:expr) {
            $(($key_pat:pat, $value_pat:pat) => ($key_native:ty, $offset_native:ty),)+
        }
    ) => {
        match ($key, $value) {
            $(
                ($key_pat, $value_pat) => {
                    let record_reader = GenericRecordReader::new($desc);
                    let array_reader =
                        ByteArrayDictionaryReader::<$key_native, $offset_native>::new(
                            $page_iter,
                            $dtype,
                            record_reader,
                        );
                    Ok(Box::new(array_reader))
                }
            )+
            _ => Err(general_err!(
                "unsupported data type for byte array dictionary reader - {}",
                $dtype
            )),
        }
    }
}
63
64pub fn make_byte_array_dictionary_reader(
77 pages: Box<dyn PageIterator>,
78 column_desc: ColumnDescPtr,
79 arrow_type: Option<ArrowType>,
80) -> Result<Box<dyn ArrayReader>> {
81 let data_type = match arrow_type {
83 Some(t) => t,
84 None => parquet_to_arrow_field(column_desc.as_ref())?
85 .data_type()
86 .clone(),
87 };
88
89 match &data_type {
90 ArrowType::Dictionary(key_type, value_type) => {
91 make_reader! {
92 (pages, column_desc, data_type) => match (key_type.as_ref(), value_type.as_ref()) {
93 (ArrowType::UInt8, ArrowType::Binary | ArrowType::Utf8) => (u8, i32),
94 (ArrowType::UInt8, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (u8, i64),
95 (ArrowType::Int8, ArrowType::Binary | ArrowType::Utf8) => (i8, i32),
96 (ArrowType::Int8, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (i8, i64),
97 (ArrowType::UInt16, ArrowType::Binary | ArrowType::Utf8) => (u16, i32),
98 (ArrowType::UInt16, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (u16, i64),
99 (ArrowType::Int16, ArrowType::Binary | ArrowType::Utf8) => (i16, i32),
100 (ArrowType::Int16, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (i16, i64),
101 (ArrowType::UInt32, ArrowType::Binary | ArrowType::Utf8) => (u32, i32),
102 (ArrowType::UInt32, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (u32, i64),
103 (ArrowType::Int32, ArrowType::Binary | ArrowType::Utf8) => (i32, i32),
104 (ArrowType::Int32, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (i32, i64),
105 (ArrowType::UInt64, ArrowType::Binary | ArrowType::Utf8) => (u64, i32),
106 (ArrowType::UInt64, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (u64, i64),
107 (ArrowType::Int64, ArrowType::Binary | ArrowType::Utf8) => (i64, i32),
108 (ArrowType::Int64, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (i64, i64),
109 }
110 }
111 }
112 _ => Err(general_err!(
113 "invalid non-dictionary data type for byte array dictionary reader - {}",
114 data_type
115 )),
116 }
117}
118
119struct ByteArrayDictionaryReader<K: ArrowNativeType, V: OffsetSizeTrait> {
123 data_type: ArrowType,
124 pages: Box<dyn PageIterator>,
125 def_levels_buffer: Option<Vec<i16>>,
126 rep_levels_buffer: Option<Vec<i16>>,
127 record_reader: GenericRecordReader<DictionaryBuffer<K, V>, DictionaryDecoder<K, V>>,
128}
129
130impl<K, V> ByteArrayDictionaryReader<K, V>
131where
132 K: FromBytes + Ord + ArrowNativeType,
133 V: OffsetSizeTrait,
134{
135 fn new(
136 pages: Box<dyn PageIterator>,
137 data_type: ArrowType,
138 record_reader: GenericRecordReader<DictionaryBuffer<K, V>, DictionaryDecoder<K, V>>,
139 ) -> Self {
140 Self {
141 data_type,
142 pages,
143 def_levels_buffer: None,
144 rep_levels_buffer: None,
145 record_reader,
146 }
147 }
148}
149
150impl<K, V> ArrayReader for ByteArrayDictionaryReader<K, V>
151where
152 K: FromBytes + Ord + ArrowNativeType,
153 V: OffsetSizeTrait,
154{
155 fn as_any(&self) -> &dyn Any {
156 self
157 }
158
159 fn get_data_type(&self) -> &ArrowType {
160 &self.data_type
161 }
162
163 fn read_records(&mut self, batch_size: usize) -> Result<usize> {
164 read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
165 }
166
167 fn consume_batch(&mut self) -> Result<ArrayRef> {
168 let buffer = self.record_reader.consume_record_data();
169 let null_buffer = self.record_reader.consume_bitmap_buffer();
170 let array = buffer.into_array(null_buffer, &self.data_type)?;
171
172 self.def_levels_buffer = self.record_reader.consume_def_levels();
173 self.rep_levels_buffer = self.record_reader.consume_rep_levels();
174 self.record_reader.reset();
175
176 Ok(array)
177 }
178
179 fn skip_records(&mut self, num_records: usize) -> Result<usize> {
180 skip_records(&mut self.record_reader, self.pages.as_mut(), num_records)
181 }
182
183 fn get_def_levels(&self) -> Option<&[i16]> {
184 self.def_levels_buffer.as_deref()
185 }
186
187 fn get_rep_levels(&self) -> Option<&[i16]> {
188 self.rep_levels_buffer.as_deref()
189 }
190}
191
192enum MaybeDictionaryDecoder {
196 Dict {
197 decoder: RleDecoder,
198 max_remaining_values: usize,
201 },
202 Fallback(ByteArrayDecoder),
203}
204
205struct DictionaryDecoder<K, V> {
207 dict: Option<ArrayRef>,
209
210 decoder: Option<MaybeDictionaryDecoder>,
212
213 validate_utf8: bool,
214
215 value_type: ArrowType,
216
217 phantom: PhantomData<(K, V)>,
218}
219
220impl<K, V> ColumnValueDecoder for DictionaryDecoder<K, V>
221where
222 K: FromBytes + Ord + ArrowNativeType,
223 V: OffsetSizeTrait,
224{
225 type Buffer = DictionaryBuffer<K, V>;
226
227 fn new(col: &ColumnDescPtr) -> Self {
228 let validate_utf8 = col.converted_type() == ConvertedType::UTF8;
229
230 let value_type = match (V::IS_LARGE, col.converted_type() == ConvertedType::UTF8) {
231 (true, true) => ArrowType::LargeUtf8,
232 (true, false) => ArrowType::LargeBinary,
233 (false, true) => ArrowType::Utf8,
234 (false, false) => ArrowType::Binary,
235 };
236
237 Self {
238 dict: None,
239 decoder: None,
240 validate_utf8,
241 value_type,
242 phantom: Default::default(),
243 }
244 }
245
246 fn set_dict(
247 &mut self,
248 buf: Bytes,
249 num_values: u32,
250 encoding: Encoding,
251 _is_sorted: bool,
252 ) -> Result<()> {
253 if !matches!(
254 encoding,
255 Encoding::PLAIN | Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY
256 ) {
257 return Err(nyi_err!(
258 "Invalid/Unsupported encoding type for dictionary: {}",
259 encoding
260 ));
261 }
262
263 if K::from_usize(num_values as usize).is_none() {
264 return Err(general_err!("dictionary too large for index type"));
265 }
266
267 let len = num_values as usize;
268 let mut buffer = OffsetBuffer::<V>::default();
269 let mut decoder = ByteArrayDecoderPlain::new(buf, len, Some(len), self.validate_utf8);
270 decoder.read(&mut buffer, usize::MAX)?;
271
272 let array = buffer.into_array(None, self.value_type.clone());
273 self.dict = Some(Arc::new(array));
274 Ok(())
275 }
276
277 fn set_data(
278 &mut self,
279 encoding: Encoding,
280 data: Bytes,
281 num_levels: usize,
282 num_values: Option<usize>,
283 ) -> Result<()> {
284 let decoder = match encoding {
285 Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => {
286 let bit_width = data[0];
287 let mut decoder = RleDecoder::new(bit_width);
288 decoder.set_data(data.slice(1..));
289 MaybeDictionaryDecoder::Dict {
290 decoder,
291 max_remaining_values: num_values.unwrap_or(num_levels),
292 }
293 }
294 _ => MaybeDictionaryDecoder::Fallback(ByteArrayDecoder::new(
295 encoding,
296 data,
297 num_levels,
298 num_values,
299 self.validate_utf8,
300 )?),
301 };
302
303 self.decoder = Some(decoder);
304 Ok(())
305 }
306
307 fn read(&mut self, out: &mut Self::Buffer, num_values: usize) -> Result<usize> {
308 match self.decoder.as_mut().expect("decoder set") {
309 MaybeDictionaryDecoder::Fallback(decoder) => {
310 decoder.read(out.spill_values()?, num_values, None)
311 }
312 MaybeDictionaryDecoder::Dict {
313 decoder,
314 max_remaining_values,
315 } => {
316 let len = num_values.min(*max_remaining_values);
317
318 let dict = self
319 .dict
320 .as_ref()
321 .ok_or_else(|| general_err!("missing dictionary page for column"))?;
322
323 assert_eq!(dict.data_type(), &self.value_type);
324
325 if dict.is_empty() {
326 return Ok(0); }
328
329 match out.as_keys(dict) {
330 Some(keys) => {
331 let start = keys.len();
336 keys.resize(start + len, K::default());
337 let len = decoder.get_batch(&mut keys[start..])?;
338 keys.truncate(start + len);
339 *max_remaining_values -= len;
340 Ok(len)
341 }
342 None => {
343 let values = out.spill_values()?;
348 let mut keys = vec![K::default(); len];
349 let len = decoder.get_batch(&mut keys)?;
350
351 assert_eq!(dict.data_type(), &self.value_type);
352
353 let data = dict.to_data();
354 let dict_buffers = data.buffers();
355 let dict_offsets = dict_buffers[0].typed_data::<V>();
356 let dict_values = dict_buffers[1].as_slice();
357
358 values.extend_from_dictionary(&keys[..len], dict_offsets, dict_values)?;
359 *max_remaining_values -= len;
360 Ok(len)
361 }
362 }
363 }
364 }
365 }
366
367 fn skip_values(&mut self, num_values: usize) -> Result<usize> {
368 match self.decoder.as_mut().expect("decoder set") {
369 MaybeDictionaryDecoder::Fallback(decoder) => decoder.skip::<V>(num_values, None),
370 MaybeDictionaryDecoder::Dict {
371 decoder,
372 max_remaining_values,
373 } => {
374 let num_values = num_values.min(*max_remaining_values);
375 *max_remaining_values -= num_values;
376 decoder.skip(num_values)
377 }
378 }
379 }
380}
381
#[cfg(test)]
mod tests {
    use arrow::compute::cast;
    use arrow_array::{Array, StringArray};
    use arrow_buffer::Buffer;

    use crate::arrow::array_reader::test_util::{
        byte_array_all_encodings, encode_dictionary, utf8_column,
    };
    use crate::arrow::record_reader::buffer::ValuesBuffer;
    use crate::data_type::ByteArray;

    use super::*;

    /// The `Dictionary(Int32, Utf8)` type every test decodes into.
    fn utf8_dictionary() -> ArrowType {
        ArrowType::Dictionary(Box::new(ArrowType::Int32), Box::new(ArrowType::Utf8))
    }

    /// Seven values drawn from the three distinct strings "0", "1" and "2".
    fn repeated_digits() -> Vec<ByteArray> {
        vec!["0", "1", "0", "1", "2", "1", "2"]
            .into_iter()
            .map(ByteArray::from)
            .collect()
    }

    /// Casts `array` to `Utf8` and returns it as a `StringArray`.
    fn to_strings(array: &dyn Array) -> StringArray {
        let utf8 = cast(array, &ArrowType::Utf8).unwrap();
        utf8.as_any().downcast_ref::<StringArray>().unwrap().clone()
    }

    /// Pads `output` with eight nulls and checks the resulting array is all null.
    fn assert_eight_nulls(mut output: DictionaryBuffer<i32, i32>, data_type: &ArrowType) {
        output.pad_nulls(0, 0, 8, &[0]);
        let array = output
            .into_array(Some(Buffer::from(&[0])), data_type)
            .unwrap();

        assert_eq!(array.len(), 8);
        assert_eq!(array.null_count(), 8);
        assert_eq!(array.logical_null_count(), 8);
    }

    #[test]
    fn test_dictionary_preservation() {
        let data_type = utf8_dictionary();

        let values = repeated_digits();
        let (dict_page, data_page) = encode_dictionary(&values);

        let column_desc = utf8_column();
        let mut decoder = DictionaryDecoder::<i32, i32>::new(&column_desc);

        decoder
            .set_dict(dict_page, 3, Encoding::RLE_DICTIONARY, false)
            .unwrap();

        decoder
            .set_data(Encoding::RLE_DICTIONARY, data_page, 14, Some(values.len()))
            .unwrap();

        let mut output = DictionaryBuffer::<i32, i32>::default();

        // First batch: three values spread over six levels.
        assert_eq!(decoder.read(&mut output, 3).unwrap(), 3);

        let mut valid = vec![false, false, true, true, false, true];
        let valid_buffer = Buffer::from_iter(valid.iter().copied());
        output.pad_nulls(0, 3, valid.len(), valid_buffer.as_slice());

        assert!(matches!(output, DictionaryBuffer::Dict { .. }));

        // Second batch: four values spread over a further eight levels.
        assert_eq!(decoder.read(&mut output, 4).unwrap(), 4);

        valid.extend_from_slice(&[false, false, true, true, false, true, true, false]);
        let valid_buffer = Buffer::from_iter(valid.iter().copied());
        output.pad_nulls(6, 4, 8, valid_buffer.as_slice());

        assert!(matches!(output, DictionaryBuffer::Dict { .. }));

        let array = output.into_array(Some(valid_buffer), &data_type).unwrap();
        assert_eq!(array.data_type(), &data_type);

        let strings = to_strings(&array);
        assert_eq!(strings.len(), 14);

        let decoded: Vec<_> = strings.iter().collect();
        assert_eq!(
            decoded,
            vec![
                None,
                None,
                Some("0"),
                Some("1"),
                None,
                Some("0"),
                None,
                None,
                Some("1"),
                Some("2"),
                None,
                Some("1"),
                Some("2"),
                None
            ]
        )
    }

    #[test]
    fn test_dictionary_preservation_skip() {
        let data_type = utf8_dictionary();

        let values = repeated_digits();
        let (dict_page, data_page) = encode_dictionary(&values);

        let column_desc = utf8_column();
        let mut decoder = DictionaryDecoder::<i32, i32>::new(&column_desc);

        decoder
            .set_dict(dict_page, 3, Encoding::RLE_DICTIONARY, false)
            .unwrap();

        decoder
            .set_data(Encoding::RLE_DICTIONARY, data_page, 7, Some(values.len()))
            .unwrap();

        let mut output = DictionaryBuffer::<i32, i32>::default();

        // Alternate reads and skips until the page is exhausted.
        assert_eq!(decoder.read(&mut output, 2).unwrap(), 2);
        assert_eq!(decoder.skip_values(1).unwrap(), 1);

        assert!(matches!(output, DictionaryBuffer::Dict { .. }));

        assert_eq!(decoder.read(&mut output, 2).unwrap(), 2);
        assert_eq!(decoder.skip_values(1).unwrap(), 1);

        assert_eq!(decoder.read(&mut output, 1).unwrap(), 1);
        assert_eq!(decoder.skip_values(4).unwrap(), 0);

        let valid = [true, true, true, true, true];
        let valid_buffer = Buffer::from_iter(valid.iter().copied());
        output.pad_nulls(0, 5, 5, valid_buffer.as_slice());

        assert!(matches!(output, DictionaryBuffer::Dict { .. }));

        let array = output.into_array(Some(valid_buffer), &data_type).unwrap();
        assert_eq!(array.data_type(), &data_type);

        let strings = to_strings(&array);
        assert_eq!(strings.len(), 5);

        let decoded: Vec<_> = strings.iter().collect();
        assert_eq!(
            decoded,
            vec![Some("0"), Some("1"), Some("1"), Some("2"), Some("2"),]
        )
    }

    #[test]
    fn test_dictionary_fallback() {
        let data_type = utf8_dictionary();
        let data = vec!["hello", "world", "a", "b"];

        let (pages, encoded_dictionary) = byte_array_all_encodings(data.clone());
        let num_encodings = pages.len();

        let column_desc = utf8_column();
        let mut decoder = DictionaryDecoder::<i32, i32>::new(&column_desc);

        decoder
            .set_dict(encoded_dictionary, 4, Encoding::RLE_DICTIONARY, false)
            .unwrap();

        // Every page, whatever its encoding, is read into the same output buffer.
        let mut output = DictionaryBuffer::<i32, i32>::default();

        for (encoding, page) in pages {
            decoder.set_data(encoding, page, 4, Some(4)).unwrap();
            assert_eq!(decoder.read(&mut output, 1024).unwrap(), 4);
        }
        let array = output.into_array(None, &data_type).unwrap();
        assert_eq!(array.data_type(), &data_type);

        let strings = to_strings(&array);
        assert_eq!(strings.len(), data.len() * num_encodings);

        for page_idx in 0..num_encodings {
            let page_values: Vec<_> = strings
                .iter()
                .skip(page_idx * data.len())
                .take(data.len())
                .map(|value| value.unwrap())
                .collect();
            assert_eq!(page_values, data)
        }
    }

    #[test]
    fn test_dictionary_skip_fallback() {
        let data_type = utf8_dictionary();
        let data = vec!["hello", "world", "a", "b"];

        let (pages, encoded_dictionary) = byte_array_all_encodings(data.clone());
        let num_encodings = pages.len();

        let column_desc = utf8_column();
        let mut decoder = DictionaryDecoder::<i32, i32>::new(&column_desc);

        decoder
            .set_dict(encoded_dictionary, 4, Encoding::RLE_DICTIONARY, false)
            .unwrap();

        // Every page, whatever its encoding, is read into the same output buffer.
        let mut output = DictionaryBuffer::<i32, i32>::default();

        for (encoding, page) in pages {
            decoder.set_data(encoding, page, 4, Some(4)).unwrap();
            decoder.skip_values(2).expect("skipping two values");
            assert_eq!(decoder.read(&mut output, 1024).unwrap(), 2);
        }
        let array = output.into_array(None, &data_type).unwrap();
        assert_eq!(array.data_type(), &data_type);

        // Two of the four values are skipped on each page.
        let kept_per_page = data.len() - 2;

        let strings = to_strings(&array);
        assert_eq!(strings.len(), kept_per_page * num_encodings);

        for page_idx in 0..num_encodings {
            let page_values: Vec<_> = strings
                .iter()
                .skip(page_idx * kept_per_page)
                .take(kept_per_page)
                .map(|value| value.unwrap())
                .collect();
            assert_eq!(&page_values, &data[2..])
        }
    }

    #[test]
    fn test_too_large_dictionary() {
        let data: Vec<_> = (0..128)
            .map(|x| ByteArray::from(x.to_string().as_str()))
            .collect();
        let (dictionary, _) = encode_dictionary(&data);

        let column_desc = utf8_column();

        // `set_dict` must reject a 128-entry dictionary for `i8` keys.
        let mut narrow = DictionaryDecoder::<i8, i32>::new(&column_desc);
        let err = narrow
            .set_dict(dictionary.clone(), 128, Encoding::RLE_DICTIONARY, false)
            .unwrap_err()
            .to_string();

        assert!(err.contains("dictionary too large for index type"));

        // The same dictionary is accepted for `i16` keys.
        let mut wide = DictionaryDecoder::<i16, i32>::new(&column_desc);
        wide.set_dict(dictionary, 128, Encoding::RLE_DICTIONARY, false)
            .unwrap();
    }

    #[test]
    fn test_nulls() {
        let data_type = utf8_dictionary();
        let (pages, encoded_dictionary) = byte_array_all_encodings(Vec::<&str>::new());

        let column_desc = utf8_column();
        let mut decoder = DictionaryDecoder::<i32, i32>::new(&column_desc);

        decoder
            .set_dict(encoded_dictionary, 4, Encoding::PLAIN_DICTIONARY, false)
            .unwrap();

        // Reading a page without values yields nothing but nulls.
        for (encoding, page) in pages.clone() {
            let mut output = DictionaryBuffer::<i32, i32>::default();
            decoder.set_data(encoding, page, 8, None).unwrap();
            assert_eq!(decoder.read(&mut output, 1024).unwrap(), 0);

            assert_eight_nulls(output, &data_type);
        }

        // Skipping over such a page behaves the same way.
        for (encoding, page) in pages {
            let output = DictionaryBuffer::<i32, i32>::default();
            decoder.set_data(encoding, page, 8, None).unwrap();
            assert_eq!(decoder.skip_values(1024).unwrap(), 0);

            assert_eight_nulls(output, &data_type);
        }
    }
}