Anforderungen  |   Konzepte  |   Entwurf  |   Entwicklung  |   Qualitätssicherung  |   Lebenszyklus  |   Steuerung
 
 
 
 


Quelle  flp.rs   Sprache: unbekannt

 
// SPDX-License-Identifier: MPL-2.0

//! Implementation of the generic Fully Linear Proof (FLP) system specified in
//! [[draft-irtf-cfrg-vdaf-08]]. This is the main building block of [`Prio3`](crate::vdaf::prio3).
//!
//! The FLP is derived for any implementation of the [`Type`] trait. Such an implementation
//! specifies a validity circuit that defines the set of valid measurements, as well as the finite
//! field in which the validity circuit is evaluated. It also determines how raw measurements are
//! encoded as inputs to the validity circuit, and how aggregates are decoded from sums of
//! measurements.
//!
//! # Overview
//!
//! The proof system is comprised of three algorithms. The first, `prove`, is run by the prover in
//! order to generate a proof of a statement's validity. The second and third, `query` and
//! `decide`, are run by the verifier in order to check the proof. The proof asserts that the input
//! is an element of a language recognized by the arithmetic circuit. If an input is _not_ valid,
//! then the verification step will fail with high probability:
//!
//! ```
//! use prio::flp::types::Count;
//! use prio::flp::Type;
//! use prio::field::{random_vector, FieldElement, Field64};
//!
//! // The prover chooses a measurement.
//! let count = Count::new();
//! let input: Vec<Field64> = count.encode_measurement(&false).unwrap();
//!
//! // The prover and verifier agree on "joint randomness" used to generate and
//! // check the proof. The application needs to ensure that the prover
//! // "commits" to the input before this point. In Prio3, the joint
//! // randomness is derived from additive shares of the input.
//! let joint_rand = random_vector(count.joint_rand_len()).unwrap();
//!
//! // The prover generates the proof.
//! let prove_rand = random_vector(count.prove_rand_len()).unwrap();
//! let proof = count.prove(&input, &prove_rand, &joint_rand).unwrap();
//!
//! // The verifier checks the proof. In the first step, the verifier "queries"
//! // the input and proof, getting the "verifier message" in response. It then
//! // inspects the verifier to decide if the input is valid.
//! let query_rand = random_vector(count.query_rand_len()).unwrap();
//! let verifier = count.query(&input, &proof, &query_rand, &joint_rand, 1).unwrap();
//! assert!(count.decide(&verifier).unwrap());
//! ```
//!
//! [draft-irtf-cfrg-vdaf-08]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/08/

#[cfg(feature = "experimental")]
use crate::dp::DifferentialPrivacyStrategy;
use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish, FftError};
use crate::field::{FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldError};
use crate::fp::log2;
use crate::polynomial::poly_eval;
use std::any::Any;
use std::convert::TryFrom;
use std::fmt::Debug;

pub mod gadgets;
pub mod types;

