Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/mls-rs/src/group/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 36 kB image not shown  

Quelle  secret_tree.rs   Sprache: unbekannt

 
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use alloc::vec::Vec;
use core::{
    fmt::{self, Debug},
    ops::{Deref, DerefMut},
};

use zeroize::Zeroizing;

use crate::{client::MlsError, tree_kem::math::TreeIndex, CipherSuiteProvider};

use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::error::IntoAnyError;

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap;

use super::key_schedule::kdf_expand_with_label;

/// Upper bound on how many generations ahead of its current generation a
/// [`SecretKeyRatchet`] may be advanced by a single key lookup.
pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1 << 10;

/// Contents of one stored node of the secret tree: either a tree secret that
/// has not been consumed yet, or the ratchets already derived for a leaf.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
enum SecretTreeNode {
    Secret(TreeSecret) = 0u8,
    Ratchet(SecretRatchets) = 1u8,
}

impl SecretTreeNode {
    /// Unwraps the tree secret, yielding `None` when the node holds ratchets instead.
    fn into_secret(self) -> Option<TreeSecret> {
        match self {
            SecretTreeNode::Secret(tree_secret) => Some(tree_secret),
            SecretTreeNode::Ratchet(_) => None,
        }
    }
}

/// Raw secret bytes held at a secret-tree node, kept in a `Zeroizing` buffer.
#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecret(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    Zeroizing<Vec<u8>>,
);

impl Debug for TreeSecret {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pretty = mls_rs_core::debug::pretty_bytes(&self.0);
        pretty.named("TreeSecret").fmt(f)
    }
}

impl Deref for TreeSecret {
    type Target = Vec<u8>;

    fn deref(&self) -> &Vec<u8> {
        let TreeSecret(inner) = self;
        inner
    }
}

impl DerefMut for TreeSecret {
    fn deref_mut(&mut self) -> &mut Vec<u8> {
        let TreeSecret(inner) = self;
        inner
    }
}

impl AsRef<[u8]> for TreeSecret {
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl From<Vec<u8>> for TreeSecret {
    /// Wraps plain bytes, moving them into a `Zeroizing` buffer.
    fn from(bytes: Vec<u8>) -> Self {
        Self::from(Zeroizing::new(bytes))
    }
}

impl From<Zeroizing<Vec<u8>>> for TreeSecret {
    fn from(bytes: Zeroizing<Vec<u8>>) -> Self {
        Self(bytes)
    }
}

/// Sparse storage mapping a tree index to the node currently known there.
///
/// With `std` this is a `HashMap`; without `std` it is a `Vec` of
/// `(index, node)` pairs searched linearly.
#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecretsVec<T: TreeIndex> {
    #[cfg(feature = "std")]
    inner: HashMap<T, SecretTreeNode>,
    #[cfg(not(feature = "std"))]
    inner: Vec<(T, SecretTreeNode)>,
}

#[cfg(feature = "std")]
impl<T: TreeIndex> TreeSecretsVec<T> {
    /// Stores `value` at `index`, replacing any node already stored there.
    fn set_node(&mut self, index: T, value: SecretTreeNode) {
        self.inner.insert(index, value);
    }

    /// Removes and returns the node stored at `index`, if any.
    fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
        self.inner.remove(index)
    }
}

#[cfg(not(feature = "std"))]
impl<T: TreeIndex> TreeSecretsVec<T> {
    /// Stores `value` at `index`, replacing any node already stored there.
    /// New entries are appended, so existing entries keep their order.
    fn set_node(&mut self, index: T, value: SecretTreeNode) {
        if let Some(i) = self.find_node(&index) {
            self.inner[i] = (index, value)
        } else {
            self.inner.push((index, value))
        }
    }

    /// Removes and returns the node stored at `index`, if any.
    /// Uses an order-preserving `Vec::remove` (not `swap_remove`) so the
    /// encoded form and `PartialEq` comparisons stay stable.
    fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
        self.find_node(index).map(|i| self.inner.remove(i).1)
    }

    /// Position of the entry for `index` within `inner`, if present.
    fn find_node(&self, index: &T) -> Option<usize> {
        // std `Iterator::position` gives exactly what `itertools::find_position`
        // + `.map(|(i, _)| i)` did, without the extra dependency.
        self.inner.iter().position(|(i, _)| i == index)
    }
}

