directed_graph/
lib.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::cmp::min;
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::fmt::{Debug, Display};
8use std::hash::Hash;
9
10/// A directed graph, whose nodes contain an identifier of type `T`.
11pub struct DirectedGraph<T: Clone + PartialEq + Hash + Ord + Debug + Display>(
12    HashMap<T, DirectedNode<T>>,
13);
14
15impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> DirectedGraph<T> {
16    /// Created a new empty `DirectedGraph`.
17    pub fn new() -> Self {
18        Self(HashMap::new())
19    }
20
21    /// Add an edge to the graph, adding nodes if necessary.
22    pub fn add_edge(&mut self, source: T, target: T) {
23        self.0.entry(source).or_insert_with(DirectedNode::new).add_target(target.clone());
24        self.0.entry(target).or_insert_with(DirectedNode::new);
25    }
26
27    /// Get targets of all edges from this node.
28    pub fn get_targets<'a>(&'a self, id: &T) -> Option<&'a HashSet<T>> {
29        self.0.get(id).as_ref().map(|node| &node.0)
30    }
31
32    /// Given a dependency graph, find the set of all nodes that have a
33    /// dependency on the `start` node (i.e., the reverse dependency closure).
34    /// This includes `start` itself.
35    pub fn get_closure<'a>(&'a self, start: &T) -> HashSet<&'a T> {
36        let Some((start, _)) = self.0.get_key_value(start) else {
37            return HashSet::new();
38        };
39        let mut reverse_deps: HashMap<&T, Vec<&T>> = HashMap::new();
40        for (source, targets) in &self.0 {
41            for target in &targets.0 {
42                reverse_deps.entry(target).or_default().push(source);
43            }
44        }
45
46        let mut closure = HashSet::new();
47        let mut to_visit = VecDeque::new();
48
49        closure.insert(start);
50        to_visit.push_back(start);
51
52        while let Some(current_node) = to_visit.pop_front() {
53            if let Some(parents) = reverse_deps.get(&current_node) {
54                for parent in parents {
55                    if closure.insert(*parent) {
56                        to_visit.push_back(*parent);
57                    }
58                }
59            }
60        }
61
62        closure
63    }
64
65    /// Adds an edgeless node to the graph. Used by shutdown to ensure every node gets visited,
66    /// even if there are no paths to the node.
67    pub fn add_node(&mut self, source: T) {
68        self.0.entry(source).or_insert_with(DirectedNode::new);
69    }
70
71    /// Returns the nodes of the graph in reverse topological order, or an error if the graph
72    /// contains a cycle.
73    pub fn topological_sort<'a>(&'a self) -> Result<Vec<&'a T>, Error<'_, T>> {
74        TarjanSCC::new(self).run()
75    }
76
77    /// Finds the shortest path between the `from` and `to` nodes in this graph, if such a path
78    /// exists. Both `from` and `to` are included in the returned path.
79    pub fn find_shortest_path<'a>(&'a self, from: &T, to: &T) -> Option<Vec<&'a T>> {
80        // Keeps track of edges in the shortest path to each node.
81        //
82        // The key in this map is a node whose shortest path to it is known. The value
83        // is the next-to-last node in the shortest path to the key node.
84        //
85        // For example, if the shortest path from `a` to `b` is `{a, b, c}`, this
86        // map will contain:
87        // (c, b)
88        // (b, a)
89        let mut shortest_path_edges: HashMap<&'a T, &'a T> = HashMap::new();
90        let from = self.0.get_key_value(from).map(|e| e.0)?;
91        let to = self.0.get_key_value(to).map(|e| e.0)?;
92
93        // Nodes which we have found in the graph but have not yet been visited.
94        let mut discovered_nodes = VecDeque::new();
95        discovered_nodes.push_back(from);
96
97        loop {
98            // Visit the first node in the list.
99            let Some(current_node) = discovered_nodes.pop_front() else {
100                // If there are no more nodes to visit, then a shortest path must not exist.
101                return None;
102            };
103            match self.get_targets(current_node) {
104                None => continue,
105                Some(targets) if targets.is_empty() => continue,
106                Some(targets) => {
107                    for target in targets {
108                        // If we haven't yet visited this node, add it to our set of edges and add
109                        // it to the set of nodes we should visit.
110                        if !shortest_path_edges.contains_key(target) {
111                            shortest_path_edges.insert(target, current_node);
112                            discovered_nodes.push_back(target);
113                        }
114                        // If this node is the node we're searching for a path to, then compute the
115                        // path based on the hashmap we've built and return it.
116                        if target == to {
117                            let mut result = vec![target];
118                            let mut path_node: &T = target;
119                            loop {
120                                path_node = shortest_path_edges.get(&path_node).unwrap();
121                                result.push(path_node);
122                                if path_node == from {
123                                    break;
124                                }
125                            }
126                            result.reverse();
127                            return Some(result);
128                        }
129                    }
130                }
131            }
132        }
133    }
134}
135
136impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> Default for DirectedGraph<T> {
137    fn default() -> Self {
138        Self(HashMap::new())
139    }
140}
141
/// A graph node. Its contents are the identifiers of the nodes reached by the edges leaving
/// this node.
#[derive(Eq, PartialEq)]
pub struct DirectedNode<T: Clone + PartialEq + Hash + Ord + Debug + Display>(HashSet<T>);

impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> DirectedNode<T> {
    /// Creates a node with no outgoing edges.
    pub fn new() -> Self {
        Self(HashSet::new())
    }

    /// Adds an edge from this node to `target`. Targets are stored in a set, so adding a target
    /// that is already present leaves the node unchanged.
    pub fn add_target(&mut self, target: T) {
        self.0.insert(target);
    }
}

// Implemented by hand rather than derived so that no `T: Default` bound is required.
impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> Default for DirectedNode<T> {
    /// Returns a node with no outgoing edges, the same value `DirectedNode::new` produces.
    fn default() -> Self {
        Self::new()
    }
}
157
/// Errors produced by `DirectedGraph`.
#[derive(Debug)]
pub enum Error<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> {
    /// The graph contains cycles. Each cycle is a sequence of node references.
    CyclesDetected(HashSet<Vec<&'a T>>),
}

impl<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> Error<'a, T> {
    /// Renders the detected cycles as a single string such as `{{a -> b -> a}, {c -> d -> c}}`.
    ///
    /// The cycles are sorted before rendering so the output does not depend on the iteration
    /// order of the underlying set.
    pub fn format_cycle(&self) -> String {
        let Error::CyclesDetected(cycles) = self;

        // Sorting references orders them by the cycles they point at.
        let mut ordered: Vec<&Vec<&'a T>> = cycles.iter().collect();
        ordered.sort_unstable();

        let rendered: Vec<String> = ordered
            .iter()
            .map(|cycle| {
                let steps: Vec<String> = cycle.iter().map(|item| item.to_string()).collect();
                format!("{{{}}}", steps.join(" -> "))
            })
            .collect();

        format!("{{{}}}", rendered.join(", "))
    }
}
192
193/// Runs the tarjan strongly connected components algorithm on a graph to produce either a reverse
194/// topological sort of the nodes in the graph, or a set of the cycles present in the graph.
195///
196/// Description of algorithm:
197/// https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
198struct TarjanSCC<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> {
199    // Each node is assigned an index in the order we find them. This tracks the next index to use.
200    index: u64,
201    // The mappings between nodes and indices
202    indices: HashMap<&'a T, u64>,
203    // The lowest index (numerically) that's accessible from each node
204    low_links: HashMap<&'a T, u64>,
205    // The set of nodes we're currently in the process of considering
206    stack: Vec<&'a T>,
207    // A set containing the nodes in the stack, so we can more efficiently check if an element is
208    // in the stack
209    on_stack: HashSet<&'a T>,
210    // Detected cycles
211    cycles: HashSet<Vec<&'a T>>,
212    // Nodes sorted by reverse topological order
213    node_order: Vec<&'a T>,
214    // The graph this run will be operating on
215    graph: &'a DirectedGraph<T>,
216}
217
218impl<'a, T: Clone + Hash + Ord + Debug + Display> TarjanSCC<'a, T> {
219    fn new(graph: &'a DirectedGraph<T>) -> Self {
220        TarjanSCC {
221            index: 0,
222            indices: HashMap::new(),
223            low_links: HashMap::new(),
224            stack: Vec::new(),
225            on_stack: HashSet::new(),
226            cycles: HashSet::new(),
227            node_order: Vec::new(),
228            graph,
229        }
230    }
231
232    /// Runs the tarjan scc algorithm. Must only be called once, as it will panic on subsequent
233    /// calls.
234    fn run(&mut self) -> Result<Vec<&'a T>, Error<'a, T>> {
235        // Sort the nodes we visit, to make the output deterministic instead of being based on
236        // whichever node we find first.
237        let mut nodes: Vec<_> = self.graph.0.keys().collect();
238        nodes.sort_unstable();
239        for node in &nodes {
240            // Iterate over each node, visiting each one we haven't already visited. We determine
241            // if a node has been visited by if an index has been assigned to it yet.
242            if !self.indices.contains_key(node) {
243                self.visit(*node);
244            }
245        }
246
247        if self.cycles.is_empty() {
248            Ok(self.node_order.drain(..).collect())
249        } else {
250            Err(Error::CyclesDetected(self.cycles.drain().collect()))
251        }
252    }
253
254    fn visit(&mut self, current_node: &'a T) {
255        // assign a new index for this node, and push it on to the stack
256        self.indices.insert(current_node, self.index);
257        self.low_links.insert(current_node, self.index);
258        self.index += 1;
259        self.stack.push(current_node);
260        self.on_stack.insert(current_node);
261
262        let mut targets: Vec<_> = self.graph.0[current_node].0.iter().collect();
263        targets.sort_unstable();
264
265        for target in targets {
266            if !self.indices.contains_key(target) {
267                // Target has not yet been visited; recurse on it
268                self.visit(target);
269                // Set our lowlink to the min of our lowlink and the target's new lowlink
270                let current_node_low_link = *self.low_links.get(&current_node).unwrap();
271                let target_low_link = *self.low_links.get(&target).unwrap();
272                self.low_links.insert(current_node, min(current_node_low_link, target_low_link));
273            } else if self.on_stack.contains(target) {
274                let current_node_low_link = *self.low_links.get(&current_node).unwrap();
275                let target_index = *self.indices.get(&target).unwrap();
276                self.low_links.insert(current_node, min(current_node_low_link, target_index));
277            }
278        }
279
280        // If current_node is a root node, pop the stack and generate an SCC
281        if self.low_links.get(&current_node) == self.indices.get(&current_node) {
282            let mut strongly_connected_nodes = HashSet::new();
283            let mut stack_node;
284            loop {
285                stack_node = self.stack.pop().unwrap();
286                self.on_stack.remove(&stack_node);
287                strongly_connected_nodes.insert(stack_node);
288                if stack_node == current_node {
289                    break;
290                }
291            }
292            self.insert_cycles_from_scc(
293                &strongly_connected_nodes,
294                stack_node,
295                HashSet::new(),
296                vec![],
297            );
298        }
299        self.node_order.push(current_node);
300    }
301
302    /// Given a set of strongly connected components, computes the cycles present in the set and
303    /// adds those cycles to self.cycles.
304    fn insert_cycles_from_scc(
305        &mut self,
306        scc_nodes: &HashSet<&'a T>,
307        current_node: &'a T,
308        mut visited_nodes: HashSet<&'a T>,
309        mut path: Vec<&'a T>,
310    ) {
311        if visited_nodes.contains(&current_node) {
312            // We've already visited this node, we've got a cycle. Grab all the elements in the
313            // path starting at the first time we visited this node.
314            let (current_node_path_index, _) =
315                path.iter().enumerate().find(|(_, val)| val == &&current_node).unwrap();
316            let mut cycle = path[current_node_path_index..].to_vec();
317
318            // Rotate the cycle such that the lowest value comes first, so that the cycles we
319            // report are consistent.
320            Self::rotate_cycle(&mut cycle);
321            // Push a copy of the first node on to the end, so it's clear that this path ends where
322            // it starts
323            cycle.push(*cycle.first().unwrap());
324            self.cycles.insert(cycle);
325            return;
326        }
327
328        visited_nodes.insert(current_node);
329        path.push(current_node);
330
331        let targets_in_scc: Vec<_> =
332            self.graph.0[&current_node].0.iter().filter(|n| scc_nodes.contains(n)).collect();
333        for target in targets_in_scc {
334            self.insert_cycles_from_scc(scc_nodes, target, visited_nodes.clone(), path.clone());
335        }
336    }
337
338    /// Rotates the cycle such that ordering is maintained and the lowest element comes first. This
339    /// is so that the reported cycles are consistent, as opposed to varying based on which node we
340    /// happened to find first.
341    fn rotate_cycle(cycle: &mut Vec<&'a T>) {
342        let mut lowest_index = 0;
343        let mut lowest_value = cycle.first().unwrap();
344        for (index, node) in cycle.iter().enumerate() {
345            if node < lowest_value {
346                lowest_index = index;
347                lowest_value = node;
348            }
349        }
350        cycle.rotate_left(lowest_index);
351    }
352}
353
354#[cfg(test)]
355mod tests {
356    use super::*;
357
    // Expands each `name => { edges = ..., order = ..., }` entry into a `#[test]` function named
    // `name` that passes references to the edge list and expected order to
    // `topological_sort_test`.
    macro_rules! test_topological_sort {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    order = $order:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    topological_sort_test(&$edges, &$order);
                }
            )+
        }
    }
