Anforderungen  |   Konzepte  |   Entwurf  |   Entwicklung  |   Qualitätssicherung  |   Lebenszyklus  |   Steuerung
 
 
 
 


Quelle  gadgets.rs   Sprache: unbekannt

 
// SPDX-License-Identifier: MPL-2.0

//! A collection of gadgets.

use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
use crate::field::FftFriendlyFieldElement;
use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
use crate::polynomial::{poly_deg, poly_eval, poly_mul};

#[cfg(feature = "multithreaded")]
use rayon::prelude::*;

use std::any::Any;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::marker::PhantomData;

/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for
/// polynomial multiplication. Otherwise, the gadget uses direct multiplication.
const FFT_THRESHOLD: usize = 60;

/// An arity-2 gadget that multiples its inputs.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Mul<F: FftFriendlyFieldElement> {
    /// Size of buffer for FFT operations.
    n: usize,
    /// Inverse of `n` in `F`.
    n_inv: F,
    /// The number of times this gadget will be called.
    num_calls: usize,
}

impl<F: FftFriendlyFieldElement> Mul<F> {
    /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be
    /// called by the validity circuit.
    pub fn new(num_calls: usize) -> Self {
        let n = gadget_poly_fft_mem_len(2, num_calls);
        let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
        Self {
            n,
            n_inv,
            num_calls,
        }
    }

    // Multiply input polynomials directly.
    pub(crate) fn call_poly_direct(
        &mut self,
        outp: &mut [F],
        inp: &[Vec<F>],
    ) -> Result<(), FlpError> {
        let v = poly_mul(&inp[0], &inp[1]);
        outp[..v.len()].clone_from_slice(&v);
        Ok(())
    }

    // Multiply input polynomials using FFT.
    pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        let n = self.n;
        let mut buf = vec![F::zero(); n];

        discrete_fourier_transform(&mut buf, &inp[0], n)?;
        discrete_fourier_transform(outp, &inp[1], n)?;

        for i in 0..n {
            buf[i] *= outp[i];
        }

        discrete_fourier_transform(outp, &buf, n)?;
        discrete_fourier_transform_inv_finish(outp, n, self.n_inv);
        Ok(())
    }
}

impl<F: FftFriendlyFieldElement> Gadget<F> for Mul<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        Ok(inp[0] * inp[1])
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;
        if inp[0].len() >= FFT_THRESHOLD {
            self.call_poly_fft(outp, inp)
        } else {
            self.call_poly_direct(outp, inp)
        }
    }

    fn arity(&self) -> usize {
        2
    }

    fn degree(&self) -> usize {
        2
    }

    fn calls(&self) -> usize {
        self.num_calls
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

/// An arity-1 gadget that evaluates its input on some polynomial.
//
// TODO Make `poly` an array of length determined by a const generic.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PolyEval<F: FftFriendlyFieldElement> {
    poly: Vec<F>,
    /// Size of buffer for FFT operations.
    n: usize,
    /// Inverse of `n` in `F`.
    n_inv: F,
    /// The number of times this gadget will be called.
    num_calls: usize,
}

impl<F: FftFriendlyFieldElement> PolyEval<F> {
    /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times
    /// this gadget is called by the validity circuit.
    pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
        let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls);
        let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
        Self {
            poly,
            n,
            n_inv,
            num_calls,
        }
    }
}

impl<F: FftFriendlyFieldElement> PolyEval<F> {
    // Multiply input polynomials directly.
    fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        outp[0] = self.poly[0];
        let mut x = inp[0].to_vec();
        for i in 1..self.poly.len() {
            for j in 0..x.len() {
                outp[j] += self.poly[i] * x[j];
            }

            if i < self.poly.len() - 1 {
                x = poly_mul(&x, &inp[0]);
            }
        }
        Ok(())
    }

    // Multiply input polynomials using FFT.
    fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        let n = self.n;
        let inp = &inp[0];

        let mut inp_vals = vec![F::zero(); n];
        discrete_fourier_transform(&mut inp_vals, inp, n)?;

        let mut x_vals = inp_vals.clone();
        let mut x = vec![F::zero(); n];
        x[..inp.len()].clone_from_slice(inp);

        outp[0] = self.poly[0];
        for i in 1..self.poly.len() {
            for j in 0..n {
                outp[j] += self.poly[i] * x[j];
            }

            if i < self.poly.len() - 1 {
                for j in 0..n {
                    x_vals[j] *= inp_vals[j];
                }

                discrete_fourier_transform(&mut x, &x_vals, n)?;
                discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv);
            }
        }
        Ok(())
    }
}

