Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/prio/src/vdaf/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 9 kB image not shown  

Quelle  prio3_test.rs   Sprache: unbekannt

 
Spracherkennung für: .rs vermutete Sprache: Unknown {[0] [0] [0]} [Methode: Schwerpunktbildung, einfache Gewichte, sechs Dimensionen]

// SPDX-License-Identifier: MPL-2.0

//! Tools for evaluating Prio3 test vectors.

use crate::{
    codec::{Encode, ParameterizedDecode},
    flp::Type,
    vdaf::{
        prio3::{Prio3, Prio3InputShare, Prio3PrepareShare, Prio3PublicShare},
        xof::Xof,
        Aggregator, Collector, OutputShare, PrepareTransition, Vdaf,
    },
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, convert::TryInto, fmt::Debug};

#[derive(Debug, Deserialize, Serialize)]
struct TEncoded(#[serde(with = "hex")] Vec<u8>);

impl AsRef<[u8]> for TEncoded {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

#[derive(Deserialize, Serialize)]
struct TPrio3Prep<M> {
    measurement: M,
    #[serde(with = "hex")]
    nonce: Vec<u8>,
    #[serde(with = "hex")]
    rand: Vec<u8>,
    public_share: TEncoded,
    input_shares: Vec<TEncoded>,
    prep_shares: Vec<Vec<TEncoded>>,
    prep_messages: Vec<TEncoded>,
    out_shares: Vec<Vec<TEncoded>>,
}

#[derive(Deserialize, Serialize)]
struct TPrio3<M> {
    verify_key: TEncoded,
    shares: u8,
    prep: Vec<TPrio3Prep<M>>,
    agg_shares: Vec<TEncoded>,
    agg_result: serde_json::Value,
    #[serde(flatten)]
    other_params: HashMap<String, serde_json::Value>,
}

macro_rules! err {
    (
        $test_num:ident,
        $error:expr,
        $msg:expr
    ) => {
        panic!("test #{} failed: {} err: {}", $test_num, $msg, $error)
    };
}

// TODO Generalize this method to work with any VDAF. To do so we would need to add
// `shard_with_random()` to traits. (There may be a less invasive alternative.)
fn check_prep_test_vec<MS, MP, T, P, const SEED_SIZE: usize>(
    prio3: &Prio3<T, P, SEED_SIZE>,
    verify_key: &[u8; SEED_SIZE],
    test_num: usize,
    t: &TPrio3Prep<MS>,
) -> Vec<OutputShare<T::Field>>
where
    MS: Clone,
    MP: From<MS>,
    T: Type<Measurement = MP>,
    P: Xof<SEED_SIZE>,
{
    let nonce = <[u8; 16]>::try_from(t.nonce.clone()).unwrap();
    let (public_share, input_shares) = prio3
        .shard_with_random(&t.measurement.clone().into(), &nonce, &t.rand)
        .expect("failed to generate input shares");

    assert_eq!(
        public_share,
        Prio3PublicShare::get_decoded_with_param(prio3, t.public_share.as_ref())
            .unwrap_or_else(|e| err!(test_num, e, "decode test vector (public share)")),
    );
    for (agg_id, want) in t.input_shares.iter().enumerate() {
        assert_eq!(
            input_shares[agg_id],
            Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), want.as_ref())
                .unwrap_or_else(|e| err!(test_num, e, "decode test vector (input share)")),
            "#{test_num}"
        );
        assert_eq!(
            input_shares[agg_id].get_encoded().unwrap(),
            want.as_ref(),
            "#{test_num}"
        )
    }

    let mut states = Vec::new();
    let mut prep_shares = Vec::new();
    for (agg_id, input_share) in input_shares.iter().enumerate() {
        let (state, prep_share) = prio3
            .prepare_init(verify_key, agg_id, &(), &nonce, &public_share, input_share)
            .unwrap_or_else(|e| err!(test_num, e, "prep state init"));
        states.push(state);
        prep_shares.push(prep_share);
    }

    assert_eq!(1, t.prep_shares.len(), "#{test_num}");
    for (i, want) in t.prep_shares[0].iter().enumerate() {
        assert_eq!(
            prep_shares[i],
            Prio3PrepareShare::get_decoded_with_param(&states[i], want.as_ref())
                .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")),
            "#{test_num}"
        );
        assert_eq!(
            prep_shares[i].get_encoded().unwrap(),
            want.as_ref(),
            "#{test_num}"
        );
    }

    let inbound = prio3
        .prepare_shares_to_prepare_message(&(), prep_shares)
        .unwrap_or_else(|e| err!(test_num, e, "prep preprocess"));
    assert_eq!(t.prep_messages.len(), 1);
    assert_eq!(inbound.get_encoded().unwrap(), t.prep_messages[0].as_ref());

    let mut out_shares = Vec::new();
    for state in states.iter_mut() {
        match prio3.prepare_next(state.clone(), inbound.clone()).unwrap() {
            PrepareTransition::Finish(out_share) => {
                out_shares.push(out_share);
            }
            _ => panic!("unexpected transition"),
        }
    }

    for (got, want) in out_shares.iter().zip(t.out_shares.iter()) {
        let got: Vec<Vec<u8>> = got
            .as_ref()
            .iter()
            .map(|x| x.get_encoded().unwrap())
            .collect();
        assert_eq!(got.len(), want.len());
        for (got_elem, want_elem) in got.iter().zip(want.iter()) {
            assert_eq!(got_elem.as_slice(), want_elem.as_ref());
        }
    }

    out_shares
}

