arrow_ord/
rank.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Provides `rank` function to assign a rank to each value in an array
19
20use arrow_array::cast::AsArray;
21use arrow_array::types::*;
22use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp, GenericByteArray};
23use arrow_buffer::NullBuffer;
24use arrow_schema::{ArrowError, DataType, SortOptions};
25use std::cmp::Ordering;
26
27/// Assigns a rank to each value in `array` based on its position in the sorted order
28///
29/// Where values are equal, they will be assigned the highest of their ranks,
30/// leaving gaps in the overall rank assignment
31///
32/// ```
33/// # use arrow_array::StringArray;
34/// # use arrow_ord::rank::rank;
35/// let array = StringArray::from(vec![Some("foo"), None, Some("foo"), None, Some("bar")]);
36/// let ranks = rank(&array, None).unwrap();
37/// assert_eq!(ranks, &[5, 2, 5, 2, 3]);
38/// ```
39pub fn rank(array: &dyn Array, options: Option<SortOptions>) -> Result<Vec<u32>, ArrowError> {
40    let options = options.unwrap_or_default();
41    let ranks = downcast_primitive_array! {
42        array => primitive_rank(array.values(), array.nulls(), options),
43        DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), options),
44        DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(), options),
45        DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(), options),
46        DataType::LargeBinary => bytes_rank(array.as_bytes::<LargeBinaryType>(), options),
47        d => return Err(ArrowError::ComputeError(format!("{d:?} not supported in rank")))
48    };
49    Ok(ranks)
50}
51
52#[inline(never)]
53fn primitive_rank<T: ArrowNativeTypeOp>(
54    values: &[T],
55    nulls: Option<&NullBuffer>,
56    options: SortOptions,
57) -> Vec<u32> {
58    let len: u32 = values.len().try_into().unwrap();
59    let to_sort = match nulls.filter(|n| n.null_count() > 0) {
60        Some(n) => n
61            .valid_indices()
62            .map(|idx| (values[idx], idx as u32))
63            .collect(),
64        None => values.iter().copied().zip(0..len).collect(),
65    };
66    rank_impl(values.len(), to_sort, options, T::compare, T::is_eq)
67}
68
69#[inline(never)]
70fn bytes_rank<T: ByteArrayType>(array: &GenericByteArray<T>, options: SortOptions) -> Vec<u32> {
71    let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n| n.null_count() > 0) {
72        Some(n) => n
73            .valid_indices()
74            .map(|idx| (array.value(idx).as_ref(), idx as u32))
75            .collect(),
76        None => (0..array.len())
77            .map(|idx| (array.value(idx).as_ref(), idx as u32))
78            .collect(),
79    };
80    rank_impl(array.len(), to_sort, options, Ord::cmp, PartialEq::eq)
81}
82
83fn rank_impl<T, C, E>(
84    len: usize,
85    mut valid: Vec<(T, u32)>,
86    options: SortOptions,
87    compare: C,
88    eq: E,
89) -> Vec<u32>
90where
91    T: Copy,
92    C: Fn(T, T) -> Ordering,
93    E: Fn(T, T) -> bool,
94{
95    // We can use an unstable sort as we combine equal values later
96    valid.sort_unstable_by(|a, b| compare(a.0, b.0));
97    if options.descending {
98        valid.reverse();
99    }
100
101    let (mut valid_rank, null_rank) = match options.nulls_first {
102        true => (len as u32, (len - valid.len()) as u32),
103        false => (valid.len() as u32, len as u32),
104    };
105
106    let mut out: Vec<_> = vec![null_rank; len];
107    if let Some(v) = valid.last() {
108        out[v.1 as usize] = valid_rank;
109    }
110
111    let mut count = 1; // Number of values in rank
112    for w in valid.windows(2).rev() {
113        match eq(w[0].0, w[1].0) {
114            true => {
115                count += 1;
116                out[w[0].1 as usize] = valid_rank;
117            }
118            false => {
119                valid_rank -= count;
120                count = 1;
121                out[w[0].1 as usize] = valid_rank
122            }
123        }
124    }
125
126    out
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::*;

    /// Shorthand for building the [`SortOptions`] used by the cases below
    fn sort_options(descending: bool, nulls_first: bool) -> SortOptions {
        SortOptions {
            descending,
            nulls_first,
        }
    }

    #[test]
    fn test_primitive() {
        let input = Int32Array::from(vec![Some(1), Some(1), None, Some(3), Some(3), Some(4)]);

        // Each case pairs the options passed to `rank` with the ranks expected for `input`
        let cases: [(Option<SortOptions>, [u32; 6]); 4] = [
            (None, [3, 3, 1, 5, 5, 6]),
            (Some(sort_options(true, true)), [6, 6, 1, 4, 4, 2]),
            (Some(sort_options(false, false)), [2, 2, 6, 4, 4, 5]),
            (Some(sort_options(true, false)), [5, 5, 6, 3, 3, 1]),
        ];
        for (options, expected) in cases {
            assert_eq!(rank(&input, options).unwrap(), &expected);
        }

        // Null slots backed by non-zero values must still be ranked as nulls
        let validity = NullBuffer::from(vec![true, true, false, true, false, false]);
        let masked = Int32Array::new(vec![1, 4, 3, 4, 5, 5].into(), Some(validity));
        assert_eq!(rank(&masked, None).unwrap(), &[4, 6, 3, 6, 3, 3]);
    }

    #[test]
    fn test_bytes() {
        let strings = vec!["foo", "fo", "bar", "bar"];
        let string_ranks: [u32; 4] = [4, 3, 2, 2];

        let utf8 = StringArray::from(strings.clone());
        assert_eq!(rank(&utf8, None).unwrap(), &string_ranks);

        let large_utf8 = LargeStringArray::from(strings.clone());
        assert_eq!(rank(&large_utf8, None).unwrap(), &string_ranks);

        let blobs: Vec<&[u8]> = vec![&[1, 2], &[0], &[1, 2, 3], &[1, 2]];
        let blob_ranks: [u32; 4] = [3, 1, 4, 3];

        let large_binary = LargeBinaryArray::from(blobs.clone());
        assert_eq!(rank(&large_binary, None).unwrap(), &blob_ranks);

        let binary = BinaryArray::from(blobs);
        assert_eq!(rank(&binary, None).unwrap(), &blob_ranks);
    }
}