#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretTree<T: TreeIndex> {
    known_secrets: TreeSecretsVec<T>,
    leaf_count: T,
}

impl<T: TreeIndex> SecretTree<T> {
    pub(crate) fn empty() -> SecretTree<T> {
        SecretTree {
            known_secrets: Default::default(),
            leaf_count: T::zero(),
        }
    }
}

/// The application and handshake key ratchets derived for a single leaf.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretRatchets {
    pub application: SecretKeyRatchet,
    pub handshake: SecretKeyRatchet,
}

impl SecretRatchets {
    /// Looks up the key for `generation` in the ratchet chosen by `key_type`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn message_key_generation<P: CipherSuiteProvider>(
        &mut self,
        cipher_suite_provider: &P,
        generation: u32,
        key_type: KeyType,
    ) -> Result<MessageKeyData, MlsError> {
        let selected = match key_type {
            KeyType::Handshake => &mut self.handshake,
            KeyType::Application => &mut self.application,
        };

        selected
            .get_message_key(cipher_suite_provider, generation)
            .await
    }

    /// Produces the next key from the ratchet chosen by `key_type`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn next_message_key<P: CipherSuiteProvider>(
        &mut self,
        cipher_suite: &P,
        key_type: KeyType,
    ) -> Result<MessageKeyData, MlsError> {
        let selected = match key_type {
            KeyType::Handshake => &mut self.handshake,
            KeyType::Application => &mut self.application,
        };

        selected.next_message_key(cipher_suite).await
    }
}

impl<T: TreeIndex> SecretTree<T> {
    /// Creates a secret tree over `leaf_count` leaves whose root node holds
    /// `encryption_secret`. All other node secrets are derived on demand when a
    /// leaf is first accessed.
    pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> {
        let mut known_secrets = TreeSecretsVec::default();

        let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret));
        known_secrets.set_node(leaf_count.root(), root_secret);

        Self {
            known_secrets,
            leaf_count,
        }
    }

    /// Replaces the secret stored at `index` with the secrets of its two children,
    /// derived with `kdf_expand_with_label(secret, "tree", "left" | "right")`.
    ///
    /// The node at `index` is removed in every case. If it was absent or did not
    /// hold a `Secret`, nothing is inserted.
    ///
    /// # Errors
    /// `MlsError::LeafNodeNoChildren` if `index` has no children, or any error
    /// returned by the key derivation.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn consume_node<P: CipherSuiteProvider>(
        &mut self,
        cipher_suite_provider: &P,
        index: &T,
    ) -> Result<(), MlsError> {
        let node = self.known_secrets.take_node(index);

        if let Some(secret) = node.and_then(|n| n.into_secret()) {
            let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?;
            let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?;

            let left_secret =
                kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None)
                    .await?;

            let right_secret =
                kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None)
                    .await?;

            self.known_secrets
                .set_node(left_index, SecretTreeNode::Secret(left_secret.into()));

            self.known_secrets
                .set_node(right_index, SecretTreeNode::Secret(right_secret.into()));
        }

        Ok(())
    }

    /// Removes and returns the ratchets for `leaf_index`.
    ///
    /// If nothing is stored at the leaf yet, the secrets on its direct path are
    /// consumed from the root downward until the leaf's own secret exists. Fresh
    /// application and handshake ratchets are then derived from that secret. The
    /// caller must store the returned ratchets back into the tree.
    ///
    /// # Errors
    /// `MlsError::InvalidLeafConsumption` if no node can be obtained for the leaf,
    /// plus any error from `consume_node` or ratchet derivation.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn take_leaf_ratchet<P: CipherSuiteProvider>(
        &mut self,
        cipher_suite: &P,
        leaf_index: &T,
    ) -> Result<SecretRatchets, MlsError> {
        let node_index = leaf_index;

        let node = match self.known_secrets.take_node(node_index) {
            Some(node) => node,
            None => {
                // Start at the root node and work your way down consuming any intermediates needed
                for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() {
                    self.consume_node(cipher_suite, &i.path).await?;
                }

                self.known_secrets
                    .take_node(node_index)
                    .ok_or(MlsError::InvalidLeafConsumption)?
            }
        };

        Ok(match node {
            SecretTreeNode::Ratchet(ratchet) => ratchet,
            SecretTreeNode::Secret(secret) => SecretRatchets {
                application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application)
                    .await?,
                handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?,
            },
        })
    }

    /// Returns the next key of type `key_type` for `leaf_index` and advances
    /// that leaf's ratchet.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn next_message_key<P: CipherSuiteProvider>(
        &mut self,
        cipher_suite: &P,
        leaf_index: T,
        key_type: KeyType,
    ) -> Result<MessageKeyData, MlsError> {
        let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
        let res = ratchet.next_message_key(cipher_suite, key_type).await;

        // Re-insert the ratchet even when key derivation failed. `take_leaf_ratchet`
        // removed it, and the path secrets it came from are already consumed, so
        // dropping it would make every later lookup for this leaf fail with
        // `InvalidLeafConsumption`.
        self.known_secrets
            .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));

        res
    }

    /// Returns the key of type `key_type` at `generation` for `leaf_index`.
    ///
    /// # Errors
    /// Propagates the ratchet's lookup errors (e.g. `KeyMissing`,
    /// `InvalidFutureGeneration`). The leaf's ratchet stays in the tree, so later
    /// valid requests for the same leaf still succeed.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn message_key_generation<P: CipherSuiteProvider>(
        &mut self,
        cipher_suite: &P,
        leaf_index: T,
        key_type: KeyType,
        generation: u32,
    ) -> Result<MessageKeyData, MlsError> {
        let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;

        let res = ratchet
            .message_key_generation(cipher_suite, generation, key_type)
            .await;

        // Always put the ratchet back. Otherwise a single rejected generation (for
        // example an out-of-range value from a peer) would permanently lose this
        // leaf's ratchets.
        self.known_secrets
            .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));

        res
    }
}

