Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/mls-rs/src/group/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 18 kB image not shown  

Quelle  state_repo.rs   Sprache: unbekannt

 
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use crate::client::MlsError;
use crate::{group::PriorEpoch, key_package::KeyPackageRef};

use alloc::collections::VecDeque;
use alloc::vec::Vec;
use core::fmt::{self, Debug};
use mls_rs_codec::{MlsDecode, MlsEncode};
use mls_rs_core::group::{EpochRecord, GroupState};
use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage};

use super::snapshot::Snapshot;

#[cfg(feature = "psk")]
use crate::group::ResumptionPsk;

#[cfg(feature = "psk")]
use mls_rs_core::psk::PreSharedKey;

/// A set of changes to apply to a GroupStateStorage implementation. These changes MUST
/// be made in a single transaction to avoid creating invalid states.
#[derive(Default, Clone, Debug)]
struct EpochStorageCommit {
    pub(crate) inserts: VecDeque<PriorEpoch>,
    pub(crate) updates: Vec<PriorEpoch>,
}

#[derive(Clone)]
pub(crate) struct GroupStateRepository<S, K>
where
    S: GroupStateStorage,
    K: KeyPackageStorage,
{
    pending_commit: EpochStorageCommit,
    pending_key_package_removal: Option<KeyPackageRef>,
    group_id: Vec<u8>,
    storage: S,
    key_package_repo: K,
}

impl<S, K> Debug for GroupStateRepository<S, K>
where
    S: GroupStateStorage + Debug,
    K: KeyPackageStorage + Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GroupStateRepository")
            .field("pending_commit", &self.pending_commit)
            .field(
                "pending_key_package_removal",
                &self.pending_key_package_removal,
            )
            .field(
                "group_id",
                &mls_rs_core::debug::pretty_group_id(&self.group_id),
            )
            .field("storage", &self.storage)
            .field("key_package_repo", &self.key_package_repo)
            .finish()
    }
}

impl<S, K> GroupStateRepository<S, K>
where
    S: GroupStateStorage,
    K: KeyPackageStorage,
{
    pub fn new(
        group_id: Vec<u8>,
        storage: S,
        key_package_repo: K,
        // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
        key_package_to_remove: Option<KeyPackageRef>,
    ) -> Result<GroupStateRepository<S, K>, MlsError> {
        Ok(GroupStateRepository {
            group_id,
            storage,
            pending_key_package_removal: key_package_to_remove,
            pending_commit: Default::default(),
            key_package_repo,
        })
    }

    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn find_max_id(&self) -> Result<Option<u64>, MlsError> {
        if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) {
            Ok(Some(max))
        } else {
            self.storage
                .max_epoch_id(&self.group_id)
                .await
                .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))
        }
    }

    #[cfg(feature = "psk")]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn resumption_secret(
        &self,
        psk_id: &ResumptionPsk,
    ) -> Result<Option<PreSharedKey>, MlsError> {
        // Search the local inserts cache
        if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
            if psk_id.psk_epoch >= min {
                return Ok(self
                    .pending_commit
                    .inserts
                    .get((psk_id.psk_epoch - min) as usize)
                    .map(|e| e.secrets.resumption_secret.clone()));
            }
        }

        // Search the local updates cache
        let maybe_pending = self.find_pending(psk_id.psk_epoch);

        if let Some(pending) = maybe_pending {
            return Ok(Some(
                self.pending_commit.updates[pending]
                    .secrets
                    .resumption_secret
                    .clone(),
            ));
        }

        // Search the stored cache
        self.storage
            .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch)
            .await
            .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
            .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret))
            .transpose()
    }

    #[cfg(feature = "private_message")]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn get_epoch_mut(
        &mut self,
        epoch_id: u64,
    ) -> Result<Option<&mut PriorEpoch>, MlsError> {
        // Search the local inserts cache
        if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
            if epoch_id >= min {
                return Ok(self
                    .pending_commit
                    .inserts
                    .get_mut((epoch_id - min) as usize));
            }
        }

        // Look in the cached updates map, and if not found look in disk storage
        // and insert into the updates map for future caching
        match self.find_pending(epoch_id) {
            Some(i) => self.pending_commit.updates.get_mut(i).map(Ok),
            None => self
                .storage
                .epoch(&self.group_id, epoch_id)
                .await
                .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
                .and_then(|epoch| {
                    PriorEpoch::mls_decode(&mut &*epoch)
                        .map(|epoch| {
                            self.pending_commit.updates.push(epoch);
                            self.pending_commit.updates.last_mut()
                        })
                        .transpose()
                }),
        }
        .transpose()
        .map_err(Into::into)
    }

    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> {
        if epoch.group_id() != self.group_id {
            return Err(MlsError::GroupIdMismatch);
        }

        let epoch_id = epoch.epoch_id();

        if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) {
            if epoch_id != expected_id {
                return Err(MlsError::InvalidEpoch);
            }
        }

        self.pending_commit.inserts.push_back(epoch);

        Ok(())
    }

    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
        let inserts = self
            .pending_commit
            .inserts
            .iter()
            .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
            .collect::<Result<_, MlsError>>()?;

        let updates = self
            .pending_commit
            .updates
            .iter()
            .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
            .collect::<Result<_, MlsError>>()?;

        let group_state = GroupState {
            data: group_snapshot.mls_encode_to_vec()?,
            id: group_snapshot.state.context.group_id,
        };

        self.storage
            .write(group_state, inserts, updates)
            .await
            .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;

        if let Some(ref key_package_ref) = self.pending_key_package_removal {
            self.key_package_repo
                .delete(key_package_ref)
                .await
                .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
        }

        self.pending_commit.inserts.clear();
        self.pending_commit.updates.clear();

        Ok(())
    }

    #[cfg(any(feature = "psk", feature = "private_message"))]
    fn find_pending(&self, epoch_id: u64) -> Option<usize> {
        self.pending_commit
            .updates
            .iter()
            .position(|ep| ep.context.epoch == epoch_id)
    }
}

