Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/prio/src/vdaf/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 18 kB image not shown  

Quelle  prio2.rs   Sprache: unbekannt

 
Spracherkennung für: .rs vermutete Sprache: Unknown {[0] [0] [0]} [Methode: Schwerpunktbildung, einfache Gewichte, sechs Dimensionen]

// SPDX-License-Identifier: MPL-2.0

//! Backwards-compatible port of the ENPA Prio system to a VDAF.

use crate::{
    codec::{CodecError, Decode, Encode, ParameterizedDecode},
    field::{
        decode_fieldvec, FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldPrio2,
    },
    prng::Prng,
    vdaf::{
        prio2::{
            client::{self as v2_client, proof_length},
            server as v2_server,
        },
        xof::Seed,
        Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare,
        PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError,
    },
};
use hmac::{Hmac, Mac};
use rand_core::RngCore;
use sha2::Sha256;
use std::{convert::TryFrom, io::Cursor};
use subtle::{Choice, ConstantTimeEq};

mod client;
mod server;
#[cfg(test)]
mod test_vector;

/// The Prio2 VDAF. It supports the same measurement type as
/// [`Prio3SumVec`](crate::vdaf::prio3::Prio3SumVec) with `bits == 1` but uses the proof system and
/// finite field deployed in ENPA.
#[derive(Clone, Debug)]
pub struct Prio2 {
    input_len: usize,
}

impl Prio2 {
    /// Returns an instance of the VDAF for the given input length.
    pub fn new(input_len: usize) -> Result<Self, VdafError> {
        let n = (input_len + 1).next_power_of_two();
        if let Ok(size) = u32::try_from(2 * n) {
            if size > FieldPrio2::generator_order() {
                return Err(VdafError::Uncategorized(
                    "input size exceeds field capacity".into(),
                ));
            }
        } else {
            return Err(VdafError::Uncategorized(
                "input size exceeds memory capacity".into(),
            ));
        }

        Ok(Prio2 { input_len })
    }

    /// Prepare an input share for aggregation using the given field element `query_rand` to
    /// compute the verifier share.
    ///
    /// In the [`Aggregator`] trait implementation for [`Prio2`], the query randomness is computed
    /// jointly by the Aggregators. This method is designed to be used in applications, like ENPA,
    /// in which the query randomness is instead chosen by a third-party.
    pub fn prepare_init_with_query_rand(
        &self,
        query_rand: FieldPrio2,
        input_share: &Share<FieldPrio2, 32>,
        is_leader: bool,
    ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
        let expanded_data: Option<Vec<FieldPrio2>> = match input_share {
            Share::Leader(_) => None,
            Share::Helper(ref seed) => {
                let prng = Prng::from_prio2_seed(seed.as_ref());
                Some(prng.take(proof_length(self.input_len)).collect())
            }
        };
        let data = match input_share {
            Share::Leader(ref data) => data,
            Share::Helper(_) => expanded_data.as_ref().unwrap(),
        };

        let verifier_share = v2_server::generate_verification_message(
            self.input_len,
            query_rand,
            data, // Combined input and proof shares
            is_leader,
        )
        .map_err(|e| VdafError::Uncategorized(e.to_string()))?;

        let truncated_share = match input_share {
            Share::Leader(data) => Share::Leader(data[..self.input_len].to_vec()),
            Share::Helper(seed) => Share::Helper(seed.clone()),
        };

        Ok((
            Prio2PrepareState(truncated_share),
            Prio2PrepareShare(verifier_share),
        ))
    }

    /// Choose a random point for polynomial evaluation.
    ///
    /// The point returned is not one of the roots used for polynomial interpolation.
    pub(crate) fn choose_eval_at<S>(&self, prng: &mut Prng<FieldPrio2, S>) -> FieldPrio2
    where
        S: RngCore,
    {
        // Make sure the query randomness isn't a root of unity. Evaluating the proof at any of
        // these points would be a privacy violation, since these points were used by the prover to
        // construct the wire polynomials.
        let n = (self.input_len + 1).next_power_of_two();
        let proof_length = 2 * n;
        loop {
            let eval_at: FieldPrio2 = prng.get();
            // Unwrap safety: the constructor checks that this conversion succeeds.
            if eval_at.pow(u32::try_from(proof_length).unwrap()) != FieldPrio2::one() {
                return eval_at;
            }
        }
    }
}

impl Vdaf for Prio2 {
    type Measurement = Vec<u32>;
    type AggregateResult = Vec<u32>;
    type AggregationParam = ();
    type PublicShare = ();
    type InputShare = Share<FieldPrio2, 32>;
    type OutputShare = OutputShare<FieldPrio2>;
    type AggregateShare = AggregateShare<FieldPrio2>;