/// Selects which of a leaf's two ratchets, handshake or application, to use.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KeyType {
    Handshake,
    Application,
}

#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// AEAD key material (key, nonce, and generation) derived by the MLS secret tree.
pub struct MessageKeyData {
    /// AEAD nonce bytes.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    pub(crate) nonce: Zeroizing<Vec<u8>>,
    /// AEAD key bytes.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    pub(crate) key: Zeroizing<Vec<u8>>,
    /// Ratchet generation this key belongs to.
    pub(crate) generation: u32,
}

impl Debug for MessageKeyData {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let nonce = mls_rs_core::debug::pretty_bytes(&self.nonce);
        let key = mls_rs_core::debug::pretty_bytes(&self.key);

        let mut builder = f.debug_struct("MessageKeyData");
        builder.field("nonce", &nonce);
        builder.field("key", &key);
        builder.field("generation", &self.generation);
        builder.finish()
    }
}

#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl MessageKeyData {
    /// AEAD nonce.
    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
    pub fn nonce(&self) -> &[u8] {
        &self.nonce
    }

    /// AEAD key.
    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
    pub fn key(&self) -> &[u8] {
        &self.key
    }

    /// Generation of this key within the key schedule.
    #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
    pub fn generation(&self) -> u32 {
        self.generation
    }
}

/// A symmetric-key ratchet for one leaf and key type.
///
/// It holds the current chain secret and the generation of the next key it will
/// produce. With `out_of_order`, it also keeps keys for generations that were
/// skipped over.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretKeyRatchet {
    secret: TreeSecret,
    generation: u32,
    #[cfg(all(feature = "out_of_order", feature = "std"))]
    history: HashMap<u32, MessageKeyData>,
    #[cfg(all(feature = "out_of_order", not(feature = "std")))]
    history: BTreeMap<u32, MessageKeyData>,
}

impl MlsSize for SecretKeyRatchet {
    fn mls_encoded_len(&self) -> usize {
        let total = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret)
            + self.generation.mls_encoded_len();

        #[cfg(feature = "out_of_order")]
        let total = total + mls_rs_codec::iter::mls_encoded_len(self.history.values());

        total
    }
}

