Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/mls-rs/src/tree_kem/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 23 kB image not shown  

Quelle  kem.rs   Sprache: unbekannt

 
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use crate::client::MlsError;
use crate::crypto::{CipherSuiteProvider, SignatureSecretKey};
use crate::group::GroupContext;
use crate::identity::SigningIdentity;
use crate::iter::wrap_iter;
use crate::tree_kem::math as tree_math;
use alloc::vec;
use alloc::vec::Vec;
use itertools::Itertools;
use mls_rs_codec::MlsEncode;
use tree_math::{CopathNode, TreeIndex};

#[cfg(all(not(mls_build_async), feature = "rayon"))]
use {crate::iter::ParallelIteratorExt, rayon::prelude::*};

#[cfg(mls_build_async)]
use futures::{StreamExt, TryStreamExt};

#[cfg(feature = "std")]
use std::collections::HashSet;

use super::hpke_encryption::HpkeEncryptable;
use super::leaf_node::ConfigProperties;
use super::node::NodeTypeResolver;
use super::{
    node::{LeafIndex, NodeIndex},
    path_secret::{PathSecret, PathSecretGenerator},
    TreeKemPrivate, TreeKemPublic, UpdatePath, UpdatePathNode, ValidatedUpdatePath,
};

#[cfg(test)]
use crate::{group::CommitModifiers, signer::Signable};

/// Mutable view over a group's public ratchet tree and this member's private path keys,
/// used to create (`encap`) and process (`decap`) update paths.
pub struct TreeKem<'a> {
    /// Public ratchet tree; `encap` writes new public keys, parent hashes and the
    /// re-committed own leaf into it.
    tree_kem_public: &'a mut TreeKemPublic,
    /// This member's leaf index and private keys: `secret_keys[0]` holds the leaf key and
    /// `secret_keys[i + 1]` the key for direct-path node `i` (`None` when not held).
    private_key: &'a mut TreeKemPrivate,
}

/// Output of `TreeKem::encap`.
pub struct EncapGeneration {
    /// Update path to distribute: the new own leaf plus one node per non-filtered
    /// direct-path node, carrying the encrypted path secrets.
    pub update_path: UpdatePath,
    /// Path secret for each direct-path node, `None` where the node was filtered.
    pub path_secrets: Vec<Option<PathSecret>>,
    /// Secret generated after the last path secret; used as the commit secret.
    pub commit_secret: PathSecret,
}

impl<'a> TreeKem<'a> {
    /// Creates a TreeKEM view over a public ratchet tree and this member's private path
    /// keys. Both are borrowed mutably because `encap` and `decap` update them in place.
    pub fn new(
        tree_kem_public: &'a mut TreeKemPublic,
        private_key: &'a mut TreeKemPrivate,
    ) -> Self {
        TreeKem {
            tree_kem_public,
            private_key,
        }
    }

