1use std::cmp::min;
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::fmt::{Debug, Display};
8use std::hash::Hash;
9
10pub struct DirectedGraph<T: Clone + PartialEq + Hash + Ord + Debug + Display>(
12 HashMap<T, DirectedNode<T>>,
13);
14
15impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> DirectedGraph<T> {
16 pub fn new() -> Self {
18 Self(HashMap::new())
19 }
20
21 pub fn add_edge(&mut self, source: T, target: T) {
23 self.0.entry(source).or_insert_with(DirectedNode::new).add_target(target.clone());
24 self.0.entry(target).or_insert_with(DirectedNode::new);
25 }
26
27 pub fn get_targets<'a>(&'a self, id: &T) -> Option<&'a HashSet<T>> {
29 self.0.get(id).as_ref().map(|node| &node.0)
30 }
31
32 pub fn get_closure<'a>(&'a self, start: &T) -> HashSet<&'a T> {
36 let Some((start, _)) = self.0.get_key_value(start) else {
37 return HashSet::new();
38 };
39 let mut reverse_deps: HashMap<&T, Vec<&T>> = HashMap::new();
40 for (source, targets) in &self.0 {
41 for target in &targets.0 {
42 reverse_deps.entry(target).or_default().push(source);
43 }
44 }
45
46 let mut closure = HashSet::new();
47 let mut to_visit = VecDeque::new();
48
49 closure.insert(start);
50 to_visit.push_back(start);
51
52 while let Some(current_node) = to_visit.pop_front() {
53 if let Some(parents) = reverse_deps.get(¤t_node) {
54 for parent in parents {
55 if closure.insert(*parent) {
56 to_visit.push_back(*parent);
57 }
58 }
59 }
60 }
61
62 closure
63 }
64
65 pub fn add_node(&mut self, source: T) {
68 self.0.entry(source).or_insert_with(DirectedNode::new);
69 }
70
71 pub fn topological_sort<'a>(&'a self) -> Result<Vec<&'a T>, Error<'_, T>> {
74 TarjanSCC::new(self).run()
75 }
76
77 pub fn find_shortest_path<'a>(&'a self, from: &T, to: &T) -> Option<Vec<&'a T>> {
80 let mut shortest_path_edges: HashMap<&'a T, &'a T> = HashMap::new();
90 let from = self.0.get_key_value(from).map(|e| e.0)?;
91 let to = self.0.get_key_value(to).map(|e| e.0)?;
92
93 let mut discovered_nodes = VecDeque::new();
95 discovered_nodes.push_back(from);
96
97 loop {
98 let Some(current_node) = discovered_nodes.pop_front() else {
100 return None;
102 };
103 match self.get_targets(current_node) {
104 None => continue,
105 Some(targets) if targets.is_empty() => continue,
106 Some(targets) => {
107 for target in targets {
108 if !shortest_path_edges.contains_key(target) {
111 shortest_path_edges.insert(target, current_node);
112 discovered_nodes.push_back(target);
113 }
114 if target == to {
117 let mut result = vec![target];
118 let mut path_node: &T = target;
119 loop {
120 path_node = shortest_path_edges.get(&path_node).unwrap();
121 result.push(path_node);
122 if path_node == from {
123 break;
124 }
125 }
126 result.reverse();
127 return Some(result);
128 }
129 }
130 }
131 }
132 }
133 }
134}
135
136impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> Default for DirectedGraph<T> {
137 fn default() -> Self {
138 Self(HashMap::new())
139 }
140}
141
/// The outgoing edges of a single graph node: the set of nodes it points to.
///
/// Targets are stored in a `HashSet`, so adding the same edge twice keeps a
/// single copy.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DirectedNode<T: Clone + PartialEq + Hash + Ord + Debug + Display>(HashSet<T>);

impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> DirectedNode<T> {
    /// Creates a node with no outgoing edges.
    pub fn new() -> Self {
        Self(HashSet::new())
    }

    /// Records an edge to `target`; adding an existing target is a no-op.
    pub fn add_target(&mut self, target: T) {
        self.0.insert(target);
    }
}