impl MlsEncode for SecretKeyRatchet {
    /// Writes the secret, then the generation, then (with `out_of_order`) the
    /// stored history keys.
    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
        mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
        self.generation.mls_encode(writer)?;

        #[cfg(feature = "out_of_order")]
        mls_rs_codec::iter::mls_encode(self.history.values(), writer)?;

        Ok(())
    }
}

impl MlsDecode for SecretKeyRatchet {
    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
        // Fields are read in the same order `mls_encode` writes them.
        let secret: TreeSecret = mls_rs_codec::byte_vec::mls_decode(reader)?;
        let generation = u32::mls_decode(reader)?;

        // History keys are decoded one by one and re-keyed by their generation.
        #[cfg(all(feature = "std", feature = "out_of_order"))]
        let history: HashMap<u32, MessageKeyData> =
            mls_rs_codec::iter::mls_decode_collection(reader, |data| {
                let mut keys_by_generation = HashMap::default();

                while !data.is_empty() {
                    let key_data = MessageKeyData::mls_decode(data)?;
                    keys_by_generation.insert(key_data.generation, key_data);
                }

                Ok(keys_by_generation)
            })?;

        #[cfg(all(not(feature = "std"), feature = "out_of_order"))]
        let history: alloc::collections::BTreeMap<u32, MessageKeyData> =
            mls_rs_codec::iter::mls_decode_collection(reader, |data| {
                let mut keys_by_generation = alloc::collections::BTreeMap::default();

                while !data.is_empty() {
                    let key_data = MessageKeyData::mls_decode(data)?;
                    keys_by_generation.insert(key_data.generation, key_data);
                }

                Ok(keys_by_generation)
            })?;

        Ok(Self {
            secret,
            generation,
            #[cfg(feature = "out_of_order")]
            history,
        })
    }
}

impl SecretKeyRatchet {
    /// Creates a ratchet at generation 0.
    ///
    /// The chain secret is `kdf_expand_with_label(secret, "handshake" | "application", "")`,
    /// with the label chosen by `key_type`.
    ///
    /// # Errors
    /// `MlsError::CryptoProviderError` if the key derivation fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn new<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
        secret: &[u8],
        key_type: KeyType,
    ) -> Result<Self, MlsError> {
        let label = match key_type {
            KeyType::Handshake => b"handshake".as_slice(),
            KeyType::Application => b"application".as_slice(),
        };

        let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None)
            .await
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;

        Ok(Self {
            secret: TreeSecret::from(secret),
            generation: 0,
            #[cfg(feature = "out_of_order")]
            history: Default::default(),
        })
    }

    /// Returns the message key for `generation`.
    ///
    /// - `generation < self.generation`: with `out_of_order`, the key is removed
    ///   from the history and returned, so each past key is handed out at most
    ///   once. If it is absent, or without `out_of_order`, the result is
    ///   `MlsError::KeyMissing`.
    /// - `generation > self.generation + MAX_RATCHET_BACK_HISTORY`: the result is
    ///   `MlsError::InvalidFutureGeneration`, and the ratchet is left untouched.
    /// - Otherwise the ratchet advances past `generation`. With `out_of_order`,
    ///   keys for the skipped generations are kept in the history; without it
    ///   they are discarded.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_message_key<P: CipherSuiteProvider>(
        &mut self,
        cipher_suite_provider: &P,
        generation: u32,
    ) -> Result<MessageKeyData, MlsError> {
        #[cfg(feature = "out_of_order")]
        if generation < self.generation {
            return self
                .history
                .remove(&generation)
                .ok_or(MlsError::KeyMissing(generation));
        }

        #[cfg(not(feature = "out_of_order"))]
        if generation < self.generation {
            return Err(MlsError::KeyMissing(generation));
        }

        // Saturate so a ratchet whose generation is near `u32::MAX` (e.g. restored
        // from storage) cannot overflow the bound.
        let max_generation_allowed = self.generation.saturating_add(MAX_RATCHET_BACK_HISTORY);

        if generation > max_generation_allowed {
            return Err(MlsError::InvalidFutureGeneration(generation));
        }

        #[cfg(not(feature = "out_of_order"))]
        while self.generation < generation {
            // `.await` is required in async builds; `must_be_sync` strips it in sync builds.
            self.next_message_key(cipher_suite_provider).await?;
        }

        #[cfg(feature = "out_of_order")]
        while self.generation < generation {
            let key_data = self.next_message_key(cipher_suite_provider).await?;
            self.history.insert(key_data.generation, key_data);
        }

        self.next_message_key(cipher_suite_provider).await
    }

    /// Derives the nonce and key for the current generation, then replaces the
    /// chain secret with its `"secret"` expansion and increments the generation.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn next_message_key<P: CipherSuiteProvider>(
        &mut self,
        cipher_suite_provider: &P,
    ) -> Result<MessageKeyData, MlsError> {
        let generation = self.generation;

        let key = MessageKeyData {
            nonce: self
                .derive_secret(
                    cipher_suite_provider,
                    b"nonce",
                    cipher_suite_provider.aead_nonce_size(),
                )
                .await?,
            key: self
                .derive_secret(
                    cipher_suite_provider,
                    b"key",
                    cipher_suite_provider.aead_key_size(),
                )
                .await?,
            generation,
        };

        self.secret = self
            .derive_secret(
                cipher_suite_provider,
                b"secret",
                cipher_suite_provider.kdf_extract_size(),
            )
            .await?
            .into();

        // NOTE(review): overflows if `generation == u32::MAX`; unreachable in
        // practice, but a checked increment would be needed for full hardening.
        self.generation = generation + 1;

        Ok(key)
    }

    /// Expands the current chain secret with `label` into `len` bytes. The
    /// context is the current generation encoded as big-endian bytes.
    ///
    /// # Errors
    /// `MlsError::CryptoProviderError` if the key derivation fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn derive_secret<P: CipherSuiteProvider>(
        &self,
        cipher_suite_provider: &P,
        label: &[u8],
        len: usize,
    ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
        kdf_expand_with_label(
            cipher_suite_provider,
            self.secret.as_ref(),
            label,
            &self.generation.to_be_bytes(),
            Some(len),
        )
        .await
        .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
    }
}

