guppy/petgraph_support/
topo.rs

1// Copyright (c) The cargo-guppy Contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use petgraph::{
5    graph::IndexType,
6    prelude::*,
7    visit::{
8        GraphRef, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCompactIndexable, VisitMap,
9        Visitable, Walker,
10    },
11};
12use std::marker::PhantomData;
13
/// A cycle-aware topological sort of a graph.
///
/// Rather than storing the sorted node list itself, this keeps, for every node index, the
/// position of that node in the computed order.
#[derive(Debug, Clone)]
pub struct TopoWithCycles<Ix> {
    /// Maps each node index to its corresponding topo index.
    reverse_index: Box<[usize]>,
    /// Ties this sort to a single index type so that indexes of a different type can't be
    /// mixed in by mistake.
    _phantom: PhantomData<Ix>,
}
22
23impl<Ix: IndexType> TopoWithCycles<Ix> {
24    pub fn new<G>(graph: G) -> Self
25    where
26        G: GraphRef
27            + Visitable<NodeId = NodeIndex<Ix>>
28            + IntoNodeIdentifiers
29            + IntoNeighborsDirected<NodeId = NodeIndex<Ix>>
30            + NodeCompactIndexable,
31        G::Map: VisitMap<NodeIndex<Ix>>,
32    {
33        // petgraph's default topo algorithms don't handle cycles. Use DfsPostOrder which does.
34        let mut dfs = DfsPostOrder::empty(graph);
35
36        let roots = graph
37            .node_identifiers()
38            .filter(move |&a| graph.neighbors_directed(a, Incoming).next().is_none());
39        dfs.stack.extend(roots);
40
41        let mut topo: Vec<NodeIndex<Ix>> = (&mut dfs).iter(graph).collect();
42        // dfs returns its data in postorder (reverse topo order), so reverse that for forward topo
43        // order.
44        topo.reverse();
45
46        // Because the graph is NodeCompactIndexable, the indexes are in the range
47        // (0..graph.node_count()).
48        // Use this property to build a reverse map.
49        let mut reverse_index = vec![0; graph.node_count()];
50        topo.iter().enumerate().for_each(|(topo_ix, node_ix)| {
51            reverse_index[node_ix.index()] = topo_ix;
52        });
53
54        // topo.len cannot possibly exceed graph.node_count().
55        assert!(
56            topo.len() <= graph.node_count(),
57            "topo.len() <= graph.node_count() ({} is actually > {})",
58            topo.len(),
59            graph.node_count(),
60        );
61        if topo.len() < graph.node_count() {
62            // This means there was a cycle in the graph which caused some nodes to be skipped (e.g.
63            // consider a node with a self-loop -- it will be filtered out by the
64            // graph.neighbors_directed call above, and might not end up being part of the topo
65            // order).
66            //
67            // In this case, do a best-effort job: fill in the missing nodes with their reverse
68            // index set to the end of the topo order. We could do something fancier here with sccs,
69            // but for guppy this should never happen in practice. (In fact, the one time this code
70            // was hit there was actually an underlying bug.)
71            let mut next = topo.len();
72            for n in 0..graph.node_count() {
73                let a = NodeIndex::new(n);
74                if !dfs.finished.is_visited(&a) {
75                    // a is a missing index.
76                    reverse_index[a.index()] = next;
77                    next += 1;
78                }
79            }
80        }
81
82        Self {
83            reverse_index: reverse_index.into_boxed_slice(),
84            _phantom: PhantomData,
85        }
86    }
87
88    /// Sort nodes based on the topo order in self.
89    #[inline]
90    pub fn sort_nodes(&self, nodes: &mut [NodeIndex<Ix>]) {
91        nodes.sort_unstable_by_key(|node_ix| self.topo_ix(*node_ix))
92    }
93
94    #[inline]
95    pub fn topo_ix(&self, node_ix: NodeIndex<Ix>) -> usize {
96        self.reverse_index[node_ix.index()]
97    }
98}
99
#[cfg(all(test, feature = "proptest1"))]
mod proptests {
    use super::*;
    use petgraph::visit::EdgeRef;
    use proptest::prelude::*;

    proptest! {
        // Every node gets a distinct topo index, and sorting all nodes yields them in exactly
        // the topo order.
        #[test]
        fn graph_topo_sort(graph in possibly_cyclic_graph()) {
            let topo = TopoWithCycles::new(&graph);
            let mut nodes: Vec<_> = graph.node_indices().collect();

            check_consistency(&topo, graph.node_count());

            topo.sort_nodes(&mut nodes);
            for (topo_ix, node_ix) in nodes.iter().enumerate() {
                assert_eq!(topo.topo_ix(*node_ix), topo_ix);
            }
        }

        // For an acyclic graph the order must be a real topological order: every edge points
        // from an earlier node to a later one.
        #[test]
        fn dag_topo_sort(graph in acyclic_graph()) {
            let topo = TopoWithCycles::new(&graph);

            check_consistency(&topo, graph.node_count());

            for edge in graph.edge_references() {
                let (src, dst) = (edge.source(), edge.target());
                assert!(
                    topo.topo_ix(src) < topo.topo_ix(dst),
                    "edge {:?} -> {:?} should point forward in topo order ({} >= {})",
                    src,
                    dst,
                    topo.topo_ix(src),
                    topo.topo_ix(dst),
                );
            }
        }
    }

    fn possibly_cyclic_graph() -> impl Strategy<Value = Graph<(), ()>> {
        // Generate a graph in adjacency list form: N nodes, each with up to N - 1 outgoing edge
        // entries (so up to N * (N - 1) edges after deduplication by update_edge).
        (1..=100usize)
            .prop_flat_map(|n| {
                (
                    Just(n),
                    prop::collection::vec(prop::collection::vec(0..n, 0..n), n),
                )
            })
            .prop_map(|(n, adj)| {
                let mut graph =
                    Graph::<(), ()>::with_capacity(n, adj.iter().map(|x| x.len()).sum());
                for _ in 0..n {
                    // Add all the nodes under consideration.
                    graph.add_node(());
                }
                for (src, dsts) in adj.into_iter().enumerate() {
                    let src = NodeIndex::new(src);
                    for dst in dsts {
                        let dst = NodeIndex::new(dst);
                        graph.update_edge(src, dst, ());
                    }
                }
                graph
            })
    }

    fn acyclic_graph() -> impl Strategy<Value = Graph<(), ()>> {
        possibly_cyclic_graph().prop_map(|mut graph| {
            // Keep only edges from a lower to a higher node index: such a graph cannot contain a
            // cycle (self-loops included).
            graph.retain_edges(|g, edge_ix| {
                let (src, dst) = g
                    .edge_endpoints(edge_ix)
                    .expect("edge index passed to retain_edges is valid");
                src < dst
            });
            graph
        })
    }

    fn check_consistency(topo: &TopoWithCycles<u32>, n: usize) {
        // Ensure that all indexes are covered and unique.
        let mut seen = vec![false; n];
        for i in 0..n {
            let topo_ix = topo.topo_ix(NodeIndex::new(i));
            assert!(
                !seen[topo_ix],
                "topo_ix {} should be seen exactly once, but seen twice",
                topo_ix
            );
            seen[topo_ix] = true;
        }
        for (i, &this_seen) in seen.iter().enumerate() {
            assert!(this_seen, "topo_ix {} should be seen, but wasn't", i);
        }
    }
}