diff --git a/backend/src/util/num_stats.rs b/backend/src/util/num_stats.rs index 6c80b46..42cb548 100644 --- a/backend/src/util/num_stats.rs +++ b/backend/src/util/num_stats.rs @@ -22,7 +22,7 @@ impl> NumS pub fn push(&mut self, val: T) { let slot = &mut self.stack[self.index % self.stack.len()]; - self.sum = (self.sum + val).saturating_sub(*slot); + self.sum = (self.sum + val) - *slot; *slot = val; @@ -31,6 +31,11 @@ impl> NumS pub fn average(&self) -> T { let cap = std::cmp::min(self.index, self.stack.len()); + + if cap == 0 { + return T::zero(); + } + let cap = T::try_from(cap).unwrap_or_else(|_| T::max_value()); self.sum / cap @@ -39,5 +44,61 @@ impl> NumS pub fn reset(&mut self) { self.index = 0; self.sum = T::zero(); + + for val in self.stack.iter_mut() { + *val = T::zero(); + } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn calculates_correct_average() { + let mut stats: NumStats = NumStats::new(10); + + stats.push(3); + stats.push(7); + + assert_eq!(stats.average(), 5); + } + + #[test] + fn calculates_correct_average_over_bounds() { + let mut stats: NumStats = NumStats::new(10); + + stats.push(100); + + for _ in 0..9 { + stats.push(0); + } + + assert_eq!(stats.average(), 10); + + stats.push(0); + + assert_eq!(stats.average(), 0); + } + + #[test] + fn resets_properly() { + let mut stats: NumStats = NumStats::new(10); + + for _ in 0..10 { + stats.push(100); + } + + assert_eq!(stats.average(), 100); + + stats.reset(); + + assert_eq!(stats.average(), 0); + + stats.push(7); + stats.push(3); + + assert_eq!(stats.average(), 5); + } +} \ No newline at end of file