Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/mls-rs/src/tree_kem/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 10 kB image not shown  

Quelle  math.rs   Sprache: unbekannt

 
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use alloc::vec::Vec;
use core::{fmt::Debug, hash::Hash};
use mls_rs_codec::{MlsDecode, MlsEncode};

use super::node::LeafIndex;

pub trait TreeIndex:
    Send + Sync + Eq + Clone + Debug + Default + MlsEncode + MlsDecode + Hash + Ord
{
    fn root(&self) -> Self;

    fn left_unchecked(&self) -> Self;
    fn right_unchecked(&self) -> Self;

    fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>>;
    fn is_leaf(&self) -> bool;
    fn is_in_tree(&self, root: &Self) -> bool;

    #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
    fn zero() -> Self;

    #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
    fn left(&self) -> Option<Self> {
        (!self.is_leaf()).then(|| self.left_unchecked())
    }

    #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
    fn right(&self) -> Option<Self> {
        (!self.is_leaf()).then(|| self.right_unchecked())
    }

    fn direct_copath(&self, leaf_count: &Self) -> Vec<CopathNode<Self>> {
        let root = leaf_count.root();

        if !self.is_in_tree(&root) {
            return Vec::new();
        }

        let mut path = Vec::new();
        let mut parent = self.clone();

        while let Some(ps) = parent.parent_sibling(leaf_count) {
            path.push(CopathNode::new(ps.parent.clone(), ps.sibling));
            parent = ps.parent;
        }

        path
    }
}

#[derive(Clone, PartialEq, Eq, Debug)]
pub struct CopathNode<T> {
    pub path: T,
    pub copath: T,
}

impl<T: Clone + PartialEq + Eq + core::fmt::Debug> CopathNode<T> {
    pub fn new(path: T, copath: T) -> CopathNode<T> {
        CopathNode { path, copath }
    }
}

#[derive(Clone, PartialEq, Eq, Debug)]
pub struct ParentSibling<T> {
    pub parent: T,
    pub sibling: T,
}

impl<T: Clone + PartialEq + Eq + core::fmt::Debug> ParentSibling<T> {
    pub fn new(parent: T, sibling: T) -> ParentSibling<T> {
        ParentSibling { parent, sibling }
    }
}

macro_rules! impl_tree_stdint {
    ($t:ty) => {
        impl TreeIndex for $t {
            fn root(&self) -> $t {
                *self - 1
            }

            /// Panicks if `x` is even in debug, overflows in release.
            fn left_unchecked(&self) -> Self {
                *self ^ (0x01 << (self.trailing_ones() - 1))
            }

            /// Panicks if `x` is even in debug, overflows in release.
            fn right_unchecked(&self) -> Self {
                *self ^ (0x03 << (self.trailing_ones() - 1))
            }

            fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>> {
                if self == &leaf_count.root() {
                    return None;
                }

                let lvl = self.trailing_ones();
                let p = (self & !(1 << (lvl + 1))) | (1 << lvl);

                let s = if *self < p {
                    p.right_unchecked()
                } else {
                    p.left_unchecked()
                };

                Some(ParentSibling::new(p, s))
            }

            fn is_leaf(&self) -> bool {
                self & 1 == 0
            }

            fn is_in_tree(&self, root: &Self) -> bool {
                *self <= 2 * root
            }

            #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
            fn zero() -> Self {
                0
            }
        }
    };
}

impl_tree_stdint!(u32);

#[cfg(test)]
impl_tree_stdint!(u64);

pub fn leaf_lca_level(x: u32, y: u32) -> u32 {
    let mut xn = x;
    let mut yn = y;
    let mut k = 0;

    while xn != yn {
        xn >>= 1;
        yn >>= 1;
        k += 1;
    }

    k
}

pub fn subtree(x: u32) -> (LeafIndex, LeafIndex) {
    let breadth = 1 << x.trailing_ones();
    (
        LeafIndex((x + 1 - breadth) >> 1),
        LeafIndex(((x + breadth) >> 1) + 1),
    )
}

