differential_dataflow/algorithms/prefix_sum.rs
//! Implementation of Parallel Prefix Sum
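//!
//! The computation proceeds in two stages: `aggregate` folds the input into accumulations over
//! power-of-two aligned ranges, and `broadcast` then unfolds those ranges so that each queried
//! location receives the accumulation of all values at strictly smaller indices.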
use timely::dataflow::Scope;

use crate::{Collection, ExchangeData};
use crate::lattice::Lattice;
use crate::operators::*;

/// Extension trait for the prefix_sum method.
pub trait PrefixSum<G: Scope, K, D> {
    /// Computes the prefix sum for each element in the collection.
    ///
    /// The prefix sum is data-parallel, in the sense that the sums are computed independently for
    /// each key of type `K`. For a single prefix sum this type can be `()`, but this permits the
    /// more general accumulation of multiple independent sequences.
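    ///
    /// A minimal sketch of intended use (hypothetical scaffold; assumes `collection` is already
    /// a `Collection<G, ((usize, K), D)>` in scope, here with `D = u64`):
    ///
    /// ```ignore
    /// // Sum u64 values independently within each key; 0 is the additive identity.
    /// let sums = collection.prefix_sum(0, |_key, x, y| x + y);
    /// ```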
    fn prefix_sum<F>(&self, zero: D, combine: F) -> Self where F: Fn(&K,&D,&D)->D + 'static;
    /// Determine the prefix sum at each element of `locations`.
    fn prefix_sum_at<F>(&self, locations: Collection<G, (usize, K)>, zero: D, combine: F) -> Self where F: Fn(&K,&D,&D)->D + 'static;
}
impl<G, K, D> PrefixSum<G, K, D> for Collection<G, ((usize, K), D)>
where
    G: Scope,
    G::Timestamp: Lattice,
    K: ExchangeData+::std::hash::Hash,
    D: ExchangeData+::std::hash::Hash,
{
    fn prefix_sum<F>(&self, zero: D, combine: F) -> Self where F: Fn(&K,&D,&D)->D + 'static {
        self.prefix_sum_at(self.map(|(x,_)| x), zero, combine)
    }
    fn prefix_sum_at<F>(&self, locations: Collection<G, (usize, K)>, zero: D, combine: F) -> Self where F: Fn(&K,&D,&D)->D + 'static {
        let combine1 = ::std::rc::Rc::new(combine);
        let combine2 = combine1.clone();
        let ranges = aggregate(self.clone(), move |k,x,y| (*combine1)(k,x,y));
        broadcast(ranges, locations, zero, move |k,x,y| (*combine2)(k,x,y))
    }
}
/// Accumulate data in `collection` into all power-of-two intervals containing them.
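///
/// Each output triple `(pos, log, key)` names the half-open interval of indices
/// `[pos << log, (pos << log) + (1 << log))`, and carries the accumulation of all
/// input values for `key` whose indices fall within that interval.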
pub fn aggregate<G, K, D, F>(collection: Collection<G, ((usize, K), D)>, combine: F) -> Collection<G, ((usize, usize, K), D)>
where
    G: Scope,
    G::Timestamp: Lattice,
    K: ExchangeData+::std::hash::Hash,
    D: ExchangeData+::std::hash::Hash,
    F: Fn(&K,&D,&D)->D + 'static,
{
    // Initial ranges sit at each index, with width 2^0.
    let unit_ranges = collection.map(|((index, key), data)| ((index, 0, key), data));

    unit_ranges
        .iterate(|ranges|
            // Each available range of width less than 2^64 (i.e., with log < 64) advertises itself
            // as part of the range twice as large, aligned to integer multiples of its size. Each
            // such doubled range, which contains at most two elements, then summarizes itself using
            // the `combine` function. Finally, we re-add the initial `unit_ranges` intervals, so
            // that the set of ranges grows monotonically.
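            // For example (a hypothetical trace): a unit range (5, 0, key) proposes its parent
            // (2, 1, key) covering [4, 6); that in turn proposes (1, 2, key) covering [4, 8),
            // then (0, 3, key) covering [0, 8), pairing with a sibling via `combine` when present.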
            ranges
                .filter(|&((_pos, log, _), _)| log < 64)
                .map(|((pos, log, key), data)| ((pos >> 1, log + 1, key), (pos, data)))
                .reduce(move |&(_pos, _log, ref key), input, output| {
                    let mut result = (input[0].0).1.clone();
                    if input.len() > 1 { result = combine(key, &result, &(input[1].0).1); }
                    output.push((result, 1));
                })
                .concat(&unit_ranges.enter(&ranges.scope()))
        )
}
/// Produces the accumulated values at each of the `usize` locations in `queries`.
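///
/// As assembled here, the value produced at a location `idx` is the accumulation of all
/// values at indices strictly less than `idx` (an exclusive prefix sum), with `zero` at index 0.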
pub fn broadcast<G, K, D, F>(
    ranges: Collection<G, ((usize, usize, K), D)>,
    queries: Collection<G, (usize, K)>,
    zero: D,
    combine: F) -> Collection<G, ((usize, K), D)>
where
    G: Scope,
    G::Timestamp: Lattice+Ord+::std::fmt::Debug,
    K: ExchangeData+::std::hash::Hash,
    D: ExchangeData+::std::hash::Hash,
    F: Fn(&K,&D,&D)->D + 'static,
{
    let zero0 = zero.clone();
    let zero1 = zero.clone();
    let zero2 = zero.clone();
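    // Three handles on `zero`, one for each of the `move` closures below.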
    // The `queries` collection may not line up with an existing element of `ranges`, and so we must
    // track down the first range that matches. If it doesn't exist, we will need to produce a zero
    // value. We could produce the full path from (0, key) to (idx, key), and aggregate any and all
    // matches. This has the defect of being n log n rather than linear, as the root ranges will be
    // replicated for each query.
    //
    // I think it works to have each (idx, key) propose each of the intervals it knows should be used
    // to assemble its input. We then `distinct` these and intersect them with the offered `ranges`,
    // essentially performing a semijoin. We then perform the unfolding, where we might need to use
    // empty ranges if none exist in `ranges`.
    // We extract desired ranges for each `idx` from its binary representation: each set bit requires
    // the contribution of a range, and we call out each of these. This could produce a super-linear
    // amount of data (multiple requests for the roots), but it will be compacted down in `distinct`.
    // We could reduce the amount of data by producing the requests iteratively, with a distinct in
    // the loop to pre-suppress duplicate requests. This comes at a complexity cost, though.
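    // For example (a hypothetical decomposition): idx = 11 = 0b1011 has set bits 0, 1, and 3,
    // and so requests the ranges (10, 0, key) covering [10, 11), (4, 1, key) covering [8, 10),
    // and (0, 3, key) covering [0, 8), which together cover exactly [0, 11).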
    let requests =
        queries
            .flat_map(|(idx, key)|
                (0 .. 64)
                    .filter(move |i| (idx & (1usize << i)) != 0)    // set bits require help.
                    .map(move |i| ((idx >> i) - 1, i, key.clone())) // width 2^i interval.
            )
            .distinct();
    // Acquire each requested range.
    let full_ranges =
        ranges
            .semijoin(&requests);
    // Each requested range should exist, even if as a zero range, for correct reconstruction.
    let zero_ranges =
        full_ranges
            .map(move |((idx, log, key), _)| ((idx, log, key), zero0.clone()))
            .negate()
            .concat(&requests.map(move |(idx, log, key)| ((idx, log, key), zero1.clone())));
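    // The `negate`/`concat` pair acts as an antijoin: requests matched by a full range cancel,
    // leaving a zero-valued range exactly where `ranges` offered nothing.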
    // Merge occupied and empty ranges.
    let used_ranges = full_ranges.concat(&zero_ranges);
    // Each key should initiate a value of `zero` at position `0`.
    let init_states =
        queries
            .map(move |(_, key)| ((0, key), zero2.clone()))
            .distinct();
    // Iteratively expand assigned values by joining existing ranges with current assignments.
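    // For example (a hypothetical trace, key elided): starting from `zero` at position 0, a used
    // range (0, 3, _) covering [0, 8) produces a state at position 8; ranges over [8, 10) and
    // [10, 11) then extend it to positions 10 and 11, each state accumulating everything below it.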
    init_states
        .iterate(|states| {
            used_ranges
                .enter(&states.scope())
                .map(|((pos, log, key), data)| ((pos << log, key), (log, data)))
                .join_map(states, move |&(pos, ref key), &(log, ref data), state|
                    ((pos + (1 << log), key.clone()), combine(key, state, data)))
                .concat(&init_states.enter(&states.scope()))
                .distinct()
        })
        .semijoin(&queries)
}