1use anyhow::{format_err, Error};
6
/// Butcher-tableau coefficients of the six-stage Runge-Kutta-Fehlberg 4(5) pair used by
/// `rkf45_step`.
///
/// The first stage is always evaluated at the start of the step, so it has no row in `A` and no
/// entry in `C`; both tables therefore describe stages 2 through `NUM_STAGES` only.
mod rkf45_params {
    /// Number of derivative evaluations (stages) performed per step.
    pub const NUM_STAGES: usize = 6;

    /// Number of stages that are built from earlier stages (every stage except the first).
    const NUM_DEPENDENT_STAGES: usize = NUM_STAGES - 1;

    /// Stage-coupling weights. Row `i - 1` holds the weights applied to the derivatives of
    /// stages `0..i` when forming the state at which stage `i` is evaluated. Only the lower
    /// triangle (including the diagonal) is read; the remaining entries are zero padding.
    pub static A: [[f64; NUM_DEPENDENT_STAGES]; NUM_DEPENDENT_STAGES] = [
        [1.0 / 4.0, 0.0, 0.0, 0.0, 0.0],
        [3.0 / 32.0, 9.0 / 32.0, 0.0, 0.0, 0.0],
        [1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0, 0.0, 0.0],
        [439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0, 0.0],
        [-8.0 / 27.0, 2.0, -3544.0 / 2565.0, 1859.0 / 4104.0, -11.0 / 40.0],
    ];

    /// Weights that combine the stage derivatives into the solution `rkf45_step` writes back
    /// to its `y` argument (the fourth-order member of the pair).
    pub static B4: [f64; NUM_STAGES] =
        [25.0 / 216.0, 0.0, 1408.0 / 2565.0, 2197.0 / 4104.0, -1.0 / 5.0, 0.0];

    /// Weights that combine the stage derivatives into the fifth-order solution, which
    /// `rkf45_step` uses only to estimate the error of the `B4` solution.
    pub static B5: [f64; NUM_STAGES] =
        [16.0 / 135.0, 0.0, 6656.0 / 12825.0, 28561.0 / 56430.0, -9.0 / 50.0, 2.0 / 55.0];

    /// Time offsets, as fractions of the step size, at which stages 2 through `NUM_STAGES`
    /// are evaluated. Entry `i - 1` belongs to stage `i` and equals the sum of row `i - 1`
    /// of `A`.
    pub static C: [f64; NUM_DEPENDENT_STAGES] =
        [1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0];
}
35
/// Tuning constants for the step-size controller in `rkf45_adaptive`.
///
/// After each trial step the controller looks at the error ratio (estimated error divided by
/// the allowed error bound) and either retries with a smaller step or accepts the step and
/// possibly enlarges the next one.
mod adaptive_stepping_params {
    // --- Accept / reject decisions -----------------------------------------------------------

    /// A trial step is accepted only when its error ratio is below this value; otherwise the
    /// step is retried with a smaller step size.
    pub static ERROR_RATIO_REFINING_THRESHOLD: f64 = 1.1;

    /// After an accepted step, the step size is enlarged only when the error ratio is below
    /// this value.
    pub static ERROR_RATIO_COARSENING_THRESHOLD: f64 = 0.5;

    // --- Step-size scaling -------------------------------------------------------------------

    /// Multiplier applied to every power-law scaling factor before it is clamped, making the
    /// new step size slightly smaller than the power law alone would give.
    pub static SAFETY_FACTOR: f64 = 0.9;

    /// Exponent applied to the error ratio when shrinking the step after a rejection.
    pub static REFINING_EXPONENT: f64 = -0.25;

    /// Exponent applied to the error ratio when enlarging the step after an acceptance.
    pub static COARSENING_EXPONENT: f64 = -0.20;

    // --- Clamps ------------------------------------------------------------------------------

    /// Lower limit on the shrink factor, so one rejection never reduces the step size to less
    /// than this fraction of its previous value.
    pub static MIN_REFINING_FACTOR: f64 = 0.2;

    /// Upper limit on the growth factor, so one acceptance never enlarges the step size by
    /// more than this multiple.
    pub static MAX_COARSENING_FACTOR: f64 = 5.0;
}
86
/// In-place elementwise arithmetic on a vector of numbers.
///
/// Every method mutates `self`; none of them allocates.
trait VectorOperations {
    /// Multiplies every element of `self` by `a`.
    fn scale(&mut self, a: f64);

    /// Overwrites the elements of `self` with the corresponding elements of `x`.
    fn copy_from(&mut self, x: &Self);

    /// Adds `x` to `self` elementwise.
    fn add(&mut self, x: &Self);