pub struct BfsIterTopDown {
    level: usize,
    mask: usize,
    level_end: usize,
    ctr: usize,
}

impl BfsIterTopDown {
    pub fn new(num_leaves: usize) -> Self {
        let depth = num_leaves.trailing_zeros() as usize;
        Self {
            level: depth + 1,
            mask: (1 << depth) - 1,
            level_end: 1,
            ctr: 0,
        }
    }
}

impl Iterator for BfsIterTopDown {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ctr == self.level_end {
            if self.level == 1 {
                return None;
            }
            self.level_end = (((self.level_end - 1) << 1) | 1) + 1;
            self.level -= 1;
            self.ctr = 0;
            self.mask >>= 1;
        }
        let res = Some((self.ctr << self.level) | self.mask);
        self.ctr += 1;
        res
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use serde::{Deserialize, Serialize};

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[derive(Serialize, Deserialize)]
    struct TestCase {
        n_leaves: u32,
        n_nodes: u32,
        root: u32,
        left: Vec<Option<u32>>,
        right: Vec<Option<u32>>,
        parent: Vec<Option<u32>>,
        sibling: Vec<Option<u32>>,
    }

    pub fn node_width(n: u32) -> u32 {
        if n == 0 {
            0
        } else {
            2 * (n - 1) + 1
        }
    }

