Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/mls-rs-codec-derive/src/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 10 kB image not shown  

Quelle  lib.rs   Sprache: unbekannt

 
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use std::str::FromStr;

use darling::{
    ast::{self, Fields},
    FromDeriveInput, FromField, FromVariant,
};
use proc_macro2::{Literal, TokenStream};
use quote::quote;
use syn::{
    parse_macro_input, parse_quote, Attribute, DeriveInput, Expr, Generics, Ident, Index, Lit, Path,
};

enum Operation {
    Size,
    Encode,
    Decode,
}

impl Operation {
    fn path(&self) -> Path {
        match self {
            Operation::Size => parse_quote! { mls_rs_codec::MlsSize },
            Operation::Encode => parse_quote! { mls_rs_codec::MlsEncode },
            Operation::Decode => parse_quote! { mls_rs_codec::MlsDecode },
        }
    }

    fn call(&self) -> TokenStream {
        match self {
            Operation::Size => quote! { mls_encoded_len },
            Operation::Encode => quote! { mls_encode },
            Operation::Decode => quote! { mls_decode },
        }
    }

    fn extras(&self) -> TokenStream {
        match self {
            Operation::Size => quote! {},
            Operation::Encode => quote! { , writer },
            Operation::Decode => quote! { reader },
        }
    }

    fn is_result(&self) -> bool {
        match self {
            Operation::Size => false,
            Operation::Encode => true,
            Operation::Decode => true,
        }
    }
}

#[derive(Debug, FromField)]
#[darling(attributes(mls_codec))]
struct MlsFieldReceiver {
    ident: Option<Ident>,
    with: Option<Path>,
}

