1use arrow_array::cast::AsArray;
21use arrow_array::types::*;
22use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp, GenericByteArray};
23use arrow_buffer::NullBuffer;
24use arrow_schema::{ArrowError, DataType, SortOptions};
25use std::cmp::Ordering;
26
27pub fn rank(array: &dyn Array, options: Option<SortOptions>) -> Result<Vec<u32>, ArrowError> {
40 let options = options.unwrap_or_default();
41 let ranks = downcast_primitive_array! {
42 array => primitive_rank(array.values(), array.nulls(), options),
43 DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), options),
44 DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(), options),
45 DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(), options),
46 DataType::LargeBinary => bytes_rank(array.as_bytes::<LargeBinaryType>(), options),
47 d => return Err(ArrowError::ComputeError(format!("{d:?} not supported in rank")))
48 };
49 Ok(ranks)
50}
51
52#[inline(never)]
53fn primitive_rank<T: ArrowNativeTypeOp>(
54 values: &[T],
55 nulls: Option<&NullBuffer>,
56 options: SortOptions,
57) -> Vec<u32> {
58 let len: u32 = values.len().try_into().unwrap();
59 let to_sort = match nulls.filter(|n| n.null_count() > 0) {
60 Some(n) => n
61 .valid_indices()
62 .map(|idx| (values[idx], idx as u32))
63 .collect(),
64 None => values.iter().copied().zip(0..len).collect(),
65 };
66 rank_impl(values.len(), to_sort, options, T::compare, T::is_eq)
67}
68
69#[inline(never)]
70fn bytes_rank<T: ByteArrayType>(array: &GenericByteArray<T>, options: SortOptions) -> Vec<u32> {
71 let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n| n.null_count() > 0) {
72 Some(n) => n
73 .valid_indices()
74 .map(|idx| (array.value(idx).as_ref(), idx as u32))
75 .collect(),
76 None => (0..array.len())
77 .map(|idx| (array.value(idx).as_ref(), idx as u32))
78 .collect(),
79 };
80 rank_impl(array.len(), to_sort, options, Ord::cmp, PartialEq::eq)
81}
82
83fn rank_impl<T, C, E>(
84 len: usize,
85 mut valid: Vec<(T, u32)>,
86 options: SortOptions,
87 compare: C,
88 eq: E,
89) -> Vec<u32>
90where
91 T: Copy,
92 C: Fn(T, T) -> Ordering,
93 E: Fn(T, T) -> bool,
94{
95 valid.sort_unstable_by(|a, b| compare(a.0, b.0));
97 if options.descending {
98 valid.reverse();
99 }
100
101 let (mut valid_rank, null_rank) = match options.nulls_first {
102 true => (len as u32, (len - valid.len()) as u32),
103 false => (valid.len() as u32, len as u32),
104 };
105
106 let mut out: Vec<_> = vec![null_rank; len];
107 if let Some(v) = valid.last() {
108 out[v.1 as usize] = valid_rank;
109 }
110
111 let mut count = 1; for w in valid.windows(2).rev() {
113 match eq(w[0].0, w[1].0) {
114 true => {
115 count += 1;
116 out[w[0].1 as usize] = valid_rank;
117 }
118 false => {
119 valid_rank -= count;
120 count = 1;
121 out[w[0].1 as usize] = valid_rank
122 }
123 }
124 }
125
126 out
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::*;

    /// Ranks `array` with `options` and compares the outcome with `expected`.
    fn assert_ranks(array: &dyn Array, options: Option<SortOptions>, expected: &[u32]) {
        let actual = rank(array, options).unwrap();
        assert_eq!(actual, expected);
    }

    /// Shorthand for building a `SortOptions` from its two flags.
    fn sort_options(descending: bool, nulls_first: bool) -> SortOptions {
        SortOptions {
            descending,
            nulls_first,
        }
    }

    #[test]
    fn test_primitive() {
        let input = Int32Array::from(vec![Some(1), Some(1), None, Some(3), Some(3), Some(4)]);

        // (options, expected ranks) for the same input column.
        let cases: Vec<(Option<SortOptions>, [u32; 6])> = vec![
            // Default options.
            (None, [3, 3, 1, 5, 5, 6]),
            // Descending, nulls first.
            (Some(sort_options(true, true)), [6, 6, 1, 4, 4, 2]),
            // Ascending, nulls last.
            (Some(sort_options(false, false)), [2, 2, 6, 4, 4, 5]),
            // Descending, nulls last.
            (Some(sort_options(true, false)), [5, 5, 6, 3, 3, 1]),
        ];
        for (options, expected) in cases {
            assert_ranks(&input, options, &expected);
        }

        // Slots masked out by the validity buffer still hold non-zero values;
        // those values must not influence the ranks.
        let validity = NullBuffer::from(vec![true, true, false, true, false, false]);
        let masked = Int32Array::new(vec![1, 4, 3, 4, 5, 5].into(), Some(validity));
        assert_ranks(&masked, None, &[4, 6, 3, 6, 3, 3]);
    }

    #[test]
    fn test_bytes() {
        let strings = vec!["foo", "fo", "bar", "bar"];
        let string_ranks: [u32; 4] = [4, 3, 2, 2];

        let utf8 = StringArray::from(strings.clone());
        assert_ranks(&utf8, None, &string_ranks);

        let large_utf8 = LargeStringArray::from(strings.clone());
        assert_ranks(&large_utf8, None, &string_ranks);

        let blobs: Vec<&[u8]> = vec![&[1, 2], &[0], &[1, 2, 3], &[1, 2]];
        let blob_ranks: [u32; 4] = [3, 1, 4, 3];

        let large_binary = LargeBinaryArray::from(blobs.clone());
        assert_ranks(&large_binary, None, &blob_ranks);

        let binary = BinaryArray::from(blobs);
        assert_ranks(&binary, None, &blob_ranks);
    }
}