/// Errors propagated by methods in this module.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum FlpError {
    /// Calling [`Type::prove`] returned an error.
    #[error("prove error: {0}")]
    Prove(String),

    /// Calling [`Type::query`] returned an error.
    #[error("query error: {0}")]
    Query(String),

    /// Calling [`Type::decide`] returned an error.
    #[error("decide error: {0}")]
    Decide(String),

    /// Calling a gadget returned an error.
    #[error("gadget error: {0}")]
    Gadget(String),

    /// Calling the validity circuit returned an error.
    #[error("validity circuit error: {0}")]
    Valid(String),

    /// Calling [`Type::encode_measurement`] returned an error.
    #[error("value error: {0}")]
    Encode(String),

    /// Calling [`Type::decode_result`] returned an error.
    #[error("value error: {0}")]
    Decode(String),

    /// Calling [`Type::truncate`] returned an error.
    #[error("truncate error: {0}")]
    Truncate(String),

    /// Generic invalid parameter. This may be returned when an FLP type cannot be constructed.
    #[error("invalid paramter: {0}")]
    InvalidParameter(String),

    /// Returned if an FFT operation propagates an error.
    #[error("FFT error: {0}")]
    Fft(#[from] FftError),

    /// Returned if a field operation encountered an error.
    #[error("Field error: {0}")]
    Field(#[from] FieldError),

    #[cfg(feature = "experimental")]
    /// An error happened during noising.
    #[error("differential privacy error: {0}")]
    DifferentialPrivacy(#[from] crate::dp::DpError),
}

/// A type. Implementations of this trait specify how a particular kind of measurement is encoded
/// as a vector of field elements and how validity of the encoded measurement is determined.
/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement.
pub trait Type: Sized + Eq + Clone + Debug {
    /// The type of raw measurement to be encoded.
    type Measurement: Clone + Debug;

    /// The type of aggregate result for this type.
    type AggregateResult: Clone + Debug;

    /// The finite field used for this type.
    type Field: FftFriendlyFieldElement;

    /// Encodes a measurement as a vector of [`Self::input_len`] field elements.
    fn encode_measurement(
        &self,
        measurement: &Self::Measurement,
    ) -> Result<Vec<Self::Field>, FlpError>;

    /// Decode an aggregate result.
    fn decode_result(
        &self,
        data: &[Self::Field],
        num_measurements: usize,
    ) -> Result<Self::AggregateResult, FlpError>;

    /// Returns the sequence of gadgets associated with the validity circuit.
    ///
    /// # Notes
    ///
    /// The construction of [[BBCG+19], Theorem 4.3] uses a single gadget rather than many.  The
    /// idea to generalize the proof system to allow multiple gadgets is discussed briefly in
    /// [[BBCG+19], Remark 4.5], but no construction is given. The construction implemented here
    /// requires security analysis.
    ///
    /// [BBCG+19]: https://ia.cr/2019/188
    fn gadget(&self) -> Vec<Box<dyn Gadget<Self::Field>>>;

    /// Evaluates the validity circuit on an input and returns the output.
    ///
    /// # Parameters
    ///
    /// * `gadgets` is the sequence of gadgets, presumably output by [`Self::gadget`].
    /// * `input` is the input to be validated.
    /// * `joint_rand` is the joint randomness shared by the prover and verifier.
    /// * `num_shares` is the number of input shares.
    ///
    /// # Example usage
    ///
    /// Applications typically do not call this method directly. It is used internally by
    /// [`Self::prove`] and [`Self::query`] to generate and verify the proof respectively.
    ///
    /// ```
    /// use prio::flp::types::Count;
    /// use prio::flp::Type;
    /// use prio::field::{random_vector, FieldElement, Field64};
    ///
    /// let count = Count::new();
    /// let input: Vec<Field64> = count.encode_measurement(&true).unwrap();
    /// let joint_rand = random_vector(count.joint_rand_len()).unwrap();
    /// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap();
    /// assert_eq!(v, Field64::zero());
    /// ```
    fn valid(
        &self,
        gadgets: &mut Vec<Box<dyn Gadget<Self::Field>>>,
        input: &[Self::Field],
        joint_rand: &[Self::Field],
        num_shares: usize,
    ) -> Result<Self::Field, FlpError>;

    /// Constructs an aggregatable output from an encoded input. Calling this method is only safe
    /// once `input` has been validated.
    fn truncate(&self, input: Vec<Self::Field>) -> Result<Vec<Self::Field>, FlpError>;

    /// The length in field elements of the encoded input returned by [`Self::encode_measurement`].
    fn input_len(&self) -> usize;

    /// The length in field elements of the proof generated for this type.
    fn proof_len(&self) -> usize;

    /// The length in field elements of the verifier message constructed by [`Self::query`].
    fn verifier_len(&self) -> usize;

    /// The length of the truncated output (i.e., the output of [`Type::truncate`]).
    fn output_len(&self) -> usize;

    /// The length of the joint random input.
    fn joint_rand_len(&self) -> usize;

    /// The length in field elements of the random input consumed by the prover to generate a
    /// proof. This is the same as the sum of the arity of each gadget in the validity circuit.
    fn prove_rand_len(&self) -> usize;

    /// The length in field elements of the random input consumed by the verifier to make queries
    /// against inputs and proofs. This is the same as the number of gadgets in the validity
    /// circuit.
    fn query_rand_len(&self) -> usize;

    /// Generate a proof of an input's validity. The return value is a sequence of
    /// [`Self::proof_len`] field elements.
    ///
    /// # Parameters
    ///
    /// * `input` is the input.
    /// * `prove_rand` is the prover' randomness.
    /// * `joint_rand` is the randomness shared by the prover and verifier.
    fn prove(
        &self,
        input: &[Self::Field],
        prove_rand: &[Self::Field],
        joint_rand: &[Self::Field],
    ) -> Result<Vec<Self::Field>, FlpError> {
        if input.len() != self.input_len() {
            return Err(FlpError::Prove(format!(
                "unexpected input length: got {}; want {}",
                input.len(),
                self.input_len()
            )));
        }

        if prove_rand.len() != self.prove_rand_len() {
            return Err(FlpError::Prove(format!(
                "unexpected prove randomness length: got {}; want {}",
                prove_rand.len(),
                self.prove_rand_len()
            )));
        }

        if joint_rand.len() != self.joint_rand_len() {
            return Err(FlpError::Prove(format!(
                "unexpected joint randomness length: got {}; want {}",
                joint_rand.len(),
                self.joint_rand_len()
            )));
        }

        let mut prove_rand_len = 0;
        let mut shims = self
            .gadget()
            .into_iter()
            .map(|inner| {
                let inner_arity = inner.arity();
                if prove_rand_len + inner_arity > prove_rand.len() {
                    return Err(FlpError::Prove(format!(
                        "short prove randomness: got {}; want at least {}",
                        prove_rand.len(),
                        prove_rand_len + inner_arity
                    )));
                }

                let gadget = Box::new(ProveShimGadget::new(
                    inner,
                    &prove_rand[prove_rand_len..prove_rand_len + inner_arity],
                )?) as Box<dyn Gadget<Self::Field>>;
                prove_rand_len += inner_arity;

                Ok(gadget)
            })
            .collect::<Result<Vec<_>, FlpError>>()?;
        assert_eq!(prove_rand_len, self.prove_rand_len());

        // Create a buffer for storing the proof. The buffer is longer than the proof itself; the extra
        // length is to accommodate the computation of each gadget polynomial.
        let data_len = shims
            .iter()
            .map(|shim| {
                let gadget_poly_len = gadget_poly_len(shim.degree(), wire_poly_len(shim.calls()));

                // Computing the gadget polynomial using FFT requires an amount of memory that is a
                // power of 2. Thus we choose the smallest power of 2 that is at least as large as
                // the gadget polynomial. The wire seeds are encoded in the proof, too, so we
                // include the arity of the gadget to ensure there is always enough room at the end
                // of the buffer to compute the next gadget polynomial. It's likely that the
                // memory footprint here can be reduced, with a bit of care.
                shim.arity() + gadget_poly_len.next_power_of_two()
            })
            .sum();
        let mut proof = vec![Self::Field::zero(); data_len];

        // Run the validity circuit with a sequence of "shim" gadgets that record the value of each
        // input wire of each gadget evaluation. These values are used to construct the wire
        // polynomials for each gadget in the next step.
        let _ = self.valid(&mut shims, input, joint_rand, 1)?;

        // Construct the proof.
        let mut proof_len = 0;
        for shim in shims.iter_mut() {
            let gadget = shim
                .as_any()
                .downcast_mut::<ProveShimGadget<Self::Field>>()
                .unwrap();

            // Interpolate the wire polynomials `f[0], ..., f[g_arity-1]` from the input wires of each
            // evaluation of the gadget.
            let m = wire_poly_len(gadget.calls());
            let m_inv = Self::Field::from(
                <Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap(),
            )
            .inv();
            let mut f = vec![vec![Self::Field::zero(); m]; gadget.arity()];
            for ((coefficients, values), proof_val) in f[..gadget.arity()]
                .iter_mut()
                .zip(gadget.f_vals[..gadget.arity()].iter())
                .zip(proof[proof_len..proof_len + gadget.arity()].iter_mut())
            {
                discrete_fourier_transform(coefficients, values, m)?;
                discrete_fourier_transform_inv_finish(coefficients, m, m_inv);

                // The first point on each wire polynomial is a random value chosen by the prover. This
                // point is stored in the proof so that the verifier can reconstruct the wire
                // polynomials.
                *proof_val = values[0];
            }

            // Construct the gadget polynomial `G(f[0], ..., f[g_arity-1])` and append it to `proof`.
            let gadget_poly_len = gadget_poly_len(gadget.degree(), m);
            let start = proof_len + gadget.arity();
            let end = start + gadget_poly_len.next_power_of_two();
            gadget.call_poly(&mut proof[start..end], &f)?;
            proof_len += gadget.arity() + gadget_poly_len;
        }

        // Truncate the buffer to the size of the proof.
        assert_eq!(proof_len, self.proof_len());
        proof.truncate(proof_len);
        Ok(proof)
    }

    /// Query an input and proof and return the verifier message. The return value has length
    /// [`Self::verifier_len`].
    ///
    /// # Parameters
    ///
    /// * `input` is the input or input share.
    /// * `proof` is the proof or proof share.
    /// * `query_rand` is the verifier's randomness.
    /// * `joint_rand` is the randomness shared by the prover and verifier.
    /// * `num_shares` is the total number of input shares.
    fn query(
        &self,
        input: &[Self::Field],
        proof: &[Self::Field],
        query_rand: &[Self::Field],
        joint_rand: &[Self::Field],
        num_shares: usize,
    ) -> Result<Vec<Self::Field>, FlpError> {
        if input.len() != self.input_len() {
            return Err(FlpError::Query(format!(
                "unexpected input length: got {}; want {}",
                input.len(),
                self.input_len()
            )));
        }

        if proof.len() != self.proof_len() {
            return Err(FlpError::Query(format!(
                "unexpected proof length: got {}; want {}",
                proof.len(),
                self.proof_len()
            )));
        }

        if query_rand.len() != self.query_rand_len() {
            return Err(FlpError::Query(format!(
                "unexpected query randomness length: got {}; want {}",
                query_rand.len(),
                self.query_rand_len()
            )));
        }

        if joint_rand.len() != self.joint_rand_len() {
            return Err(FlpError::Query(format!(
                "unexpected joint randomness length: got {}; want {}",
                joint_rand.len(),
                self.joint_rand_len()
            )));
        }

        let mut proof_len = 0;
        let mut shims = self
            .gadget()
            .into_iter()
            .enumerate()
            .map(|(idx, gadget)| {
                let gadget_degree = gadget.degree();
                let gadget_arity = gadget.arity();
                let m = (1 + gadget.calls()).next_power_of_two();
                let r = query_rand[idx];

                // Make sure the query randomness isn't a root of unity. Evaluating the gadget
                // polynomial at any of these points would be a privacy violation, since these points
                // were used by the prover to construct the wire polynomials.
                if r.pow(<Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap())
                    == Self::Field::one()
                {
                    return Err(FlpError::Query(format!(
                        "invalid query randomness: encountered 2^{m}-th root of unity"
                    )));
                }

                // Compute the length of the sub-proof corresponding to the `idx`-th gadget.
                let next_len = gadget_arity + gadget_degree * (m - 1) + 1;
                let proof_data = &proof[proof_len..proof_len + next_len];
                proof_len += next_len;

                Ok(Box::new(QueryShimGadget::new(gadget, r, proof_data)?)
                    as Box<dyn Gadget<Self::Field>>)
            })
            .collect::<Result<Vec<_>, _>>()?;

        // Create a buffer for the verifier data. This includes the output of the validity circuit and,
        // for each gadget `shim[idx].inner`, the wire polynomials evaluated at the query randomness
        // `query_rand[idx]` and the gadget polynomial evaluated at `query_rand[idx]`.
        let data_len = 1 + shims.iter().map(|shim| shim.arity() + 1).sum::<usize>();
        let mut verifier = Vec::with_capacity(data_len);

        // Run the validity circuit with a sequence of "shim" gadgets that record the inputs to each
        // wire for each gadget call. Record the output of the circuit and append it to the verifier
        // message.
        //
        // NOTE The proof of [BBC+19, Theorem 4.3] assumes that the output of the validity circuit is
        // equal to the output of the last gadget evaluation. Here we relax this assumption. This
        // should be OK, since it's possible to transform any circuit into one for which this is true.
        // (Needs security analysis.)
        let validity = self.valid(&mut shims, input, joint_rand, num_shares)?;
        verifier.push(validity);

        // Fill the buffer with the verifier message.
        for (query_rand_val, shim) in query_rand[..shims.len()].iter().zip(shims.iter_mut()) {
            let gadget = shim
                .as_any()
                .downcast_ref::<QueryShimGadget<Self::Field>>()
                .unwrap();

            // Reconstruct the wire polynomials `f[0], ..., f[g_arity-1]` and evaluate each wire
            // polynomial at query randomness value.
            let m = (1 + gadget.calls()).next_power_of_two();
            let m_inv = Self::Field::from(
                <Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap(),
            )
            .inv();
            let mut f = vec![Self::Field::zero(); m];
            for wire in 0..gadget.arity() {
                discrete_fourier_transform(&mut f, &gadget.f_vals[wire], m)?;
                discrete_fourier_transform_inv_finish(&mut f, m, m_inv);
                verifier.push(poly_eval(&f, *query_rand_val));
            }

            // Add the value of the gadget polynomial evaluated at the query randomness value.
            verifier.push(gadget.p_at_r);
        }

        assert_eq!(verifier.len(), self.verifier_len());
        Ok(verifier)
    }

    /// Returns true if the verifier message indicates that the input from which it was generated is valid.
    fn decide(&self, verifier: &[Self::Field]) -> Result<bool, FlpError> {
        if verifier.len() != self.verifier_len() {
            return Err(FlpError::Decide(format!(
                "unexpected verifier length: got {}; want {}",
                verifier.len(),
                self.verifier_len()
            )));
        }

        // Check if the output of the circuit is 0.
        if verifier[0] != Self::Field::zero() {
            return Ok(false);
        }

        // Check that each of the proof polynomials are well-formed.
        let mut gadgets = self.gadget();
        let mut verifier_len = 1;
        for gadget in gadgets.iter_mut() {
            let next_len = 1 + gadget.arity();

            let e = gadget.call(&verifier[verifier_len..verifier_len + next_len - 1])?;
            if e != verifier[verifier_len + next_len - 1] {
                return Ok(false);
            }

            verifier_len += next_len;
        }

        Ok(true)
    }

    /// Check whether `input` and `joint_rand` have the length expected by `self`,
    /// return [`FlpError::Valid`] otherwise.
    fn valid_call_check(
        &self,
        input: &[Self::Field],
        joint_rand: &[Self::Field],
    ) -> Result<(), FlpError> {
        if input.len() != self.input_len() {
            return Err(FlpError::Valid(format!(
                "unexpected input length: got {}; want {}",
                input.len(),
                self.input_len(),
            )));
        }

        if joint_rand.len() != self.joint_rand_len() {
            return Err(FlpError::Valid(format!(
                "unexpected joint randomness length: got {}; want {}",
                joint_rand.len(),
                self.joint_rand_len()
            )));
        }

        Ok(())
    }

    /// Check if the length of `input` matches `self`'s `input_len()`,
    /// return [`FlpError::Truncate`] otherwise.
    fn truncate_call_check(&self, input: &[Self::Field]) -> Result<(), FlpError> {
        if input.len() != self.input_len() {
            return Err(FlpError::Truncate(format!(
                "Unexpected input length: got {}; want {}",
                input.len(),
                self.input_len()
            )));
        }

        Ok(())
    }
}

/// A type which supports adding noise to aggregate shares for Server Differential Privacy.
#[cfg(feature = "experimental")]
#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))]
pub trait TypeWithNoise<S>: Type
where
    S: DifferentialPrivacyStrategy,
{
    /// Add noise to the aggregate share to obtain differential privacy.
    fn add_noise_to_result(
        &self,
        dp_strategy: &S,
        agg_result: &mut [Self::Field],
        num_measurements: usize,
    ) -> Result<(), FlpError>;
}

/// A gadget, a non-affine arithmetic circuit that is called when evaluating a validity circuit.
pub trait Gadget<F: FftFriendlyFieldElement>: Debug {
    /// Evaluates the gadget on input `inp` and returns the output.
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError>;

    /// Evaluate the gadget on input of a sequence of polynomials. The output is written to `outp`.
    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError>;

    /// Returns the arity of the gadget. This is the length of `inp` passed to `call` or
    /// `call_poly`.
    fn arity(&self) -> usize;

    /// Returns the circuit's arithmetic degree. This determines the minimum length the `outp`
    /// buffer passed to `call_poly`.
    fn degree(&self) -> usize;

    /// Returns the number of times the gadget is expected to be called.
    fn calls(&self) -> usize;

    /// This call is used to downcast a `Box<dyn Gadget<F>>` to a concrete type.
    fn as_any(&mut self) -> &mut dyn Any;
}

// A "shim" gadget used during proof generation to record the input wires each time a gadget is
// evaluated.
#[derive(Debug)]
struct ProveShimGadget<F: FftFriendlyFieldElement> {
    inner: Box<dyn Gadget<F>>,

    /// Points at which the wire polynomials are interpolated.
    f_vals: Vec<Vec<F>>,

    /// The number of times the gadget has been called so far.
    ct: usize,
}

impl<F: FftFriendlyFieldElement> ProveShimGadget<F> {
    fn new(inner: Box<dyn Gadget<F>>, prove_rand: &[F]) -> Result<Self, FlpError> {
        let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; inner.arity()];

        for (prove_rand_val, wire_poly_vals) in
            prove_rand[..f_vals.len()].iter().zip(f_vals.iter_mut())
        {
            // Choose a random field element as the first point on the wire polynomial.
            wire_poly_vals[0] = *prove_rand_val;
        }

        Ok(Self {
            inner,
            f_vals,
            ct: 1,
        })
    }
}

impl<F: FftFriendlyFieldElement> Gadget<F> for ProveShimGadget<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) {
            wire_poly_vals[self.ct] = *inp_val;
        }
        self.ct += 1;
        self.inner.call(inp)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        self.inner.call_poly(outp, inp)
    }

    fn arity(&self) -> usize {
        self.inner.arity()
    }

    fn degree(&self) -> usize {
        self.inner.degree()
    }

    fn calls(&self) -> usize {
        self.inner.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

// A "shim" gadget used during proof verification to record the points at which the intermediate
// proof polynomials are evaluated.
#[derive(Debug)]
struct QueryShimGadget<F: FftFriendlyFieldElement> {
    inner: Box<dyn Gadget<F>>,

    /// Points at which intermediate proof polynomials are interpolated.
    f_vals: Vec<Vec<F>>,

    /// Points at which the gadget polynomial is interpolated.
    p_vals: Vec<F>,

    /// The gadget polynomial evaluated on a random input `r`.
    p_at_r: F,

    /// Used to compute an index into `p_val`.
    step: usize,

    /// The number of times the gadget has been called so far.
    ct: usize,
}

impl<F: FftFriendlyFieldElement> QueryShimGadget<F> {
    fn new(inner: Box<dyn Gadget<F>>, r: F, proof_data: &[F]) -> Result<Self, FlpError> {
        let gadget_degree = inner.degree();
        let gadget_arity = inner.arity();
        let m = (1 + inner.calls()).next_power_of_two();
        let p = m * gadget_degree;

        // Each call to this gadget records the values at which intermediate proof polynomials were
        // interpolated. The first point was a random value chosen by the prover and transmitted in
        // the proof.
        let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; gadget_arity];
        for wire in 0..gadget_arity {
            f_vals[wire][0] = proof_data[wire];
        }

        // Evaluate the gadget polynomial at roots of unity.
        let size = p.next_power_of_two();
        let mut p_vals = vec![F::zero(); size];
        discrete_fourier_transform(&mut p_vals, &proof_data[gadget_arity..], size)?;

        // The step is used to compute the element of `p_val` that will be returned by a call to
        // the gadget.
        let step = (1 << (log2(p as u128) - log2(m as u128))) as usize;

        // Evaluate the gadget polynomial `p` at query randomness `r`.
        let p_at_r = poly_eval(&proof_data[gadget_arity..], r);

        Ok(Self {
            inner,
            f_vals,
            p_vals,
            p_at_r,
            step,
            ct: 1,
        })
    }
}