// A public zero-argument `new` should be mirrored by `Default`
// (clippy::new_without_default). A manual impl avoids the `T: Default` bound
// that `#[derive(Default)]` would add.
impl<T: Clone + PartialEq + Hash + Ord + Debug + Display> Default for DirectedNode<T> {
    fn default() -> Self {
        Self::new()
    }
}
157
/// Errors produced by the graph's topological sort.
#[derive(Debug)]
pub enum Error<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> {
    /// The graph is not acyclic. Each entry is one cycle listed as the nodes
    /// visited in order, ending with a repeat of the first node.
    CyclesDetected(HashSet<Vec<&'a T>>),
}

impl<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> Error<'a, T> {
    /// Renders the cycles as `{{a -> b -> a}, {c -> c}}`.
    ///
    /// Cycles are sorted first, so the output does not depend on set iteration
    /// order. With no cycles the result is `{}`.
    pub fn format_cycle(&self) -> String {
        match self {
            Error::CyclesDetected(cycles) => {
                // Sort borrowed cycles; the set's iteration order is arbitrary and
                // there is no need to clone each cycle just to order it.
                let mut cycles: Vec<&Vec<&T>> = cycles.iter().collect();
                cycles.sort_unstable();

                let rendered: Vec<String> = cycles
                    .iter()
                    .map(|cycle| {
                        let nodes: Vec<String> =
                            cycle.iter().map(|node| node.to_string()).collect();
                        format!("{{{}}}", nodes.join(" -> "))
                    })
                    .collect();
                format!("{{{}}}", rendered.join(", "))
            }
        }
    }
}

impl<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> Display for Error<'a, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "cycles detected: {}", self.format_cycle())
    }
}