    fn algorithm_id(&self) -> u32 {
        0xFFFF0000
    }

    fn num_aggregators(&self) -> usize {
        // Prio2 can easily be extended to support more than two Aggregators.
        2
    }
}

impl Client<16> for Prio2 {
    fn shard(
        &self,
        measurement: &Vec<u32>,
        _nonce: &[u8; 16],
    ) -> Result<(Self::PublicShare, Vec<Share<FieldPrio2, 32>>), VdafError> {
        if measurement.len() != self.input_len {
            return Err(VdafError::Uncategorized("incorrect input length".into()));
        }
        let mut input: Vec<FieldPrio2> = Vec::with_capacity(measurement.len());
        for int in measurement {
            input.push((*int).into());
        }

        let mut mem = v2_client::ClientMemory::new(self.input_len)?;
        let copy_data = |share_data: &mut [FieldPrio2]| {
            share_data[..].clone_from_slice(&input);
        };
        let mut leader_data = mem.prove_with(self.input_len, copy_data);

        let helper_seed = Seed::generate()?;
        let helper_prng = Prng::from_prio2_seed(helper_seed.as_ref());
        for (s1, d) in leader_data.iter_mut().zip(helper_prng.into_iter()) {
            *s1 -= d;
        }

        Ok((
            (),
            vec![Share::Leader(leader_data), Share::Helper(helper_seed)],
        ))
    }
}

/// State of each [`Aggregator`] during the Preparation phase.
#[derive(Clone, Debug)]
pub struct Prio2PrepareState(Share<FieldPrio2, 32>);

impl PartialEq for Prio2PrepareState {
    fn eq(&self, other: &Self) -> bool {
        self.ct_eq(other).into()
    }
}

impl Eq for Prio2PrepareState {}

impl ConstantTimeEq for Prio2PrepareState {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

impl Encode for Prio2PrepareState {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.0.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        self.0.encoded_len()
    }
}

impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState {
    fn decode_with_param(
        (prio2, agg_id): &(&'a Prio2, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let share_decoder = if *agg_id == 0 {
            ShareDecodingParameter::Leader(prio2.input_len)
        } else {
            ShareDecodingParameter::Helper
        };
        let out_share = Share::decode_with_param(&share_decoder, bytes)?;
        Ok(Self(out_share))
    }
}

/// Message emitted by each [`Aggregator`] during the Preparation phase.
#[derive(Clone, Debug)]
pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>);

impl Encode for Prio2PrepareShare {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.0.f_r.encode(bytes)?;
        self.0.g_r.encode(bytes)?;
        self.0.h_r.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(FieldPrio2::ENCODED_SIZE * 3)
    }
}

impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare {
    fn decode_with_param(
        _state: &Prio2PrepareState,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        Ok(Self(v2_server::VerificationMessage {
            f_r: FieldPrio2::decode(bytes)?,
            g_r: FieldPrio2::decode(bytes)?,
            h_r: FieldPrio2::decode(bytes)?,
        }))
    }
}

impl Aggregator<32, 16> for Prio2 {
    type PrepareState = Prio2PrepareState;
    type PrepareShare = Prio2PrepareShare;
    type PrepareMessage = ();

    fn prepare_init(
        &self,
        agg_key: &[u8; 32],
        agg_id: usize,
        _agg_param: &Self::AggregationParam,
        nonce: &[u8; 16],
        _public_share: &Self::PublicShare,
        input_share: &Share<FieldPrio2, 32>,
    ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
        let is_leader = role_try_from(agg_id)?;

        // In the ENPA Prio system, the query randomness is generated by a third party and
        // distributed to the Aggregators after they receive their input shares. In a VDAF, shared
        // randomness is derived from a nonce selected by the client. For Prio2 we compute the
        // query using HMAC-SHA256 evaluated over the nonce.
        //
        // Unwrap safety: new_from_slice() is infallible for Hmac.
        let mut mac = Hmac::<Sha256>::new_from_slice(agg_key).unwrap();
        mac.update(nonce);
        let hmac_tag = mac.finalize();
        let mut prng = Prng::from_prio2_seed(&hmac_tag.into_bytes().into());
        let query_rand = self.choose_eval_at(&mut prng);

        self.prepare_init_with_query_rand(query_rand, input_share, is_leader)
    }

    fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Prio2PrepareShare>>(
        &self,
        _: &Self::AggregationParam,
        inputs: M,
    ) -> Result<(), VdafError> {
        let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> =
            inputs.into_iter().map(|msg| msg.0).collect();
        if verifier_shares.len() != 2 {
            return Err(VdafError::Uncategorized(
                "wrong number of verifier shares".into(),
            ));
        }

        if !v2_server::is_valid_share(&verifier_shares[0], &verifier_shares[1]) {
            return Err(VdafError::Uncategorized(
                "proof verifier check failed".into(),
            ));
        }

        Ok(())
    }

    fn prepare_next(
        &self,
        state: Prio2PrepareState,
        _input: (),
    ) -> Result<PrepareTransition<Self, 32, 16>, VdafError> {
        let data = match state.0 {
            Share::Leader(data) => data,
            Share::Helper(seed) => {
                let prng = Prng::from_prio2_seed(seed.as_ref());
                prng.take(self.input_len).collect()
            }
        };
        Ok(PrepareTransition::Finish(OutputShare::from(data)))
    }

    fn aggregate<M: IntoIterator<Item = OutputShare<FieldPrio2>>>(
        &self,
        _agg_param: &Self::AggregationParam,
        out_shares: M,
    ) -> Result<AggregateShare<FieldPrio2>, VdafError> {
        let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
        for out_share in out_shares.into_iter() {
            agg_share.accumulate(&out_share)?;
        }

        Ok(agg_share)
    }
}

impl Collector for Prio2 {
    fn unshard<M: IntoIterator<Item = AggregateShare<FieldPrio2>>>(
        &self,
        _agg_param: &Self::AggregationParam,
        agg_shares: M,
        _num_measurements: usize,
    ) -> Result<Vec<u32>, VdafError> {
        let mut agg = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
        for agg_share in agg_shares.into_iter() {
            agg.merge(&agg_share)?;
        }

        Ok(agg.0.into_iter().map(u32::from).collect())
    }
}

impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share<FieldPrio2, 32> {
    fn decode_with_param(
        (prio2, agg_id): &(&'a Prio2, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?;
        let decoder = if is_leader {
            ShareDecodingParameter::Leader(proof_length(prio2.input_len))
        } else {
            ShareDecodingParameter::Helper
        };

        Share::decode_with_param(&decoder, bytes)
    }
}

impl<'a, F> ParameterizedDecode<(&'a Prio2, &'a ())> for OutputShare<F>
where
    F: FieldElement,
{
    fn decode_with_param(
        (prio2, _): &(&'a Prio2, &'a ()),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        decode_fieldvec(prio2.input_len, bytes).map(Self)
    }
}

impl<'a, F> ParameterizedDecode<(&'a Prio2, &'a ())> for AggregateShare<F>
where
    F: FieldElement,
{
    fn decode_with_param(
        (prio2, _): &(&'a Prio2, &'a ()),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        decode_fieldvec(prio2.input_len, bytes).map(Self)
    }
}

fn role_try_from(agg_id: usize) -> Result<bool, VdafError> {
    match agg_id {
        0 => Ok(true),
        1 => Ok(false),
        _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::vdaf::{
        equality_comparison_test, fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector,
        test_utils::run_vdaf,
    };
    use assert_matches::assert_matches;
    use rand::prelude::*;

    #[test]
    fn run_prio2() {
        let prio2 = Prio2::new(6).unwrap();

        assert_eq!(
            run_vdaf(
                &prio2,
                &(),
                [
                    vec![0, 0, 0, 0, 1, 0],
                    vec![0, 1, 0, 0, 0, 0],
                    vec![0, 1, 1, 0, 0, 0],
                    vec![1, 1, 1, 0, 0, 0],
                    vec![0, 0, 0, 0, 1, 1],
                ]
            )
            .unwrap(),
            vec![1, 3, 2, 0, 2, 1],
        );
    }

    #[test]
    fn prepare_state_serialization() {
        let mut rng = thread_rng();
        let verify_key = rng.gen::<[u8; 32]>();
        let nonce = rng.gen::<[u8; 16]>();
        let data = vec![0, 0, 1, 1, 0];
        let prio2 = Prio2::new(data.len()).unwrap();
        let (public_share, input_shares) = prio2.shard(&data, &nonce).unwrap();
        for (agg_id, input_share) in input_shares.iter().enumerate() {
            let (prepare_state, prepare_share) = prio2
                .prepare_init(
                    &verify_key,
                    agg_id,
                    &(),
                    &[0; 16],
                    &public_share,
                    input_share,
                )
                .unwrap();

            let encoded_prepare_state = prepare_state.get_encoded().unwrap();
            let decoded_prepare_state = Prio2PrepareState::get_decoded_with_param(
                &(&prio2, agg_id),
                &encoded_prepare_state,
            )
            .expect("failed to decode prepare state");
            assert_eq!(decoded_prepare_state, prepare_state);
            assert_eq!(
                prepare_state.encoded_len().unwrap(),
                encoded_prepare_state.len()
            );

            let encoded_prepare_share = prepare_share.get_encoded().unwrap();
            let decoded_prepare_share =
                Prio2PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share)
                    .expect("failed to decode prepare share");
            assert_eq!(decoded_prepare_share.0.f_r, prepare_share.0.f_r);
            assert_eq!(decoded_prepare_share.0.g_r, prepare_share.0.g_r);
            assert_eq!(decoded_prepare_share.0.h_r, prepare_share.0.h_r);
            assert_eq!(
                prepare_share.encoded_len().unwrap(),
                encoded_prepare_share.len()
            );
        }
    }

    #[test]
    fn roundtrip_output_share() {
        let vdaf = Prio2::new(31).unwrap();
        fieldvec_roundtrip_test::<FieldPrio2, Prio2, OutputShare<FieldPrio2>>(&vdaf, &(), 31);
    }

    #[test]
    fn roundtrip_aggregate_share() {
        let vdaf = Prio2::new(31).unwrap();
        fieldvec_roundtrip_test::<FieldPrio2, Prio2, AggregateShare<FieldPrio2>>(&vdaf, &(), 31);
    }

    #[test]
    fn priov2_backward_compatibility() {
        let test_vector: Priov2TestVector =
            serde_json::from_str(include_str!("test_vec/prio2/fieldpriov2.json")).unwrap();
        let vdaf = Prio2::new(test_vector.dimension).unwrap();
        let mut leader_output_shares = Vec::new();
        let mut helper_output_shares = Vec::new();
        for (server_1_share, server_2_share) in test_vector
            .server_1_decrypted_shares
            .iter()
            .zip(&test_vector.server_2_decrypted_shares)
        {
            let input_share_1 = Share::get_decoded_with_param(&(&vdaf, 0), server_1_share).unwrap();
            let input_share_2 = Share::get_decoded_with_param(&(&vdaf, 1), server_2_share).unwrap();
            let (prepare_state_1, prepare_share_1) = vdaf
                .prepare_init(&[0; 32], 0, &(), &[0; 16], &(), &input_share_1)
                .unwrap();
            let (prepare_state_2, prepare_share_2) = vdaf
                .prepare_init(&[0; 32], 1, &(), &[0; 16], &(), &input_share_2)
                .unwrap();
            vdaf.prepare_shares_to_prepare_message(&(), [prepare_share_1, prepare_share_2])
                .unwrap();
            let transition_1 = vdaf.prepare_next(prepare_state_1, ()).unwrap();
            let output_share_1 =
                assert_matches!(transition_1, PrepareTransition::Finish(out) => out);
            let transition_2 = vdaf.prepare_next(prepare_state_2, ()).unwrap();
            let output_share_2 =
                assert_matches!(transition_2, PrepareTransition::Finish(out) => out);
            leader_output_shares.push(output_share_1);
            helper_output_shares.push(output_share_2);
        }

        let leader_aggregate_share = vdaf.aggregate(&(), leader_output_shares).unwrap();
        let helper_aggregate_share = vdaf.aggregate(&(), helper_output_shares).unwrap();
        let aggregate_result = vdaf
            .unshard(
                &(),
                [leader_aggregate_share, helper_aggregate_share],
                test_vector.server_1_decrypted_shares.len(),
            )
            .unwrap();
        let reconstructed = aggregate_result
            .into_iter()
            .map(FieldPrio2::from)
            .collect::<Vec<_>>();

        assert_eq!(reconstructed, test_vector.reference_sum);
    }

    #[test]
    fn prepare_state_equality_test() {
        equality_comparison_test(&[
            Prio2PrepareState(Share::Leader(Vec::from([
                FieldPrio2::from(0),
                FieldPrio2::from(1),
            ]))),
            Prio2PrepareState(Share::Leader(Vec::from([
                FieldPrio2::from(1),
                FieldPrio2::from(0),
            ]))),
            Prio2PrepareState(Share::Helper(Seed(
                (0..32).collect::<Vec<_>>().try_into().unwrap(),
            ))),
            Prio2PrepareState(Share::Helper(Seed(
                (1..33).collect::<Vec<_>>().try_into().unwrap(),
            ))),
        ])
    }
}

[ Dauer der Verarbeitung: 0.36 Sekunden  ]