1use std::cmp::min;
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::fmt::{Debug, Display};
8use std::hash::Hash;
9use std::iter;
10
11#[derive(Debug, Clone, PartialEq, Eq)]
13pub struct DirectedGraph<T: Clone + PartialEq + Hash + Ord + Debug + Display>(
14 HashMap<T, DirectedNode<T>>,
15);
16
17impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> DirectedGraph<T> {
18 pub fn new() -> Self {
20 Self(HashMap::new())
21 }
22
23 pub fn retain(&mut self, predicate: impl Fn(&T, &T) -> bool) {
25 let mut dangling_nodes = HashSet::new();
26 for (k, set) in &mut self.0 {
27 set.0.retain(|v| predicate(k, v));
28 if set.0.is_empty() {
29 dangling_nodes.insert(k.clone());
30 }
31 }
32 for (_, set) in &self.0 {
34 for v in &set.0 {
35 let _ = dangling_nodes.remove(v);
36 }
37 }
38 self.0.retain(|k, _| !dangling_nodes.contains(k));
39 }
40
41 pub fn extend(&mut self, other: impl IntoIterator<Item = (T, T)>) {
42 for (a, b) in other.into_iter() {
43 self.add_edge(a, b);
44 }
45 }
46
47 pub fn add_edge(&mut self, source: T, target: T) {
49 self.0.entry(source).or_insert_with(DirectedNode::new).add_target(target.clone());
50 self.0.entry(target).or_insert_with(DirectedNode::new);
51 }
52
53 pub fn get_targets<'a>(&'a self, id: &T) -> Option<&'a HashSet<T>> {
55 self.0.get(id).as_ref().map(|node| &node.0)
56 }
57
58 pub fn get_closure<'a>(&'a self, start: &T) -> HashSet<&'a T> {
62 let Some((start, _)) = self.0.get_key_value(start) else {
63 return HashSet::new();
64 };
65 let mut reverse_deps: HashMap<&T, Vec<&T>> = HashMap::new();
66 for (source, targets) in &self.0 {
67 for target in &targets.0 {
68 reverse_deps.entry(target).or_default().push(source);
69 }
70 }
71
72 let mut closure = HashSet::new();
73 let mut to_visit = VecDeque::new();
74
75 closure.insert(start);
76 to_visit.push_back(start);
77
78 while let Some(current_node) = to_visit.pop_front() {
79 if let Some(parents) = reverse_deps.get(¤t_node) {
80 for parent in parents {
81 if closure.insert(*parent) {
82 to_visit.push_back(*parent);
83 }
84 }
85 }
86 }
87
88 closure
89 }
90
91 pub fn add_node(&mut self, source: T) {
94 self.0.entry(source).or_insert_with(DirectedNode::new);
95 }
96
97 pub fn topological_sort<'a>(&'a self) -> Result<Vec<&'a T>, Error<'_, T>> {
100 TarjanSCC::new(self).run()
101 }
102
103 pub fn find_shortest_path<'a>(&'a self, from: &T, to: &T) -> Option<Vec<&'a T>> {
106 let mut shortest_path_edges: HashMap<&'a T, &'a T> = HashMap::new();
116 let from = self.0.get_key_value(from).map(|e| e.0)?;
117 let to = self.0.get_key_value(to).map(|e| e.0)?;
118
119 let mut discovered_nodes = VecDeque::new();
121 discovered_nodes.push_back(from);
122
123 loop {
124 let Some(current_node) = discovered_nodes.pop_front() else {
126 return None;
128 };
129 match self.get_targets(current_node) {
130 None => continue,
131 Some(targets) if targets.is_empty() => continue,
132 Some(targets) => {
133 for target in targets {
134 if !shortest_path_edges.contains_key(target) {
137 shortest_path_edges.insert(target, current_node);
138 discovered_nodes.push_back(target);
139 }
140 if target == to {
143 let mut result = vec![target];
144 let mut path_node: &T = target;
145 loop {
146 path_node = shortest_path_edges.get(&path_node).unwrap();
147 result.push(path_node);
148 if path_node == from {
149 break;
150 }
151 }
152 result.reverse();
153 return Some(result);
154 }
155 }
156 }
157 }
158 }
159 }
160}
161
162impl<T: Clone + PartialEq + Hash + Ord + Debug + Display, const N: usize> From<[(T, T); N]>
163 for DirectedGraph<T>
164{
165 fn from(items: [(T, T); N]) -> Self {
166 let mut this = Self::new();
167 for (a, b) in IntoIterator::into_iter(items) {
168 this.add_edge(a, b);
169 }
170 this
171 }
172}
173
174impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> From<Box<[(T, T)]>> for DirectedGraph<T> {
175 fn from(items: Box<[(T, T)]>) -> Self {
176 let mut this = Self::new();
177 for (a, b) in IntoIterator::into_iter(items) {
178 this.add_edge(a, b);
179 }
180 this
181 }
182}
183
184impl<T: Clone + PartialEq + Hash + Ord + Debug + Display + 'static> IntoIterator
185 for DirectedGraph<T>
186{
187 type Item = (T, T);
188 type IntoIter = Box<dyn Iterator<Item = (T, T)>>;
189
190 fn into_iter(self) -> Self::IntoIter {
191 Box::new(
192 self.0
193 .into_iter()
194 .map(|(k, set)| iter::zip(iter::repeat(k), set.0.into_iter()))
195 .flatten(),
196 )
197 }
198}
199
200impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> Default for DirectedGraph<T> {
201 fn default() -> Self {
202 Self(HashMap::new())
203 }
204}
205
/// The outgoing side of one graph node: the set of nodes its edges point to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DirectedNode<T>(HashSet<T>)
where
    T: Clone + PartialEq + Hash + Ord + Debug + Display;