375
    // Expands each `name => { edges = ..., cycles = ..., }` entry into a `#[test]` function named
    // `name` that passes references to the edge list and expected cycles to `cycles_test`.
    macro_rules! test_cycles {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    cycles = $cycles:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    cycles_test(&$edges, &$cycles);
                }
            )+
        }
    }
393
    // Expands each `name => { edges = ..., from = ..., to = ..., shortest_path = ..., }` entry
    // into a `#[test]` function named `name` that forwards the four expressions unchanged to
    // `shortest_path_test`.
    macro_rules! test_shortest_path {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    from = $from:expr,
                    to = $to:expr,
                    shortest_path = $shortest_path:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    shortest_path_test($edges, $from, $to, $shortest_path);
                }
            )+
        }
    }
413
414    fn topological_sort_test(edges: &[(&'static str, &'static str)], order: &[&'static str]) {
415        let mut graph = DirectedGraph::new();
416        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
417        let actual_order = graph.topological_sort().expect("found a cycle");
418
419        let expected_order: Vec<_> = order.iter().collect();
420        assert_eq!(expected_order, actual_order);
421    }
422
423    fn cycles_test(edges: &[(&'static str, &'static str)], cycles: &[&[&'static str]]) {
424        let mut graph = DirectedGraph::new();
425        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
426        let Error::CyclesDetected(reported_cycles) = graph
427            .topological_sort()
428            .expect_err("topological sort succeeded on a dataset with a cycle");
429
430        let expected_cycles: HashSet<Vec<_>> =
431            cycles.iter().cloned().map(|c| c.iter().collect()).collect();
432        assert_eq!(reported_cycles, expected_cycles);
433    }
434
435    fn shortest_path_test(
436        edges: &[(&'static str, &'static str)],
437        from: &'static str,
438        to: &'static str,
439        expected_shortest_path: Option<&[&'static str]>,
440    ) {
441        let mut graph = DirectedGraph::new();
442        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
443        let actual_shortest_path = graph.find_shortest_path(&from, &to);
444        let expected_shortest_path =
445            expected_shortest_path.map(|path| path.iter().collect::<Vec<_>>());
446        assert_eq!(actual_shortest_path, expected_shortest_path);
447    }
448
    // Tests with no cycles
    //
    // Each case lists the edges added to the graph as `(source, target)` pairs and the exact
    // sequence `topological_sort` is expected to return. In every expected order a node appears
    // after all of the nodes it has edges to.

    test_topological_sort! {
        test_empty => {
            edges = [],
            order = [],
        },
        test_fan_out => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("b", "d"),
                ("d", "e"),
            ],
            order = ["c", "e", "d", "b", "a"],
        },
        test_fan_in => {
            edges = [
                ("a", "b"),
                ("b", "d"),
                ("c", "d"),
                ("d", "e"),
            ],
            order = ["e", "d", "b", "a", "c"],
        },
        test_forest => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("d", "e"),
            ],
            order = ["c", "b", "a", "e", "d"],
        },
        test_diamond => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("c", "d"),
            ],
            order = ["d", "b", "c", "a"],
        },
        test_lattice => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("b", "e"),
                ("c", "d"),
                ("e", "f"),
                ("d", "f"),
            ],
            order = ["f", "d", "e", "b", "c", "a"],
        },
        // The duplicated ("a", "b") edge must not change the result.
        test_deduped_edge => {
            edges = [
                ("a", "b"),
                ("a", "b"),
                ("b", "c"),
            ],
            order = ["c", "b", "a"],
        },
    }