#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::{string::String, vec::Vec};
    use mls_rs_core::crypto::CipherSuiteProvider;
    use zeroize::Zeroizing;

    use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex};

    use super::{KeyType, SecretKeyRatchet, SecretTree};

    /// Builds a [`SecretTree`] with `leaf_count` leaves whose root secret is `secret`.
    pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> {
        let encryption_secret = Zeroizing::new(secret);
        SecretTree::new(leaf_count, encryption_secret)
    }

    impl SecretTree<u32> {
        /// Returns a copy of the root secret, reading from a clone so `self` is unchanged.
        ///
        /// # Panics
        /// Panics if the root node is no longer stored as a secret.
        pub(crate) fn get_root_secret(&self) -> Vec<u8> {
            let root_index = self.leaf_count.root();
            let mut snapshot = self.known_secrets.clone();

            let root_node = snapshot.take_node(&root_index).unwrap();
            root_node.into_secret().unwrap().to_vec()
        }
    }

    /// One `derive_tree_secret` case from the `basic_crypto` test vectors.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct RatchetInteropTestCase {
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        label: String,
        generation: u32,
        length: usize,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    /// A `basic_crypto` test vector entry for a single cipher suite.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        derive_tree_secret: RatchetInteropTestCase,
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());

        for test_case in test_cases {
            let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                // Cipher suite unavailable in the test crypto provider.
                continue;
            };

            test_case.derive_tree_secret.verify(&cs).await;
        }
    }

    impl RatchetInteropTestCase {
        /// Asserts that `derive_secret(label, length)` on a ratchet set to
        /// (`secret`, `generation`) produces `out`.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            let mut key_ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application)
                .await
                .unwrap();

            // Overwrite the derived state so the vector controls secret and generation directly.
            key_ratchet.secret = self.secret.clone().into();
            key_ratchet.generation = self.generation;

            let derived = key_ratchet
                .derive_secret(cs, self.label.as_bytes(), self.length)
                .await
                .unwrap();

            assert_eq!(&derived.to_vec(), &self.out);
        }
    }
}