impl<T> DirectedNode<T>
where
    T: Clone + PartialEq + Hash + Ord + Debug + Display,
{
    /// Creates a node without any targets.
    pub fn new() -> Self {
        DirectedNode(HashSet::new())
    }

    /// Records an edge to `target`; a target that is already present is not added again.
    pub fn add_target(&mut self, target: T) {
        self.0.insert(target);
    }
}
221
/// Errors produced while ordering the nodes of a graph.
#[derive(Debug)]
pub enum Error<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> {
    /// The graph contains at least one cycle; each inner vector is one cycle's node sequence.
    CyclesDetected(HashSet<Vec<&'a T>>),
}

impl<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> Error<'a, T> {
    /// Renders the detected cycles as `{{a -> b -> a}, {c -> d -> c}}`.
    ///
    /// Cycles are listed in sorted order so the text is deterministic. An error without any
    /// cycles renders as `{}`.
    pub fn format_cycle(&self) -> String {
        use std::fmt::Write as _;

        match self {
            Error::CyclesDetected(cycles) => {
                // Sort references to the cycles; there is no need to clone them.
                let mut sorted_cycles: Vec<&Vec<&'a T>> = cycles.iter().collect();
                sorted_cycles.sort_unstable();

                let mut output = String::from("{");
                for (cycle_index, cycle) in sorted_cycles.iter().enumerate() {
                    // Separators go in front of every element but the first, so nothing has
                    // to be trimmed off afterwards.
                    if cycle_index > 0 {
                        output.push_str(", ");
                    }
                    output.push('{');
                    for (item_index, item) in cycle.iter().enumerate() {
                        if item_index > 0 {
                            output.push_str(" -> ");
                        }
                        // `String`'s writer never fails; only a failing `Display` impl can
                        // trigger this, which `format!` would turn into a panic as well.
                        write!(output, "{item}").expect("Display implementation returned an error");
                    }
                    output.push('}');
                }
                output.push('}');
                output
            }
        }
    }
}
256
257struct TarjanSCC<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> {
263 index: u64,
265 indices: HashMap<&'a T, u64>,
267 low_links: HashMap<&'a T, u64>,
269 stack: Vec<&'a T>,
271 on_stack: HashSet<&'a T>,
274 cycles: HashSet<Vec<&'a T>>,
276 node_order: Vec<&'a T>,
278 graph: &'a DirectedGraph<T>,
280}
281
282impl<'a, T: Clone + Hash + Ord + Debug + Display> TarjanSCC<'a, T> {
283 fn new(graph: &'a DirectedGraph<T>) -> Self {
284 TarjanSCC {
285 index: 0,
286 indices: HashMap::new(),
287 low_links: HashMap::new(),
288 stack: Vec::new(),
289 on_stack: HashSet::new(),
290 cycles: HashSet::new(),
291 node_order: Vec::new(),
292 graph,
293 }
294 }
295
296 fn run(&mut self) -> Result<Vec<&'a T>, Error<'a, T>> {
299 let mut nodes: Vec<_> = self.graph.0.keys().collect();
302 nodes.sort_unstable();
303 for node in &nodes {
304 if !self.indices.contains_key(node) {
307 self.visit(*node);
308 }
309 }
310
311 if self.cycles.is_empty() {
312 Ok(self.node_order.drain(..).collect())
313 } else {
314 Err(Error::CyclesDetected(self.cycles.drain().collect()))
315 }
316 }
317
318 fn visit(&mut self, current_node: &'a T) {
319 self.indices.insert(current_node, self.index);
321 self.low_links.insert(current_node, self.index);
322 self.index += 1;
323 self.stack.push(current_node);
324 self.on_stack.insert(current_node);
325
326 let mut targets: Vec<_> = self.graph.0[current_node].0.iter().collect();
327 targets.sort_unstable();
328
329 for target in targets {
330 if !self.indices.contains_key(target) {
331 self.visit(target);
333 let current_node_low_link = *self.low_links.get(¤t_node).unwrap();
335 let target_low_link = *self.low_links.get(&target).unwrap();
336 self.low_links.insert(current_node, min(current_node_low_link, target_low_link));
337 } else if self.on_stack.contains(target) {
338 let current_node_low_link = *self.low_links.get(¤t_node).unwrap();
339 let target_index = *self.indices.get(&target).unwrap();
340 self.low_links.insert(current_node, min(current_node_low_link, target_index));
341 }
342 }
343
344 if self.low_links.get(¤t_node) == self.indices.get(¤t_node) {
346 let mut strongly_connected_nodes = HashSet::new();
347 let mut stack_node;
348 loop {
349 stack_node = self.stack.pop().unwrap();
350 self.on_stack.remove(&stack_node);
351 strongly_connected_nodes.insert(stack_node);
352 if stack_node == current_node {
353 break;
354 }
355 }
356 self.insert_cycles_from_scc(
357 &strongly_connected_nodes,
358 stack_node,
359 HashSet::new(),
360 vec![],
361 );
362 }
363 self.node_order.push(current_node);
364 }
365
366 fn insert_cycles_from_scc(
369 &mut self,
370 scc_nodes: &HashSet<&'a T>,
371 current_node: &'a T,
372 mut visited_nodes: HashSet<&'a T>,
373 mut path: Vec<&'a T>,
374 ) {
375 if visited_nodes.contains(¤t_node) {
376 let (current_node_path_index, _) =
379 path.iter().enumerate().find(|(_, val)| val == &¤t_node).unwrap();
380 let mut cycle = path[current_node_path_index..].to_vec();
381
382 Self::rotate_cycle(&mut cycle);
385 cycle.push(*cycle.first().unwrap());
388 self.cycles.insert(cycle);
389 return;
390 }
391
392 visited_nodes.insert(current_node);
393 path.push(current_node);
394
395 let targets_in_scc: Vec<_> =
396 self.graph.0[¤t_node].0.iter().filter(|n| scc_nodes.contains(n)).collect();
397 for target in targets_in_scc {
398 self.insert_cycles_from_scc(scc_nodes, target, visited_nodes.clone(), path.clone());
399 }
400 }
401
402 fn rotate_cycle(cycle: &mut Vec<&'a T>) {
406 let mut lowest_index = 0;
407 let mut lowest_value = cycle.first().unwrap();
408 for (index, node) in cycle.iter().enumerate() {
409 if node < lowest_value {
410 lowest_index = index;
411 lowest_value = node;
412 }
413 }
414 cycle.rotate_left(lowest_index);
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates one `#[test]` per case; each checks the exact topological order of a graph.
    macro_rules! test_topological_sort {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    order = $order:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    topological_sort_test(&$edges, &$order);
                }
            )+
        }
    }

    /// Generates one `#[test]` per case; each checks the exact set of cycles reported.
    macro_rules! test_cycles {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    cycles = $cycles:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    cycles_test(&$edges, &$cycles);
                }
            )+
        }
    }

    /// Generates one `#[test]` per case; each checks the result of a shortest-path query.
    macro_rules! test_shortest_path {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    from = $from:expr,
                    to = $to:expr,
                    shortest_path = $shortest_path:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    shortest_path_test($edges, $from, $to, $shortest_path);
                }
            )+
        }
    }

    /// Builds a graph holding exactly the given edges.
    fn build_graph(edges: &[(&'static str, &'static str)]) -> DirectedGraph<&'static str> {
        let mut graph = DirectedGraph::new();
        for &(source, target) in edges {
            graph.add_edge(source, target);
        }
        graph
    }

    /// Asserts that sorting the graph made of `edges` succeeds and yields `order`.
    fn topological_sort_test(edges: &[(&'static str, &'static str)], order: &[&'static str]) {
        let graph = build_graph(edges);
        let actual_order = graph.topological_sort().expect("found a cycle");

        let expected_order: Vec<_> = order.iter().collect();
        assert_eq!(expected_order, actual_order);
    }

    /// Asserts that sorting the graph made of `edges` fails and reports exactly `cycles`.
    fn cycles_test(edges: &[(&'static str, &'static str)], cycles: &[&[&'static str]]) {
        let graph = build_graph(edges);
        let Error::CyclesDetected(reported_cycles) = graph
            .topological_sort()
            .expect_err("topological sort succeeded on a dataset with a cycle");

        let expected_cycles: HashSet<Vec<_>> =
            cycles.iter().map(|cycle| cycle.iter().collect()).collect();
        assert_eq!(reported_cycles, expected_cycles);
    }

    /// Asserts that the shortest path between `from` and `to` matches the expectation.
    fn shortest_path_test(
        edges: &[(&'static str, &'static str)],
        from: &'static str,
        to: &'static str,
        expected_shortest_path: Option<&[&'static str]>,
    ) {
        let graph = build_graph(edges);
        let actual_shortest_path = graph.find_shortest_path(&from, &to);
        let expected_shortest_path =
            expected_shortest_path.map(|path| path.iter().collect::<Vec<_>>());
        assert_eq!(actual_shortest_path, expected_shortest_path);
    }

    #[test]
    fn operations() {
        /// Compares the graph's edges, sorted, against `expected`.
        fn assert_elements(
            graph: &DirectedGraph<&'static str>,
            expected: &[(&'static str, &'static str)],
        ) {
            let mut elements: Vec<_> = graph.clone().into_iter().collect();
            elements.sort_unstable();
            assert_eq!(&elements, expected);
        }

        let mut graph = DirectedGraph::new();
        graph.add_edge("a", "b");
        assert_elements(&graph, &[("a", "b")]);

        graph.extend(vec![("c", "b"), ("a", "e")]);
        assert_elements(&graph, &[("a", "b"), ("a", "e"), ("c", "b")]);

        // Drops the a -> e edge only.
        graph.retain(|k, v| *k == "c" || *v != "e");
        assert_elements(&graph, &[("a", "b"), ("c", "b")]);

        // Drops every edge leaving "a".
        graph.retain(|k, _| *k != "a");
        assert_elements(&graph, &[("c", "b")]);

        // Nodes left without any edge must be gone too, not merely edge-less.
        let mut expected = DirectedGraph::new();
        expected.add_edge("c", "b");
        assert_eq!(graph, expected);
    }

    test_topological_sort! {
        test_empty => {
            edges = [],
            order = [],
        },
        test_fan_out => {
            edges = [("a", "b"), ("b", "c"), ("b", "d"), ("d", "e")],
            order = ["c", "e", "d", "b", "a"],
        },
        test_fan_in => {
            edges = [("a", "b"), ("b", "d"), ("c", "d"), ("d", "e")],
            order = ["e", "d", "b", "a", "c"],
        },
        test_forest => {
            edges = [("a", "b"), ("b", "c"), ("d", "e")],
            order = ["c", "b", "a", "e", "d"],
        },
        test_diamond => {
            edges = [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")],
            order = ["d", "b", "c", "a"],
        },
        test_lattice => {
            edges = [
                ("a", "b"), ("a", "c"), ("b", "d"), ("b", "e"), ("c", "d"), ("e", "f"), ("d", "f"),
            ],
            order = ["f", "d", "e", "b", "c", "a"],
        },
        test_deduped_edge => {
            edges = [("a", "b"), ("a", "b"), ("b", "c")],
            order = ["c", "b", "a"],
        },
    }

    test_cycles! {
        // Cases where every node belongs to a single connected group.
        test_cycle_self_referential => {
            edges = [("a", "a")],
            cycles = [&["a", "a"]],
        },
        test_cycle_two_nodes => {
            edges = [("a", "b"), ("b", "a")],
            cycles = [&["a", "b", "a"]],
        },
        test_cycle_two_nodes_with_path_in => {
            edges = [("a", "b"), ("b", "c"), ("c", "d"), ("d", "c")],
            cycles = [&["c", "d", "c"]],
        },
        test_cycle_two_nodes_with_path_out => {
            edges = [("a", "b"), ("b", "a"), ("b", "c"), ("c", "d")],
            cycles = [&["a", "b", "a"]],
        },
        test_cycle_three_nodes => {
            edges = [("a", "b"), ("b", "c"), ("c", "a")],
            cycles = [&["a", "b", "c", "a"]],
        },
        test_cycle_three_nodes_with_inner_cycle => {
            edges = [("a", "b"), ("b", "c"), ("c", "b"), ("c", "a")],
            cycles = [&["a", "b", "c", "a"], &["b", "c", "b"]],
        },
        test_cycle_three_nodes_doubly_linked => {
            edges = [("a", "b"), ("b", "a"), ("b", "c"), ("c", "b"), ("c", "a"), ("a", "c")],
            cycles = [
                &["a", "b", "a"],
                &["b", "c", "b"],
                &["a", "c", "a"],
                &["a", "b", "c", "a"],
                &["a", "c", "b", "a"],
            ],
        },
        test_cycle_with_inner_cycle => {
            edges = [("a", "b"), ("b", "c"), ("c", "a"), ("b", "d"), ("d", "e"), ("e", "c")],
            cycles = [&["a", "b", "c", "a"], &["a", "b", "d", "e", "c", "a"]],
        },
        test_two_join_cycles => {
            edges = [("a", "b"), ("b", "c"), ("c", "a"), ("b", "d"), ("d", "a")],
            cycles = [&["a", "b", "c", "a"], &["a", "b", "d", "a"]],
        },
        test_cycle_four_nodes_doubly_linked => {
            edges = [
                ("a", "b"), ("b", "a"), ("b", "c"), ("c", "b"),
                ("c", "d"), ("d", "c"), ("d", "a"), ("a", "d"),
            ],
            cycles = [
                &["a", "b", "c", "d", "a"],
                &["a", "b", "a"],
                &["a", "d", "c", "b", "a"],
                &["a", "d", "a"],
                &["b", "c", "b"],
                &["c", "d", "c"],
            ],
        },

        // Cases with several separate groups of nodes.
        test_cycle_self_referential_islands => {
            edges = [("a", "a"), ("b", "b"), ("c", "c"), ("d", "e")],
            cycles = [&["a", "a"], &["b", "b"], &["c", "c"]],
        },
        test_cycle_two_sets_of_two_nodes => {
            edges = [("a", "b"), ("b", "a"), ("c", "d"), ("d", "c")],
            cycles = [&["a", "b", "a"], &["c", "d", "c"]],
        },
        test_cycle_two_sets_of_two_nodes_connected => {
            edges = [("a", "b"), ("b", "a"), ("c", "d"), ("d", "c"), ("a", "c")],
            cycles = [&["a", "b", "a"], &["c", "d", "c"]],
        },
    }

    test_shortest_path! {
        test_empty_graph => {
            edges = &[],
            from = "a",
            to = "b",
            shortest_path = None,
        },
        test_two_nodes => {
            edges = &[("a", "b")],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_to_self => {
            edges = &[("a", "a")],
            from = "a",
            to = "a",
            shortest_path = Some(&["a", "a"]),
        },
        test_path_to_self_no_edge => {
            edges = &[("a", "b")],
            from = "a",
            to = "a",
            shortest_path = None,
        },
        test_path_three_nodes => {
            edges = &[("a", "b"), ("b", "c")],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "b", "c"]),
        },
        test_path_multiple_options => {
            edges = &[("a", "b"), ("b", "c"), ("a", "c")],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "c"]),
        },
        test_path_two_islands => {
            edges = &[("a", "b"), ("c", "d")],
            from = "a",
            to = "d",
            shortest_path = None,
        },
        test_path_with_cycle => {
            edges = &[("a", "b"), ("b", "a")],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_2 => {
            edges = &[("a", "b"), ("b", "c"), ("c", "b")],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_3 => {
            edges = &[("a", "b"), ("b", "c"), ("c", "b"), ("b", "d"), ("d", "e")],
            from = "a",
            to = "e",
            shortest_path = Some(&["a", "b", "d", "e"]),
        },
    }
}