Oumuamua Labs

ML-DSA Chiplet

Crates.io Docs.rs CI License: Apache 2.0

FIPS 204 Dilithium verification, expressed as a Hekate AIR: a valid proof attests that the signature checks out while keeping the public key, signature, and message private.

TL;DR

ML-DSA (Dilithium) signature verification expressed as a Hekate AIR — the prover runs the FIPS 204 Verify algorithm inside the circuit, and a valid proof is the verdict: the constraint system is unsatisfiable for forged signatures, so no proof can exist. Wires NTT, twiddle ROM, Keccak, norm-check, high-bits, and RAM chiplets together over the binary tower field. Solves the "I checked this Dilithium signature, here's a 70 ms / 5 MB proof you can verify on a phone" problem instead of asking the verifier to redo 64 KB of polynomial arithmetic.

Key Characteristics

Usage in Action

use hekate_core::config::Config;
use hekate_core::trace::{ColumnTrace, ColumnType, TraceBuilder};
use hekate_crypto::DefaultHasher;
use hekate_crypto::transcript::Transcript;
use hekate_math::{Bit, Block32, Block128, TowerField};
use hekate_pqc::mldsa::{
    self, CpuMlDsaColumns, CpuMlDsaUnit, MlDsaChiplet, MlDsaLevel, MlDsaParams,
    MlDsaPublicKey, MlDsaSignature,
};
use hekate_program::{
    Air, Program, ProgramInstance, ProgramWitness,
    chiplet::ChipletDef,
    constraint::{BoundaryConstraint, ConstraintAst, builder::ConstraintSystem},
    permutation::PermutationCheckSpec,
};
use hekate_prover_sys::prove;
use hekate_verifier::HekateVerifier;
use pqcrypto_mldsa::mldsa65;
use pqcrypto_traits::sign::{DetachedSignature, PublicKey};
use rand::TryRngCore;
use rand::rngs::OsRng;

/// Field element type used for the AIR, the public inputs, and the verifier.
type F = Block128;
/// Hasher type used for the verifier transcript.
type H = DefaultHasher;

/// Program description for the ML-DSA verification example: the ML-DSA
/// chiplet plus the number of public inputs exposed by the CPU trace.
#[derive(Clone)]
struct MlDsaVerifyProgram {
    /// Chiplet whose composite definitions are returned by `chiplet_defs`.
    mldsa: MlDsaChiplet<F>,
    /// Number of public inputs; also the number of boundary constraints
    /// emitted by `boundary_constraints`.
    num_public: usize,
}

impl Air<F> for MlDsaVerifyProgram {
    /// Human-readable name of this AIR.
    fn name(&self) -> String {
        "MlDsaVerifyProgram".into()
    }

    /// Column count of the CPU trace, as reported by `CpuMlDsaUnit`.
    fn num_columns(&self) -> usize {
        CpuMlDsaUnit::num_columns()
    }

    /// Emits one boundary constraint per public input `k`, built with
    /// `with_public_input(DATA, k, k)`.
    ///
    /// NOTE(review): presumably this pins row `k` of the `DATA` column to
    /// public input `k` — confirm the argument order against the API.
    fn boundary_constraints(&self) -> Vec<BoundaryConstraint<F>> {
        (0..self.num_public)
            .map(|k| BoundaryConstraint::with_public_input(CpuMlDsaColumns::DATA, k, k))
            .collect()
    }

    /// Column layout of the CPU trace.
    ///
    /// The layout comes from an argument-free associated function, so it is
    /// identical for every instance. It is built once and cached in a
    /// process-wide static; repeated calls return the same slice instead of
    /// leaking a fresh allocation each time.
    fn column_layout(&self) -> &[ColumnType] {
        static LAYOUT: std::sync::OnceLock<Vec<ColumnType>> = std::sync::OnceLock::new();

        LAYOUT
            .get_or_init(|| CpuMlDsaColumns::build_layout())
            .as_slice()
    }

    /// Declares a single permutation check, keyed by the ML-DSA data bus id
    /// and using the CPU unit's linking spec.
    fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
        vec![(mldsa::MLDSA_DATA_BUS_ID.into(), CpuMlDsaUnit::linking_spec())]
    }

    /// Builds the constraint AST; the only constraint asserted here is that
    /// the `SELECTOR` column is boolean.
    fn constraint_ast(&self) -> ConstraintAst<F> {
        let cs = ConstraintSystem::<F>::new();
        cs.assert_boolean(cs.col(CpuMlDsaColumns::SELECTOR));

        cs.build()
    }
}

impl Program<F> for MlDsaVerifyProgram {
    /// Number of public inputs this program exposes.
    fn num_public_inputs(&self) -> usize {
        let Self { num_public, .. } = self;
        *num_public
    }

    /// Chiplet definitions, obtained by flattening the ML-DSA chiplet's
    /// composite.
    fn chiplet_defs(&self) -> hekate_core::errors::Result<Vec<ChipletDef<F>>> {
        let composite = self.mldsa.composite();
        composite.flatten_defs()
    }
}