#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{
        cipher_suite::CipherSuite,
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::{
            test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
        },
        tree_kem::node::NodeIndex,
    };

    #[cfg(not(mls_build_async))]
    use crate::group::test_utils::random_bytes;

    use super::{test_utils::get_test_tree, *};

    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_tree() {
        test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await;
        test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await;
    }

    /// Takes the ratchets for every leaf in `leaves_to_check` from a fresh tree and
    /// checks that they are pairwise distinct. If `all_deleted`, it also checks
    /// that the tree holds no nodes afterwards.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_secret_tree_custom<T: TreeIndex>(
        leaf_count: T,
        leaves_to_check: Vec<T>,
        all_deleted: bool,
    ) {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let cs_provider = test_cipher_suite_provider(cipher_suite);

            let test_secret = vec![0u8; cs_provider.kdf_extract_size()];
            let mut test_tree = get_test_tree(test_secret, leaf_count.clone());

            let mut secrets = Vec::<SecretRatchets>::new();

            for i in &leaves_to_check {
                let secret = test_tree
                    .take_leaf_ratchet(&cs_provider, i)
                    .await
                    .unwrap();

                secrets.push(secret);
            }

            // Verify the tree is now completely empty
            assert!(!all_deleted || test_tree.known_secrets.inner.is_empty());

            // Verify that all the secrets are unique. `Vec::dedup` only removes
            // *consecutive* duplicates, so compare every pair explicitly.
            for (i, secret) in secrets.iter().enumerate() {
                assert!(secrets[i + 1..].iter().all(|other| other != secret));
            }
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_key_ratchet() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let mut app_ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let mut handshake_ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Handshake,
            )
            .await
            .unwrap();

            let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap();
            let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap();
            let app_keys = vec![app_key_one, app_key_two];

            let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap();
            let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap();
            let handshake_keys = vec![handshake_key_one, handshake_key_two];

            // Verify that the keys have different outcomes due to their different labels
            assert_ne!(app_keys, handshake_keys);

            // Verify that the keys at each generation are different
            assert_ne!(handshake_keys[0], handshake_keys[1]);
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_get_key() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let mut ratchet = SecretKeyRatchet::new(
                &test_cipher_suite_provider(cipher_suite),
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let mut ratchet_clone = ratchet.clone();

            // This will generate keys 0 and 1 in ratchet_clone
            let _ = ratchet_clone.next_message_key(&provider).await.unwrap();
            let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap();

            // Going back in time should result in an error
            let res = ratchet_clone.get_message_key(&provider, 0).await;
            assert!(res.is_err());

            // Calling get key should be the same as calling next until hitting the desired generation
            let second_key = ratchet
                .get_message_key(&provider, ratchet_clone.generation - 1)
                .await
                .unwrap();

            assert_eq!(clone_2, second_key)
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_ratchet() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let provider = test_cipher_suite_provider(cipher_suite);

            let mut ratchet = SecretKeyRatchet::new(
                &provider,
                &vec![0u8; provider.kdf_extract_size()],
                KeyType::Application,
            )
            .await
            .unwrap();

            let original_secret = ratchet.secret.clone();
            let _ = ratchet.next_message_key(&provider).await.unwrap();
            let new_secret = ratchet.secret;
            assert_ne!(original_secret, new_secret)
        }
    }

    #[cfg(feature = "out_of_order")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_out_of_order_keys() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
            .await
            .unwrap();
        let mut ratchet_clone = ratchet.clone();

        // Ask for all the keys in order from the original ratchet
        let mut ordered_keys = Vec::<MessageKeyData>::new();

        for i in 0..=MAX_RATCHET_BACK_HISTORY {
            ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap());
        }

        // Ask for a key at index MAX_RATCHET_BACK_HISTORY in the clone
        let last_key = ratchet_clone
            .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY)
            .await
            .unwrap();

        assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]);

        // Get all the other keys: generations 0..MAX_RATCHET_BACK_HISTORY were skipped
        // by the jump above and must all be retrievable from the history, including
        // the last one (MAX_RATCHET_BACK_HISTORY - 1).
        let mut back_history_keys = Vec::<MessageKeyData>::new();

        for i in 0..MAX_RATCHET_BACK_HISTORY {
            back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap());
        }

        assert_eq!(
            back_history_keys,
            ordered_keys[..MAX_RATCHET_BACK_HISTORY as usize]
        );
    }

    #[cfg(not(feature = "out_of_order"))]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn out_of_order_keys_should_throw_error() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
            .await
            .unwrap();

        ratchet.get_message_key(&provider, 10).await.unwrap();
        let res = ratchet.get_message_key(&provider, 9).await;
        assert_matches!(res, Err(MlsError::KeyMissing(9)))
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_too_out_of_order() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
            .await
            .unwrap();

        let res = ratchet
            .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1)
            .await;

        let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1;

        assert_matches!(
            res,
            Err(MlsError::InvalidFutureGeneration(invalid))
            if invalid == invalid_generation
        )
    }

    #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Ratchet {
        application_keys: Vec<Vec<u8>>,
        handshake_keys: Vec<Vec<u8>>,
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        ratchets: Vec<Ratchet>,
    }

    /// Collects 20 "application" and 20 handshake encoded keys for each of the
    /// first 16 leaves (node indices 0, 2, ..., 30) of `secret_tree`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_ratchet_data(
        secret_tree: &mut SecretTree<NodeIndex>,
        cipher_suite: CipherSuite,
    ) -> Vec<Ratchet> {
        let provider = test_cipher_suite_provider(cipher_suite);
        let mut ratchet_data = Vec::new();

        for index in 0..16 {
            let mut ratchets = secret_tree
                .take_leaf_ratchet(&provider, &(index * 2))
                .await
                .unwrap();

            let mut application_keys = Vec::new();

            // NOTE(review): this reads from `ratchets.handshake`, not
            // `ratchets.application`, so `application_keys` actually holds handshake
            // generations 0..20. The `secret_tree` vectors are produced by
            // `generate_test_vector`, which calls this function. Presumably a stored
            // copy exists, so switching to `ratchets.application` would require
            // regenerating those vectors; confirm before changing.
            for _ in 0..20 {
                let key = ratchets
                    .handshake
                    .next_message_key(&provider)
                    .await
                    .unwrap()
                    .mls_encode_to_vec()
                    .unwrap();

                application_keys.push(key);
            }

            let mut handshake_keys = Vec::new();

            for _ in 0..20 {
                let key = ratchets
                    .handshake
                    .next_message_key(&provider)
                    .await
                    .unwrap()
                    .mls_encode_to_vec()
                    .unwrap();

                handshake_keys.push(key);
            }

            ratchet_data.push(Ratchet {
                application_keys,
                handshake_keys,
            });
        }

        ratchet_data
    }

    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<TestCase> {
        CipherSuite::all()
            .map(|cipher_suite| {
                let provider = test_cipher_suite_provider(cipher_suite);
                let encryption_secret = random_bytes(provider.kdf_extract_size());

                let mut secret_tree =
                    SecretTree::new(16, Zeroizing::new(encryption_secret.clone()));

                TestCase {
                    cipher_suite: cipher_suite.into(),
                    encryption_secret,
                    ratchets: get_ratchet_data(&mut secret_tree, cipher_suite),
                }
            })
            .collect()
    }

    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_secret_tree_test_vectors() {
        let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector());

        for case in test_cases {
            let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret));
            let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await;

            assert_eq!(ratchet_data, case.ratchets);
        }
    }
}