    /// Generates a fresh update path starting at this member's own leaf.
    ///
    /// For every non-filtered node on the direct path a new path secret is derived and the
    /// resulting HPKE key pair is installed: the public key in the public tree, the secret
    /// key in `private_key.secret_keys[i + 1]`. Filtered nodes get `None`. Afterwards the
    /// parent hashes are updated, the own leaf is updated through its `commit` method
    /// (given `signer`, `update_leaf_properties` and `signing_identity`) and its key stored
    /// in `secret_keys[0]`, the tree hashes are recomputed and written into
    /// `context.tree_hash`, and each path secret is encrypted, under the encoded updated
    /// `context`, to the resolution of its copath node.
    ///
    /// Nodes belonging to the leaves in `excluding` receive no ciphertexts.
    /// NOTE(review): presumably these are leaves added by the same commit — confirm with
    /// the caller.
    ///
    /// Returns the update path, the per-node path secrets (`None` for filtered nodes) and
    /// the commit secret, which is the secret generated after the last path secret.
    ///
    /// # Errors
    /// Propagates any [`MlsError`] from tree access, key derivation, hashing, encoding,
    /// leaf commit/signing or encryption.
    // Building the new leaf and the update path needs all of these inputs, and the
    // test-only `commit_modifiers` hook adds one more.
    #[allow(clippy::too_many_arguments)]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn encap<P>(
        self,
        context: &mut GroupContext,
        excluding: &[LeafIndex],
        signer: &SignatureSecretKey,
        update_leaf_properties: ConfigProperties,
        signing_identity: Option<SigningIdentity>,
        cipher_suite_provider: &P,
        #[cfg(test)] commit_modifiers: &CommitModifiers,
    ) -> Result<EncapGeneration, MlsError>
    where
        P: CipherSuiteProvider + Send + Sync,
    {
        let self_index = self.private_key.self_index;
        let path = self.tree_kem_public.nodes.direct_copath(self_index);
        let filtered = self.tree_kem_public.nodes.filtered(self_index)?;

        // Slot 0 is the leaf key, slot `i + 1` the key for direct-path node `i`.
        self.private_key.secret_keys.resize(path.len() + 1, None);

        let mut secret_generator = PathSecretGenerator::new(cipher_suite_provider);
        let mut path_secrets = vec![];

        for (i, (node, f)) in path.iter().zip(&filtered).enumerate() {
            if !f {
                let secret = secret_generator.next_secret().await?;

                let (secret_key, public_key) =
                    secret.to_hpke_key_pair(cipher_suite_provider).await?;

                self.private_key.secret_keys[i + 1] = Some(secret_key);
                self.tree_kem_public.update_node(public_key, node.path)?;
                path_secrets.push(Some(secret));
            } else {
                // Filtered nodes get neither a key nor a path secret.
                self.private_key.secret_keys[i + 1] = None;
                path_secrets.push(None);
            }
        }

        #[cfg(test)]
        (commit_modifiers.modify_tree)(self.tree_kem_public);

        self.tree_kem_public
            .update_parent_hashes(self_index, false, cipher_suite_provider)
            .await?;

        let update_path_leaf = {
            let own_leaf = self.tree_kem_public.nodes.borrow_as_leaf_mut(self_index)?;

            self.private_key.secret_keys[0] = Some(
                own_leaf
                    .commit(
                        cipher_suite_provider,
                        &context.group_id,
                        *self_index,
                        update_leaf_properties,
                        signing_identity,
                        signer,
                    )
                    .await?,
            );

            #[cfg(test)]
            if let Some(signer) = (commit_modifiers.modify_leaf)(own_leaf, signer) {
                let context = &(context.group_id.as_slice(), *self_index).into();

                own_leaf
                    .sign(cipher_suite_provider, &signer, context)
                    .await
                    .unwrap();
            }

            own_leaf.clone()
        };

        // Tree modifications are all done so we can update the tree hash and encrypt with the new context
        self.tree_kem_public
            .update_hashes(&[self_index], cipher_suite_provider)
            .await?;

        context.tree_hash = self
            .tree_kem_public
            .tree_hash(cipher_suite_provider)
            .await?;

        let context_bytes = context.mls_encode_to_vec()?;

        let node_updates = self
            .encrypt_path_secrets(
                path,
                &path_secrets,
                &context_bytes,
                cipher_suite_provider,
                excluding,
            )
            .await?;

        #[cfg(test)]
        let node_updates = (commit_modifiers.modify_path)(node_updates);

        // Create an update path with the new node and parent node updates
        let update_path = UpdatePath {
            leaf_node: update_path_leaf,
            nodes: node_updates,
        };

        Ok(EncapGeneration {
            update_path,
            path_secrets,
            commit_secret: secret_generator.next_secret().await?,
        })
    }

    /// Encrypts every present path secret to the resolution of its copath node, skipping
    /// nodes that belong to the leaves in `excluding`.
    ///
    /// Entries whose path secret is `None` (filtered nodes) produce no `UpdatePathNode`.
    /// Sequential variant, used for async builds or when the `rayon` feature is disabled.
    #[cfg(any(mls_build_async, not(feature = "rayon")))]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn encrypt_path_secrets<P: CipherSuiteProvider>(
        &self,
        path: Vec<CopathNode<NodeIndex>>,
        path_secrets: &[Option<PathSecret>],
        context_bytes: &[u8],
        cipher_suite: &P,
        excluding: &[LeafIndex],
    ) -> Result<Vec<UpdatePathNode>, MlsError> {
        let excluding = excluding.iter().copied().map(NodeIndex::from);

        #[cfg(feature = "std")]
        let excluding = excluding.collect::<HashSet<NodeIndex>>();
        #[cfg(not(feature = "std"))]
        let excluding = excluding.collect::<Vec<NodeIndex>>();

        let mut node_updates = Vec::new();

        for (index, path_secret) in path.into_iter().zip(path_secrets.iter()) {
            if let Some(path_secret) = path_secret {
                node_updates.push(
                    self.encrypt_copath_node_resolution(
                        cipher_suite,
                        path_secret,
                        index.copath,
                        context_bytes,
                        &excluding,
                    )
                    .await?,
                );
            }
        }

        Ok(node_updates)
    }

    /// Parallel (rayon) variant of `encrypt_path_secrets` for synchronous builds: encrypts
    /// every present path secret to the resolution of its copath node, skipping nodes that
    /// belong to the leaves in `excluding`. `None` path secrets produce no output node.
    #[cfg(all(not(mls_build_async), feature = "rayon"))]
    fn encrypt_path_secrets<P: CipherSuiteProvider>(
        &self,
        path: Vec<CopathNode<NodeIndex>>,
        path_secrets: &[Option<PathSecret>],
        context_bytes: &[u8],
        cipher_suite: &P,
        excluding: &[LeafIndex],
    ) -> Result<Vec<UpdatePathNode>, MlsError> {
        let excluding = excluding.iter().copied().map(NodeIndex::from);

        #[cfg(feature = "std")]
        let excluding = excluding.collect::<HashSet<NodeIndex>>();
        #[cfg(not(feature = "std"))]
        let excluding = excluding.collect::<Vec<NodeIndex>>();

        path.into_par_iter()
            .zip(path_secrets.par_iter())
            .filter_map(|(node, path_secret)| {
                path_secret.as_ref().map(|path_secret| {
                    self.encrypt_copath_node_resolution(
                        cipher_suite,
                        path_secret,
                        node.copath,
                        context_bytes,
                        &excluding,
                    )
                })
            })
            .collect()
    }

    /// Processes an update path received from `sender_index`.
    ///
    /// Decrypts the path secret addressed to us at the lowest common ancestor (LCA) of our
    /// leaf and the sender's leaf, then derives every following path secret from it. For
    /// each non-filtered node from the LCA upward, the derived public key is checked
    /// against the one carried in `update_path` and the matching private key is stored in
    /// `secret_keys`; filtered nodes get `None`. The public tree is only read.
    ///
    /// `added_leaves` are removed from the resolution when locating our ciphertext, and
    /// `context_bytes` is the context passed to `PathSecret::decrypt`.
    ///
    /// Returns the commit secret: the secret generated after the last path secret.
    ///
    /// # Errors
    /// * [`MlsError::LcaNotFoundInDirectPath`] if the LCA cannot be mapped onto our path,
    ///   or the update path carries no node or ciphertext for it.
    /// * [`MlsError::UpdateErrorNoSecretKey`] if we hold no key that can decrypt it.
    /// * [`MlsError::PubKeyMismatch`] if a derived public key differs from the update.
    ///
    /// Other errors are propagated from tree access and cryptographic operations.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn decap<CP>(
        self,
        sender_index: LeafIndex,
        update_path: &ValidatedUpdatePath,
        added_leaves: &[LeafIndex],
        context_bytes: &[u8],
        cipher_suite_provider: &CP,
    ) -> Result<PathSecret, MlsError>
    where
        CP: CipherSuiteProvider,
    {
        let self_index = self.private_key.self_index;

        let lca_level = tree_math::leaf_lca_level(self_index.into(), sender_index.into()) as usize;

        // Maps the LCA level onto an index into `update_path.nodes` (and into `path` below).
        // Checked so that a level below 2 (presumably `sender_index == self_index`) is an
        // error instead of an integer underflow / out-of-bounds panic.
        let lca_index = lca_level
            .checked_sub(2)
            .ok_or(MlsError::LcaNotFoundInDirectPath)?;

        // Our direct path with our own leaf prepended, so that position `i` is the node whose
        // private key lives in `secret_keys[i]`.
        let mut path = self.tree_kem_public.nodes.direct_copath(self_index);
        path.insert(0, CopathNode::new(self_index.into(), 0));

        // Reject an index that does not fall on our path instead of panicking below.
        let lca_path_node = path
            .get(lca_index)
            .ok_or(MlsError::LcaNotFoundInDirectPath)?
            .path;

        // `find_resolved_pos` never returns a position above `lca_index`, so indexing `path`
        // with it is in bounds.
        let resolved_pos = self.find_resolved_pos(&path, lca_index)?;

        let ct_pos =
            self.find_ciphertext_pos(lca_path_node, path[resolved_pos].path, added_leaves)?;

        let lca_node = update_path
            .nodes
            .get(lca_index)
            .and_then(|node| node.as_ref())
            .ok_or(MlsError::LcaNotFoundInDirectPath)?;

        let ct = lca_node
            .encrypted_path_secret
            .get(ct_pos)
            .ok_or(MlsError::LcaNotFoundInDirectPath)?;

        let secret = self
            .private_key
            .secret_keys
            .get(resolved_pos)
            .and_then(|key| key.as_ref())
            .ok_or(MlsError::UpdateErrorNoSecretKey)?;

        let public = self
            .tree_kem_public
            .nodes
            .borrow_node(path[resolved_pos].path)?
            .as_ref()
            .ok_or(MlsError::UpdateErrorNoSecretKey)?
            .public_key();

        let lca_path_secret =
            PathSecret::decrypt(cipher_suite_provider, secret, public, context_bytes, ct).await?;

        // Derive the rest of the secrets for the tree and assign to the proper nodes
        let mut node_secret_gen =
            PathSecretGenerator::starting_with(cipher_suite_provider, lca_path_secret);

        // One slot for our leaf plus one per direct-path node. `path` already contains our
        // leaf at position 0, so its length is exactly that count (matching the
        // `path.len() + 1` sizing in `encap`, where `path` excludes the leaf).
        self.private_key.secret_keys.resize(path.len(), None);

        for (i, update) in update_path.nodes.iter().enumerate().skip(lca_index) {
            if let Some(update) = update {
                let secret = node_secret_gen.next_secret().await?;

                // Verify the private key we calculated properly matches the public key we inserted into the tree. This guarantees
                // that we will be able to decrypt later.
                let (hpke_private, hpke_public) =
                    secret.to_hpke_key_pair(cipher_suite_provider).await?;

                if hpke_public != update.public_key {
                    return Err(MlsError::PubKeyMismatch);
                }

                self.private_key.secret_keys[i + 1] = Some(hpke_private);
            } else {
                self.private_key.secret_keys[i + 1] = None;
            }
        }

        node_secret_gen.next_secret().await
    }

    /// Builds the `UpdatePathNode` for the direct-path parent of `copath_index`.
    ///
    /// Encrypts `path_secret` (with `context`) to the public key of every node in the
    /// resolution of `copath_index` that is not in `excluding`. It pairs those ciphertexts
    /// with the parent's public key as already stored in the public tree.
    ///
    /// # Errors
    /// [`MlsError::ExpectedNode`] if `copath_index` has no parent. Other errors come from
    /// tree access (including resolution nodes that are empty) and encryption.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn encrypt_copath_node_resolution<P: CipherSuiteProvider>(
        &self,
        cipher_suite_provider: &P,
        path_secret: &PathSecret,
        copath_index: NodeIndex,
        context: &[u8],
        #[cfg(feature = "std")] excluding: &HashSet<NodeIndex>,
        #[cfg(not(feature = "std"))] excluding: &[NodeIndex],
    ) -> Result<UpdatePathNode, MlsError> {
        let reso = self
            .tree_kem_public
            .nodes
            .get_resolution_index(copath_index)?;

        let make_ctxt = |idx| async move {
            let node = self
                .tree_kem_public
                .nodes
                .borrow_node(idx)?
                .as_non_empty()?;

            path_secret
                .encrypt(cipher_suite_provider, node.public_key(), context)
                .await
        };

        let ctxts = wrap_iter(reso).filter(|&idx| async move { !excluding.contains(&idx) });

        #[cfg(not(mls_build_async))]
        let ctxts = ctxts.map(make_ctxt);

        #[cfg(mls_build_async)]
        let ctxts = ctxts.then(make_ctxt);

        let ctxts = ctxts.try_collect().await?;

        let path_index = copath_index
            .parent_sibling(&self.tree_kem_public.total_leaf_count())
            .ok_or(MlsError::ExpectedNode)?
            .parent;

        Ok(UpdatePathNode {
            public_key: self
                .tree_kem_public
                .nodes
                .borrow_as_parent(path_index)?
                .public_key
                .clone(),
            encrypted_path_secret: ctxts,
        })
    }

    /// Finds the position in `path` (our leaf at 0, then our direct path) of the node whose
    /// private key decrypts the ciphertext addressed to us.
    ///
    /// Starting at `path[lca_index]`, walks down to the first non-blank node. If we hold no
    /// private key for that node, falls back to position 0 (our own leaf). The returned
    /// position is never greater than `lca_index`.
    ///
    /// # Errors
    /// [`MlsError::UpdateErrorNoSecretKey`] if every node down to our own leaf is blank.
    #[inline]
    fn find_resolved_pos(
        &self,
        path: &[CopathNode<NodeIndex>],
        mut lca_index: usize,
    ) -> Result<usize, MlsError> {
        while self.tree_kem_public.nodes.is_blank(path[lca_index].path)? {
            // Position 0 is our own leaf; if even that is blank there is nothing to decrypt
            // with, so fail instead of underflowing.
            lca_index = lca_index
                .checked_sub(1)
                .ok_or(MlsError::UpdateErrorNoSecretKey)?;
        }

        // If we don't have the key, we should be an unmerged leaf at the resolved node. (If
        // we're not, an error will be thrown later.) A slot that does not exist yet, such as
        // when the private key holds fewer entries than the path, also means "no key"; it
        // must not panic.
        if !matches!(self.private_key.secret_keys.get(lca_index), Some(Some(_))) {
            lca_index = 0;
        }

        Ok(lca_index)
    }

    /// Returns the index of `resolved` within the resolution of `lca` after removing the
    /// leaves in `excluding`. This is the position of our ciphertext in the LCA's
    /// `encrypted_path_secret` list.
    ///
    /// Odd node indices are kept unconditionally (the code treats them as parent nodes);
    /// an even index `idx` is dropped when leaf `idx / 2` is in `excluding`.
    ///
    /// # Errors
    /// [`MlsError::UpdateErrorNoSecretKey`] if `resolved` is not in the filtered resolution.
    #[inline]
    fn find_ciphertext_pos(
        &self,
        lca: NodeIndex,
        resolved: NodeIndex,
        excluding: &[LeafIndex],
    ) -> Result<usize, MlsError> {
        let reso = self.tree_kem_public.nodes.get_resolution_index(lca)?;

        let (ct_pos, _) = reso
            .iter()
            .filter(|idx| **idx % 2 == 1 || !excluding.contains(&LeafIndex(**idx / 2)))
            .find_position(|idx| idx == &&resolved)
            .ok_or(MlsError::UpdateErrorNoSecretKey)?;

        Ok(ct_pos)
    }
}