impl MlsFieldReceiver {
    pub fn call_tokens(&self, index: Index) -> TokenStream {
        if let Some(ref ident) = self.ident {
            quote! { &self.#ident }
        } else {
            quote! { &self.#index }
        }
    }

    pub fn name(&self, index: Index) -> TokenStream {
        if let Some(ref ident) = self.ident {
            quote! {#ident: }
        } else {
            quote! { #index: }
        }
    }
}

#[derive(Debug, FromVariant)]
#[darling(attributes(mls_codec))]
struct MlsVariantReceiver {
    ident: Ident,
    discriminant: Option<Expr>,
    fields: ast::Fields<MlsFieldReceiver>,
}

#[derive(FromDeriveInput)]
#[darling(attributes(mls_codec), forward_attrs(repr))]
struct MlsInputReceiver {
    attrs: Vec<Attribute>,
    ident: Ident,
    generics: Generics,
    data: ast::Data<MlsVariantReceiver, MlsFieldReceiver>,
}

impl MlsInputReceiver {
    fn handle_input(&self, operation: Operation) -> TokenStream {
        match self.data {
            ast::Data::Struct(ref s) => struct_impl(s, operation),
            ast::Data::Enum(ref e) => enum_impl(&self.ident, &self.attrs, e, operation),
        }
    }
}

fn repr_ident(attrs: &[Attribute]) -> Option<Ident> {
    let repr_path = attrs
        .iter()
        .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer))
        .find(|attr| attr.path().is_ident("repr"))
        .map(|repr| repr.parse_args())
        .transpose()
        .ok()
        .flatten();

    let Some(Expr::Path(path)) = repr_path else {
        return None;
    };

    path.path
        .segments
        .iter()
        .find(|s| s.ident != "C")
        .map(|path| path.ident.clone())
}

/// Provides the discriminant for a given variant. If the variant does not specify a suffix
/// and a `repr_ident` is provided, it will be appended to number.
fn discriminant_for_variant(
    variant: &MlsVariantReceiver,
    repr_ident: &Option<Ident>,
) -> TokenStream {
    let discriminant = variant
        .discriminant
        .clone()
        .expect("Enum discriminants must be explicitly defined");

    let Expr::Lit(lit_expr) = &discriminant else {
        return quote! {#discriminant};
    };

    let Lit::Int(lit_int) = &lit_expr.lit else {
        return quote! {#discriminant};
    };

    if lit_int.suffix().is_empty() {
        // This is dirty and there is probably a better way of doing this but I'm way too much of a noob at
        // proc macros to pull it off...
        // TODO: Add proper support for correctly ignoring transparent, packed and modifiers
        let str = format!(
            "{}{}",
            lit_int.base10_digits(),
            &repr_ident.clone().expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.")
        );
        Literal::from_str(&str)
            .map(|l| quote! {#l})
            .ok()
            .unwrap_or_else(|| quote! {#discriminant})
    } else {
        quote! {#discriminant}
    }
}

fn enum_impl(
    ident: &Ident,
    attrs: &[Attribute],
    variants: &[MlsVariantReceiver],
    operation: Operation,
) -> TokenStream {
    let handle_error = operation.is_result().then_some(quote! { ? });
    let path = operation.path();
    let call = operation.call();
    let extras = operation.extras();
    let enum_name = &ident;
    let repr_ident = repr_ident(attrs);
    if matches!(operation, Operation::Decode) {
        let cases = variants.iter().map(|variant| {
            let variant_name = &variant.ident;

            let discriminant = discriminant_for_variant(variant, &repr_ident);

            // TODO: Support more than 1 field
            match variant.fields.len() {
                0 => quote! { #discriminant => Ok(#enum_name::#variant_name), },
                1 =>{
                    let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);
                    quote! { #discriminant => Ok(#enum_name::#variant_name(#path::#call(#extras) #handle_error)), }
                },
                _ => panic!("Enum discriminants with more than 1 field are not currently supported")
            }
        });

        return quote! {
            let discriminant = #path::#call(#extras)#handle_error;

            match discriminant {
                #(#cases)*
                _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
            }
        };
    }

    let cases = variants.iter().map(|variant| {
        let variant_name = &variant.ident;

        let discriminant = discriminant_for_variant(variant, &repr_ident);

        let (parameter, field) = if variant.fields.is_empty() {
            (None, None)
        } else {
            let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);

            let start = match operation {
                Operation::Size => Some(quote! { + }),
                Operation::Encode => Some(quote! {;}),
                Operation::Decode => None,
            };

            (
                Some(quote! {(ref val)}),
                Some(quote! { #start #path::#call (val #extras) #handle_error }),
            )
        };

        let discrim = quote! { #path::#call (&#discriminant #extras) #handle_error };

        quote! { #enum_name::#variant_name #parameter => { #discrim #field }}
    });

    let enum_impl = quote! {
        match self {
            #(#cases)*
        }
    };

    if operation.is_result() {
        quote! {
            Ok(#enum_impl)
        }
    } else {
        enum_impl
    }
}

fn struct_impl(s: &Fields<MlsFieldReceiver>, operation: Operation) -> TokenStream {
    let recurse = s.fields.iter().enumerate().map(|(index, field)| {
        let (call_tokens, field_name) = match operation {
            Operation::Size | Operation::Encode => {
                (field.call_tokens(Index::from(index)), quote! {})
            }
            Operation::Decode => (quote! {}, field.name(Index::from(index))),
        };

        let handle_error = operation.is_result().then_some(quote! { ? });
        let path = field.with.clone().unwrap_or(operation.path());
        let call = operation.call();
        let extras = operation.extras();

        quote! {
           #field_name #path::#call (#call_tokens #extras) #handle_error
        }
    });

    match operation {
        Operation::Size => quote! { 0 #(+ #recurse)* },
        Operation::Encode => quote! { #(#recurse;)* Ok(()) },
        Operation::Decode => quote! { Ok(Self { #(#recurse,)* }) },
    }
}

fn derive_impl<F>(
    input: proc_macro::TokenStream,
    trait_name: TokenStream,
    function_def: TokenStream,
    internals: F,
) -> proc_macro::TokenStream
where
    F: FnOnce(&MlsInputReceiver) -> TokenStream,
{
    let input = parse_macro_input!(input as DeriveInput);

    let input = MlsInputReceiver::from_derive_input(&input).unwrap();

    let name = &input.ident;

    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Generate an expression to sum up the heap size of each field.
    let function_impl = internals(&input);

    let expanded = quote! {
        // The generated impl.
        impl #impl_generics #trait_name for #name #ty_generics #where_clause {
            #function_def {
                #function_impl
            }
        }
    };

    // Hand the output tokens back to the compiler.
    proc_macro::TokenStream::from(expanded)
}

#[proc_macro_derive(MlsSize, attributes(mls_codec))]
pub fn derive_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let trait_name = quote! { mls_rs_codec::MlsSize };
    let function_def = quote! {fn mls_encoded_len(&self) -> usize };

    derive_impl(input, trait_name, function_def, |input| {
        input.handle_input(Operation::Size)
    })
}

#[proc_macro_derive(MlsEncode, attributes(mls_codec))]
pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let trait_name = quote! { mls_rs_codec::MlsEncode };

    let function_def = quote! { fn mls_encode(&self, writer: &mut mls_rs_codec::Vec<u8>) -> Result<(), mls_rs_codec::Error> };

    derive_impl(input, trait_name, function_def, |input| {
        input.handle_input(Operation::Encode)
    })
}

#[proc_macro_derive(MlsDecode, attributes(mls_codec))]
pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let trait_name = quote! { mls_rs_codec::MlsDecode };

    let function_def =
        quote! { fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> };

    derive_impl(input, trait_name, function_def, |input| {
        input.handle_input(Operation::Decode)
    })
}

[ Dauer der Verarbeitung: 0.24 Sekunden  (vorverarbeitet)  ]