// Cross-implementation check of the secret tree against the shared `secret-tree` test vectors.
#[cfg(all(test, feature = "rfc_compliant", feature = "std"))]
mod interop_tests {
    #[cfg(not(mls_build_async))]
    use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
    use zeroize::Zeroizing;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType},
    };

    use super::SecretTree;

    /// Index passed to `SecretTree::message_key_generation` for the leaf at `leaf`
    /// (twice the leaf position).
    fn leaf_node_index(leaf: u32) -> u32 {
        leaf * 2
    }

    /// Replays every interop case.
    ///
    /// Verifies the sender-data derivation first. Then, for each listed generation of each
    /// leaf, compares the derived application and handshake key/nonce with the expected
    /// values. Cipher suites without an available test provider are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn interop_test_vector() {
        // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json
        for case in load_interop_test_cases() {
            let cs = match try_test_cipher_suite_provider(case.cipher_suite) {
                Some(provider) => provider,
                None => continue,
            };

            case.sender_data.verify(&cs).await;

            // One entry in `leaves` per tree leaf, so its length is the tree's leaf count.
            let leaf_count = case.leaves.len() as u32;
            let mut tree = SecretTree::new(leaf_count, Zeroizing::new(case.encryption_secret));

            for (position, expected_leaves) in case.leaves.iter().enumerate() {
                let node_index = leaf_node_index(position as u32);

                for expected in expected_leaves {
                    let application = tree
                        .message_key_generation(
                            &cs,
                            node_index,
                            KeyType::Application,
                            expected.generation,
                        )
                        .await
                        .unwrap();

                    assert_eq!(application.key.to_vec(), expected.application_key);
                    assert_eq!(application.nonce.to_vec(), expected.application_nonce);

                    let handshake = tree
                        .message_key_generation(
                            &cs,
                            node_index,
                            KeyType::Handshake,
                            expected.generation,
                        )
                        .await
                        .unwrap();

                    assert_eq!(handshake.key.to_vec(), expected.handshake_key);
                    assert_eq!(handshake.nonce.to_vec(), expected.handshake_nonce);
                }
            }
        }
    }

    /// One interop case: a cipher suite, the tree's encryption secret, the sender-data check
    /// and, per leaf, the expected keys for a set of generations.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropTestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        sender_data: InteropSenderData,
        leaves: Vec<Vec<InteropLeaf>>,
    }

    /// Expected application and handshake key/nonce pairs of a single leaf at one generation.
    /// Byte fields are hex-encoded in the JSON.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropLeaf {
        generation: u32,
        #[serde(with = "hex::serde")]
        application_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        application_nonce: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_nonce: Vec<u8>,
    }

    /// Loads the `secret_tree_interop` cases through `load_test_case_json!`.
    ///
    /// `generate_test_vector()` is passed to the macro as the generator. It is presumably
    /// used only when no stored file exists; see the macro's definition.
    fn load_interop_test_cases() -> Vec<InteropTestCase> {
        load_test_case_json!(secret_tree_interop, generate_test_vector())
    }

    /// Builds random interop cases for every supported cipher suite.
    ///
    /// Trees of 1, 8 and 32 leaves are covered, and generations 0 and 15 are sampled for
    /// every leaf. Per generation, the handshake key is derived before the application key.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        const GENERATIONS: [u32; 2] = [0, 15];
        const TREE_SIZES: [u32; 3] = [1, 8, 32];

        let mut cases = Vec::new();

        for suite in CipherSuite::all() {
            let Some(cs) = try_test_cipher_suite_provider(*suite) else {
                continue;
            };

            for leaf_count in TREE_SIZES {
                let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();

                let mut tree =
                    SecretTree::new(leaf_count, Zeroizing::new(encryption_secret.clone()));

                let mut leaves = Vec::new();

                for leaf in 0..leaf_count {
                    let node_index = leaf_node_index(leaf);
                    let mut per_generation = Vec::new();

                    for generation in GENERATIONS {
                        let handshake_key = tree
                            .message_key_generation(
                                &cs,
                                node_index,
                                KeyType::Handshake,
                                generation,
                            )
                            .unwrap();

                        let app_key = tree
                            .message_key_generation(
                                &cs,
                                node_index,
                                KeyType::Application,
                                generation,
                            )
                            .unwrap();

                        per_generation.push(InteropLeaf {
                            generation,
                            application_key: app_key.key.to_vec(),
                            application_nonce: app_key.nonce.to_vec(),
                            handshake_key: handshake_key.key.to_vec(),
                            handshake_nonce: handshake_key.nonce.to_vec(),
                        });
                    }

                    leaves.push(per_generation);
                }

                // Sender data is generated after all leaf keys.
                cases.push(InteropTestCase {
                    cipher_suite: *cs.cipher_suite(),
                    encryption_secret,
                    sender_data: InteropSenderData::new(&cs),
                    leaves,
                });
            }
        }

        cases
    }

    /// Generation is only supported in sync builds; in async builds this always panics.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        panic!("Tests cannot be generated in async mode");
    }
}

[ Dauer der Verarbeitung: 0.34 Sekunden  (vorverarbeitet)  ]