Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/rust/prio/src/flp/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 17 kB image not shown  

Quelle  gadgets.rs   Sprache: unbekannt

 
Spracherkennung für: .rs vermutete Sprache: Unknown {[0] [0] [0]} [Methode: Schwerpunktbildung, einfache Gewichte, sechs Dimensionen]

// SPDX-License-Identifier: MPL-2.0

//! A collection of gadgets.

use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
use crate::field::FftFriendlyFieldElement;
use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
use crate::polynomial::{poly_deg, poly_eval, poly_mul};

#[cfg(feature = "multithreaded")]
use rayon::prelude::*;

use std::any::Any;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::marker::PhantomData;

/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for
/// polynomial multiplication. Otherwise, the gadget uses direct multiplication.
const FFT_THRESHOLD: usize = 60;

/// An arity-2 gadget that multiples its inputs.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Mul<F: FftFriendlyFieldElement> {
    /// Size of buffer for FFT operations.
    n: usize,
    /// Inverse of `n` in `F`.
    n_inv: F,
    /// The number of times this gadget will be called.
    num_calls: usize,
}

impl<F: FftFriendlyFieldElement> Mul<F> {
    /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be
    /// called by the validity circuit.
    pub fn new(num_calls: usize) -> Self {
        let n = gadget_poly_fft_mem_len(2, num_calls);
        let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
        Self {
            n,
            n_inv,
            num_calls,
        }
    }

    // Multiply input polynomials directly.
    pub(crate) fn call_poly_direct(
        &mut self,
        outp: &mut [F],
        inp: &[Vec<F>],
    ) -> Result<(), FlpError> {
        let v = poly_mul(&inp[0], &inp[1]);
        outp[..v.len()].clone_from_slice(&v);
        Ok(())
    }

    // Multiply input polynomials using FFT.
    pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        let n = self.n;
        let mut buf = vec![F::zero(); n];

        discrete_fourier_transform(&mut buf, &inp[0], n)?;
        discrete_fourier_transform(outp, &inp[1], n)?;

        for i in 0..n {
            buf[i] *= outp[i];
        }

        discrete_fourier_transform(outp, &buf, n)?;
        discrete_fourier_transform_inv_finish(outp, n, self.n_inv);
        Ok(())
    }
}

impl<F: FftFriendlyFieldElement> Gadget<F> for Mul<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        Ok(inp[0] * inp[1])
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;
        if inp[0].len() >= FFT_THRESHOLD {
            self.call_poly_fft(outp, inp)
        } else {
            self.call_poly_direct(outp, inp)
        }
    }

    fn arity(&self) -> usize {
        2
    }

    fn degree(&self) -> usize {
        2
    }

    fn calls(&self) -> usize {
        self.num_calls
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

/// An arity-1 gadget that evaluates its input on some polynomial.
//
// TODO Make `poly` an array of length determined by a const generic.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PolyEval<F: FftFriendlyFieldElement> {
    poly: Vec<F>,
    /// Size of buffer for FFT operations.
    n: usize,
    /// Inverse of `n` in `F`.
    n_inv: F,
    /// The number of times this gadget will be called.
    num_calls: usize,
}

impl<F: FftFriendlyFieldElement> PolyEval<F> {
    /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times
    /// this gadget is called by the validity circuit.
    pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
        let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls);
        let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
        Self {
            poly,
            n,
            n_inv,
            num_calls,
        }
    }
}

impl<F: FftFriendlyFieldElement> PolyEval<F> {
    // Multiply input polynomials directly.
    fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        outp[0] = self.poly[0];
        let mut x = inp[0].to_vec();
        for i in 1..self.poly.len() {
            for j in 0..x.len() {
                outp[j] += self.poly[i] * x[j];
            }

            if i < self.poly.len() - 1 {
                x = poly_mul(&x, &inp[0]);
            }
        }
        Ok(())
    }

    // Multiply input polynomials using FFT.
    fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        let n = self.n;
        let inp = &inp[0];

        let mut inp_vals = vec![F::zero(); n];
        discrete_fourier_transform(&mut inp_vals, inp, n)?;

        let mut x_vals = inp_vals.clone();
        let mut x = vec![F::zero(); n];
        x[..inp.len()].clone_from_slice(inp);

        outp[0] = self.poly[0];
        for i in 1..self.poly.len() {
            for j in 0..n {
                outp[j] += self.poly[i] * x[j];
            }

            if i < self.poly.len() - 1 {
                for j in 0..n {
                    x_vals[j] *= inp_vals[j];
                }

                discrete_fourier_transform(&mut x, &x_vals, n)?;
                discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv);
            }
        }
        Ok(())
    }
}