impl<F: FftFriendlyFieldElement> Gadget<F> for PolyEval<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        Ok(poly_eval(&self.poly, inp[0]))
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;

        for item in outp.iter_mut() {
            *item = F::zero();
        }

        if inp[0].len() >= FFT_THRESHOLD {
            self.call_poly_fft(outp, inp)
        } else {
            self.call_poly_direct(outp, inp)
        }
    }

    fn arity(&self) -> usize {
        1
    }

    fn degree(&self) -> usize {
        poly_deg(&self.poly)
    }

    fn calls(&self) -> usize {
        self.num_calls
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

/// Trait for abstracting over [`ParallelSum`].
pub trait ParallelSumGadget<F: FftFriendlyFieldElement, G>: Gadget<F> + Debug {
    /// Wraps `inner` into a sum gadget that calls it `chunks` many times, and adds the reuslts.
    fn new(inner: G, chunks: usize) -> Self;
}

/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
/// outputs. The arity is equal to the arity of the inner gadget times the number of times it is
/// called.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelSum<F: FftFriendlyFieldElement, G: Gadget<F>> {
    inner: G,
    chunks: usize,
    phantom: PhantomData<F>,
}

impl<F: FftFriendlyFieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G>
    for ParallelSum<F, G>
{
    fn new(inner: G, chunks: usize) -> Self {
        Self {
            inner,
            chunks,
            phantom: PhantomData,
        }
    }
}

impl<F: FftFriendlyFieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        let mut outp = F::zero();
        for chunk in inp.chunks(self.inner.arity()) {
            outp += self.inner.call(chunk)?;
        }
        Ok(outp)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;

        for x in outp.iter_mut() {
            *x = F::zero();
        }

        let mut partial_outp = vec![F::zero(); outp.len()];

        for chunk in inp.chunks(self.inner.arity()) {
            self.inner.call_poly(&mut partial_outp, chunk)?;
            for i in 0..outp.len() {
                outp[i] += partial_outp[i]
            }
        }

        Ok(())
    }

    fn arity(&self) -> usize {
        self.chunks * self.inner.arity()
    }

    fn degree(&self) -> usize {
        self.inner.degree()
    }

    fn calls(&self) -> usize {
        self.inner.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum
/// evaluation is multithreaded.
#[cfg(feature = "multithreaded")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelSumMultithreaded<F: FftFriendlyFieldElement, G: Gadget<F>> {
    serial_sum: ParallelSum<F, G>,
}

#[cfg(feature = "multithreaded")]
impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G>
where
    F: FftFriendlyFieldElement + Sync + Send,
    G: 'static + Gadget<F> + Clone + Sync + Send,
{
    fn new(inner: G, chunks: usize) -> Self {
        Self {
            serial_sum: ParallelSum::new(inner, chunks),
        }
    }
}

/// Data structures passed between fold operations in [`ParallelSumMultithreaded`].
#[cfg(feature = "multithreaded")]
struct ParallelSumFoldState<F, G> {
    /// Inner gadget.
    inner: G,
    /// Output buffer for `call_poly()`.
    partial_output: Vec<F>,
    /// Sum accumulator.
    partial_sum: Vec<F>,
}

#[cfg(feature = "multithreaded")]
impl<F, G> ParallelSumFoldState<F, G> {
    fn new(gadget: &G, length: usize) -> ParallelSumFoldState<F, G>
    where
        G: Clone,
        F: FftFriendlyFieldElement,
    {
        ParallelSumFoldState {
            inner: gadget.clone(),
            partial_output: vec![F::zero(); length],
            partial_sum: vec![F::zero(); length],
        }
    }
}

#[cfg(feature = "multithreaded")]
impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G>
where
    F: FftFriendlyFieldElement + Sync + Send,
    G: 'static + Gadget<F> + Clone + Sync + Send,
{
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        self.serial_sum.call(inp)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;

        // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the
        // gadget on each input polynomial, using the first temporary buffer as an output buffer.
        // Then accumulate that result into the second temporary buffer, which acts as a running
        // sum. Then, discard everything but the partial sums, add them, and finally copy the sum
        // to the output parameter. This is equivalent to the single threaded calculation in
        // ParallelSum, since we only rearrange additions, and field addition is associative.
        let res = inp
            .par_chunks(self.serial_sum.inner.arity())
            .fold(
                || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()),
                |mut state, chunk| {
                    state
                        .inner
                        .call_poly(&mut state.partial_output, chunk)
                        .unwrap();
                    for (sum_elem, output_elem) in state
                        .partial_sum
                        .iter_mut()
                        .zip(state.partial_output.iter())
                    {
                        *sum_elem += *output_elem;
                    }
                    state
                },
            )
            .map(|state| state.partial_sum)
            .reduce(
                || vec![F::zero(); outp.len()],
                |mut x, y| {
                    for (xi, yi) in x.iter_mut().zip(y.iter()) {
                        *xi += *yi;
                    }
                    x
                },
            );

        outp.copy_from_slice(&res[..]);
        Ok(())
    }

    fn arity(&self) -> usize {
        self.serial_sum.arity()
    }

    fn degree(&self) -> usize {
        self.serial_sum.degree()
    }

    fn calls(&self) -> usize {
        self.serial_sum.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}

// Check that the input parameters of g.call() are well-formed.
fn gadget_call_check<F: FftFriendlyFieldElement, G: Gadget<F>>(
    gadget: &G,
    in_len: usize,
) -> Result<(), FlpError> {
    if in_len != gadget.arity() {
        return Err(FlpError::Gadget(format!(
            "unexpected number of inputs: got {}; want {}",
            in_len,
            gadget.arity()
        )));
    }

    if in_len == 0 {
        return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string()));
    }

    Ok(())
}

