Oumuamua Labs

ML-KEM Chiplet

Crates.io Docs.rs CI License: Apache 2.0

A Hekate AIR for FIPS 203 (ML-KEM, formerly Kyber) decapsulation: it proves that a shared secret was correctly recovered from a ciphertext without ever revealing the secret key.

TL;DR

ML-KEM (Kyber) decapsulation, lifted into a Hekate AIR chiplet so that a verifier can confirm a shared secret was correctly recovered from a ciphertext and secret key without ever seeing the secret key. The NIST-standardized Module-LWE KEM is re-expressed over the binary tower field using native NTT, basemul, Keccak, and RAM chiplets. All of these sit under a single AIR proven by Hekate (sumcheck + Brakedown PCS). This addresses the "I decapsulated correctly; prove it without revealing my long-term key" problem, which off-the-shelf KEM libraries do not answer.

Key Characteristics

Usage in Action

use hekate_core::config::Config;
use hekate_core::trace::{ColumnTrace, ColumnType, TraceBuilder};
use hekate_crypto::DefaultHasher;
use hekate_crypto::transcript::Transcript;
use hekate_math::{Bit, Block32, Block128, TowerField};
use hekate_pqc::mlkem::{
    self, CpuMlKemColumns, CpuMlKemUnit, MlKemChiplet, MlKemLevel, MlKemParams,
};
use hekate_program::{
    Air, Program, ProgramInstance, ProgramWitness,
    chiplet::ChipletDef,
    constraint::{BoundaryConstraint, ConstraintAst, builder::ConstraintSystem},
    permutation::PermutationCheckSpec,
};
use hekate_prover_sys::prove;
use hekate_verifier::HekateVerifier;
use pqcrypto_mlkem::mlkem768;
use pqcrypto_traits::kem::{Ciphertext as _, SecretKey as _, SharedSecret as _};
use rand::TryRngCore;
use rand::rngs::OsRng;

type F = Block128;
type H = DefaultHasher;

/// Top-level program for proving ML-KEM-768 decapsulation.
///
/// The CPU trace carries the ciphertext words and the recovered shared
/// secret, and is linked to the composite ML-KEM chiplet that performs
/// the actual decapsulation.
#[derive(Clone)]
struct MlKemDecapsProgram {
    /// Number of public inputs (one per 32-bit ciphertext word in this example).
    num_public: usize,
    /// Composite ML-KEM chiplet whose definitions are flattened into the program.
    mlkem: MlKemChiplet<F>,
}

impl Air<F> for MlKemDecapsProgram {
    /// Human-readable program name.
    fn name(&self) -> String {
        "MlKemDecapsProgram".into()
    }

    /// Width of the CPU trace, as defined by the CPU-side ML-KEM unit.
    fn num_columns(&self) -> usize {
        CpuMlKemUnit::num_columns()
    }

    /// Binds CPU trace cell `DATA[k]` to public input `k`, for every
    /// `k` in `0..num_public`.
    fn boundary_constraints(&self) -> Vec<BoundaryConstraint<F>> {
        (0..self.num_public)
            .map(|k| BoundaryConstraint::with_public_input(CpuMlKemColumns::DATA, k, k))
            .collect()
    }

    /// Column types of the CPU trace.
    ///
    /// The layout does not depend on the instance, so it is built once and
    /// cached in a process-wide `OnceLock`. Every call then returns the same
    /// `'static` slice instead of allocating and leaking a new layout.
    fn column_layout(&self) -> &[ColumnType] {
        // NOTE(review): requires `ColumnType: Send + Sync` to live in a static;
        // expected for a plain column-type enum. Confirm against hekate_core.
        static LAYOUT: std::sync::OnceLock<Vec<ColumnType>> = std::sync::OnceLock::new();
        LAYOUT.get_or_init(CpuMlKemColumns::build_layout).as_slice()
    }

    /// Registers the two bus links between the CPU trace and the ML-KEM
    /// chiplet: the data bus and the shared-secret bus.
    fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
        vec![
            (mlkem::MLKEM_DATA_BUS_ID.into(), CpuMlKemUnit::linking_spec()),
            (mlkem::MLKEM_SS_BUS_ID.into(), CpuMlKemUnit::ss_linking_spec()),
        ]
    }

    /// Local CPU-trace constraints: both selector columns must be boolean.
    fn constraint_ast(&self) -> ConstraintAst<F> {
        let cs = ConstraintSystem::<F>::new();
        cs.assert_boolean(cs.col(CpuMlKemColumns::SELECTOR));
        cs.assert_boolean(cs.col(CpuMlKemColumns::SS_SELECTOR));

        cs.build()
    }
}

impl Program<F> for MlKemDecapsProgram {
    fn num_public_inputs(&self) -> usize {
        self.num_public
    }

    fn chiplet_defs(&self) -> hekate_core::errors::Result<Vec<ChipletDef<F>>> {
        self.mlkem.composite().flatten_defs()
    }
}