/// Builds the CPU trace that carries the challenge hash c̃.
///
/// `c_tilde` is copied and zero-padded to a 4-byte boundary so it packs into
/// B32 lanes. Each little-endian 32-bit word `i` of the padded buffer is
/// written to row `i` of the `DATA` column, and the `SELECTOR` bit is set on
/// that row.
///
/// Returns `(cpu_trace, io_buf)`, where `io_buf` is the padded copy of c̃.
///
/// # Errors
///
/// Returns an error if `cpu_num_rows` is not a non-zero power of two, if the
/// padded c̃ needs more rows than `cpu_num_rows`, or if the trace builder
/// rejects the layout or a cell write.
fn generate_trace(
    c_tilde: &[u8],
    cpu_num_rows: usize,
) -> Result<(ColumnTrace, Vec<u8>), Box<dyn std::error::Error>> {
    // `trailing_zeros` only equals log2 for a non-zero power of two; any
    // other value would silently produce a trace of the wrong height.
    if !cpu_num_rows.is_power_of_two() {
        return Err(format!(
            "cpu_num_rows must be a non-zero power of two, got {cpu_num_rows}"
        )
        .into());
    }
    let cpu_vars = cpu_num_rows.trailing_zeros() as usize;

    // Zero-pad to a whole number of 32-bit words.
    let mut io_buf = c_tilde.to_vec();
    io_buf.resize(c_tilde.len().next_multiple_of(4), 0);

    let word_count = io_buf.len() / 4;
    if word_count > cpu_num_rows {
        return Err(format!(
            "c_tilde needs {word_count} rows but the CPU trace has only {cpu_num_rows}"
        )
        .into());
    }

    let layout = CpuMlDsaColumns::build_layout();
    let mut cpu_tb = TraceBuilder::new(&layout, cpu_vars)?;

    for (row, word) in io_buf.chunks_exact(4).enumerate() {
        let val = u32::from_le_bytes([word[0], word[1], word[2], word[3]]);
        cpu_tb.set_b32(CpuMlDsaColumns::DATA, row, Block32::from(val))?;
        cpu_tb.set_bit(CpuMlDsaColumns::SELECTOR, row, Bit::ONE)?;
    }

    Ok((cpu_tb.build(), io_buf))
}

/// End-to-end example: sign a message with ML-DSA-65, prove the verification
/// through the Hekate prover, then check the proof with the Hekate verifier.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let domain = b"ML-DSA_Verify";
    let level = MlDsaLevel::MLDSA_65;
    let cpu_num_rows: usize = 1 << 10;

    // Chiplet trace sizes are picked by hand for now;
    // future releases derive these from `MlDsaLevel`.
    let params = MlDsaParams {
        ctrl_rows: 1 << 16,
        keccak_rows: 1 << 13,
        ntt_rows: 1 << 16,
        twiddle_rows: 1 << 16,
        norm_rows: 1 << 11,
        highbits_rows: 1 << 11,
        ram_rows: 1 << 16,
    };

    // Fresh key pair and a detached signature over the example message.
    let msg = b"Hekate ML-DSA-65 verification example";
    let (pk, sk) = mldsa65::keypair();
    let sig = mldsa65::detached_sign(msg, &sk);

    // Re-parse the raw key and signature bytes into the chiplet's own types.
    let pk_air = MlDsaPublicKey::from_bytes(level, pk.as_bytes());
    let sig_air = MlDsaSignature::from_bytes(level, sig.as_bytes())?;

    // Chiplet traces first, then the CPU trace carrying c̃.
    let mldsa_chiplet = MlDsaChiplet::<F>::new(level, params);
    let chiplet_traces = mldsa_chiplet.generate_traces(&pk_air, &sig_air, msg)?;
    let (cpu_trace, io_buf) = generate_trace(&sig_air.c_tilde, cpu_num_rows)?;

    // Public inputs: the padded c̃, one little-endian 32-bit word per element.
    let ct_public: Vec<F> = io_buf
        .chunks(4)
        .map(|c| {
            let word = u32::from_le_bytes([c[0], c[1], c[2], c[3]]);
            Block128(u128::from(word))
        })
        .collect();
    let num_public = ct_public.len();

    let air = MlDsaVerifyProgram {
        mldsa: mldsa_chiplet,
        num_public,
    };
    let instance = ProgramInstance::new(cpu_num_rows, ct_public);
    let witness = ProgramWitness::new(cpu_trace).with_chiplets(chiplet_traces);

    // Both seeds are drawn from the OS RNG.
    let mut config = Config {
        sumcheck_blinding_factor: 2,
        ..Config::default()
    };
    OsRng.try_fill_bytes(&mut config.matrix_seed)?;

    let mut blinding_seed = [0u8; 32];
    OsRng.try_fill_bytes(&mut blinding_seed)?;

    let proof = prove(domain, &air, &instance, &witness, &config, blinding_seed, None)?;

    // The verifier gets a fresh transcript created from the same domain bytes.
    let mut transcript = Transcript::<H>::new(domain);
    let accepted =
        HekateVerifier::<F, H>::verify(&air, &instance, &proof, &mut transcript, &config)?;

    assert!(accepted);

    Ok(())
}