directed_graph/
lib.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::cmp::min;
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::fmt::{Debug, Display};
8use std::hash::Hash;
9
10/// A directed graph, whose nodes contain an identifier of type `T`.
11pub struct DirectedGraph<T: PartialEq + Hash + Copy + Ord + Debug + Display>(
12    HashMap<T, DirectedNode<T>>,
13);
14
15impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> DirectedGraph<T> {
16    /// Created a new empty `DirectedGraph`.
17    pub fn new() -> Self {
18        Self(HashMap::new())
19    }
20
21    /// Add an edge to the graph, adding nodes if necessary.
22    pub fn add_edge(&mut self, source: T, target: T) {
23        self.0.entry(source).or_insert_with(DirectedNode::new).add_target(target);
24        self.0.entry(target).or_insert_with(DirectedNode::new);
25    }
26
27    /// Get targets of all edges from this node.
28    pub fn get_targets(&self, id: T) -> Option<&HashSet<T>> {
29        self.0.get(&id).as_ref().map(|node| &node.0)
30    }
31
32    /// Returns the nodes of the graph in reverse topological order, or an error if the graph
33    /// contains a cycle.
34    ///
35    /// TODO: //src/devices/tools/banjo/srt/ast.rs can be migrated to use this feature.
36    pub fn topological_sort(&self) -> Result<Vec<T>, Error<T>> {
37        TarjanSCC::new(self).run()
38    }
39
40    /// Finds the shortest path between the `from` and `to` nodes in this graph, if such a path
41    /// exists. Both `from` and `to` are included in the returned path.
42    pub fn find_shortest_path(&self, from: T, to: T) -> Option<Vec<T>> {
43        // Keeps track of edges in the shortest path to each node.
44        //
45        // The key in this map is a node whose shortest path to it is known. The value
46        // is the next-to-last node in the shortest path to the key node.
47        //
48        // For example, if the shortest path from `a` to `b` is `{a, b, c}`, this
49        // map will contain:
50        // (c, b)
51        // (b, a)
52        let mut shortest_path_edges: HashMap<T, T> = HashMap::new();
53
54        // Nodes which we have found in the graph but have not yet been visited.
55        let mut discovered_nodes = VecDeque::new();
56        discovered_nodes.push_back(from);
57
58        loop {
59            // Visit the first node in the list.
60            let Some(current_node) = discovered_nodes.pop_front() else {
61                // If there are no more nodes to visit, then a shortest path must not exist.
62                return None;
63            };
64            match self.get_targets(current_node) {
65                None => continue,
66                Some(targets) if targets.is_empty() => continue,
67                Some(targets) => {
68                    for target in targets {
69                        // If we haven't yet visited this node, add it to our set of edges and add
70                        // it to the set of nodes we should visit.
71                        if !shortest_path_edges.contains_key(target) {
72                            shortest_path_edges.insert(*target, current_node);
73                            discovered_nodes.push_back(*target);
74                        }
75                        // If this node is the node we're searching for a path to, then compute the
76                        // path based on the hashmap we've built and return it.
77                        if *target == to {
78                            let mut result = vec![*target];
79                            let mut path_node: T = *target;
80                            loop {
81                                path_node = *shortest_path_edges.get(&path_node).unwrap();
82                                result.push(path_node);
83                                if path_node == from {
84                                    break;
85                                }
86                            }
87                            result.reverse();
88                            return Some(result);
89                        }
90                    }
91                }
92            }
93        }
94    }
95}
96
97impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Default for DirectedGraph<T> {
98    fn default() -> Self {
99        Self(HashMap::new())
100    }
101}
102
/// A single graph node, represented by the set of node ids that its outgoing edges point to.
#[derive(Eq, PartialEq)]
struct DirectedNode<T>(HashSet<T>)
where
    T: PartialEq + Hash + Copy + Ord + Debug + Display;

impl<T> DirectedNode<T>
where
    T: PartialEq + Hash + Copy + Ord + Debug + Display,
{
    /// Builds a node that has no outgoing edges yet.
    pub fn new() -> Self {
        Self(HashSet::new())
    }

    /// Records an edge from this node to `target`. A target that is already present is kept
    /// only once.
    pub fn add_target(&mut self, target: T) {
        let Self(targets) = self;
        targets.insert(target);
    }
}
118
/// Errors produced by `DirectedGraph`.
#[derive(Debug)]
pub enum Error<T: PartialEq + Hash + Copy + Ord + Debug + Display> {
    /// The graph contains cycles; the payload is the set of cycles found, each one given as the
    /// sequence of nodes along the cycle.
    CyclesDetected(HashSet<Vec<T>>),
}

impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Error<T> {
    /// Renders the detected cycles as a single line of text.
    ///
    /// Each cycle is written as its nodes joined by ` -> ` inside braces, and the cycles are
    /// joined by `, ` inside an outer pair of braces, for example
    /// `{{a -> b -> a}, {c -> d -> c}}`. Cycles are sorted first so that the output does not
    /// depend on the iteration order of the underlying set.
    pub fn format_cycle(&self) -> String {
        let Error::CyclesDetected(cycles) = self;

        // Sorting references orders the cycles exactly as sorting owned copies would.
        let mut ordered: Vec<&Vec<T>> = cycles.iter().collect();
        ordered.sort_unstable();

        let rendered: Vec<String> = ordered
            .into_iter()
            .map(|cycle| {
                let nodes: Vec<String> = cycle.iter().map(|node| node.to_string()).collect();
                format!("{{{}}}", nodes.join(" -> "))
            })
            .collect();
        format!("{{{}}}", rendered.join(", "))
    }
}
153
154/// Runs the tarjan strongly connected components algorithm on a graph to produce either a reverse
155/// topological sort of the nodes in the graph, or a set of the cycles present in the graph.
156///
157/// Description of algorithm:
158/// https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
159struct TarjanSCC<'a, T: PartialEq + Hash + Copy + Ord + Debug + Display> {
160    // Each node is assigned an index in the order we find them. This tracks the next index to use.
161    index: u64,
162    // The mappings between nodes and indices
163    indices: HashMap<T, u64>,
164    // The lowest index (numerically) that's accessible from each node
165    low_links: HashMap<T, u64>,
166    // The set of nodes we're currently in the process of considering
167    stack: Vec<T>,
168    // A set containing the nodes in the stack, so we can more efficiently check if an element is
169    // in the stack
170    on_stack: HashSet<T>,
171    // Detected cycles
172    cycles: HashSet<Vec<T>>,
173    // Nodes sorted by reverse topological order
174    node_order: Vec<T>,
175    // The graph this run will be operating on
176    graph: &'a DirectedGraph<T>,
177}
178
179impl<'a, T: Hash + Copy + Ord + Debug + Display> TarjanSCC<'a, T> {
180    fn new(graph: &'a DirectedGraph<T>) -> Self {
181        TarjanSCC {
182            index: 0,
183            indices: HashMap::new(),
184            low_links: HashMap::new(),
185            stack: Vec::new(),
186            on_stack: HashSet::new(),
187            cycles: HashSet::new(),
188            node_order: Vec::new(),
189            graph,
190        }
191    }
192
193    /// Runs the tarjan scc algorithm. Must only be called once, as it will panic on subsequent
194    /// calls.
195    fn run(mut self) -> Result<Vec<T>, Error<T>> {
196        // Sort the nodes we visit, to make the output deterministic instead of being based on
197        // whichever node we find first.
198        let mut nodes: Vec<_> = self.graph.0.keys().cloned().collect();
199        nodes.sort_unstable();
200        for node in &nodes {
201            // Iterate over each node, visiting each one we haven't already visited. We determine
202            // if a node has been visited by if an index has been assigned to it yet.
203            if !self.indices.contains_key(node) {
204                self.visit(*node);
205            }
206        }
207
208        if self.cycles.is_empty() {
209            Ok(self.node_order.drain(..).collect())
210        } else {
211            Err(Error::CyclesDetected(self.cycles.drain().collect()))
212        }
213    }
214
215    fn visit(&mut self, current_node: T) {
216        // assign a new index for this node, and push it on to the stack
217        self.indices.insert(current_node, self.index);
218        self.low_links.insert(current_node, self.index);
219        self.index += 1;
220        self.stack.push(current_node);
221        self.on_stack.insert(current_node);
222
223        let mut targets: Vec<_> = self.graph.0[&current_node].0.iter().cloned().collect();
224        targets.sort_unstable();
225
226        for target in &targets {
227            if !self.indices.contains_key(target) {
228                // Target has not yet been visited; recurse on it
229                self.visit(*target);
230                // Set our lowlink to the min of our lowlink and the target's new lowlink
231                let current_node_low_link = *self.low_links.get(&current_node).unwrap();
232                let target_low_link = *self.low_links.get(&target).unwrap();
233                self.low_links.insert(current_node, min(current_node_low_link, target_low_link));
234            } else if self.on_stack.contains(target) {
235                let current_node_low_link = *self.low_links.get(&current_node).unwrap();
236                let target_index = *self.indices.get(&target).unwrap();
237                self.low_links.insert(current_node, min(current_node_low_link, target_index));
238            }
239        }
240
241        // If current_node is a root node, pop the stack and generate an SCC
242        if self.low_links.get(&current_node) == self.indices.get(&current_node) {
243            let mut strongly_connected_nodes = HashSet::new();
244            let mut stack_node;
245            loop {
246                stack_node = self.stack.pop().unwrap();
247                self.on_stack.remove(&stack_node);
248                strongly_connected_nodes.insert(stack_node);
249                if stack_node == current_node {
250                    break;
251                }
252            }
253            self.insert_cycles_from_scc(
254                &strongly_connected_nodes,
255                stack_node,
256                HashSet::new(),
257                vec![],
258            );
259        }
260        self.node_order.push(current_node);
261    }
262
263    /// Given a set of strongly connected components, computes the cycles present in the set and
264    /// adds those cycles to self.cycles.
265    fn insert_cycles_from_scc(
266        &mut self,
267        scc_nodes: &HashSet<T>,
268        current_node: T,
269        mut visited_nodes: HashSet<T>,
270        mut path: Vec<T>,
271    ) {
272        if visited_nodes.contains(&current_node) {
273            // We've already visited this node, we've got a cycle. Grab all the elements in the
274            // path starting at the first time we visited this node.
275            let (current_node_path_index, _) =
276                path.iter().enumerate().find(|(_, val)| val == &&current_node).unwrap();
277            let mut cycle = path[current_node_path_index..].to_vec();
278
279            // Rotate the cycle such that the lowest value comes first, so that the cycles we
280            // report are consistent.
281            Self::rotate_cycle(&mut cycle);
282            // Push a copy of the first node on to the end, so it's clear that this path ends where
283            // it starts
284            cycle.push(*cycle.first().unwrap());
285            self.cycles.insert(cycle);
286            return;
287        }
288
289        visited_nodes.insert(current_node);
290        path.push(current_node);
291
292        let targets_in_scc: Vec<_> =
293            self.graph.0[&current_node].0.iter().filter(|n| scc_nodes.contains(n)).collect();
294        for target in targets_in_scc {
295            self.insert_cycles_from_scc(scc_nodes, *target, visited_nodes.clone(), path.clone());
296        }
297    }
298
299    /// Rotates the cycle such that ordering is maintained and the lowest element comes first. This
300    /// is so that the reported cycles are consistent, as opposed to varying based on which node we
301    /// happened to find first.
302    fn rotate_cycle(cycle: &mut Vec<T>) {
303        let mut lowest_index = 0;
304        let mut lowest_value = cycle.first().unwrap();
305        for (index, node) in cycle.iter().enumerate() {
306            if node < lowest_value {
307                lowest_index = index;
308                lowest_value = node;
309            }
310        }
311        cycle.rotate_left(lowest_index);
312    }
313}
314
315#[cfg(test)]
316mod tests {
317    use super::*;
318
319    macro_rules! test_topological_sort {
320        (
321            $(
322                $test_name:ident => {
323                    edges = $edges:expr,
324                    order = $order:expr,
325                },
326            )+
327        ) => {
328            $(
329                #[test]
330                fn $test_name() {
331                    topological_sort_test(&$edges, &$order);
332                }
333            )+
334        }
335    }
336
337    macro_rules! test_cycles {
338        (
339            $(
340                $test_name:ident => {
341                    edges = $edges:expr,
342                    cycles = $cycles:expr,
343                },
344            )+
345        ) => {
346            $(
347                #[test]
348                fn $test_name() {
349                    cycles_test(&$edges, &$cycles);
350                }
351            )+
352        }
353    }
354
355    macro_rules! test_shortest_path {
356        (
357            $(
358                $test_name:ident => {
359                    edges = $edges:expr,
360                    from = $from:expr,
361                    to = $to:expr,
362                    shortest_path = $shortest_path:expr,
363                },
364            )+
365        ) => {
366            $(
367                #[test]
368                fn $test_name() {
369                    shortest_path_test($edges, $from, $to, $shortest_path);
370                }
371            )+
372        }
373    }
374
375    fn topological_sort_test(edges: &[(&'static str, &'static str)], order: &[&'static str]) {
376        let mut graph = DirectedGraph::new();
377        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
378        let actual_order = graph.topological_sort().expect("found a cycle");
379
380        let expected_order: Vec<_> = order.iter().cloned().collect();
381        assert_eq!(expected_order, actual_order);
382    }
383
384    fn cycles_test(edges: &[(&'static str, &'static str)], cycles: &[&[&'static str]]) {
385        let mut graph = DirectedGraph::new();
386        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
387        let Error::CyclesDetected(reported_cycles) = graph
388            .topological_sort()
389            .expect_err("topological sort succeeded on a dataset with a cycle");
390
391        let expected_cycles: HashSet<Vec<_>> =
392            cycles.iter().cloned().map(|c| c.iter().cloned().collect()).collect();
393        assert_eq!(reported_cycles, expected_cycles);
394    }
395
396    fn shortest_path_test(
397        edges: &[(&'static str, &'static str)],
398        from: &'static str,
399        to: &'static str,
400        expected_shortest_path: Option<&[&'static str]>,
401    ) {
402        let mut graph = DirectedGraph::new();
403        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
404        let actual_shortest_path = graph.find_shortest_path(from, to);
405        let expected_shortest_path =
406            expected_shortest_path.map(|path| path.iter().cloned().collect::<Vec<_>>());
407        assert_eq!(actual_shortest_path, expected_shortest_path);
408    }
409
410    // Tests with no cycles
411
412    test_topological_sort! {
413        test_empty => {
414            edges = [],
415            order = [],
416        },
417        test_fan_out => {
418            edges = [
419                ("a", "b"),
420                ("b", "c"),
421                ("b", "d"),
422                ("d", "e"),
423            ],
424            order = ["c", "e", "d", "b", "a"],
425        },
426        test_fan_in => {
427            edges = [
428                ("a", "b"),
429                ("b", "d"),
430                ("c", "d"),
431                ("d", "e"),
432            ],
433            order = ["e", "d", "b", "a", "c"],
434        },
435        test_forest => {
436            edges = [
437                ("a", "b"),
438                ("b", "c"),
439                ("d", "e"),
440            ],
441            order = ["c", "b", "a", "e", "d"],
442        },
443        test_diamond => {
444            edges = [
445                ("a", "b"),
446                ("a", "c"),
447                ("b", "d"),
448                ("c", "d"),
449            ],
450            order = ["d", "b", "c", "a"],
451        },
452        test_lattice => {
453            edges = [
454                ("a", "b"),
455                ("a", "c"),
456                ("b", "d"),
457                ("b", "e"),
458                ("c", "d"),
459                ("e", "f"),
460                ("d", "f"),
461            ],
462            order = ["f", "d", "e", "b", "c", "a"],
463        },
464        test_deduped_edge => {
465            edges = [
466                ("a", "b"),
467                ("a", "b"),
468                ("b", "c"),
469            ],
470            order = ["c", "b", "a"],
471        },
472    }
473
474    test_cycles! {
475        // Tests where only 1 SCC contains cycles
476
477        test_cycle_self_referential => {
478            edges = [
479                ("a", "a"),
480            ],
481            cycles = [
482                &["a", "a"],
483            ],
484        },
485        test_cycle_two_nodes => {
486            edges = [
487                ("a", "b"),
488                ("b", "a"),
489            ],
490            cycles = [
491                &["a", "b", "a"],
492            ],
493        },
494        test_cycle_two_nodes_with_path_in => {
495            edges = [
496                ("a", "b"),
497                ("b", "c"),
498                ("c", "d"),
499                ("d", "c"),
500            ],
501            cycles = [
502                &["c", "d", "c"],
503            ],
504        },
505        test_cycle_two_nodes_with_path_out => {
506            edges = [
507                ("a", "b"),
508                ("b", "a"),
509                ("b", "c"),
510                ("c", "d"),
511            ],
512            cycles = [
513                &["a", "b", "a"],
514            ],
515        },
516        test_cycle_three_nodes => {
517            edges = [
518                ("a", "b"),
519                ("b", "c"),
520                ("c", "a"),
521            ],
522            cycles = [
523                &["a", "b", "c", "a"],
524            ],
525        },
526        test_cycle_three_nodes_with_inner_cycle => {
527            edges = [
528                ("a", "b"),
529                ("b", "c"),
530                ("c", "b"),
531                ("c", "a"),
532            ],
533            cycles = [
534                &["a", "b", "c", "a"],
535                &["b", "c", "b"],
536            ],
537        },
538        test_cycle_three_nodes_doubly_linked => {
539            edges = [
540                ("a", "b"),
541                ("b", "a"),
542                ("b", "c"),
543                ("c", "b"),
544                ("c", "a"),
545                ("a", "c"),
546            ],
547            cycles = [
548                &["a", "b", "a"],
549                &["b", "c", "b"],
550                &["a", "c", "a"],
551                &["a", "b", "c", "a"],
552                &["a", "c", "b", "a"],
553            ],
554        },
555        test_cycle_with_inner_cycle => {
556            edges = [
557                ("a", "b"),
558                ("b", "c"),
559                ("c", "a"),
560
561                ("b", "d"),
562                ("d", "e"),
563                ("e", "c"),
564            ],
565            cycles = [
566                &["a", "b", "c", "a"],
567                &["a", "b", "d", "e", "c", "a"],
568            ],
569        },
570        test_two_join_cycles => {
571            edges = [
572                ("a", "b"),
573                ("b", "c"),
574                ("c", "a"),
575                ("b", "d"),
576                ("d", "a"),
577            ],
578            cycles = [
579                &["a", "b", "c", "a"],
580                &["a", "b", "d", "a"],
581            ],
582        },
583        test_cycle_four_nodes_doubly_linked => {
584            edges = [
585                ("a", "b"),
586                ("b", "a"),
587                ("b", "c"),
588                ("c", "b"),
589                ("c", "d"),
590                ("d", "c"),
591                ("d", "a"),
592                ("a", "d"),
593            ],
594            cycles = [
595                &["a", "b", "c", "d", "a"],
596                &["a", "b", "a"],
597                &["a", "d", "c", "b", "a"],
598                &["a", "d", "a"],
599                &["b", "c", "b"],
600                &["c", "d", "c"],
601            ],
602        },
603
604        // Tests with multiple SCCs that contain cycles
605
606        test_cycle_self_referential_islands => {
607            edges = [
608                ("a", "a"),
609                ("b", "b"),
610                ("c", "c"),
611                ("d", "e"),
612            ],
613            cycles = [
614                &["a", "a"],
615                &["b", "b"],
616                &["c", "c"],
617            ],
618        },
619        test_cycle_two_sets_of_two_nodes => {
620            edges = [
621                ("a", "b"),
622                ("b", "a"),
623                ("c", "d"),
624                ("d", "c"),
625            ],
626            cycles = [
627                &["a", "b", "a"],
628                &["c", "d", "c"],
629            ],
630        },
631        test_cycle_two_sets_of_two_nodes_connected => {
632            edges = [
633                ("a", "b"),
634                ("b", "a"),
635                ("c", "d"),
636                ("d", "c"),
637                ("a", "c"),
638            ],
639            cycles = [
640                &["a", "b", "a"],
641                &["c", "d", "c"],
642            ],
643        },
644    }
645
646    test_shortest_path! {
647        test_empty_graph => {
648            edges = &[],
649            from = "a",
650            to = "b",
651            shortest_path = None,
652        },
653        test_two_nodes => {
654            edges = &[
655                ("a", "b"),
656            ],
657            from = "a",
658            to = "b",
659            shortest_path = Some(&["a", "b"]),
660        },
661        test_path_to_self => {
662            edges = &[
663                ("a", "a"),
664            ],
665            from = "a",
666            to = "a",
667            shortest_path = Some(&["a", "a"]),
668        },
669        test_path_to_self_no_edge => {
670            edges = &[
671                ("a", "b"),
672            ],
673            from = "a",
674            to = "a",
675            shortest_path = None,
676        },
677        test_path_three_nodes => {
678            edges = &[
679                ("a", "b"),
680                ("b", "c"),
681            ],
682            from = "a",
683            to = "c",
684            shortest_path = Some(&["a", "b", "c"]),
685        },
686        test_path_multiple_options => {
687            edges = &[
688                ("a", "b"),
689                ("b", "c"),
690                ("a", "c"),
691            ],
692            from = "a",
693            to = "c",
694            shortest_path = Some(&["a", "c"]),
695        },
696        test_path_two_islands => {
697            edges = &[
698                ("a", "b"),
699                ("c", "d"),
700            ],
701            from = "a",
702            to = "d",
703            shortest_path = None,
704        },
705        test_path_with_cycle => {
706            edges = &[
707                ("a", "b"),
708                ("b", "a"),
709            ],
710            from = "a",
711            to = "b",
712            shortest_path = Some(&["a", "b"]),
713        },
714        test_path_with_cycle_2 => {
715            edges = &[
716                ("a", "b"),
717                ("b", "c"),
718                ("c", "b"),
719            ],
720            from = "a",
721            to = "b",
722            shortest_path = Some(&["a", "b"]),
723        },
724        test_path_with_cycle_3 => {
725            edges = &[
726                ("a", "b"),
727                ("b", "c"),
728                ("c", "b"),
729                ("b", "d"),
730                ("d", "e"),
731            ],
732            from = "a",
733            to = "e",
734            shortest_path = Some(&["a", "b", "d", "e"]),
735        },
736    }
737}