impl<F: FftFriendlyFieldElement> Gadget<F> for PolyEval<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        Ok(poly_eval(&self.poly, inp[0]))
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;

        for item in outp.iter_mut() {
            *item = F::zero();
        }

        if inp[0].len() >= FFT_THRESHOLD {
            self.call_poly_fft(outp, inp)
        } else {
            self.call_poly_direct(outp, inp)
        }
    }

    fn arity(&self) -> usize {
        1
    }

    fn degree(&self) -> usize {
        poly_deg(&self.poly)
    }

    fn calls(&self) -> usize {
        self.num_calls
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

/// Trait for abstracting over [`ParallelSum`].
pub trait ParallelSumGadget<F: FftFriendlyFieldElement, G>: Gadget<F> + Debug {
    /// Wraps `inner` into a sum gadget that calls it `chunks` many times, and adds the reuslts.
    fn new(inner: G, chunks: usize) -> Self;
}

/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
/// outputs. The arity is equal to the arity of the inner gadget times the number of times it is
/// called.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelSum<F: FftFriendlyFieldElement, G: Gadget<F>> {
    inner: G,
    chunks: usize,
    phantom: PhantomData<F>,
}

impl<F: FftFriendlyFieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G>
    for ParallelSum<F, G>
{
    fn new(inner: G, chunks: usize) -> Self {
        Self {
            inner,
            chunks,
            phantom: PhantomData,
        }
    }
}

impl<F: FftFriendlyFieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        let mut outp = F::zero();
        for chunk in inp.chunks(self.inner.arity()) {
            outp += self.inner.call(chunk)?;
        }
        Ok(outp)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;

        for x in outp.iter_mut() {
            *x = F::zero();
        }

        let mut partial_outp = vec![F::zero(); outp.len()];

        for chunk in inp.chunks(self.inner.arity()) {
            self.inner.call_poly(&mut partial_outp, chunk)?;
            for i in 0..outp.len() {
                outp[i] += partial_outp[i]
            }
        }

        Ok(())
    }

    fn arity(&self) -> usize {
        self.chunks * self.inner.arity()
    }

    fn degree(&self) -> usize {
        self.inner.degree()
    }

    fn calls(&self) -> usize {
        self.inner.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum
/// evaluation is multithreaded.
#[cfg(feature = "multithreaded")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelSumMultithreaded<F: FftFriendlyFieldElement, G: Gadget<F>> {
    serial_sum: ParallelSum<F, G>,
}

#[cfg(feature = "multithreaded")]
impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G>
where
    F: FftFriendlyFieldElement + Sync + Send,
    G: 'static + Gadget<F> + Clone + Sync + Send,
{
    fn new(inner: G, chunks: usize) -> Self {
        Self {
            serial_sum: ParallelSum::new(inner, chunks),
        }
    }
}

/// Data structures passed between fold operations in [`ParallelSumMultithreaded`].
#[cfg(feature = "multithreaded")]
struct ParallelSumFoldState<F, G> {
    /// Inner gadget.
    inner: G,
    /// Output buffer for `call_poly()`.
    partial_output: Vec<F>,
    /// Sum accumulator.
    partial_sum: Vec<F>,
}

#[cfg(feature = "multithreaded")]
impl<F, G> ParallelSumFoldState<F, G> {
    fn new(gadget: &G, length: usize) -> ParallelSumFoldState<F, G>
    where
        G: Clone,
        F: FftFriendlyFieldElement,
    {
        ParallelSumFoldState {
            inner: gadget.clone(),
            partial_output: vec![F::zero(); length],
            partial_sum: vec![F::zero(); length],
        }
    }
}

#[cfg(feature = "multithreaded")]
impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G>
where
    F: FftFriendlyFieldElement + Sync + Send,
    G: 'static + Gadget<F> + Clone + Sync + Send,
{
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        self.serial_sum.call(inp)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;

        // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the
        // gadget on each input polynomial, using the first temporary buffer as an output buffer.
        // Then accumulate that result into the second temporary buffer, which acts as a running
        // sum. Then, discard everything but the partial sums, add them, and finally copy the sum
        // to the output parameter. This is equivalent to the single threaded calculation in
        // ParallelSum, since we only rearrange additions, and field addition is associative.
        let res = inp
            .par_chunks(self.serial_sum.inner.arity())
            .fold(
                || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()),
                |mut state, chunk| {
                    state
                        .inner
                        .call_poly(&mut state.partial_output, chunk)
                        .unwrap();
                    for (sum_elem, output_elem) in state
                        .partial_sum
                        .iter_mut()
                        .zip(state.partial_output.iter())
                    {
                        *sum_elem += *output_elem;
                    }
                    state
                },
            )
            .map(|state| state.partial_sum)
            .reduce(
                || vec![F::zero(); outp.len()],
                |mut x, y| {
                    for (xi, yi) in x.iter_mut().zip(y.iter()) {
                        *xi += *yi;
                    }
                    x
                },
            );

        outp.copy_from_slice(&res[..]);
        Ok(())
    }

    fn arity(&self) -> usize {
        self.serial_sum.arity()
    }

    fn degree(&self) -> usize {
        self.serial_sum.degree()
    }

    fn calls(&self) -> usize {
        self.serial_sum.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

// Check that the input parameters of g.call() are well-formed.
fn gadget_call_check<F: FftFriendlyFieldElement, G: Gadget<F>>(
    gadget: &G,
    in_len: usize,
) -> Result<(), FlpError> {
    if in_len != gadget.arity() {
        return Err(FlpError::Gadget(format!(
            "unexpected number of inputs: got {}; want {}",
            in_len,
            gadget.arity()
        )));
    }

    if in_len == 0 {
        return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string()));
    }

    Ok(())
}