/// Builds the CPU-side trace that carries the ciphertext and the recovered
/// shared secret.
///
/// Row layout:
/// - rows `0..ceil(ct.len() / 4)`: one little-endian 32-bit ciphertext word
///   per row in `DATA`, with `SELECTOR` set to one. The final word is
///   zero-padded when `ct.len()` is not a multiple of 4.
/// - the row immediately after: the 32-byte shared secret, split into four
///   8-byte words. Word `i` puts its low 4 bytes (little-endian) in
///   `SS_DATA + i` and its high 4 bytes in `SS_DATA + 4 + i`. `SS_SELECTOR`
///   is set to one on this row.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `cpu_num_rows` is not a non-zero power of two;
/// - `cpu_num_rows` is too small to hold the ciphertext rows plus the
///   shared-secret row;
/// - the `TraceBuilder` rejects a write.
fn generate_trace(
    ct: &[u8],
    shared_secret: &[u8; 32],
    cpu_num_rows: usize,
) -> Result<ColumnTrace, Box<dyn std::error::Error>> {
    // The builder is sized by log2(rows). `trailing_zeros` equals that
    // exponent only when the row count is an exact (non-zero) power of two.
    if !cpu_num_rows.is_power_of_two() {
        return Err(format!("cpu_num_rows must be a power of two, got {cpu_num_rows}").into());
    }

    // Shared secret occupies the row
    // immediately after the ct chunks.
    let ss_row = ct.len().div_ceil(4);
    if ss_row >= cpu_num_rows {
        return Err(format!(
            "cpu_num_rows = {cpu_num_rows} is too small: need {ss_row} ciphertext rows \
             plus 1 shared-secret row"
        )
        .into());
    }

    let layout = CpuMlKemColumns::build_layout();
    let cpu_vars = cpu_num_rows.trailing_zeros() as usize;

    let mut cpu_tb = TraceBuilder::new(&layout, cpu_vars)?;

    for (i, chunk) in ct.chunks(4).enumerate() {
        // Zero-pad a trailing partial word.
        let mut buf = [0u8; 4];
        buf[..chunk.len()].copy_from_slice(chunk);

        cpu_tb.set_b32(CpuMlKemColumns::DATA, i, Block32::from(u32::from_le_bytes(buf)))?;
        cpu_tb.set_bit(CpuMlKemColumns::SELECTOR, i, Bit::ONE)?;
    }

    for (i, word) in shared_secret.chunks_exact(8).enumerate() {
        let lo = u32::from_le_bytes(word[..4].try_into()?);
        let hi = u32::from_le_bytes(word[4..].try_into()?);

        cpu_tb.set_b32(CpuMlKemColumns::SS_DATA + i, ss_row, Block32::from(lo))?;
        cpu_tb.set_b32(CpuMlKemColumns::SS_DATA + 4 + i, ss_row, Block32::from(hi))?;
    }

    cpu_tb.set_bit(CpuMlKemColumns::SS_SELECTOR, ss_row, Bit::ONE)?;

    Ok(cpu_tb.build())
}

/// End-to-end example: decapsulates a real ML-KEM-768 ciphertext inside the
/// chiplet, proves the execution with Hekate, and verifies the proof.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Transcript domain separator; the prover and the verifier must use
    // identical bytes.
    let domain = b"ML-KEM-768_Decaps";

    // Per-chiplet trace heights, sized by hand for ML-KEM-768.
    // Future releases are expected to derive these from `MlKemLevel`.
    let params = MlKemParams {
        ctrl_rows: 1 << 16,
        keccak_rows: 1 << 11,
        ntt_rows: 1 << 15,
        twiddle_rows: 1 << 15,
        basemul_rows: 1 << 12,
        ram_rows: 1 << 16,
    };

    // Height of the CPU trace (a power of two).
    let cpu_num_rows: usize = 1 << 10;

    // Reference run with the NIST implementation. It supplies a real
    // ciphertext/secret-key pair and the shared secret we expect back.
    let (reference_pk, reference_sk) = mlkem768::keypair();
    let (reference_ss, reference_ct) = mlkem768::encapsulate(&reference_pk);
    let (ct, sk, expected_ss) = (
        reference_ct.as_bytes(),
        reference_sk.as_bytes(),
        reference_ss.as_bytes(),
    );

    // Decapsulate inside the chiplet (which also emits its traces), then
    // cross-check the result against the reference shared secret.
    let mlkem_chiplet = MlKemChiplet::<F>::new(MlKemLevel::MLKEM_768, params);
    let (chiplet_traces, shared_secret) = mlkem_chiplet.generate_traces(ct, sk)?;
    assert_eq!(&shared_secret, expected_ss);

    let cpu_trace = generate_trace(ct, &shared_secret, cpu_num_rows)?;

    // Public inputs: the ciphertext as little-endian 32-bit words (last word
    // zero-padded), mirroring what `generate_trace` writes into `DATA`.
    let ct_public = ct
        .chunks(4)
        .map(|word| {
            let mut le = [0u8; 4];
            le[..word.len()].copy_from_slice(word);
            Block128(u128::from(u32::from_le_bytes(le)))
        })
        .collect::<Vec<F>>();

    let air = MlKemDecapsProgram {
        num_public: ct_public.len(),
        mlkem: mlkem_chiplet,
    };

    let instance = ProgramInstance::new(cpu_num_rows, ct_public);
    let witness = ProgramWitness::new(cpu_trace).with_chiplets(chiplet_traces);

    // Prover configuration: sumcheck blinding factor, plus fresh OS
    // randomness for the matrix seed and for the prover's blinding seed.
    let mut config = Config {
        sumcheck_blinding_factor: 2,
        ..Config::default()
    };
    OsRng.try_fill_bytes(&mut config.matrix_seed)?;
    let mut blinding_seed = [0u8; 32];
    OsRng.try_fill_bytes(&mut blinding_seed)?;

    let proof = prove(domain, &air, &instance, &witness, &config, blinding_seed, None)?;

    let mut transcript = Transcript::<H>::new(domain);
    let accepted =
        HekateVerifier::<F, H>::verify(&air, &instance, &proof, &mut transcript, &config)?;
    assert!(accepted);

    Ok(())
}