// Check that the input parameters of g.call_poly() are well-formed.
fn gadget_call_poly_check<F: FftFriendlyFieldElement, G: Gadget<F>>(
    gadget: &G,
    outp: &[F],
    inp: &[Vec<F>],
) -> Result<(), FlpError>
where
    G: Gadget<F>,
{
    gadget_call_check(gadget, inp.len())?;

    for i in 1..inp.len() {
        if inp[i].len() != inp[0].len() {
            return Err(FlpError::Gadget(
                "gadget called on wire polynomials with different lengths".to_string(),
            ));
        }
    }

    let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two();
    if outp.len() != expected {
        return Err(FlpError::Gadget(format!(
            "incorrect output length: got {}; want {}",
            outp.len(),
            expected
        )));
    }

    Ok(())
}

#[inline]
fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize {
    gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "multithreaded")]
    use crate::field::FieldElement;
    use crate::field::{random_vector, Field64 as TestField};
    use crate::prng::Prng;

    #[test]
    fn test_mul() {
        // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the
        // naive multiplication code path.
        let num_calls = FFT_THRESHOLD / 2;
        let mut g: Mul<TestField> = Mul::new(num_calls);
        gadget_test(&mut g, num_calls);

        // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises
        // FFT-based polynomial multiplication.
        let num_calls = FFT_THRESHOLD;
        let mut g: Mul<TestField> = Mul::new(num_calls);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    fn test_poly_eval() {
        let poly: Vec<TestField> = random_vector(10).unwrap();

        let num_calls = FFT_THRESHOLD / 2;
        let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls);
        gadget_test(&mut g, num_calls);

        let num_calls = FFT_THRESHOLD;
        let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    fn test_parallel_sum() {
        let num_calls = 10;
        let chunks = 23;

        let mut g = ParallelSum::new(Mul::<TestField>::new(num_calls), chunks);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    #[cfg(feature = "multithreaded")]
    fn test_parallel_sum_multithreaded() {
        use std::iter;

        for num_calls in [1, 10, 100] {
            let chunks = 23;

            let mut g = ParallelSumMultithreaded::new(Mul::new(num_calls), chunks);
            gadget_test(&mut g, num_calls);

            // Test that the multithreaded version has the same output as the normal version.
            let mut g_serial = ParallelSum::new(Mul::new(num_calls), chunks);
            assert_eq!(g.arity(), g_serial.arity());
            assert_eq!(g.degree(), g_serial.degree());
            assert_eq!(g.calls(), g_serial.calls());

            let arity = g.arity();
            let degree = g.degree();

            // Test that both gadgets evaluate to the same value when run on scalar inputs.
            let inp: Vec<TestField> = random_vector(arity).unwrap();
            let result = g.call(&inp).unwrap();
            let result_serial = g_serial.call(&inp).unwrap();
            assert_eq!(result, result_serial);

            // Test that both gadgets evaluate to the same value when run on polynomial inputs.
            let mut poly_outp =
                vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
            let mut poly_outp_serial =
                vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()];
            let mut prng: Prng<TestField, _> = Prng::new().unwrap();
            let poly_inp: Vec<_> = iter::repeat_with(|| {
                iter::repeat_with(|| prng.get())
                    .take(1 + num_calls)
                    .collect::<Vec<_>>()
            })
            .take(arity)
            .collect();

            g.call_poly(&mut poly_outp, &poly_inp).unwrap();
            g_serial
                .call_poly(&mut poly_outp_serial, &poly_inp)
                .unwrap();
            assert_eq!(poly_outp, poly_outp_serial);
        }
    }

    // Test that calling g.call_poly() and evaluating the output at a given point is equivalent
    // to evaluating each of the inputs at the same point and applying g.call() on the results.
    fn gadget_test<F: FftFriendlyFieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) {
        let wire_poly_len = (1 + num_calls).next_power_of_two();
        let mut prng = Prng::new().unwrap();
        let mut inp = vec![F::zero(); g.arity()];
        let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)];
        let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()];

        let r = prng.get();
        for i in 0..g.arity() {
            for j in 0..wire_poly_len {
                wire_polys[i][j] = prng.get();
            }
            inp[i] = poly_eval(&wire_polys[i], r);
        }

        g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
        let got = poly_eval(&gadget_poly, r);
        let want = g.call(&inp).unwrap();
        assert_eq!(got, want);

        // Repeat the call to make sure that the gadget's memory is reset properly between calls.
        g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
        let got = poly_eval(&gadget_poly, r);
        assert_eq!(got, want);
    }
}

[ Dauer der Verarbeitung: 0.3 Sekunden  (vorverarbeitet)  ]

                                                                                                                                                                                                                                                                                                                                                                                                     


Neuigkeiten

     Aktuelles
     Motto des Tages

Software

     Produkte
     Quellcodebibliothek

Aktivitäten

     Artikel über Sicherheit
     Anleitung zur Aktivierung von SSL

Muße

     Gedichte
     Musik
     Bilder

Jenseits des Üblichen ....

Besucherstatistik

Besucherstatistik

Monitoring

Montastic status badge