#[cfg(test)]
mod tests {
    // Round-trip tests for `TreeKem`: leaf 0 runs `encap`, every other leaf applies the
    // resulting update path to its own copy of the tree and runs `decap`, and each
    // receiver's public tree must then equal the encapsulator's.
    use super::{tree_math, TreeKem};
    use crate::{
        cipher_suite::CipherSuite,
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
        extension::test_utils::TestExtension,
        group::test_utils::{get_test_group_context, random_bytes},
        identity::basic::BasicIdentityProvider,
        tree_kem::{
            leaf_node::{
                test_utils::{get_basic_test_node_sig_key, get_test_capabilities},
                ConfigProperties,
            },
            node::LeafIndex,
            Capabilities, TreeKemPrivate, TreeKemPublic, UpdatePath, ValidatedUpdatePath,
        },
        ExtensionList,
    };
    use alloc::{format, vec, vec::Vec};
    use mls_rs_codec::MlsEncode;
    use mls_rs_core::crypto::CipherSuiteProvider;
    use tree_math::TreeIndex;

    // Verify that the tree is in the correct state after generating an update path
    // NOTE(review): `update_path.nodes[i]` is compared against direct-path node `i`, which
    // assumes no direct-path node of `index` was filtered out — confirm for the trees used.
    fn verify_tree_update_path(
        tree: &TreeKemPublic,
        update_path: &UpdatePath,
        index: LeafIndex,
        capabilities: Option<Capabilities>,
        extensions: Option<ExtensionList>,
    ) {
        // Make sure the update path is based on the direct path of the sender
        let direct_path = tree.nodes.direct_copath(index);

        for (i, n) in direct_path.iter().enumerate() {
            assert_eq!(
                *tree
                    .nodes
                    .borrow_node(n.path)
                    .unwrap()
                    .as_ref()
                    .unwrap()
                    .public_key(),
                update_path.nodes[i].public_key
            );
        }

        // Verify that the leaf from the update path has been installed
        assert_eq!(
            tree.nodes.borrow_as_leaf(index).unwrap(),
            &update_path.leaf_node
        );

        // Verify that updated capabilities were installed
        if let Some(capabilities) = capabilities {
            assert_eq!(update_path.leaf_node.capabilities, capabilities);
        }

        // Verify that update extensions were installed
        if let Some(extensions) = extensions {
            assert_eq!(update_path.leaf_node.extensions, extensions);
        }

        // Verify that we have a public keys up to the root
        let root = tree.total_leaf_count().root();
        assert!(tree.nodes.borrow_node(root).unwrap().is_some());
    }