    #[test]
    fn test_bfs_iterator() {
        let expected = [7, 3, 11, 1, 5, 9, 13, 0, 2, 4, 6, 8, 10, 12, 14];
        let bfs = BfsIterTopDown::new(8);
        assert_eq!(bfs.collect::<Vec<_>>(), expected);
    }

    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_tree_math_test_cases() -> Vec<TestCase> {
        let mut test_cases = Vec::new();

        for log_n_leaves in 0..8 {
            let n_leaves = 1 << log_n_leaves;
            let n_nodes = node_width(n_leaves);
            let left = (0..n_nodes).map(|x| x.left()).collect::<Vec<_>>();
            let right = (0..n_nodes).map(|x| x.right()).collect::<Vec<_>>();

            let (parent, sibling) = (0..n_nodes)
                .map(|x| {
                    x.parent_sibling(&n_leaves)
                        .map(|ps| (ps.parent, ps.sibling))
                        .unzip()
                })
                .unzip();

            test_cases.push(TestCase {
                n_leaves,
                n_nodes,
                root: n_leaves.root(),
                left,
                right,
                parent,
                sibling,
            })
        }

        test_cases
    }

    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(tree_math, generate_tree_math_test_cases())
    }

    #[test]
    fn test_tree_math() {
        let test_cases = load_test_cases();

        for case in test_cases {
            assert_eq!(node_width(case.n_leaves), case.n_nodes);
            assert_eq!(case.n_leaves.root(), case.root);

            for x in 0..case.n_nodes {
                assert_eq!(x.left(), case.left[x as usize]);
                assert_eq!(x.right(), case.right[x as usize]);

                let (p, s) = x
                    .parent_sibling(&case.n_leaves)
                    .map(|ps| (ps.parent, ps.sibling))
                    .unzip();

                assert_eq!(p, case.parent[x as usize]);
                assert_eq!(s, case.sibling[x as usize]);
            }
        }
    }

    #[test]
    fn test_direct_path() {
        let expected: Vec<Vec<u32>> = [
            [0x01, 0x03, 0x07, 0x0f].to_vec(),
            [0x03, 0x07, 0x0f].to_vec(),
            [0x01, 0x03, 0x07, 0x0f].to_vec(),
            [0x07, 0x0f].to_vec(),
            [0x05, 0x03, 0x07, 0x0f].to_vec(),
            [0x03, 0x07, 0x0f].to_vec(),
            [0x05, 0x03, 0x07, 0x0f].to_vec(),
            [0x0f].to_vec(),
            [0x09, 0x0b, 0x07, 0x0f].to_vec(),
            [0x0b, 0x07, 0x0f].to_vec(),
            [0x09, 0x0b, 0x07, 0x0f].to_vec(),
            [0x07, 0x0f].to_vec(),
            [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
            [0x0b, 0x07, 0x0f].to_vec(),
            [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
            [].to_vec(),
            [0x11, 0x13, 0x17, 0x0f].to_vec(),
            [0x13, 0x17, 0x0f].to_vec(),
            [0x11, 0x13, 0x17, 0x0f].to_vec(),
            [0x17, 0x0f].to_vec(),
            [0x15, 0x13, 0x17, 0x0f].to_vec(),
            [0x13, 0x17, 0x0f].to_vec(),
            [0x15, 0x13, 0x17, 0x0f].to_vec(),
            [0x0f].to_vec(),
            [0x19, 0x1b, 0x17, 0x0f].to_vec(),
            [0x1b, 0x17, 0x0f].to_vec(),
            [0x19, 0x1b, 0x17, 0x0f].to_vec(),
            [0x17, 0x0f].to_vec(),
            [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
            [0x1b, 0x17, 0x0f].to_vec(),
            [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
        ]
        .to_vec();

        for (i, item) in expected.iter().enumerate() {
            let path = (i as u32)
                .direct_copath(&16)
                .into_iter()
                .map(|cp| cp.path)
                .collect_vec();

            assert_eq!(item, &path)
        }
    }

    #[test]
    fn test_copath_path() {
        let expected: Vec<Vec<u32>> = [
            [0x02, 0x05, 0x0b, 0x17].to_vec(),
            [0x05, 0x0b, 0x17].to_vec(),
            [0x00, 0x05, 0x0b, 0x17].to_vec(),
            [0x0b, 0x17].to_vec(),
            [0x06, 0x01, 0x0b, 0x17].to_vec(),
            [0x01, 0x0b, 0x17].to_vec(),
            [0x04, 0x01, 0x0b, 0x17].to_vec(),
            [0x17].to_vec(),
            [0x0a, 0x0d, 0x03, 0x17].to_vec(),
            [0x0d, 0x03, 0x17].to_vec(),
            [0x08, 0x0d, 0x03, 0x17].to_vec(),
            [0x03, 0x17].to_vec(),
            [0x0e, 0x09, 0x03, 0x17].to_vec(),
            [0x09, 0x03, 0x17].to_vec(),
            [0x0c, 0x09, 0x03, 0x17].to_vec(),
            [].to_vec(),
            [0x12, 0x15, 0x1b, 0x07].to_vec(),
            [0x15, 0x1b, 0x07].to_vec(),
            [0x10, 0x15, 0x1b, 0x07].to_vec(),
            [0x1b, 0x07].to_vec(),
            [0x16, 0x11, 0x1b, 0x07].to_vec(),
            [0x11, 0x1b, 0x07].to_vec(),
            [0x14, 0x11, 0x1b, 0x07].to_vec(),
            [0x07].to_vec(),
            [0x1a, 0x1d, 0x13, 0x07].to_vec(),
            [0x1d, 0x13, 0x07].to_vec(),
            [0x18, 0x1d, 0x13, 0x07].to_vec(),
            [0x13, 0x07].to_vec(),
            [0x1e, 0x19, 0x13, 0x07].to_vec(),
            [0x19, 0x13, 0x07].to_vec(),
            [0x1c, 0x19, 0x13, 0x07].to_vec(),
        ]
        .to_vec();

        for (i, item) in expected.iter().enumerate() {
            let copath = (i as u32)
                .direct_copath(&16)
                .into_iter()
                .map(|cp| cp.copath)
                .collect_vec();

            assert_eq!(item, &copath)
        }
    }
}

[ Dauer der Verarbeitung: 0.25 Sekunden  (vorverarbeitet)  ]