#[cfg(test)]
mod tests {
    use alloc::vec;
    use mls_rs_codec::MlsEncode;

    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret},
            test_utils::{random_bytes, test_member, TEST_GROUP},
            PskGroupId, ResumptionPSKUsage,
        },
        storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
    };

    use super::*;

    fn test_group_state_repo(
        retention_limit: usize,
    ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> {
        GroupStateRepository::new(
            TEST_GROUP.to_vec(),
            InMemoryGroupStateStorage::new()
                .with_max_epoch_retention(retention_limit)
                .unwrap(),
            InMemoryKeyPackageStorage::default(),
            None,
        )
        .unwrap()
    }

    fn test_epoch(epoch_id: u64) -> PriorEpoch {
        get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id)
    }

    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_snapshot(epoch_id: u64) -> Snapshot {
        crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_epoch_inserts() {
        let mut test_repo = test_group_state_repo(1);
        let test_epoch = test_epoch(0);

        test_repo.insert(test_epoch.clone()).await.unwrap();

        // Check the in-memory state
        assert_eq!(
            test_repo.pending_commit.inserts.back().unwrap(),
            &test_epoch
        );

        assert!(test_repo.pending_commit.updates.is_empty());

        #[cfg(feature = "std")]
        assert!(test_repo.storage.inner.lock().unwrap().is_empty());
        #[cfg(not(feature = "std"))]
        assert!(test_repo.storage.inner.lock().is_empty());

        let psk_id = ResumptionPsk {
            psk_epoch: 0,
            psk_group_id: PskGroupId(test_repo.group_id.clone()),
            usage: ResumptionPSKUsage::Application,
        };

        // Make sure you can recall an epoch sitting as a pending insert
        let resumption = test_repo.resumption_secret(&psk_id).await.unwrap();
        let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned();

        assert_eq!(
            prior_epoch.clone().unwrap().secrets.resumption_secret,
            resumption.unwrap()
        );

        assert_eq!(prior_epoch.unwrap(), test_epoch);

        // Write to the storage
        let snapshot = test_snapshot(test_epoch.epoch_id()).await;
        test_repo.write_to_storage(snapshot.clone()).await.unwrap();

        // Make sure the memory cache cleared
        assert!(test_repo.pending_commit.inserts.is_empty());
        assert!(test_repo.pending_commit.updates.is_empty());

        // Make sure the storage was written
        #[cfg(feature = "std")]
        let storage = test_repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let storage = test_repo.storage.inner.lock();

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());

        assert_eq!(stored.epoch_data.len(), 1);

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(
                test_epoch.epoch_id(),
                test_epoch.mls_encode_to_vec().unwrap()
            )
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_updates() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        // Update the stored epoch
        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
        assert_eq!(to_update, &test_epoch_0);

        let new_sender_secret = random_bytes(32);
        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
        let to_update = to_update.clone();

        assert_eq!(test_repo.pending_commit.updates.len(), 1);
        assert!(test_repo.pending_commit.inserts.is_empty());

        assert_eq!(
            test_repo.pending_commit.updates.first().unwrap(),
            &to_update
        );

        // Make sure you can access an epoch pending update
        let psk_id = ResumptionPsk {
            psk_epoch: 0,
            psk_group_id: PskGroupId(test_repo.group_id.clone()),
            usage: ResumptionPSKUsage::Application,
        };

        let owned = test_repo.resumption_secret(&psk_id).await.unwrap();
        assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret));

        // Write the update to storage
        let snapshot = test_snapshot(1).await;
        test_repo.write_to_storage(snapshot.clone()).await.unwrap();

        assert!(test_repo.pending_commit.updates.is_empty());
        assert!(test_repo.pending_commit.inserts.is_empty());

        // Make sure the storage was written
        #[cfg(feature = "std")]
        let storage = test_repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let storage = test_repo.storage.inner.lock();

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());

        assert_eq!(stored.epoch_data.len(), 1);

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_and_update() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        // Update the stored epoch
        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
        let new_sender_secret = random_bytes(32);
        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
        let to_update = to_update.clone();

        // Insert another epoch
        let test_epoch_1 = test_epoch(1);
        test_repo.insert(test_epoch_1.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(1).await)
            .await
            .unwrap();

        assert!(test_repo.pending_commit.inserts.is_empty());
        assert!(test_repo.pending_commit.updates.is_empty());

        // Make sure the storage was written
        #[cfg(feature = "std")]
        let storage = test_repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let storage = test_repo.storage.inner.lock();

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.epoch_data.len(), 2);

        assert_eq!(
            stored.epoch_data.front().unwrap(),
            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
        );

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(
                test_epoch_1.epoch_id(),
                test_epoch_1.mls_encode_to_vec().unwrap()
            )
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_many_epochs_in_storage() {
        let epochs = (0..10).map(test_epoch).collect::<Vec<_>>();

        let mut test_repo = test_group_state_repo(10);

        for epoch in epochs.iter().cloned() {
            test_repo.insert(epoch).await.unwrap()
        }

        test_repo
            .write_to_storage(test_snapshot(9).await)
            .await
            .unwrap();

        for mut epoch in epochs {
            let res = test_repo.get_epoch_mut(epoch.epoch_id()).await.unwrap();

            assert_eq!(res, Some(&mut epoch));
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_stored_groups_list() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        assert_eq!(
            test_repo.storage.stored_groups(),
            vec![test_epoch_0.context.group_id]
        )
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn reducing_retention_limit_takes_effect_on_epoch_access() {
        let mut repo = test_group_state_repo(1);

        repo.insert(test_epoch(0)).await.unwrap();
        repo.insert(test_epoch(1)).await.unwrap();

        repo.write_to_storage(test_snapshot(0).await).await.unwrap();

        let mut repo = GroupStateRepository {
            storage: repo.storage,
            ..test_group_state_repo(1)
        };

        let res = repo.get_epoch_mut(0).await.unwrap();

        assert!(res.is_none());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn in_memory_storage_obeys_retention_limit_after_saving() {
        let mut repo = test_group_state_repo(1);

        repo.insert(test_epoch(0)).await.unwrap();
        repo.write_to_storage(test_snapshot(0).await).await.unwrap();
        repo.insert(test_epoch(1)).await.unwrap();
        repo.write_to_storage(test_snapshot(1).await).await.unwrap();

        #[cfg(feature = "std")]
        let lock = repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let lock = repo.storage.inner.lock();

        assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn used_key_package_is_deleted() {
        let key_package_repo = InMemoryKeyPackageStorage::default();

        let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member")
            .await
            .0;

        let (id, data) = key_package.to_storage().unwrap();

        key_package_repo.insert(id, data);

        let mut repo = GroupStateRepository::new(
            TEST_GROUP.to_vec(),
            InMemoryGroupStateStorage::new(),
            key_package_repo,
            Some(key_package.reference.clone()),
        )
        .unwrap();

        repo.key_package_repo.get(&key_package.reference).unwrap();

        repo.write_to_storage(test_snapshot(4).await).await.unwrap();

        assert!(repo.key_package_repo.get(&key_package.reference).is_none());
    }
}

[ Dauer der Verarbeitung: 0.24 Sekunden  (vorverarbeitet)  ]