impl<F: FftFriendlyFieldElement> Gadget<F> for QueryShimGadget<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) {
            wire_poly_vals[self.ct] = *inp_val;
        }
        let outp = self.p_vals[self.ct * self.step];
        self.ct += 1;
        Ok(outp)
    }

    fn call_poly(&mut self, _outp: &mut [F], _inp: &[Vec<F>]) -> Result<(), FlpError> {
        panic!("no-op");
    }

    fn arity(&self) -> usize {
        self.inner.arity()
    }

    fn degree(&self) -> usize {
        self.inner.degree()
    }

    fn calls(&self) -> usize {
        self.inner.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

/// Compute the length of the wire polynomial constructed from the given number of gadget calls.
#[inline]
pub(crate) fn wire_poly_len(num_calls: usize) -> usize {
    (1 + num_calls).next_power_of_two()
}

/// Compute the length of the gadget polynomial for a gadget with the given degree and from wire
/// polynomials of the given length.
#[inline]
pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usize {
    gadget_degree * (wire_poly_len - 1) + 1
}

/// Utilities for testing FLPs.
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_utils {
    use super::*;
    use crate::field::{random_vector, FieldElement, FieldElementWithInteger};

    /// Various tests for an FLP.
    #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
    pub struct FlpTest<'a, T: Type> {
        /// The FLP.
        pub flp: &'a T,

        /// Optional test name.
        pub name: Option<&'a str>,

        /// The input to use for the tests.
        pub input: &'a [T::Field],

        /// If set, the expected result of truncating the input.
        pub expected_output: Option<&'a [T::Field]>,

        /// Whether the input is expected to be valid.
        pub expect_valid: bool,
    }

    impl<T: Type> FlpTest<'_, T> {
        /// Construct a test and run it. Expect the input to be valid and compare the truncated
        /// output to the provided value.
        pub fn expect_valid<const SHARES: usize>(
            flp: &T,
            input: &[T::Field],
            expected_output: &[T::Field],
        ) {
            FlpTest {
                flp,
                name: None,
                input,
                expected_output: Some(expected_output),
                expect_valid: true,
            }
            .run::<SHARES>()
        }

        /// Construct a test and run it. Expect the input to be invalid.
        pub fn expect_invalid<const SHARES: usize>(flp: &T, input: &[T::Field]) {
            FlpTest {
                flp,
                name: None,
                input,
                expect_valid: false,
                expected_output: None,
            }
            .run::<SHARES>()
        }

        /// Construct a test and run it. Expect the input to be valid.
        pub fn expect_valid_no_output<const SHARES: usize>(flp: &T, input: &[T::Field]) {
            FlpTest {
                flp,
                name: None,
                input,
                expect_valid: true,
                expected_output: None,
            }
            .run::<SHARES>()
        }

        /// Run the tests.
        pub fn run<const SHARES: usize>(&self) {
            let name = self.name.unwrap_or("unnamed test");

            assert_eq!(
                self.input.len(),
                self.flp.input_len(),
                "{name}: unexpected input length"
            );

            let mut gadgets = self.flp.gadget();
            let joint_rand = random_vector(self.flp.joint_rand_len()).unwrap();
            let prove_rand = random_vector(self.flp.prove_rand_len()).unwrap();
            let query_rand = random_vector(self.flp.query_rand_len()).unwrap();
            assert_eq!(
                self.flp.query_rand_len(),
                gadgets.len(),
                "{name}: unexpected number of gadgets"
            );
            assert_eq!(
                self.flp.joint_rand_len(),
                joint_rand.len(),
                "{name}: unexpected joint rand length"
            );
            assert_eq!(
                self.flp.prove_rand_len(),
                prove_rand.len(),
                "{name}: unexpected prove rand length",
            );
            assert_eq!(
                self.flp.query_rand_len(),
                query_rand.len(),
                "{name}: unexpected query rand length",
            );

            // Run the validity circuit.
            let v = self
                .flp
                .valid(&mut gadgets, self.input, &joint_rand, 1)
                .unwrap();
            assert_eq!(
                v == T::Field::zero(),
                self.expect_valid,
                "{name}: unexpected output of valid() returned {v}",
            );

            // Generate the proof.
            let proof = self
                .flp
                .prove(self.input, &prove_rand, &joint_rand)
                .unwrap();
            assert_eq!(
                proof.len(),
                self.flp.proof_len(),
                "{name}: unexpected proof length"
            );

            // Query the proof.
            let verifier = self
                .flp
                .query(self.input, &proof, &query_rand, &joint_rand, 1)
                .unwrap();
            assert_eq!(
                verifier.len(),
                self.flp.verifier_len(),
                "{name}: unexpected verifier length"
            );

            // Decide if the input is valid.
            let res = self.flp.decide(&verifier).unwrap();
            assert_eq!(res, self.expect_valid, "{name}: unexpected decision");

            // Run distributed FLP.
            let input_shares = split_vector::<_, SHARES>(self.input);
            let proof_shares = split_vector::<_, SHARES>(&proof);
            let verifier: Vec<T::Field> = (0..SHARES)
                .map(|i| {
                    self.flp
                        .query(
                            &input_shares[i],
                            &proof_shares[i],
                            &query_rand,
                            &joint_rand,
                            SHARES,
                        )
                        .unwrap()
                })
                .reduce(|mut left, right| {
                    for (x, y) in left.iter_mut().zip(right.iter()) {
                        *x += *y;
                    }
                    left
                })
                .unwrap();

            let res = self.flp.decide(&verifier).unwrap();
            assert_eq!(
                res, self.expect_valid,
                "{name}: unexpected distributed decision"
            );

            // Try verifying various proof mutants.
            for i in 0..std::cmp::min(proof.len(), 10) {
                let mut mutated_proof = proof.clone();
                mutated_proof[i] *= T::Field::from(
                    <T::Field as FieldElementWithInteger>::Integer::try_from(23).unwrap(),
                );
                let verifier = self
                    .flp
                    .query(self.input, &mutated_proof, &query_rand, &joint_rand, 1)
                    .unwrap();
                assert!(
                    !self.flp.decide(&verifier).unwrap(),
                    "{name}: proof mutant {} deemed valid",
                    i
                );
            }

            // Try truncating the input.
            if let Some(ref expected_output) = self.expected_output {
                let output = self.flp.truncate(self.input.to_vec()).unwrap();

                assert_eq!(
                    output.len(),
                    self.flp.output_len(),
                    "{name}: unexpected output length of truncate()"
                );

                assert_eq!(
                    &output, expected_output,
                    "{name}: unexpected output of truncate()"
                );
            }
        }
    }

    fn split_vector<F: FieldElement, const SHARES: usize>(inp: &[F]) -> [Vec<F>; SHARES] {
        let mut outp = Vec::with_capacity(SHARES);
        outp.push(inp.to_vec());

        for _ in 1..SHARES {
            let share: Vec<F> =
                random_vector(inp.len()).expect("failed to generate a random vector");
            for (x, y) in outp[0].iter_mut().zip(&share) {
                *x -= *y;
            }
            outp.push(share);
        }

        outp.try_into().unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::{random_vector, split_vector, Field128};
    use crate::flp::gadgets::{Mul, PolyEval};
    use crate::polynomial::poly_range_check;

    use std::marker::PhantomData;

    // Simple integration test for the core FLP logic. You'll find more extensive unit tests for
    // each implemented data type in src/types.rs.
    #[test]
    fn test_flp() {
        const NUM_SHARES: usize = 2;

        let typ: TestType<Field128> = TestType::new();
        let input = typ.encode_measurement(&3).unwrap();
        assert_eq!(input.len(), typ.input_len());

        let input_shares: Vec<Vec<Field128>> = split_vector(input.as_slice(), NUM_SHARES)
            .unwrap()
            .into_iter()
            .collect();

        let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
        let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
        let query_rand = random_vector(typ.query_rand_len()).unwrap();

        let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
        assert_eq!(proof.len(), typ.proof_len());

        let proof_shares: Vec<Vec<Field128>> = split_vector(&proof, NUM_SHARES)
            .unwrap()
            .into_iter()
            .collect();

        let verifier: Vec<Field128> = (0..NUM_SHARES)
            .map(|i| {
                typ.query(
                    &input_shares[i],
                    &proof_shares[i],
                    &query_rand,
                    &joint_rand,
                    NUM_SHARES,
                )
                .unwrap()
            })
            .reduce(|mut left, right| {
                for (x, y) in left.iter_mut().zip(right.iter()) {
                    *x += *y;
                }
                left
            })
            .unwrap();
        assert_eq!(verifier.len(), typ.verifier_len());

        assert!(typ.decide(&verifier).unwrap());
    }

    /// A toy type used for testing multiple gadgets. Valid inputs of this type consist of a pair
    /// of field elements `(x, y)` where `2 <= x < 5` and `x^3 == y`.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct TestType<F>(PhantomData<F>);

    impl<F> TestType<F> {
        fn new() -> Self {
            Self(PhantomData)
        }
    }

    impl<F: FftFriendlyFieldElement> Type for TestType<F> {
        type Measurement = F::Integer;
        type AggregateResult = F::Integer;
        type Field = F;

        fn valid(
            &self,
            g: &mut Vec<Box<dyn Gadget<F>>>,
            input: &[F],
            joint_rand: &[F],
            _num_shares: usize,
        ) -> Result<F, FlpError> {
            let r = joint_rand[0];
            let mut res = F::zero();

            // Check that `data[0]^3 == data[1]`.
            let mut inp = [input[0], input[0]];
            inp[0] = g[0].call(&inp)?;
            inp[0] = g[0].call(&inp)?;
            let x3_diff = inp[0] - input[1];
            res += r * x3_diff;

            // Check that `data[0]` is in the correct range.
            let x_checked = g[1].call(&[input[0]])?;
            res += (r * r) * x_checked;

            Ok(res)
        }

        fn input_len(&self) -> usize {
            2
        }

        fn proof_len(&self) -> usize {
            // First chunk
            let mul = 2 /* gadget arity */ + 2 /* gadget degree */ * (
                (1 + 2_usize /* gadget calls */).next_power_of_two() - 1) + 1;

            // Second chunk
            let poly = 1 /* gadget arity */ + 3 /* gadget degree */ * (
                (1 + 1_usize /* gadget calls */).next_power_of_two() - 1) + 1;

            mul + poly
        }

        fn verifier_len(&self) -> usize {
            // First chunk
            let mul = 1 + 2 /* gadget arity */;

            // Second chunk
            let poly = 1 + 1 /* gadget arity */;

            1 + mul + poly
        }

        fn output_len(&self) -> usize {
            self.input_len()
        }

        fn joint_rand_len(&self) -> usize {
            1
        }

        fn prove_rand_len(&self) -> usize {
            3
        }

        fn query_rand_len(&self) -> usize {
            2
        }

        fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
            vec![
                Box::new(Mul::new(2)),
                Box::new(PolyEval::new(poly_range_check(2, 5), 1)),
            ]
        }

        fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
            Ok(vec![
                F::from(*measurement),
                F::from(*measurement).pow(F::Integer::try_from(3).unwrap()),
            ])
        }

        fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
            Ok(input)
        }

        fn decode_result(
            &self,
            _data: &[F],
            _num_measurements: usize,
        ) -> Result<F::Integer, FlpError> {
            panic!("not implemented");
        }
    }

    // In https://github.com/divviup/libprio-rs/issues/254 an out-of-bounds bug was reported that
    // gets triggered when the size of the buffer passed to `gadget.call_poly()` is larger than
    // needed for computing the gadget polynomial.
    #[test]
    fn issue254() {
        let typ: Issue254Type<Field128> = Issue254Type::new();
        let input = typ.encode_measurement(&0).unwrap();
        assert_eq!(input.len(), typ.input_len());
        let joint_rand = random_vector(typ.joint_rand_len()).unwrap();
        let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
        let query_rand = random_vector(typ.query_rand_len()).unwrap();
        let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap();
        let verifier = typ
            .query(&input, &proof, &query_rand, &joint_rand, 1)
            .unwrap();
        assert_eq!(verifier.len(), typ.verifier_len());
        assert!(typ.decide(&verifier).unwrap());
    }

    #[derive(Clone, Debug, PartialEq, Eq)]
    struct Issue254Type<F> {
        num_gadget_calls: [usize; 2],
        phantom: PhantomData<F>,
    }

    impl<F> Issue254Type<F> {
        fn new() -> Self {
            Self {
                // The bug is triggered when there are two gadgets, but it doesn't matter how many
                // times the second gadget is called.
                num_gadget_calls: [100, 0],
                phantom: PhantomData,
            }
        }
    }

    impl<F: FftFriendlyFieldElement> Type for Issue254Type<F> {
        type Measurement = F::Integer;
        type AggregateResult = F::Integer;
        type Field = F;

        fn valid(
            &self,
            g: &mut Vec<Box<dyn Gadget<F>>>,
            input: &[F],
            _joint_rand: &[F],
            _num_shares: usize,
        ) -> Result<F, FlpError> {
            // This is a useless circuit, as it only accepts "0". Its purpose is to exercise the
            // use of multiple gadgets, each of which is called an arbitrary number of times.
            let mut res = F::zero();
            for _ in 0..self.num_gadget_calls[0] {
                res += g[0].call(&[input[0]])?;
            }
            for _ in 0..self.num_gadget_calls[1] {
                res += g[1].call(&[input[0]])?;
            }
            Ok(res)
        }

        fn input_len(&self) -> usize {
            1
        }

        fn proof_len(&self) -> usize {
            // First chunk
            let first = 1 /* gadget arity */ + 2 /* gadget degree */ * (
                (1 + self.num_gadget_calls[0]).next_power_of_two() - 1) + 1;

            // Second chunk
            let second = 1 /* gadget arity */ + 2 /* gadget degree */ * (
                (1 + self.num_gadget_calls[1]).next_power_of_two() - 1) + 1;

            first + second
        }

        fn verifier_len(&self) -> usize {
            // First chunk
            let first = 1 + 1 /* gadget arity */;

            // Second chunk
            let second = 1 + 1 /* gadget arity */;

            1 + first + second
        }

        fn output_len(&self) -> usize {
            self.input_len()
        }

        fn joint_rand_len(&self) -> usize {
            0
        }

        fn prove_rand_len(&self) -> usize {
            // First chunk
            let first = 1; // gadget arity

            // Second chunk
            let second = 1; // gadget arity

            first + second
        }

        fn query_rand_len(&self) -> usize {
            2 // number of gadgets
        }

        fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
            let poly = poly_range_check(0, 2); // A polynomial with degree 2
            vec![
                Box::new(PolyEval::new(poly.clone(), self.num_gadget_calls[0])),
                Box::new(PolyEval::new(poly, self.num_gadget_calls[1])),
            ]
        }

        fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
            Ok(vec![F::from(*measurement)])
        }

        fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> {
            Ok(input)
        }

        fn decode_result(
            &self,
            _data: &[F],
            _num_measurements: usize,
        ) -> Result<F::Integer, FlpError> {
            panic!("not implemented");
        }
    }
}

[ Dauer der Verarbeitung: 0.35 Sekunden  (vorverarbeitet)  ]

                                                                                                                                                                                                                                                                                                                                                                                                     


Neuigkeiten

     Aktuelles
     Motto des Tages

Software

     Produkte
     Quellcodebibliothek

Aktivitäten

     Artikel über Sicherheit
     Anleitung zur Aktivierung von SSL

Muße

     Gedichte
     Musik
     Bilder

Jenseits des Üblichen ....

Besucherstatistik

Besucherstatistik

Monitoring

Montastic status badge