    /// Adds `a * x` to `self` elementwise.
    fn add_ax(&mut self, a: f64, x: &Self);

    /// Subtracts `x` from `self` elementwise.
    fn subtract(&mut self, x: &Self);

    /// Sets every element of `self` to zero.
    fn to_zeros(&mut self);
}

/// Slice implementation.
///
/// The binary operations pair elements with `zip`, so when `self` and `x` differ in length
/// only the leading `min(self.len(), x.len())` elements of `self` are touched; no length
/// check is performed.
impl VectorOperations for [f64] {
    fn scale(&mut self, a: f64) {
        for value in self.iter_mut() {
            *value *= a;
        }
    }

    fn copy_from(&mut self, x: &Self) {
        for (dst, src) in self.iter_mut().zip(x) {
            *dst = *src;
        }
    }

    fn add(&mut self, x: &Self) {
        for (dst, src) in self.iter_mut().zip(x) {
            *dst += *src;
        }
    }

    fn add_ax(&mut self, a: f64, x: &Self) {
        for (dst, src) in self.iter_mut().zip(x) {
            *dst += a * *src;
        }
    }

    fn subtract(&mut self, x: &Self) {
        for (dst, src) in self.iter_mut().zip(x) {
            *dst -= *src;
        }
    }

    fn to_zeros(&mut self) {
        for value in self.iter_mut() {
            *value = 0.0;
        }
    }
}
137
138fn rkf45_step(
151 y: &mut [f64],
152 dydt: &impl Fn(f64, &[f64]) -> Vec<f64>,
153 tn: f64,
154 dt: f64,
155 error_control: &ErrorControlOptions,
156) -> (Vec<f64>, f64) {
157 let mut work = vec![0.0; y.len()];
159
160 let mut k = Vec::with_capacity(6);
161 k.push(dydt(tn, y));
162
163 use rkf45_params as params;
164 for i in 1..params::NUM_STAGES {
165 work.to_zeros();
166 for j in 0..i {
167 work.add_ax(params::A[i - 1][j], &k[j]);
168 }
169 work.scale(dt);
171 work.add(&y);
172 k.push(dydt(tn + params::C[i - 1] * dt, &work));
173 }
174
175 let mut y_5th_order = y.to_vec();
177 work.to_zeros();
178 for i in 0..params::B5.len() {
179 work.add_ax(params::B5[i], &k[i]);
180 }
181 y_5th_order.add_ax(dt, &work);
182
183 work.to_zeros();
185 for i in 0..params::B4.len() {
186 work.add_ax(params::B4[i], &k[i]);
187 }
188 y.add_ax(dt, &work);
189
190 let mut error_estimate = y_5th_order;
191 error_estimate.subtract(&y);
192 error_estimate.iter_mut().for_each(|x| *x = (*x).abs());
193
194 let mut max_error_ratio = 0.0;
195 for i in 0..y.len() {
196 let error_bound = error_control.absolute_magnitude
198 + error_control.relative_magnitude
199 * (error_control.function_scale * y[i].abs()
200 + error_control.derivative_scale * dt * k[0][i].abs());
201 max_error_ratio = f64::max(max_error_ratio, error_estimate[i] / error_bound);
202 }
203
204 (error_estimate, max_error_ratio)
205}
206
207pub struct AdaptiveOdeSolverOptions {
209 pub t_initial: f64,
211 pub t_final: f64,
213 pub dt_initial: f64,
215 pub error_control: ErrorControlOptions,
217}
218
/// Weights that define how much error the solver tolerates in a single step.
///
/// For component `i` of the solution, a step's estimated error is compared against
///
/// ```text
/// absolute_magnitude
///     + relative_magnitude * (function_scale * |y[i]| + derivative_scale * dt * |dydt[i]|)
/// ```
///
/// where `dt` is the step size and `dydt` is the derivative at the start of the step.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ErrorControlOptions {
    /// Error allowed regardless of the size of the solution. `rkf45_adaptive` requires this
    /// to be strictly positive, which keeps the error bound nonzero.
    pub absolute_magnitude: f64,

    /// Multiplier on the solution-dependent part of the bound. `rkf45_adaptive` requires this
    /// to be non-negative.
    pub relative_magnitude: f64,

    /// Weight of `|y[i]|` inside the solution-dependent part of the bound. `rkf45_adaptive`
    /// requires this to be non-negative.
    pub function_scale: f64,

    /// Weight of `dt * |dydt[i]|` inside the solution-dependent part of the bound.
    /// `rkf45_adaptive` requires this to be non-negative.
    pub derivative_scale: f64,
}

