1use std::cmp::min;
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::fmt::{Debug, Display};
8use std::hash::Hash;
9
10pub struct DirectedGraph<T: PartialEq + Hash + Copy + Ord + Debug + Display>(
12 HashMap<T, DirectedNode<T>>,
13);
14
15impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> DirectedGraph<T> {
16 pub fn new() -> Self {
18 Self(HashMap::new())
19 }
20
21 pub fn add_edge(&mut self, source: T, target: T) {
23 self.0.entry(source).or_insert_with(DirectedNode::new).add_target(target);
24 self.0.entry(target).or_insert_with(DirectedNode::new);
25 }
26
27 pub fn get_targets(&self, id: T) -> Option<&HashSet<T>> {
29 self.0.get(&id).as_ref().map(|node| &node.0)
30 }
31
32 pub fn topological_sort(&self) -> Result<Vec<T>, Error<T>> {
37 TarjanSCC::new(self).run()
38 }
39
40 pub fn find_shortest_path(&self, from: T, to: T) -> Option<Vec<T>> {
43 let mut shortest_path_edges: HashMap<T, T> = HashMap::new();
53
54 let mut discovered_nodes = VecDeque::new();
56 discovered_nodes.push_back(from);
57
58 loop {
59 let Some(current_node) = discovered_nodes.pop_front() else {
61 return None;
63 };
64 match self.get_targets(current_node) {
65 None => continue,
66 Some(targets) if targets.is_empty() => continue,
67 Some(targets) => {
68 for target in targets {
69 if !shortest_path_edges.contains_key(target) {
72 shortest_path_edges.insert(*target, current_node);
73 discovered_nodes.push_back(*target);
74 }
75 if *target == to {
78 let mut result = vec![*target];
79 let mut path_node: T = *target;
80 loop {
81 path_node = *shortest_path_edges.get(&path_node).unwrap();
82 result.push(path_node);
83 if path_node == from {
84 break;
85 }
86 }
87 result.reverse();
88 return Some(result);
89 }
90 }
91 }
92 }
93 }
94 }
95}
96
97impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Default for DirectedGraph<T> {
98 fn default() -> Self {
99 Self(HashMap::new())
100 }
101}
102
/// The outgoing edges of a single node, kept as a set so repeated edges collapse into one.
#[derive(Eq, PartialEq)]
struct DirectedNode<T>(HashSet<T>)
where
    T: PartialEq + Hash + Copy + Ord + Debug + Display;

impl<T> DirectedNode<T>
where
    T: PartialEq + Hash + Copy + Ord + Debug + Display,
{
    /// Builds a node that has no outgoing edges yet.
    pub fn new() -> Self {
        Self(HashSet::default())
    }

    /// Records an edge to `target`; an already-present target is left as is.
    pub fn add_target(&mut self, target: T) {
        let edges = &mut self.0;
        edges.insert(target);
    }
}
118
/// Errors returned by [`DirectedGraph::topological_sort`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error<T: PartialEq + Hash + Copy + Ord + Debug + Display> {
    /// The graph is not acyclic. Each entry is one elementary cycle; the cycles reported by
    /// `topological_sort` start at their smallest node and repeat it at the end
    /// (e.g. `[a, b, a]`).
    CyclesDetected(HashSet<Vec<T>>),
}

impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Error<T> {
    /// Renders the detected cycles as `{{a -> b -> a}, {c -> d -> c}}`.
    ///
    /// Cycles are listed in ascending order so the text is stable across runs. An empty set
    /// renders as `{}`, and an empty cycle as `{}` inside the outer braces.
    pub fn format_cycle(&self) -> String {
        match self {
            Error::CyclesDetected(cycles) => {
                // HashSet iteration order is unspecified; sort references (no cloning of the
                // cycles themselves) to get a deterministic listing.
                let mut sorted: Vec<&Vec<T>> = cycles.iter().collect();
                sorted.sort_unstable();

                let rendered: Vec<String> = sorted
                    .into_iter()
                    .map(|cycle| {
                        let nodes: Vec<String> = cycle.iter().map(ToString::to_string).collect();
                        format!("{{{}}}", nodes.join(" -> "))
                    })
                    .collect();
                format!("{{{}}}", rendered.join(", "))
            }
        }
    }
}

impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Display for Error<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::CyclesDetected(_) => write!(f, "cycles detected: {}", self.format_cycle()),
        }
    }
}

/// Lets callers propagate the error with `?` into `Box<dyn std::error::Error>` and similar.
impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> std::error::Error for Error<T> {}
153
154struct TarjanSCC<'a, T: PartialEq + Hash + Copy + Ord + Debug + Display> {
160 index: u64,
162 indices: HashMap<T, u64>,
164 low_links: HashMap<T, u64>,
166 stack: Vec<T>,
168 on_stack: HashSet<T>,
171 cycles: HashSet<Vec<T>>,
173 node_order: Vec<T>,
175 graph: &'a DirectedGraph<T>,
177}
178
179impl<'a, T: Hash + Copy + Ord + Debug + Display> TarjanSCC<'a, T> {
180 fn new(graph: &'a DirectedGraph<T>) -> Self {
181 TarjanSCC {
182 index: 0,
183 indices: HashMap::new(),
184 low_links: HashMap::new(),
185 stack: Vec::new(),
186 on_stack: HashSet::new(),
187 cycles: HashSet::new(),
188 node_order: Vec::new(),
189 graph,
190 }
191 }
192
193 fn run(mut self) -> Result<Vec<T>, Error<T>> {
196 let mut nodes: Vec<_> = self.graph.0.keys().cloned().collect();
199 nodes.sort_unstable();
200 for node in &nodes {
201 if !self.indices.contains_key(node) {
204 self.visit(*node);
205 }
206 }
207
208 if self.cycles.is_empty() {
209 Ok(self.node_order.drain(..).collect())
210 } else {
211 Err(Error::CyclesDetected(self.cycles.drain().collect()))
212 }
213 }
214
215 fn visit(&mut self, current_node: T) {
216 self.indices.insert(current_node, self.index);
218 self.low_links.insert(current_node, self.index);
219 self.index += 1;
220 self.stack.push(current_node);
221 self.on_stack.insert(current_node);
222
223 let mut targets: Vec<_> = self.graph.0[¤t_node].0.iter().cloned().collect();
224 targets.sort_unstable();
225
226 for target in &targets {
227 if !self.indices.contains_key(target) {
228 self.visit(*target);
230 let current_node_low_link = *self.low_links.get(¤t_node).unwrap();
232 let target_low_link = *self.low_links.get(&target).unwrap();
233 self.low_links.insert(current_node, min(current_node_low_link, target_low_link));
234 } else if self.on_stack.contains(target) {
235 let current_node_low_link = *self.low_links.get(¤t_node).unwrap();
236 let target_index = *self.indices.get(&target).unwrap();
237 self.low_links.insert(current_node, min(current_node_low_link, target_index));
238 }
239 }
240
241 if self.low_links.get(¤t_node) == self.indices.get(¤t_node) {
243 let mut strongly_connected_nodes = HashSet::new();
244 let mut stack_node;
245 loop {
246 stack_node = self.stack.pop().unwrap();
247 self.on_stack.remove(&stack_node);
248 strongly_connected_nodes.insert(stack_node);
249 if stack_node == current_node {
250 break;
251 }
252 }
253 self.insert_cycles_from_scc(
254 &strongly_connected_nodes,
255 stack_node,
256 HashSet::new(),
257 vec![],
258 );
259 }
260 self.node_order.push(current_node);
261 }
262
263 fn insert_cycles_from_scc(
266 &mut self,
267 scc_nodes: &HashSet<T>,
268 current_node: T,
269 mut visited_nodes: HashSet<T>,
270 mut path: Vec<T>,
271 ) {
272 if visited_nodes.contains(¤t_node) {
273 let (current_node_path_index, _) =
276 path.iter().enumerate().find(|(_, val)| val == &¤t_node).unwrap();
277 let mut cycle = path[current_node_path_index..].to_vec();
278
279 Self::rotate_cycle(&mut cycle);
282 cycle.push(*cycle.first().unwrap());
285 self.cycles.insert(cycle);
286 return;
287 }
288
289 visited_nodes.insert(current_node);
290 path.push(current_node);
291
292 let targets_in_scc: Vec<_> =
293 self.graph.0[¤t_node].0.iter().filter(|n| scc_nodes.contains(n)).collect();
294 for target in targets_in_scc {
295 self.insert_cycles_from_scc(scc_nodes, *target, visited_nodes.clone(), path.clone());
296 }
297 }
298
299 fn rotate_cycle(cycle: &mut Vec<T>) {
303 let mut lowest_index = 0;
304 let mut lowest_value = cycle.first().unwrap();
305 for (index, node) in cycle.iter().enumerate() {
306 if node < lowest_value {
307 lowest_index = index;
308 lowest_value = node;
309 }
310 }
311 cycle.rotate_left(lowest_index);
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    // Generates one `#[test]` per entry asserting the exact order returned by
    // `DirectedGraph::topological_sort` for the given edge list.
    macro_rules! test_topological_sort {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    order = $order:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    topological_sort_test(&$edges, &$order);
                }
            )+
        }
    }

    // Generates one `#[test]` per entry asserting the exact set of cycles reported for a
    // graph that is expected to be cyclic.
    macro_rules! test_cycles {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    cycles = $cycles:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    cycles_test(&$edges, &$cycles);
                }
            )+
        }
    }

    // Generates one `#[test]` per entry asserting the result of
    // `DirectedGraph::find_shortest_path` between two nodes.
    macro_rules! test_shortest_path {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    from = $from:expr,
                    to = $to:expr,
                    shortest_path = $shortest_path:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    shortest_path_test($edges, $from, $to, $shortest_path);
                }
            )+
        }
    }

    /// Builds a graph from `edges` and asserts that it sorts into exactly `order`.
    fn topological_sort_test(edges: &[(&'static str, &'static str)], order: &[&'static str]) {
        let mut graph = DirectedGraph::new();
        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
        let actual_order = graph.topological_sort().expect("found a cycle");

        let expected_order: Vec<_> = order.iter().cloned().collect();
        assert_eq!(expected_order, actual_order);
    }

    /// Builds a graph from `edges` and asserts that sorting fails with exactly `cycles`.
    fn cycles_test(edges: &[(&'static str, &'static str)], cycles: &[&[&'static str]]) {
        let mut graph = DirectedGraph::new();
        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
        let Error::CyclesDetected(reported_cycles) = graph
            .topological_sort()
            .expect_err("topological sort succeeded on a dataset with a cycle");

        let expected_cycles: HashSet<Vec<_>> =
            cycles.iter().cloned().map(|c| c.iter().cloned().collect()).collect();
        assert_eq!(reported_cycles, expected_cycles);
    }

    /// Builds a graph from `edges` and asserts the shortest path from `from` to `to`.
    fn shortest_path_test(
        edges: &[(&'static str, &'static str)],
        from: &'static str,
        to: &'static str,
        expected_shortest_path: Option<&[&'static str]>,
    ) {
        let mut graph = DirectedGraph::new();
        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
        let actual_shortest_path = graph.find_shortest_path(from, to);
        let expected_shortest_path =
            expected_shortest_path.map(|path| path.iter().cloned().collect::<Vec<_>>());
        assert_eq!(actual_shortest_path, expected_shortest_path);
    }

    test_topological_sort! {
        // Expected orders list targets before the nodes pointing at them; nodes and targets
        // are visited in ascending order, which fixes the order among independent nodes.
        test_empty => {
            edges = [],
            order = [],
        },
        test_fan_out => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("b", "d"),
                ("d", "e"),
            ],
            order = ["c", "e", "d", "b", "a"],
        },
        test_fan_in => {
            edges = [
                ("a", "b"),
                ("b", "d"),
                ("c", "d"),
                ("d", "e"),
            ],
            order = ["e", "d", "b", "a", "c"],
        },
        test_forest => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("d", "e"),
            ],
            order = ["c", "b", "a", "e", "d"],
        },
        test_diamond => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("c", "d"),
            ],
            order = ["d", "b", "c", "a"],
        },
        test_lattice => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("b", "e"),
                ("c", "d"),
                ("e", "f"),
                ("d", "f"),
            ],
            order = ["f", "d", "e", "b", "c", "a"],
        },
        test_deduped_edge => {
            edges = [
                ("a", "b"),
                ("a", "b"),
                ("b", "c"),
            ],
            order = ["c", "b", "a"],
        },
    }

    test_cycles! {
        test_cycle_self_referential => {
            // A self-loop is reported as a two-element closed cycle.
            edges = [
                ("a", "a"),
            ],
            cycles = [
                &["a", "a"],
            ],
        },
        test_cycle_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_two_nodes_with_path_in => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_nodes_with_path_out => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "d"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_three_nodes => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
            ],
        },
        test_cycle_three_nodes_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["b", "c", "b"],
            ],
        },
        test_cycle_three_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["b", "c", "b"],
                &["a", "c", "a"],
                &["a", "b", "c", "a"],
                &["a", "c", "b", "a"],
            ],
        },
        test_cycle_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),

                ("b", "d"),
                ("d", "e"),
                ("e", "c"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "e", "c", "a"],
            ],
        },
        test_two_join_cycles => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
                ("b", "d"),
                ("d", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "a"],
            ],
        },
        test_cycle_four_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "d"),
                ("d", "c"),
                ("d", "a"),
                ("a", "d"),
            ],
            cycles = [
                &["a", "b", "c", "d", "a"],
                &["a", "b", "a"],
                &["a", "d", "c", "b", "a"],
                &["a", "d", "a"],
                &["b", "c", "b"],
                &["c", "d", "c"],
            ],
        },

        test_cycle_self_referential_islands => {
            // Each disconnected self-loop is reported; the acyclic edge d -> e is not.
            edges = [
                ("a", "a"),
                ("b", "b"),
                ("c", "c"),
                ("d", "e"),
            ],
            cycles = [
                &["a", "a"],
                &["b", "b"],
                &["c", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes_connected => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
    }

    test_shortest_path! {
        test_empty_graph => {
            edges = &[],
            from = "a",
            to = "b",
            shortest_path = None,
        },
        test_two_nodes => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_to_self => {
            edges = &[
                ("a", "a"),
            ],
            from = "a",
            to = "a",
            shortest_path = Some(&["a", "a"]),
        },
        test_path_to_self_no_edge => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "a",
            shortest_path = None,
        },
        test_path_three_nodes => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "b", "c"]),
        },
        test_path_multiple_options => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("a", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "c"]),
        },
        test_path_two_islands => {
            edges = &[
                ("a", "b"),
                ("c", "d"),
            ],
            from = "a",
            to = "d",
            shortest_path = None,
        },
        test_path_with_cycle => {
            edges = &[
                ("a", "b"),
                ("b", "a"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_2 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_3 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("b", "d"),
                ("d", "e"),
            ],
            from = "a",
            to = "e",
            shortest_path = Some(&["a", "b", "d", "e"]),
        },
        test_path_from_unknown_node => {
            edges = &[
                ("a", "b"),
            ],
            from = "z",
            to = "b",
            shortest_path = None,
        },
        test_path_against_edge_direction => {
            edges = &[
                ("a", "b"),
            ],
            from = "b",
            to = "a",
            shortest_path = None,
        },
    }

    /// `format_cycle` lists cycles in ascending order regardless of discovery order.
    #[test]
    fn test_format_cycle_sorted() {
        let mut graph = DirectedGraph::new();
        let edges = [("c", "d"), ("d", "c"), ("a", "b"), ("b", "a")];
        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
        let error = graph
            .topological_sort()
            .expect_err("topological sort succeeded on a dataset with a cycle");
        assert_eq!(error.format_cycle(), "{{a -> b -> a}, {c -> d -> c}}");
    }

    /// An empty cycle set still renders as a pair of braces.
    #[test]
    fn test_format_cycle_no_cycles() {
        let error: Error<&str> = Error::CyclesDetected(HashSet::new());
        assert_eq!(error.format_cycle(), "{}");
    }
}