512
    // Each case lists the edges added to the graph as `(source, target)` pairs and the complete
    // set of cycles `topological_sort` is expected to report. Every expected cycle begins with
    // its lowest node and repeats that node at the end.
    test_cycles! {
        // Tests where only 1 SCC contains cycles

        test_cycle_self_referential => {
            edges = [
                ("a", "a"),
            ],
            cycles = [
                &["a", "a"],
            ],
        },
        test_cycle_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_two_nodes_with_path_in => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_nodes_with_path_out => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "d"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_three_nodes => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
            ],
        },
        test_cycle_three_nodes_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["b", "c", "b"],
            ],
        },
        test_cycle_three_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["b", "c", "b"],
                &["a", "c", "a"],
                &["a", "b", "c", "a"],
                &["a", "c", "b", "a"],
            ],
        },
        test_cycle_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),

                ("b", "d"),
                ("d", "e"),
                ("e", "c"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "e", "c", "a"],
            ],
        },
        test_two_join_cycles => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
                ("b", "d"),
                ("d", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "a"],
            ],
        },
        test_cycle_four_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "d"),
                ("d", "c"),
                ("d", "a"),
                ("a", "d"),
            ],
            cycles = [
                &["a", "b", "c", "d", "a"],
                &["a", "b", "a"],
                &["a", "d", "c", "b", "a"],
                &["a", "d", "a"],
                &["b", "c", "b"],
                &["c", "d", "c"],
            ],
        },

        // Tests with multiple SCCs that contain cycles

        test_cycle_self_referential_islands => {
            edges = [
                ("a", "a"),
                ("b", "b"),
                ("c", "c"),
                ("d", "e"),
            ],
            cycles = [
                &["a", "a"],
                &["b", "b"],
                &["c", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes_connected => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
    }
684
    // Each case lists the edges added to the graph as `(source, target)` pairs, the endpoints
    // passed to `find_shortest_path`, and the expected result: `Some` path that includes both
    // endpoints, or `None` when no path should be found.
    test_shortest_path! {
        test_empty_graph => {
            edges = &[],
            from = "a",
            to = "b",
            shortest_path = None,
        },
        test_two_nodes => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_to_self => {
            edges = &[
                ("a", "a"),
            ],
            from = "a",
            to = "a",
            shortest_path = Some(&["a", "a"]),
        },
        // A node only has a path to itself if some edge leads back to it.
        test_path_to_self_no_edge => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "a",
            shortest_path = None,
        },
        test_path_three_nodes => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "b", "c"]),
        },
        test_path_multiple_options => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("a", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "c"]),
        },
        test_path_two_islands => {
            edges = &[
                ("a", "b"),
                ("c", "d"),
            ],
            from = "a",
            to = "d",
            shortest_path = None,
        },
        test_path_with_cycle => {
            edges = &[
                ("a", "b"),
                ("b", "a"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_2 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_3 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("b", "d"),
                ("d", "e"),
            ],
            from = "a",
            to = "e",
            shortest_path = Some(&["a", "b", "d", "e"]),
        },
    }
776}