impl ErrorControlOptions {
    /// Builds options in which `scale` is used for both the absolute and the relative
    /// magnitude, and the function and derivative scales are both `1.0`.
    pub fn simple(scale: f64) -> Self {
        Self {
            absolute_magnitude: scale,
            relative_magnitude: scale,
            function_scale: 1.0,
            derivative_scale: 1.0,
        }
    }
}
258
259pub fn rkf45_adaptive(
273 y: &mut [f64],
274 dydt: &impl Fn(f64, &[f64]) -> Vec<f64>,
275 options: &AdaptiveOdeSolverOptions,
276) -> Result<Vec<f64>, Error> {
277 macro_rules! validate_input {
278 ($x:expr) => {
279 if !($x) {
280 return Err(format_err!("Failed input validation: {}", stringify!($x)));
281 }
282 };
283 }
284
285 validate_input!(options.t_final > options.t_initial);
286 validate_input!(options.dt_initial > 0.0);
287 validate_input!(options.error_control.absolute_magnitude > 0.0);
288 validate_input!(options.error_control.relative_magnitude >= 0.0);
289 validate_input!(options.error_control.function_scale >= 0.0);
290 validate_input!(options.error_control.derivative_scale >= 0.0);
291
292 let mut t = options.t_initial;
293 let mut dt = options.dt_initial;
294 let mut total_error = vec![0.0; y.len()];
295
296 let mut work = y.to_vec();
297 while t < options.t_final {
298 if t + dt > options.t_final {
299 dt = options.t_final - t;
300 }
301
302 let (error_estimate, max_error_ratio) =
303 rkf45_step(&mut work, dydt, t, dt, &options.error_control);
304 use adaptive_stepping_params as params;
305 if max_error_ratio < params::ERROR_RATIO_REFINING_THRESHOLD {
306 y.copy_from(&work);
308 total_error.add(&error_estimate);
309 t += dt;
310 if max_error_ratio < params::ERROR_RATIO_COARSENING_THRESHOLD {
311 let factor = f64::min(
313 params::SAFETY_FACTOR * max_error_ratio.powf(params::COARSENING_EXPONENT),
314 params::MAX_COARSENING_FACTOR,
315 );
316 dt *= factor;
317 }
318 } else {
319 work.copy_from(y);
321 let factor = f64::max(
322 params::SAFETY_FACTOR * max_error_ratio.powf(params::REFINING_EXPONENT),
323 params::MIN_REFINING_FACTOR,
324 );
325 dt *= factor;
326 }
327 }
328
329 Ok(total_error)
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;
    use test_util::{assert_gt, assert_lt};

    /// Error-control weights for tests that call `rkf45_step` directly and look only at the
    /// returned error estimate. The weights influence only the returned error ratio, which
    /// those tests ignore.
    static MEANINGLESS_OPTIONS: ErrorControlOptions = ErrorControlOptions {
        absolute_magnitude: 1e-8,
        relative_magnitude: 1e-8,
        function_scale: 1.0,
        derivative_scale: 1.0,
    };

    /// Takes a single `rkf45_step` from `t = 0` and `y_initial` for each entry of
    /// `step_sizes`.
    ///
    /// Returns `(actual_errors, errors_in_estimated_error)` for the first solution component:
    /// the distance between the computed value and `y_true(dt)`, and the distance between the
    /// solver's own error estimate and that actual error.
    fn single_step_errors(
        dydt: &impl Fn(f64, &[f64]) -> Vec<f64>,
        y_true: &impl Fn(f64) -> f64,
        y_initial: &[f64],
        step_sizes: &[f64],
    ) -> (Vec<f64>, Vec<f64>) {
        let mut actual_errors = Vec::new();
        let mut errors_in_estimated_error = Vec::new();
        for &dt in step_sizes {
            let mut y = y_initial.to_vec();
            let (error_estimate, _) = rkf45_step(&mut y, dydt, 0.0, dt, &MEANINGLESS_OPTIONS);
            let actual_error = (y_true(dt) - y[0]).abs();
            errors_in_estimated_error.push((error_estimate[0] - actual_error).abs());
            actual_errors.push(actual_error);
        }
        (actual_errors, errors_in_estimated_error)
    }

    /// Checks how the errors from three successively halved step sizes shrink.
    ///
    /// Each halving must cut the actual error by at least a factor of 32 (2^5) and the error
    /// in the estimated error by at least a factor of 2, in both cases with a 10% margin.
    fn assert_step_convergence(actual_errors: &[f64], errors_in_estimated_error: &[f64]) {
        assert_lt!(actual_errors[1], actual_errors[0] / 32.0 * 1.10);
        assert_lt!(actual_errors[2], actual_errors[1] / 32.0 * 1.10);
        assert_lt!(errors_in_estimated_error[1], errors_in_estimated_error[0] / 2.0 * 1.10);
        assert_lt!(errors_in_estimated_error[2], errors_in_estimated_error[1] / 2.0 * 1.10);
    }

    /// Runs `rkf45_adaptive` from `t = 0` to `t_final` with `ErrorControlOptions::simple(1e-6)`
    /// and returns the distance between the first solution component and `y_true(t_final)`.
    fn adaptive_solve_error(
        dydt: &impl Fn(f64, &[f64]) -> Vec<f64>,
        y_true: &impl Fn(f64) -> f64,
        y_initial: &[f64],
        t_final: f64,
        dt_initial: f64,
    ) -> Result<f64, Error> {
        let options = AdaptiveOdeSolverOptions {
            t_initial: 0.0,
            t_final,
            dt_initial,
            error_control: ErrorControlOptions::simple(1e-6),
        };
        let mut y = y_initial.to_vec();
        rkf45_adaptive(&mut y, dydt, &options)?;
        Ok((y_true(options.t_final) - y[0]).abs())
    }

    /// Single-step convergence for y' = lambda * y, y(0) = 1, whose solution is
    /// exp(lambda * t).
    #[test]
    fn test_first_order_problem_rkf45_step() {
        let lambda = -0.1;
        let dydt = |_t: f64, y: &[f64]| -> Vec<f64> { vec![lambda * y[0]] };
        let y_true = |t: f64| -> f64 { f64::exp(lambda * t) };

        let (actual_errors, errors_in_estimated_error) =
            single_step_errors(&dydt, &y_true, &[1.0], &[1.0, 0.5, 0.25]);

        assert_step_convergence(&actual_errors, &errors_in_estimated_error);
    }

    /// Adaptive solve of y' = lambda * y on [0, 3]. The final error must be small, but not so
    /// small that the solver is evidently taking far more steps than the tolerance requires.
    #[test]
    fn test_first_order_problem_rkf45_adaptive() -> Result<(), Error> {
        let lambda = -0.1;
        let dydt = |_t: f64, y: &[f64]| -> Vec<f64> { vec![lambda * y[0]] };
        let y_true = |t: f64| -> f64 { f64::exp(lambda * t) };

        let actual_error = adaptive_solve_error(&dydt, &y_true, &[1.0], 3.0, 0.1)?;

        assert_lt!(actual_error, 1e-5);
        assert_gt!(actual_error, 1e-7);

        Ok(())
    }

    /// Single-step convergence for y'' = -y written as a first-order system, with y(0) = 1
    /// and y'(0) = 0, whose solution is cos(t).
    #[test]
    fn test_second_order_problem_rkf45_step() {
        let dydt = |_t: f64, y: &[f64]| -> Vec<f64> { vec![y[1], -y[0]] };
        let y_true = |t: f64| -> f64 { f64::cos(t) };

        let (actual_errors, errors_in_estimated_error) =
            single_step_errors(&dydt, &y_true, &[1.0, 0.0], &[PI / 4.0, PI / 8.0, PI / 16.0]);

        assert_step_convergence(&actual_errors, &errors_in_estimated_error);
    }

    /// Adaptive solve of y'' = -y over one full period.
    #[test]
    fn test_second_order_problem_rkf45_adaptive() -> Result<(), Error> {
        let dydt = |_t: f64, y: &[f64]| -> Vec<f64> { vec![y[1], -y[0]] };
        let y_true = |t: f64| -> f64 { f64::cos(t) };

        let actual_error = adaptive_solve_error(&dydt, &y_true, &[1.0, 0.0], 2.0 * PI, PI / 4.0)?;

        assert_lt!(actual_error, 1e-4);
        assert_gt!(actual_error, 1e-7);

        Ok(())
    }

    /// Single-step convergence for the time-dependent third-order problem
    /// y''' = -12 * alpha^2 * t * y - 4 * alpha^2 * t^2 * y', with y(0) = 1 and
    /// y'(0) = y''(0) = 0, whose solution is cos(alpha * t^2).
    #[test]
    fn test_third_order_problem_rkf45_step() {
        let alpha = 0.1;
        let square = |x: f64| -> f64 { x * x };
        let dydt = |t: f64, y: &[f64]| -> Vec<f64> {
            vec![y[1], y[2], -12.0 * square(alpha) * t * y[0] - 4.0 * square(alpha * t) * y[1]]
        };
        let y_true = |t: f64| -> f64 { f64::cos(alpha * square(t)) };

        let (actual_errors, errors_in_estimated_error) =
            single_step_errors(&dydt, &y_true, &[1.0, 0.0, 0.0], &[0.25, 0.125, 0.0625]);

        assert_step_convergence(&actual_errors, &errors_in_estimated_error);
    }

    /// Adaptive solve of the third-order problem on [0, 4].
    #[test]
    fn test_third_order_problem_rkf45_adaptive() -> Result<(), Error> {
        let alpha = 0.1;
        let square = |x: f64| -> f64 { x * x };
        let dydt = |t: f64, y: &[f64]| -> Vec<f64> {
            vec![y[1], y[2], -12.0 * square(alpha) * t * y[0] - 4.0 * square(alpha * t) * y[1]]
        };
        let y_true = |t: f64| -> f64 { f64::cos(alpha * square(t)) };

        let actual_error = adaptive_solve_error(&dydt, &y_true, &[1.0, 0.0, 0.0], 4.0, 0.25)?;

        assert_lt!(actual_error, 1e-4);
        assert_gt!(actual_error, 1e-7);

        Ok(())
    }

    /// Two uncoupled decays whose rates differ by four orders of magnitude, solved together
    /// starting from a step that spans the whole interval.
    #[test]
    fn test_multiple_time_scales() -> Result<(), Error> {
        let lambda1 = 10.0;
        let lambda2 = 0.001;
        let dydt = |_t: f64, y: &[f64]| -> Vec<f64> { vec![-lambda1 * y[0], -lambda2 * y[1]] };
        let y_true = |t: f64| -> Vec<f64> { vec![f64::exp(-lambda1 * t), f64::exp(-lambda2 * t)] };

        let options = AdaptiveOdeSolverOptions {
            t_initial: 0.0,
            t_final: 1.0,
            dt_initial: 1.0,
            error_control: ErrorControlOptions::simple(1e-6),
        };
        let mut y = [1.0, 1.0];
        rkf45_adaptive(&mut y, &dydt, &options)?;

        let actual_error: Vec<f64> = y_true(options.t_final)
            .iter()
            .zip(y.iter())
            .map(|(exact, computed)| (exact - computed).abs())
            .collect();

        // Both components must meet the tolerance; only the fast component is also checked
        // against the lower limit.
        assert_lt!(actual_error[0], 1e-5);
        assert_lt!(actual_error[1], 1e-5);
        assert_gt!(actual_error[0], 1e-7);

        Ok(())
    }

    /// Every invalid option must make `rkf45_adaptive` return an error.
    #[test]
    fn test_error_checks() {
        let dydt = |_t: f64, _y: &[f64]| -> Vec<f64> { vec![0.0] };
        let is_rejected = |options: AdaptiveOdeSolverOptions| -> bool {
            rkf45_adaptive(&mut [1.0], &dydt, &options).is_err()
        };
        // A valid time span and initial step combined with the given error control.
        let with_error_control = |error_control: ErrorControlOptions| AdaptiveOdeSolverOptions {
            t_initial: 1.0,
            t_final: 2.0,
            dt_initial: 0.1,
            error_control,
        };

        // t_final before t_initial.
        assert!(is_rejected(AdaptiveOdeSolverOptions {
            t_initial: 2.0,
            t_final: 1.0,
            dt_initial: 0.1,
            error_control: ErrorControlOptions::simple(1e-8),
        }));

        // Negative initial step size.
        assert!(is_rejected(AdaptiveOdeSolverOptions {
            dt_initial: -0.1,
            ..with_error_control(ErrorControlOptions::simple(1e-8))
        }));

        // Negative absolute magnitude.
        assert!(is_rejected(with_error_control(ErrorControlOptions {
            absolute_magnitude: -1e-8,
            ..ErrorControlOptions::simple(1e-8)
        })));

        // Negative relative magnitude.
        assert!(is_rejected(with_error_control(ErrorControlOptions {
            relative_magnitude: -1e-8,
            ..ErrorControlOptions::simple(1e-8)
        })));

        // Negative function scale.
        assert!(is_rejected(with_error_control(ErrorControlOptions {
            function_scale: -1.0,
            ..ErrorControlOptions::simple(1e-8)
        })));

        // Negative derivative scale.
        assert!(is_rejected(with_error_control(ErrorControlOptions {
            derivative_scale: -1.0,
            ..ErrorControlOptions::simple(1e-8)
        })));

        // Zero absolute magnitude.
        assert!(is_rejected(with_error_control(ErrorControlOptions {
            absolute_magnitude: 0.0,
            ..ErrorControlOptions::simple(1e-8)
        })));
    }
}