// Check that the input parameters of g.call_poly() are well-formed.
fn gadget_call_poly_check<F: FftFriendlyFieldElement, G: Gadget<F>>(
    gadget: &G,
    outp: &[F],
    inp: &[Vec<F>],
) -> Result<(), FlpError>
where
    G: Gadget<F>,
{
    gadget_call_check(gadget, inp.len())?;

    for i in 1..inp.len() {
        if inp[i].len() != inp[0].len() {
            return Err(FlpError::Gadget(
                "gadget called on wire polynomials with different lengths".to_string(),
            ));
        }
    }

    let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two();
    if outp.len() != expected {
        return Err(FlpError::Gadget(format!(
            "incorrect output length: got {}; want {}",
            outp.len(),
            expected
        )));
    }

    Ok(())
}

#[inline]
fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize {
    gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "multithreaded")]
    use crate::field::FieldElement;
    use crate::field::{random_vector, Field64 as TestField};
    use crate::prng::Prng;

    #[test]
    fn test_mul() {
        // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the
        // naive multiplication code path.
        let num_calls = FFT_THRESHOLD / 2;
        let mut g: Mul<TestField> = Mul::new(num_calls);
        gadget_test(&mut g, num_calls);

        // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises
        // FFT-based polynomial multiplication.
        let num_calls = FFT_THRESHOLD;
        let mut g: Mul<TestField> = Mul::new(num_calls);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    fn test_poly_eval() {
        let poly: Vec<TestField> = random_vector(10).unwrap();

        let num_calls = FFT_THRESHOLD / 2;
        let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls);
        gadget_test(&mut g, num_calls);

        let num_calls = FFT_THRESHOLD;
        let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    fn test_parallel_sum() {
        let num_calls = 10;
        let chunks = 23;

        let mut g = ParallelSum::new(Mul::<TestField>::new(num_calls), chunks);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    #[cfg(feature = "multithreaded")]
    fn test_parallel_sum_multithreaded() {
        use std::iter;

        for num_calls in [1, 10, 100] {
            let chunks = 23;

            let mut g = ParallelSumMultithreaded::new(Mul::new(num_calls), chunks);
            gadget_test(&mut g, num_calls);

            // Test that the multithreaded version has the same output as the normal version.
            let mut g_serial = ParallelSum::new(Mul::new(num_calls), chunks);
            assert_eq!(g.arity(), g_serial.arity());
            assert_eq!(g.degree(), g_serial.degree());
            assert_eq!(g.calls(), g_serial.calls());

            let arity = g.arity();
            let degree = g.degree();

            // Test that both gadgets evaluate to the same value when run on scalar inputs.
            let inp: Vec<TestField> = random_vector(arity).unwrap();
            let result = g.call(&inp).unwrap();
            let result_serial = g_serial.call(&inp).unwrap();
            assert_eq!(result, result_serial);

            // Test that both gadgets evaluate to the same value when run on polynomial inputs.
            let mut poly_outp =
                vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
            let mut poly_outp_serial =
                vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
            let mut prng: Prng<TestField, _> = Prng::new().unwrap();
            let poly_inp: Vec<_> = iter::repeat_with(|| {
                iter::repeat_with(|| prng.get())
                    .take(1 + num_calls)
                    .collect::<Vec<_>>()
            })
            .take(arity)
            .collect();

            g.call_poly(&mut poly_outp, &poly_inp).unwrap();
            g_serial
                .call_poly(&mut poly_outp_serial, &poly_inp)
                .unwrap();
            assert_eq!(poly_outp, poly_outp_serial);
        }
    }

    // Test that calling g.call_poly() and evaluating the output at a given point is equivalent
    // to evaluating each of the inputs at the same point and applying g.call() on the results.
    fn gadget_test<F: FftFriendlyFieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) {
        let wire_poly_len = (1 + num_calls).next_power_of_two();
        let mut prng = Prng::new().unwrap();
        let mut inp = vec![F::zero(); g.arity()];
        let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)];
        let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()];

        let r = prng.get();
        for i in 0..g.arity() {
            for j in 0..wire_poly_len {
                wire_polys[i][j] = prng.get();
            }
            inp[i] = poly_eval(&wire_polys[i], r);
        }

        g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
        let got = poly_eval(&gadget_poly, r);
        let want = g.call(&inp).unwrap();
        assert_eq!(got, want);

        // Repeat the call to make sure that the gadget's memory is reset properly between calls.
        g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
        let got = poly_eval(&gadget_poly, r);
        assert_eq!(got, want);
    }
}

[ Dauer der Verarbeitung: 0.35 Sekunden  ]