    // Checks that, for every node on the direct path of `index`, the private key stored in
    // `secret_keys[i + 1]` opens an HPKE ciphertext sealed to that node's public key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn verify_tree_private_path(
        cipher_suite: &CipherSuite,
        public_tree: &TreeKemPublic,
        private_tree: &TreeKemPrivate,
        index: LeafIndex,
    ) {
        let provider = test_cipher_suite_provider(*cipher_suite);

        assert_eq!(private_tree.self_index, index);

        // Make sure we have private values along the direct path, and the public keys match
        let path_iter = public_tree
            .nodes
            .direct_copath(index)
            .into_iter()
            .enumerate();

        for (i, n) in path_iter {
            let secret_key = private_tree.secret_keys[i + 1].as_ref().unwrap();

            let public_key = public_tree
                .nodes
                .borrow_node(n.path)
                .unwrap()
                .as_ref()
                .unwrap()
                .public_key();

            let test_data = random_bytes(32);

            let sealed = provider
                .hpke_seal(public_key, &[], None, &test_data)
                .await
                .unwrap();

            let opened = provider
                .hpke_open(&sealed, secret_key, public_key, &[], None)
                .await
                .unwrap();

            assert_eq!(test_data, opened);
        }
    }

    // Builds a tree with the encapsulating member at leaf 0 plus leaves 1..size, runs
    // `encap` from leaf 0, checks the sender's public and private state, then has every
    // other leaf apply the update path and run `decap`, asserting that each receiver ends up
    // with a public tree equal to the sender's.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn encap_decap(
        cipher_suite: CipherSuite,
        size: usize,
        capabilities: Option<Capabilities>,
        extensions: Option<ExtensionList>,
    ) {
        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

        // Generate signing keys and key package generations, and private keys for multiple
        // participants in order to set up state

        let mut leaf_nodes = Vec::new();
        let mut private_keys = Vec::new();

        for index in 1..size {
            let (leaf_node, hpke_secret, _) =
                get_basic_test_node_sig_key(cipher_suite, &format!("{index}")).await;

            let private_key = TreeKemPrivate::new_self_leaf(LeafIndex(index as u32), hpke_secret);

            leaf_nodes.push(leaf_node);
            private_keys.push(private_key);
        }

        let (encap_node, encap_hpke_secret, encap_signer) =
            get_basic_test_node_sig_key(cipher_suite, "encap").await;

        // Build a test tree we can clone for all leaf nodes
        let (mut test_tree, mut encap_private_key) = TreeKemPublic::derive(
            encap_node,
            encap_hpke_secret,
            &BasicIdentityProvider,
            &Default::default(),
        )
        .await
        .unwrap();

        test_tree
            .add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        // Clone the tree for the first leaf, generate a new key package for that leaf
        let mut encap_tree = test_tree.clone();

        let update_leaf_properties = ConfigProperties {
            capabilities: capabilities.clone().unwrap_or_else(get_test_capabilities),
            extensions: extensions.clone().unwrap_or_default(),
        };

        // Perform the encap function
        let encap_gen = TreeKem::new(&mut encap_tree, &mut encap_private_key)
            .encap(
                &mut get_test_group_context(42, cipher_suite).await,
                &[],
                &encap_signer,
                update_leaf_properties,
                None,
                &cipher_suite_provider,
                #[cfg(test)]
                &Default::default(),
            )
            .await
            .unwrap();

        // Verify that the state of the tree matches the produced update path
        verify_tree_update_path(
            &encap_tree,
            &encap_gen.update_path,
            LeafIndex(0),
            capabilities,
            extensions,
        );

        // Verify that the private key matches the data in the public key
        verify_tree_private_path(&cipher_suite, &encap_tree, &encap_private_key, LeafIndex(0))
            .await;

        // Spread the (unfiltered-only) update path nodes back over the full direct path,
        // leaving `None` at filtered positions, as `ValidatedUpdatePath` expects.
        let filtered = test_tree.nodes.filtered(LeafIndex(0)).unwrap();
        let mut unfiltered_nodes = vec![None; filtered.len()];
        filtered
            .into_iter()
            .enumerate()
            .filter(|(_, f)| !*f)
            .zip(encap_gen.update_path.nodes.iter())
            .for_each(|((i, _), node)| {
                unfiltered_nodes[i] = Some(node.clone());
            });

        // Apply the update path to the rest of the leaf nodes using the decap function
        let validated_update_path = ValidatedUpdatePath {
            leaf_node: encap_gen.update_path.leaf_node,
            nodes: unfiltered_nodes,
        };

        encap_tree
            .update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
            .await
            .unwrap();

        let mut receiver_trees: Vec<TreeKemPublic> = (1..size).map(|_| test_tree.clone()).collect();

        for (i, tree) in receiver_trees.iter_mut().enumerate() {
            tree.apply_update_path(
                LeafIndex(0),
                &validated_update_path,
                &Default::default(),
                BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

            // Receivers decrypt under a context carrying their own recomputed tree hash.
            let mut context = get_test_group_context(42, cipher_suite).await;
            context.tree_hash = tree.tree_hash(&cipher_suite_provider).await.unwrap();

            TreeKem::new(tree, &mut private_keys[i])
                .decap(
                    LeafIndex(0),
                    &validated_update_path,
                    &[],
                    &context.mls_encode_to_vec().unwrap(),
                    &cipher_suite_provider,
                )
                .await
                .unwrap();

            tree.update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
                .await
                .unwrap();

            assert_eq!(tree, &encap_tree);
        }
    }

    // Round trip with default capabilities/extensions for every supported cipher suite.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_encap_decap() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            encap_decap(cipher_suite, 10, None, None).await;
        }
    }

    // Round trip where the new leaf advertises an extra extension type in its capabilities.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_encap_capabilities() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let mut capabilities = get_test_capabilities();
        capabilities.extensions.push(42.into());

        encap_decap(cipher_suite, 10, Some(capabilities.clone()), None).await;
    }

    // Round trip where the new leaf carries a custom extension.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_encap_extensions() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let mut extensions = ExtensionList::default();
        extensions.set_from(TestExtension { foo: 10 }).unwrap();

        encap_decap(cipher_suite, 10, None, Some(extensions)).await;
    }

    // Round trip combining custom capabilities and a custom extension.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_encap_capabilities_extensions() {
        let cipher_suite = TEST_CIPHER_SUITE;
        let mut capabilities = get_test_capabilities();
        capabilities.extensions.push(42.into());

        let mut extensions = ExtensionList::default();
        extensions.set_from(TestExtension { foo: 10 }).unwrap();

        encap_decap(cipher_suite, 10, Some(capabilities), Some(extensions)).await;
    }
}

[ Dauer der Verarbeitung: 0.25 Sekunden  (vorverarbeitet)  ]