#[must_use]
fn check_aggregate_test_vec<MS, MP, T, P, const SEED_SIZE: usize>(
    prio3: &Prio3<T, P, SEED_SIZE>,
    t: &TPrio3<MS>,
) -> T::AggregateResult
where
    MS: Clone,
    MP: From<MS>,
    T: Type<Measurement = MP>,
    P: Xof<SEED_SIZE>,
{
    let verify_key = t.verify_key.as_ref().try_into().unwrap();

    let mut all_output_shares = vec![Vec::new(); prio3.num_aggregators()];
    for (test_num, p) in t.prep.iter().enumerate() {
        let output_shares = check_prep_test_vec(prio3, verify_key, test_num, p);
        for (aggregator_output_shares, output_share) in
            all_output_shares.iter_mut().zip(output_shares.into_iter())
        {
            aggregator_output_shares.push(output_share);
        }
    }

    let aggregate_shares = all_output_shares
        .into_iter()
        .map(|aggregator_output_shares| prio3.aggregate(&(), aggregator_output_shares).unwrap())
        .collect::<Vec<_>>();

    for (got, want) in aggregate_shares.iter().zip(t.agg_shares.iter()) {
        let got = got.get_encoded().unwrap();
        assert_eq!(got.as_slice(), want.as_ref());
    }

    prio3.unshard(&(), aggregate_shares, 1).unwrap()
}

/// Evaluate a Prio3 test vector. The instance of Prio3 is constructed from the `new_vdaf` callback,
/// which takes in the VDAF parameters encoded by the test vectors and the number of shares.
///
/// This version allows customizing the deserialization of measurements, via an additional type
/// parameter.
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub fn check_test_vec_custom_de<MS, MP, A, T, P, const SEED_SIZE: usize>(
    test_vec_json_str: &str,
    new_vdaf: impl Fn(&HashMap<String, serde_json::Value>, u8) -> Prio3<T, P, SEED_SIZE>,
) where
    MS: for<'de> Deserialize<'de> + Clone,
    MP: From<MS>,
    A: for<'de> Deserialize<'de> + Debug + Eq,
    T: Type<Measurement = MP, AggregateResult = A>,
    P: Xof<SEED_SIZE>,
{
    let t: TPrio3<MS> = serde_json::from_str(test_vec_json_str).unwrap();
    let vdaf = new_vdaf(&t.other_params, t.shares);
    let agg_result = check_aggregate_test_vec(&vdaf, &t);
    assert_eq!(agg_result, serde_json::from_value(t.agg_result).unwrap());
}

/// Evaluate a Prio3 test vector. The instance of Prio3 is constructed from the `new_vdaf` callback,
/// which takes in the VDAF parameters encoded by the test vectors and the number of shares.
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub fn check_test_vec<M, A, T, P, const SEED_SIZE: usize>(
    test_vec_json_str: &str,
    new_vdaf: impl Fn(&HashMap<String, serde_json::Value>, u8) -> Prio3<T, P, SEED_SIZE>,
) where
    M: for<'de> Deserialize<'de> + Clone,
    A: for<'de> Deserialize<'de> + Debug + Eq,
    T: Type<Measurement = M, AggregateResult = A>,
    P: Xof<SEED_SIZE>,
{
    check_test_vec_custom_de::<M, M, _, _, _, SEED_SIZE>(test_vec_json_str, new_vdaf)
}

#[derive(Debug, Clone, Deserialize)]
#[serde(transparent)]
struct Prio3CountMeasurement(u8);

impl From<Prio3CountMeasurement> for bool {
    fn from(value: Prio3CountMeasurement) -> Self {
        value.0 != 0
    }
}

#[test]
fn test_vec_prio3_count() {
    for test_vector_str in [
        include_str!("test_vec/08/Prio3Count_0.json"),
        include_str!("test_vec/08/Prio3Count_1.json"),
    ] {
        check_test_vec_custom_de::<Prio3CountMeasurement, _, _, _, _, 16>(
            test_vector_str,
            |_json_params, num_shares| Prio3::new_count(num_shares).unwrap(),
        );
    }
}

#[test]
fn test_vec_prio3_sum() {
    for test_vector_str in [
        include_str!("test_vec/08/Prio3Sum_0.json"),
        include_str!("test_vec/08/Prio3Sum_1.json"),
    ] {
        check_test_vec(test_vector_str, |json_params, num_shares| {
            let bits = json_params["bits"].as_u64().unwrap() as usize;
            Prio3::new_sum(num_shares, bits).unwrap()
        });
    }
}

#[test]
fn test_vec_prio3_sum_vec() {
    for test_vector_str in [
        include_str!("test_vec/08/Prio3SumVec_0.json"),
        include_str!("test_vec/08/Prio3SumVec_1.json"),
    ] {
        check_test_vec(test_vector_str, |json_params, num_shares| {
            let bits = json_params["bits"].as_u64().unwrap() as usize;
            let length = json_params["length"].as_u64().unwrap() as usize;
            let chunk_length = json_params["chunk_length"].as_u64().unwrap() as usize;
            Prio3::new_sum_vec(num_shares, bits, length, chunk_length).unwrap()
        });
    }
}

#[test]
fn test_vec_prio3_histogram() {
    for test_vector_str in [
        include_str!("test_vec/08/Prio3Histogram_0.json"),
        include_str!("test_vec/08/Prio3Histogram_1.json"),
    ] {
        check_test_vec(test_vector_str, |json_params, num_shares| {
            let length = json_params["length"].as_u64().unwrap() as usize;
            let chunk_length = json_params["chunk_length"].as_u64().unwrap() as usize;
            Prio3::new_histogram(num_shares, length, chunk_length).unwrap()
        });
    }
}

[ Dauer der Verarbeitung: 0.32 Sekunden  ]