1// Copyright (c) The cargo-guppy Contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
34use petgraph::{
5 graph::IndexType,
6 prelude::*,
7 visit::{
8 GraphRef, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCompactIndexable, VisitMap,
9 Visitable, Walker,
10 },
11};
12use std::marker::PhantomData;
/// Topological ordering of a graph's nodes that tolerates cycles.
///
/// Rather than storing the ordered node list, this stores the inverse mapping: for every node
/// index, the position that node occupies in the computed order.
#[derive(Debug, Clone)]
pub struct TopoWithCycles<Ix> {
    /// `reverse_index[node.index()]` is that node's position in the topological order.
    reverse_index: Box<[usize]>,
    /// Ties the ordering to one index type so node indexes of different types can't be mixed.
    _phantom: PhantomData<Ix>,
}
2223impl<Ix: IndexType> TopoWithCycles<Ix> {
24pub fn new<G>(graph: G) -> Self
25where
26G: GraphRef
27 + Visitable<NodeId = NodeIndex<Ix>>
28 + IntoNodeIdentifiers
29 + IntoNeighborsDirected<NodeId = NodeIndex<Ix>>
30 + NodeCompactIndexable,
31 G::Map: VisitMap<NodeIndex<Ix>>,
32 {
33// petgraph's default topo algorithms don't handle cycles. Use DfsPostOrder which does.
34let mut dfs = DfsPostOrder::empty(graph);
3536let roots = graph
37 .node_identifiers()
38 .filter(move |&a| graph.neighbors_directed(a, Incoming).next().is_none());
39 dfs.stack.extend(roots);
4041let mut topo: Vec<NodeIndex<Ix>> = (&mut dfs).iter(graph).collect();
42// dfs returns its data in postorder (reverse topo order), so reverse that for forward topo
43 // order.
44topo.reverse();
4546// Because the graph is NodeCompactIndexable, the indexes are in the range
47 // (0..graph.node_count()).
48 // Use this property to build a reverse map.
49let mut reverse_index = vec![0; graph.node_count()];
50 topo.iter().enumerate().for_each(|(topo_ix, node_ix)| {
51 reverse_index[node_ix.index()] = topo_ix;
52 });
5354// topo.len cannot possibly exceed graph.node_count().
55assert!(
56 topo.len() <= graph.node_count(),
57"topo.len() <= graph.node_count() ({} is actually > {})",
58 topo.len(),
59 graph.node_count(),
60 );
61if topo.len() < graph.node_count() {
62// This means there was a cycle in the graph which caused some nodes to be skipped (e.g.
63 // consider a node with a self-loop -- it will be filtered out by the
64 // graph.neighbors_directed call above, and might not end up being part of the topo
65 // order).
66 //
67 // In this case, do a best-effort job: fill in the missing nodes with their reverse
68 // index set to the end of the topo order. We could do something fancier here with sccs,
69 // but for guppy this should never happen in practice. (In fact, the one time this code
70 // was hit there was actually an underlying bug.)
71let mut next = topo.len();
72for n in 0..graph.node_count() {
73let a = NodeIndex::new(n);
74if !dfs.finished.is_visited(&a) {
75// a is a missing index.
76reverse_index[a.index()] = next;
77 next += 1;
78 }
79 }
80 }
8182Self {
83 reverse_index: reverse_index.into_boxed_slice(),
84 _phantom: PhantomData,
85 }
86 }
8788/// Sort nodes based on the topo order in self.
89#[inline]
90pub fn sort_nodes(&self, nodes: &mut [NodeIndex<Ix>]) {
91 nodes.sort_unstable_by_key(|node_ix| self.topo_ix(*node_ix))
92 }
9394#[inline]
95pub fn topo_ix(&self, node_ix: NodeIndex<Ix>) -> usize {
96self.reverse_index[node_ix.index()]
97 }
98}
#[cfg(all(test, feature = "proptest1"))]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// On arbitrary (possibly cyclic) graphs, the topo indexes must form a permutation and
        /// `sort_nodes` must agree with `topo_ix`.
        #[test]
        fn graph_topo_sort(graph in possibly_cyclic_graph()) {
            let topo = TopoWithCycles::new(&graph);
            let mut nodes: Vec<_> = graph.node_indices().collect();

            check_consistency(&topo, graph.node_count());

            topo.sort_nodes(&mut nodes);
            for (topo_ix, node_ix) in nodes.iter().enumerate() {
                assert_eq!(topo.topo_ix(*node_ix), topo_ix);
            }
        }

        /// On acyclic graphs the order must be a genuine topological sort: every edge points
        /// from an earlier position to a later one.
        #[test]
        fn dag_topo_sort(graph in acyclic_graph()) {
            let topo = TopoWithCycles::new(&graph);

            check_consistency(&topo, graph.node_count());

            for edge_ix in graph.edge_indices() {
                let (src, dst) = graph
                    .edge_endpoints(edge_ix)
                    .expect("edge index comes from edge_indices, so it is valid");
                assert!(
                    topo.topo_ix(src) < topo.topo_ix(dst),
                    "edge {:?} -> {:?} is out of order: topo_ix {} >= {}",
                    src,
                    dst,
                    topo.topo_ix(src),
                    topo.topo_ix(dst),
                );
            }
        }
    }

    /// Generates `n` (1..=100) nodes in adjacency-list form: each node gets up to `n - 1`
    /// outgoing targets drawn from `0..n`. Duplicates and self-loops may appear; duplicates
    /// are collapsed when the graph is built.
    fn adjacency_lists() -> impl Strategy<Value = (usize, Vec<Vec<usize>>)> {
        (1..=100usize).prop_flat_map(|n| {
            (
                Just(n),
                prop::collection::vec(prop::collection::vec(0..n, 0..n), n),
            )
        })
    }

    /// Arbitrary directed graphs, which may contain cycles and self-loops.
    fn possibly_cyclic_graph() -> impl Strategy<Value = Graph<(), ()>> {
        adjacency_lists().prop_map(|(n, adj)| build_graph(n, adj, |_, _| true))
    }

    /// Directed acyclic graphs: only edges from a lower node index to a higher one are kept,
    /// which rules out cycles (including self-loops).
    fn acyclic_graph() -> impl Strategy<Value = Graph<(), ()>> {
        adjacency_lists().prop_map(|(n, adj)| build_graph(n, adj, |src, dst| src < dst))
    }

    /// Builds a graph with `n` nodes plus every edge in `adj` accepted by `keep_edge`.
    fn build_graph(
        n: usize,
        adj: Vec<Vec<usize>>,
        keep_edge: impl Fn(usize, usize) -> bool,
    ) -> Graph<(), ()> {
        let mut graph = Graph::<(), ()>::with_capacity(n, adj.iter().map(|x| x.len()).sum());
        for _ in 0..n {
            // Add all the nodes under consideration.
            graph.add_node(());
        }
        for (src, dsts) in adj.into_iter().enumerate() {
            for dst in dsts {
                if keep_edge(src, dst) {
                    // update_edge collapses duplicate (src, dst) pairs into a single edge.
                    graph.update_edge(NodeIndex::new(src), NodeIndex::new(dst), ());
                }
            }
        }
        graph
    }

    /// Asserts that the topo indexes of nodes `0..n` are exactly a permutation of `0..n`.
    fn check_consistency(topo: &TopoWithCycles<u32>, n: usize) {
        // Ensure that all indexes are covered and unique.
        let mut seen = vec![false; n];
        for i in 0..n {
            let topo_ix = topo.topo_ix(NodeIndex::new(i));
            assert!(
                !seen[topo_ix],
                "topo_ix {} should be seen exactly once, but seen twice",
                topo_ix
            );
            seen[topo_ix] = true;
        }
        for (i, &this_seen) in seen.iter().enumerate() {
            assert!(this_seen, "topo_ix {} should be seen, but wasn't", i);
        }
    }
}