Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/mls-rs-core/src/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 8 kB image not shown  

Quelle  extension.rs   Sprache: unbekannt

 
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use core::{
    fmt::{self, Debug},
    ops::Deref,
};

use crate::error::{AnyError, IntoAnyError};
use alloc::vec::Vec;
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

mod list;

pub use list::*;

/// Wrapper type representing an extension identifier along with default values
/// defined by the MLS RFC.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
// #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct ExtensionType(u16);

// #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl ExtensionType {
    pub const APPLICATION_ID: ExtensionType = ExtensionType(1);
    pub const RATCHET_TREE: ExtensionType = ExtensionType(2);
    pub const REQUIRED_CAPABILITIES: ExtensionType = ExtensionType(3);
    pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4);
    pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5);

    /// Default extension types defined
    /// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents)
    pub const DEFAULT: &'static [ExtensionType] = &[
        ExtensionType::APPLICATION_ID,
        ExtensionType::RATCHET_TREE,
        ExtensionType::REQUIRED_CAPABILITIES,
        ExtensionType::EXTERNAL_PUB,
        ExtensionType::EXTERNAL_SENDERS,
    ];

    /// Extension type from a raw value
    pub const fn new(raw_value: u16) -> Self {
        ExtensionType(raw_value)
    }

    /// Raw numerical wrapped value.
    pub const fn raw_value(&self) -> u16 {
        self.0
    }

    /// Determines if this extension type is required to be implemented
    /// by the MLS RFC.
    pub const fn is_default(&self) -> bool {
        self.0 <= 5
    }
}

impl From<u16> for ExtensionType {
    fn from(value: u16) -> Self {
        ExtensionType(value)
    }
}

impl Deref for ExtensionType {
    type Target = u16;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

#[derive(Debug)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum ExtensionError {
    #[cfg_attr(feature = "std", error(transparent))]
    SerializationError(AnyError),
    #[cfg_attr(feature = "std", error(transparent))]
    DeserializationError(AnyError),
    #[cfg_attr(feature = "std", error("incorrect extension type: {0:?}"))]
    IncorrectType(ExtensionType),
}

impl IntoAnyError for ExtensionError {
    #[cfg(feature = "std")]
    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
        Ok(self.into())
    }
}

#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
// #[cfg_attr(
//     all(feature = "ffi", not(test)),
//     safer_ffi_gen::ffi_type(clone, opaque)
// )]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
/// An MLS protocol [extension](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-extensions).
///
/// Extensions are used as customization points in various parts of the
/// MLS protocol and are inserted into an [ExtensionList](self::ExtensionList).
pub struct Extension {
    /// Extension type of this extension
    pub extension_type: ExtensionType,
    /// Data held within this extension
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
    pub extension_data: Vec<u8>,
}

impl Debug for Extension {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Extension")
            .field("extension_type", &self.extension_type)
            .field(
                "extension_data",
                &crate::debug::pretty_bytes(&self.extension_data),
            )
            .finish()
    }
}

// #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl Extension {
    /// Create an extension with specified type and data properties.
    pub fn new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension {
        Extension {
            extension_type,
            extension_data,
        }
    }

    /// Extension type of this extension
    #[cfg(feature = "ffi")]
    pub fn extension_type(&self) -> ExtensionType {
        self.extension_type
    }

    /// Data held within this extension
    #[cfg(feature = "ffi")]
    pub fn extension_data(&self) -> &[u8] {
        &self.extension_data
    }
}

/// Trait used to convert a type to and from an [Extension]
pub trait MlsExtension: Sized {
    /// Error type of the underlying serializer that can convert this type into a `Vec<u8>`.
    type SerializationError: IntoAnyError;

    /// Error type of the underlying deserializer that can convert a `Vec<u8>` into this type.
    type DeserializationError: IntoAnyError;

    /// Extension type value that this type represents.
    fn extension_type() -> ExtensionType;

    /// Convert this type to opaque bytes.
    fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>;

    /// Create this type from opaque bytes.
    fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>;

    /// Convert this type into an [Extension].
    fn into_extension(self) -> Result<Extension, ExtensionError> {
        Ok(Extension::new(
            Self::extension_type(),
            self.to_bytes()
                .map_err(|e| ExtensionError::SerializationError(e.into_any_error()))?,
        ))
    }

    /// Create this type from an [Extension].
    fn from_extension(ext: &Extension) -> Result<Self, ExtensionError> {
        if ext.extension_type != Self::extension_type() {
            return Err(ExtensionError::IncorrectType(ext.extension_type));
        }

        Self::from_bytes(&ext.extension_data)
            .map_err(|e| ExtensionError::DeserializationError(e.into_any_error()))
    }
}

/// Convenience trait for custom extension types that use
/// [mls_rs_codec] as an underlying serialization mechanism
pub trait MlsCodecExtension: MlsSize + MlsEncode + MlsDecode {
    fn extension_type() -> ExtensionType;
}

impl<T> MlsExtension for T
where
    T: MlsCodecExtension,
{
    type SerializationError = mls_rs_codec::Error;
    type DeserializationError = mls_rs_codec::Error;

    fn extension_type() -> ExtensionType {
        <Self as MlsCodecExtension>::extension_type()
    }

    fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
        self.mls_encode_to_vec()
    }

    fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> {
        Self::mls_decode(&mut &*data)
    }
}

#[cfg(test)]
mod tests {
    use core::convert::Infallible;

    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

    use super::{Extension, ExtensionError, ExtensionType, MlsCodecExtension, MlsExtension};

    struct TestExtension;

    #[derive(Debug, MlsSize, MlsEncode, MlsDecode)]
    struct AnotherTestExtension;

    impl MlsExtension for TestExtension {
        type SerializationError = Infallible;
        type DeserializationError = Infallible;

        fn extension_type() -> super::ExtensionType {
            ExtensionType(42)
        }

        fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
            Ok(vec![0])
        }

        fn from_bytes(_data: &[u8]) -> Result<Self, Self::DeserializationError> {
            Ok(TestExtension)
        }
    }

    impl MlsCodecExtension for AnotherTestExtension {
        fn extension_type() -> ExtensionType {
            ExtensionType(43)
        }
    }

    #[test]
    fn into_extension() {
        assert_eq!(
            TestExtension.into_extension().unwrap(),
            Extension::new(42.into(), vec![0])
        )
    }

    #[test]
    fn incorrect_type_is_discovered() {
        let ext = Extension::new(42.into(), vec![0]);

        assert_matches!(AnotherTestExtension::from_extension(&ext), Err(ExtensionError::IncorrectType(found)) if found == 42.into());
    }
}

[ Dauer der Verarbeitung: 0.25 Sekunden  (vorverarbeitet)  ]