impl<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> std::error::Error for Error<'a, T> {}
192
193struct TarjanSCC<'a, T: Clone + PartialEq + Hash + Ord + Debug + Display> {
199 index: u64,
201 indices: HashMap<&'a T, u64>,
203 low_links: HashMap<&'a T, u64>,
205 stack: Vec<&'a T>,
207 on_stack: HashSet<&'a T>,
210 cycles: HashSet<Vec<&'a T>>,
212 node_order: Vec<&'a T>,
214 graph: &'a DirectedGraph<T>,
216}
217
218impl<'a, T: Clone + Hash + Ord + Debug + Display> TarjanSCC<'a, T> {
219 fn new(graph: &'a DirectedGraph<T>) -> Self {
220 TarjanSCC {
221 index: 0,
222 indices: HashMap::new(),
223 low_links: HashMap::new(),
224 stack: Vec::new(),
225 on_stack: HashSet::new(),
226 cycles: HashSet::new(),
227 node_order: Vec::new(),
228 graph,
229 }
230 }
231
232 fn run(&mut self) -> Result<Vec<&'a T>, Error<'a, T>> {
235 let mut nodes: Vec<_> = self.graph.0.keys().collect();
238 nodes.sort_unstable();
239 for node in &nodes {
240 if !self.indices.contains_key(node) {
243 self.visit(*node);
244 }
245 }
246
247 if self.cycles.is_empty() {
248 Ok(self.node_order.drain(..).collect())
249 } else {
250 Err(Error::CyclesDetected(self.cycles.drain().collect()))
251 }
252 }
253
254 fn visit(&mut self, current_node: &'a T) {
255 self.indices.insert(current_node, self.index);
257 self.low_links.insert(current_node, self.index);
258 self.index += 1;
259 self.stack.push(current_node);
260 self.on_stack.insert(current_node);
261
262 let mut targets: Vec<_> = self.graph.0[current_node].0.iter().collect();
263 targets.sort_unstable();
264
265 for target in targets {
266 if !self.indices.contains_key(target) {
267 self.visit(target);
269 let current_node_low_link = *self.low_links.get(¤t_node).unwrap();
271 let target_low_link = *self.low_links.get(&target).unwrap();
272 self.low_links.insert(current_node, min(current_node_low_link, target_low_link));
273 } else if self.on_stack.contains(target) {
274 let current_node_low_link = *self.low_links.get(¤t_node).unwrap();
275 let target_index = *self.indices.get(&target).unwrap();
276 self.low_links.insert(current_node, min(current_node_low_link, target_index));
277 }
278 }
279
280 if self.low_links.get(¤t_node) == self.indices.get(¤t_node) {
282 let mut strongly_connected_nodes = HashSet::new();
283 let mut stack_node;
284 loop {
285 stack_node = self.stack.pop().unwrap();
286 self.on_stack.remove(&stack_node);
287 strongly_connected_nodes.insert(stack_node);
288 if stack_node == current_node {
289 break;
290 }
291 }
292 self.insert_cycles_from_scc(
293 &strongly_connected_nodes,
294 stack_node,
295 HashSet::new(),
296 vec![],
297 );
298 }
299 self.node_order.push(current_node);
300 }
301
302 fn insert_cycles_from_scc(
305 &mut self,
306 scc_nodes: &HashSet<&'a T>,
307 current_node: &'a T,
308 mut visited_nodes: HashSet<&'a T>,
309 mut path: Vec<&'a T>,
310 ) {
311 if visited_nodes.contains(¤t_node) {
312 let (current_node_path_index, _) =
315 path.iter().enumerate().find(|(_, val)| val == &¤t_node).unwrap();
316 let mut cycle = path[current_node_path_index..].to_vec();
317
318 Self::rotate_cycle(&mut cycle);
321 cycle.push(*cycle.first().unwrap());
324 self.cycles.insert(cycle);
325 return;
326 }
327
328 visited_nodes.insert(current_node);
329 path.push(current_node);
330
331 let targets_in_scc: Vec<_> =
332 self.graph.0[¤t_node].0.iter().filter(|n| scc_nodes.contains(n)).collect();
333 for target in targets_in_scc {
334 self.insert_cycles_from_scc(scc_nodes, target, visited_nodes.clone(), path.clone());
335 }
336 }
337
338 fn rotate_cycle(cycle: &mut Vec<&'a T>) {
342 let mut lowest_index = 0;
343 let mut lowest_value = cycle.first().unwrap();
344 for (index, node) in cycle.iter().enumerate() {
345 if node < lowest_value {
346 lowest_index = index;
347 lowest_value = node;
348 }
349 }
350 cycle.rotate_left(lowest_index);
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    // Table-driven tests: each macro expands one `#[test]` per named case and
    // forwards the case's data to the matching `*_test` helper below.

    macro_rules! test_topological_sort {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    order = $order:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    topological_sort_test(&$edges, &$order);
                }
            )+
        }
    }

    macro_rules! test_cycles {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    cycles = $cycles:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    cycles_test(&$edges, &$cycles);
                }
            )+
        }
    }

    macro_rules! test_shortest_path {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    from = $from:expr,
                    to = $to:expr,
                    shortest_path = $shortest_path:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    shortest_path_test($edges, $from, $to, $shortest_path);
                }
            )+
        }
    }

    /// Builds a graph containing exactly the given edges.
    fn build_graph(edges: &[(&'static str, &'static str)]) -> DirectedGraph<&'static str> {
        let mut graph = DirectedGraph::new();
        for &(source, target) in edges {
            graph.add_edge(source, target);
        }
        graph
    }

    fn topological_sort_test(edges: &[(&'static str, &'static str)], order: &[&'static str]) {
        let graph = build_graph(edges);
        let actual_order = graph.topological_sort().expect("found a cycle");

        let expected_order: Vec<_> = order.iter().collect();
        assert_eq!(expected_order, actual_order);
    }

    fn cycles_test(edges: &[(&'static str, &'static str)], cycles: &[&[&'static str]]) {
        let graph = build_graph(edges);
        let Error::CyclesDetected(reported_cycles) = graph
            .topological_sort()
            .expect_err("topological sort succeeded on a dataset with a cycle");

        let expected_cycles: HashSet<Vec<_>> =
            cycles.iter().copied().map(|cycle| cycle.iter().collect()).collect();
        assert_eq!(reported_cycles, expected_cycles);
    }

    fn shortest_path_test(
        edges: &[(&'static str, &'static str)],
        from: &'static str,
        to: &'static str,
        expected_shortest_path: Option<&[&'static str]>,
    ) {
        let graph = build_graph(edges);
        let actual_shortest_path = graph.find_shortest_path(&from, &to);
        let expected_shortest_path =
            expected_shortest_path.map(|path| path.iter().collect::<Vec<_>>());
        assert_eq!(actual_shortest_path, expected_shortest_path);
    }

    /// Sorted copy of `get_closure(start)`, for order-independent comparison.
    fn sorted_closure(graph: &DirectedGraph<&'static str>, start: &'static str) -> Vec<&'static str> {
        let mut nodes: Vec<&'static str> = graph.get_closure(&start).into_iter().copied().collect();
        nodes.sort_unstable();
        nodes
    }

    test_topological_sort! {
        test_empty => {
            edges = [],
            order = [],
        },
        test_fan_out => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("b", "d"),
                ("d", "e"),
            ],
            order = ["c", "e", "d", "b", "a"],
        },
        test_fan_in => {
            edges = [
                ("a", "b"),
                ("b", "d"),
                ("c", "d"),
                ("d", "e"),
            ],
            order = ["e", "d", "b", "a", "c"],
        },
        test_forest => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("d", "e"),
            ],
            order = ["c", "b", "a", "e", "d"],
        },
        test_diamond => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("c", "d"),
            ],
            order = ["d", "b", "c", "a"],
        },
        test_lattice => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("b", "e"),
                ("c", "d"),
                ("e", "f"),
                ("d", "f"),
            ],
            order = ["f", "d", "e", "b", "c", "a"],
        },
        test_deduped_edge => {
            edges = [
                ("a", "b"),
                ("a", "b"),
                ("b", "c"),
            ],
            order = ["c", "b", "a"],
        },
    }

    test_cycles! {
        test_cycle_self_referential => {
            edges = [
                ("a", "a"),
            ],
            cycles = [
                &["a", "a"],
            ],
        },
        test_cycle_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_two_nodes_with_path_in => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_nodes_with_path_out => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "d"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_three_nodes => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
            ],
        },
        test_cycle_three_nodes_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["b", "c", "b"],
            ],
        },
        test_cycle_three_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["b", "c", "b"],
                &["a", "c", "a"],
                &["a", "b", "c", "a"],
                &["a", "c", "b", "a"],
            ],
        },
        test_cycle_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),

                ("b", "d"),
                ("d", "e"),
                ("e", "c"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "e", "c", "a"],
            ],
        },
        test_two_join_cycles => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
                ("b", "d"),
                ("d", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "a"],
            ],
        },
        test_cycle_four_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "d"),
                ("d", "c"),
                ("d", "a"),
                ("a", "d"),
            ],
            cycles = [
                &["a", "b", "c", "d", "a"],
                &["a", "b", "a"],
                &["a", "d", "c", "b", "a"],
                &["a", "d", "a"],
                &["b", "c", "b"],
                &["c", "d", "c"],
            ],
        },

        test_cycle_self_referential_islands => {
            edges = [
                ("a", "a"),
                ("b", "b"),
                ("c", "c"),
                ("d", "e"),
            ],
            cycles = [
                &["a", "a"],
                &["b", "b"],
                &["c", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes_connected => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
    }

    test_shortest_path! {
        test_empty_graph => {
            edges = &[],
            from = "a",
            to = "b",
            shortest_path = None,
        },
        test_two_nodes => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_to_self => {
            edges = &[
                ("a", "a"),
            ],
            from = "a",
            to = "a",
            shortest_path = Some(&["a", "a"]),
        },
        test_path_to_self_no_edge => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "a",
            shortest_path = None,
        },
        test_path_three_nodes => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "b", "c"]),
        },
        test_path_multiple_options => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("a", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "c"]),
        },
        test_path_two_islands => {
            edges = &[
                ("a", "b"),
                ("c", "d"),
            ],
            from = "a",
            to = "d",
            shortest_path = None,
        },
        test_path_with_cycle => {
            edges = &[
                ("a", "b"),
                ("b", "a"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_2 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_3 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("b", "d"),
                ("d", "e"),
            ],
            from = "a",
            to = "e",
            shortest_path = Some(&["a", "b", "d", "e"]),
        },
    }

    #[test]
    fn test_get_targets() {
        let graph = build_graph(&[("a", "b"), ("a", "c")]);
        let expected: HashSet<&str> = ["b", "c"].iter().copied().collect();
        assert_eq!(graph.get_targets(&"a"), Some(&expected));
        // Edge targets are registered as nodes with no outgoing edges.
        assert_eq!(graph.get_targets(&"b"), Some(&HashSet::new()));
        assert_eq!(graph.get_targets(&"z"), None);
    }

    #[test]
    fn test_get_closure() {
        let graph = build_graph(&[("a", "b"), ("b", "c"), ("d", "c"), ("c", "e")]);
        assert_eq!(sorted_closure(&graph, "c"), ["a", "b", "c", "d"]);
        assert_eq!(sorted_closure(&graph, "e"), ["a", "b", "c", "d", "e"]);
        assert_eq!(sorted_closure(&graph, "a"), ["a"]);
        assert_eq!(sorted_closure(&graph, "z"), Vec::<&str>::new());

        let cyclic = build_graph(&[("a", "b"), ("b", "a")]);
        assert_eq!(sorted_closure(&cyclic, "a"), ["a", "b"]);
    }

    #[test]
    fn test_add_node() {
        let mut graph = build_graph(&[("a", "b")]);
        graph.add_node("x");
        // Re-adding an existing node must not drop its edges.
        graph.add_node("a");

        assert_eq!(graph.get_targets(&"x"), Some(&HashSet::new()));
        let expected_a: HashSet<&str> = ["b"].iter().copied().collect();
        assert_eq!(graph.get_targets(&"a"), Some(&expected_a));

        let order: Vec<&str> =
            graph.topological_sort().expect("found a cycle").into_iter().copied().collect();
        assert_eq!(order, ["b", "a", "x"]);
    }

    #[test]
    fn test_format_cycle() {
        let graph = build_graph(&[("a", "b"), ("b", "a"), ("c", "c")]);
        let error = graph.topological_sort().expect_err("graph has cycles");
        assert_eq!(error.format_cycle(), "{{a -> b -> a}, {c -> c}}");

        let no_cycles: HashSet<Vec<&&str>> = HashSet::new();
        assert_eq!(Error::CyclesDetected(no_cycles).format_cycle(), "{}");
    }
}