mirror of
https://github.com/pezkuwichain/pezkuwi-subxt.git
synced 2026-04-26 19:17:58 +00:00
Generic Normalize impl for arithmetic and npos-elections (#6374)
* add normalize * better api for normalize * Some grumbles * Update primitives/arithmetic/src/lib.rs Co-authored-by: Guillaume Thiolliere <gui.thiolliere@gmail.com> * More great review grumbles * Way better doc for everything. * Some improvement * Update primitives/arithmetic/src/lib.rs Co-authored-by: Bernhard Schuster <bernhard@ahoi.io> Co-authored-by: Guillaume Thiolliere <gui.thiolliere@gmail.com> Co-authored-by: Bernhard Schuster <bernhard@ahoi.io>
This commit is contained in:
@@ -41,10 +41,11 @@ mod fixed_point;
|
||||
mod rational128;
|
||||
|
||||
pub use fixed_point::{FixedPointNumber, FixedPointOperand, FixedI64, FixedI128, FixedU128};
|
||||
pub use per_things::{PerThing, InnerOf, Percent, PerU16, Permill, Perbill, Perquintill};
|
||||
pub use per_things::{PerThing, InnerOf, UpperOf, Percent, PerU16, Permill, Perbill, Perquintill};
|
||||
pub use rational128::Rational128;
|
||||
|
||||
use sp_std::cmp::Ordering;
|
||||
use sp_std::{prelude::*, cmp::Ordering, fmt::Debug, convert::TryInto};
|
||||
use traits::{BaseArithmetic, One, Zero, SaturatedConversion, Unsigned};
|
||||
|
||||
/// Trait for comparing two numbers with a threshold.
|
||||
///
|
||||
@@ -85,8 +86,365 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// A collection-like type whose elements of type `T` can be rescaled so that they add up to a
/// chosen total.
///
/// The outcome may depend on the order in which the items appear in the collection.
pub trait Normalizable<T> {
    /// Produce a normalized copy of `self` whose elements sum to `targeted_sum`.
    ///
    /// # Errors
    ///
    /// `Ok` is returned only when the sum of the returned values is guaranteed to be equal to
    /// `targeted_sum`. In every other case an error describing why normalization failed is
    /// returned instead.
    fn normalize(&self, targeted_sum: T) -> Result<Vec<T>, &'static str>;
}
|
||||
|
||||
macro_rules! impl_normalize_for_numeric {
|
||||
($($numeric:ty),*) => {
|
||||
$(
|
||||
impl Normalizable<$numeric> for Vec<$numeric> {
|
||||
fn normalize(&self, targeted_sum: $numeric) -> Result<Vec<$numeric>, &'static str> {
|
||||
normalize(self.as_ref(), targeted_sum)
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl_normalize_for_numeric!(u8, u16, u32, u64, u128);
|
||||
|
||||
impl<P: PerThing> Normalizable<P> for Vec<P> {
|
||||
fn normalize(&self, targeted_sum: P) -> Result<Vec<P>, &'static str> {
|
||||
let inners = self.iter().map(|p| p.clone().deconstruct().into()).collect::<Vec<_>>();
|
||||
let normalized = normalize(inners.as_ref(), targeted_sum.deconstruct().into())?;
|
||||
Ok(normalized.into_iter().map(|i: UpperOf<P>| P::from_parts(i.saturated_into())).collect())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Normalize `input` so that the sum of all elements reaches `targeted_sum`.
|
||||
///
|
||||
/// This implementation is currently in a balanced position between being performant and accurate.
|
||||
///
|
||||
/// 1. We prefer storing original indices, and sorting the `input` only once. This will save the
|
||||
/// cost of sorting per round at the cost of a little bit of memory.
|
||||
/// 2. The granularity of increment/decrements is determined by the number of elements in `input`
|
||||
/// and their sum difference with `targeted_sum`, namely `diff = diff(sum(input), target_sum)`.
|
||||
/// This value is then distributed into `per_round = diff / input.len()` and `leftover = diff %
|
||||
/// round`. First, per_round is applied to all elements of input, and then we move to leftover,
|
||||
/// in which case we add/subtract 1 by 1 until `leftover` is depleted.
|
||||
///
|
||||
/// When the sum is less than the target, the above approach always holds. In this case, then each
|
||||
/// individual element is also less than target. Thus, by adding `per_round` to each item, neither
|
||||
/// of them can overflow the numeric bound of `T`. In fact, neither of the can go beyond
|
||||
/// `target_sum`*.
|
||||
///
|
||||
/// If sum is more than target, there is small twist. The subtraction of `per_round`
|
||||
/// form each element might go below zero. In this case, we saturate and add the error to the
|
||||
/// `leftover` value. This ensures that the result will always stay accurate, yet it might cause the
|
||||
/// execution to become increasingly slow, since leftovers are applied one by one.
|
||||
///
|
||||
/// All in all, the complicated case above is rare to happen in all substrate use cases, hence we
|
||||
/// opt for it due to its simplicity.
|
||||
///
|
||||
/// This function will return an error is if length of `input` cannot fit in `T`, or if `sum(input)`
|
||||
/// cannot fit inside `T`.
|
||||
///
|
||||
/// * This proof is used in the implementation as well.
|
||||
pub fn normalize<T>(input: &[T], targeted_sum: T) -> Result<Vec<T>, &'static str>
|
||||
where T: Clone + Copy + Ord + BaseArithmetic + Unsigned + Debug,
|
||||
{
|
||||
// compute sum and return error if failed.
|
||||
let mut sum = T::zero();
|
||||
for t in input.iter() {
|
||||
sum = sum.checked_add(t).ok_or("sum of input cannot fit in `T`")?;
|
||||
}
|
||||
|
||||
// convert count and return error if failed.
|
||||
let count = input.len();
|
||||
let count_t: T = count.try_into().map_err(|_| "length of `inputs` cannot fit in `T`")?;
|
||||
|
||||
// Nothing to do here.
|
||||
if count.is_zero() {
|
||||
return Ok(Vec::<T>::new());
|
||||
}
|
||||
|
||||
let diff = targeted_sum.max(sum) - targeted_sum.min(sum);
|
||||
if diff.is_zero() {
|
||||
return Ok(input.to_vec());
|
||||
}
|
||||
|
||||
let needs_bump = targeted_sum > sum;
|
||||
let per_round = diff / count_t;
|
||||
let mut leftover = diff % count_t;
|
||||
|
||||
// sort output once based on diff. This will require more data transfer and saving original
|
||||
// index, but we sort only twice instead: once now and once at the very end.
|
||||
let mut output_with_idx = input.iter().cloned().enumerate().collect::<Vec<(usize, T)>>();
|
||||
output_with_idx.sort_unstable_by_key(|x| x.1);
|
||||
|
||||
if needs_bump {
|
||||
// must increase the values a bit. Bump from the min element. Index of minimum is now zero
|
||||
// because we did a sort. If at any point the min goes greater or equal the `max_threshold`,
|
||||
// we move to the next minimum.
|
||||
let mut min_index = 0;
|
||||
// at this threshold we move to next index.
|
||||
let threshold = targeted_sum / count_t;
|
||||
|
||||
if !per_round.is_zero() {
|
||||
for _ in 0..count {
|
||||
output_with_idx[min_index].1 = output_with_idx[min_index].1
|
||||
.checked_add(&per_round)
|
||||
.expect("Proof provided in the module doc; qed.");
|
||||
if output_with_idx[min_index].1 >= threshold {
|
||||
min_index += 1;
|
||||
min_index = min_index % count;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// continue with the previous min_index
|
||||
while !leftover.is_zero() {
|
||||
output_with_idx[min_index].1 = output_with_idx[min_index].1
|
||||
.checked_add(&T::one())
|
||||
.expect("Proof provided in the module doc; qed.");
|
||||
if output_with_idx[min_index].1 >= threshold {
|
||||
min_index += 1;
|
||||
min_index = min_index % count;
|
||||
}
|
||||
leftover -= One::one()
|
||||
}
|
||||
} else {
|
||||
// must decrease the stakes a bit. decrement from the max element. index of maximum is now
|
||||
// last. if at any point the max goes less or equal the `min_threshold`, we move to the next
|
||||
// maximum.
|
||||
let mut max_index = count - 1;
|
||||
// at this threshold we move to next index.
|
||||
let threshold = output_with_idx
|
||||
.first()
|
||||
.expect("length of input is greater than zero; it must have a first; qed")
|
||||
.1;
|
||||
|
||||
if !per_round.is_zero() {
|
||||
for _ in 0..count {
|
||||
output_with_idx[max_index].1 = output_with_idx[max_index].1
|
||||
.checked_sub(&per_round)
|
||||
.unwrap_or_else(|| {
|
||||
let remainder = per_round - output_with_idx[max_index].1;
|
||||
leftover += remainder;
|
||||
output_with_idx[max_index].1.saturating_sub(per_round)
|
||||
});
|
||||
if output_with_idx[max_index].1 <= threshold {
|
||||
max_index = max_index.checked_sub(1).unwrap_or(count - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// continue with the previous max_index
|
||||
while !leftover.is_zero() {
|
||||
if let Some(next) = output_with_idx[max_index].1.checked_sub(&One::one()) {
|
||||
output_with_idx[max_index].1 = next;
|
||||
if output_with_idx[max_index].1 <= threshold {
|
||||
max_index = max_index.checked_sub(1).unwrap_or(count - 1);
|
||||
}
|
||||
leftover -= One::one()
|
||||
} else {
|
||||
max_index = max_index.checked_sub(1).unwrap_or(count - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug_assert_eq!(
|
||||
output_with_idx.iter().fold(T::zero(), |acc, (_, x)| acc + *x),
|
||||
targeted_sum,
|
||||
"sum({:?}) != {:?}",
|
||||
output_with_idx,
|
||||
targeted_sum,
|
||||
);
|
||||
|
||||
// sort again based on the original index.
|
||||
output_with_idx.sort_unstable_by_key(|x| x.0);
|
||||
Ok(output_with_idx.into_iter().map(|(_, t)| t).collect())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod normalize_tests {
    use super::*;

    /// `normalize` must be usable with every unsigned primitive, as long as the length of the
    /// input can be converted to that primitive.
    #[test]
    fn work_for_all_types() {
        macro_rules! test_for {
            ($type:ty) => {
                assert_eq!(
                    normalize(&[8 as $type, 9, 7, 10], 40).unwrap(),
                    vec![10, 10, 10, 10],
                );
            };
        }

        test_for!(u128);
        test_for!(u64);
        test_for!(u32);
        test_for!(u16);
        test_for!(u8);
    }

    /// The sum of the input must fit in `T`; otherwise an error is reported.
    #[test]
    fn fails_on_if_input_sum_large() {
        assert!(normalize(&[1u8; 255], 10).is_ok());
        assert_eq!(
            normalize(&[1u8; 256], 10),
            Err("sum of input cannot fit in `T`"),
        );
    }

    /// Subtracting more than an element holds saturates instead of failing.
    #[test]
    fn does_not_fail_on_subtraction_overflow() {
        assert_eq!(normalize(&[1u8, 100, 100], 10).unwrap(), vec![1, 9, 0]);
        assert_eq!(normalize(&[1u8, 8, 9], 1).unwrap(), vec![0, 1, 0]);
    }

    /// The `Normalizable` implementation for plain vectors forwards to `normalize`.
    #[test]
    fn works_for_vec() {
        let input = vec![8u32, 9, 7, 10];
        assert_eq!(input.normalize(40).unwrap(), vec![10u32, 10, 10, 10]);
    }

    /// Vectors of per-things can be normalized around `one`.
    #[test]
    fn works_for_per_thing() {
        let thirds = vec![
            Perbill::from_percent(33),
            Perbill::from_percent(33),
            Perbill::from_percent(33),
        ];
        assert_eq!(
            thirds.normalize(Perbill::one()).unwrap(),
            vec![
                Perbill::from_parts(333333334),
                Perbill::from_parts(333333333),
                Perbill::from_parts(333333333),
            ]
        );

        let uneven = vec![
            Perbill::from_percent(20),
            Perbill::from_percent(15),
            Perbill::from_percent(30),
        ];
        assert_eq!(
            uneven.normalize(Perbill::one()).unwrap(),
            vec![
                Perbill::from_parts(316666668),
                Perbill::from_parts(383333332),
                Perbill::from_parts(300000000),
            ]
        );
    }

    /// PerU16 is a rather special case: since the inner type is exactly the same as the
    /// capacity, the sum may not be computable in the inner type. Calculating in the upper type
    /// of the per-thing should make this work.
    #[test]
    fn can_work_for_peru16() {
        let over_committed = vec![
            PerU16::from_percent(40),
            PerU16::from_percent(40),
            PerU16::from_percent(40),
        ];
        assert_eq!(
            over_committed.normalize(PerU16::one()).unwrap(),
            vec![
                PerU16::from_parts(21845), // 33%
                PerU16::from_parts(21845), // 33%
                PerU16::from_parts(21845), // 33%
            ]
        );
    }

    /// Inputs whose elements are all less than or equal to their share of the target.
    #[test]
    fn normalize_works_all_le() {
        let cases: [([u32; 4], [u32; 4]); 5] = [
            ([8, 9, 7, 10], [10, 10, 10, 10]),
            ([7, 7, 7, 7], [10, 10, 10, 10]),
            ([7, 7, 7, 10], [11, 11, 8, 10]),
            ([7, 8, 7, 10], [11, 8, 11, 10]),
            ([7, 7, 8, 10], [11, 11, 8, 10]),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(normalize(&input[..], 40).unwrap(), expected.to_vec());
        }
    }

    /// Inputs in which some elements already exceed their share of the target.
    #[test]
    fn normalize_works_some_ge() {
        assert_eq!(normalize(&[8u32, 11, 9, 10], 40).unwrap(), vec![10, 11, 9, 10]);
    }

    /// The minimum element is the one that gets bumped, wherever it sits in the input.
    #[test]
    fn always_inc_min() {
        let inputs: [[u32; 4]; 3] = [[10, 7, 10, 10], [10, 10, 7, 10], [10, 10, 10, 7]];
        for input in inputs.iter() {
            assert_eq!(normalize(&input[..], 40).unwrap(), vec![10, 10, 10, 10]);
        }
    }

    /// Inputs whose elements are all greater than or equal to their share of the target.
    #[test]
    fn normalize_works_all_ge() {
        let cases: [([u32; 4], [u32; 4]); 5] = [
            ([12, 11, 13, 10], [10, 10, 10, 10]),
            ([13, 13, 13, 13], [10, 10, 10, 10]),
            ([13, 13, 13, 10], [12, 9, 9, 10]),
            ([13, 12, 13, 10], [9, 12, 9, 10]),
            ([13, 13, 12, 10], [9, 9, 12, 10]),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(normalize(&input[..], 40).unwrap(), expected.to_vec());
        }
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod threshold_compare_tests {
|
||||
use super::*;
|
||||
use crate::traits::Saturating;
|
||||
use sp_std::cmp::Ordering;
|
||||
|
||||
Reference in New Issue
Block a user