ML-KEM Chiplet

FIPS 203 (ML-KEM, formerly Kyber) decapsulation expressed as a Hekate AIR: it proves that a shared secret was correctly
recovered from a ciphertext without revealing the secret key.
TL;DR
ML-KEM (Kyber) decapsulation, lifted into a Hekate AIR chiplet so that a verifier can confirm a shared secret was correctly
recovered from a ciphertext and secret key without ever seeing the secret key. ML-KEM is the NIST-standardized Module-LWE
KEM; here it is re-expressed over the binary tower field with native NTT, basemul, Keccak, and RAM chiplets, all under a
single AIR proven by Hekate (sumcheck + Brakedown PCS). It addresses the "I decapsulated correctly — prove it without
revealing my long-term key" problem, which off-the-shelf KEM libraries do not answer.
Key Characteristics
- FIPS 203 compliant. Decapsulation matches `pqcrypto_mlkem::mlkem768` byte-for-byte; the example asserts that the
  recovered shared secret equals the one produced by `mlkem768::encapsulate`.
- All three FIPS 203 parameter sets: `MLKEM_512`, `MLKEM_768`, and `MLKEM_1024`, chosen at runtime via `MlKemLevel`.
- Binary-field native. The trace lives in `Block128` (the binary tower field B128); NTT, basemul, twiddle ROM, Keccak,
  and RAM are separate chiplets connected by LogUp permutation buses rather than a single monolithic circuit.
- Measured on Apple M3 Max: 1.40 s proving, 331 MB peak memory, 4,232 KiB proof, 30.6 ms verification.
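- Two public buses. `MLKEM_DATA_BUS_ID` carries the ciphertext and `MLKEM_SS_BUS_ID` carries the recovered shared
  secret, both as B32 columns linked into the CPU AIR. In the example below, only the ciphertext words are bound as
  public inputs (via boundary constraints on the `DATA` column).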
Usage in Action
use hekate_core::config::Config;
use hekate_core::trace::{ColumnTrace, ColumnType, TraceBuilder};
use hekate_crypto::DefaultHasher;
use hekate_crypto::transcript::Transcript;
use hekate_math::{Bit, Block32, Block128, TowerField};
use hekate_pqc::mlkem::{
self, CpuMlKemColumns, CpuMlKemUnit, MlKemChiplet, MlKemLevel, MlKemParams,
};
use hekate_program::{
Air, Program, ProgramInstance, ProgramWitness,
chiplet::ChipletDef,
constraint::{BoundaryConstraint, ConstraintAst, builder::ConstraintSystem},
permutation::PermutationCheckSpec,
};
use hekate_prover_sys::prove;
use hekate_verifier::HekateVerifier;
use pqcrypto_mlkem::mlkem768;
use pqcrypto_traits::kem::{Ciphertext as _, SecretKey as _, SharedSecret as _};
use rand::TryRngCore;
use rand::rngs::OsRng;
/// Field in which the trace and all constraints live.
type F = Block128;
/// Hasher backing the Fiat-Shamir transcript used by the verifier.
type H = DefaultHasher;

/// Top-level decapsulation program: a small CPU-side AIR plus the ML-KEM
/// chiplet it is linked to through permutation buses.
#[derive(Clone)]
struct MlKemDecapsProgram {
    /// Count of public inputs (the example uses one per 4-byte ciphertext word).
    num_public: usize,
    /// ML-KEM chiplet whose definitions are flattened into this program.
    mlkem: MlKemChiplet<F>,
}
/// CPU-side AIR for the ML-KEM decapsulation program.
///
/// The CPU trace's ciphertext (`DATA`) and shared-secret (`SS_DATA`) columns are
/// linked to the ML-KEM chiplet through the two permutation buses declared in
/// `permutation_checks`.
impl Air<F> for MlKemDecapsProgram {
    fn name(&self) -> String {
        "MlKemDecapsProgram".into()
    }

    /// Width of the CPU trace, as defined by the ML-KEM CPU unit.
    fn num_columns(&self) -> usize {
        CpuMlKemUnit::num_columns()
    }

    /// Binds row `k` of the `DATA` column to public input `k` for every
    /// `k < num_public`.
    fn boundary_constraints(&self) -> Vec<BoundaryConstraint<F>> {
        (0..self.num_public)
            .map(|k| BoundaryConstraint::with_public_input(CpuMlKemColumns::DATA, k, k))
            .collect()
    }

    /// Column layout of the CPU trace.
    ///
    /// The layout is built once and cached for the lifetime of the process,
    /// rather than leaking a new allocation (`Box::leak`) on every call.
    fn column_layout(&self) -> &[ColumnType] {
        static LAYOUT: std::sync::OnceLock<Vec<ColumnType>> = std::sync::OnceLock::new();
        LAYOUT.get_or_init(CpuMlKemColumns::build_layout).as_slice()
    }

    /// The two buses shared with the chiplet: ciphertext words and shared secret.
    fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
        vec![
            (mlkem::MLKEM_DATA_BUS_ID.into(), CpuMlKemUnit::linking_spec()),
            (mlkem::MLKEM_SS_BUS_ID.into(), CpuMlKemUnit::ss_linking_spec()),
        ]
    }

    /// The CPU AIR itself only forces its two selector columns to be boolean.
    fn constraint_ast(&self) -> ConstraintAst<F> {
        let cs = ConstraintSystem::<F>::new();
        cs.assert_boolean(cs.col(CpuMlKemColumns::SELECTOR));
        cs.assert_boolean(cs.col(CpuMlKemColumns::SS_SELECTOR));
        cs.build()
    }
}
/// Program-level wiring: exposes the chiplet definitions and the public-input count.
impl Program<F> for MlKemDecapsProgram {
    /// Flattens the ML-KEM composite chiplet into the individual chiplet
    /// definitions the prover and verifier operate on.
    fn chiplet_defs(&self) -> hekate_core::errors::Result<Vec<ChipletDef<F>>> {
        let composite = self.mlkem.composite();
        composite.flatten_defs()
    }

    /// Number of public inputs this program instance expects.
    fn num_public_inputs(&self) -> usize {
        self.num_public
    }
}
/// Builds the CPU-side trace that feeds the ML-KEM chiplet's buses.
///
/// Rows `0..ceil(ct.len() / 4)` carry the ciphertext as little-endian 32-bit
/// words in `DATA` (the final word zero-padded) with `SELECTOR` set. The next
/// row carries the 32-byte shared secret across the eight `SS_DATA` columns
/// with `SS_SELECTOR` set.
///
/// # Errors
///
/// Returns an error if `cpu_num_rows` is not a power of two, if the trace has
/// too few rows for every ciphertext word plus the shared-secret row, or if any
/// `TraceBuilder` call fails.
fn generate_trace(
    ct: &[u8],
    shared_secret: &[u8; 32],
    cpu_num_rows: usize,
) -> Result<ColumnTrace, Box<dyn std::error::Error>> {
    // The builder is sized by log2(rows); `trailing_zeros` equals log2 only for
    // powers of two, so any other value would silently build a wrong-sized trace.
    if !cpu_num_rows.is_power_of_two() {
        return Err(format!("cpu_num_rows must be a power of two, got {cpu_num_rows}").into());
    }

    // Ciphertext words occupy rows 0..ss_row; the shared secret goes in row ss_row.
    let ss_row = ct.len().div_ceil(4);
    if ss_row >= cpu_num_rows {
        return Err(format!(
            "cpu trace has {cpu_num_rows} rows but needs {} ({ss_row} ciphertext words + 1 shared-secret row)",
            ss_row + 1
        )
        .into());
    }

    let layout = CpuMlKemColumns::build_layout();
    let cpu_vars = cpu_num_rows.trailing_zeros() as usize;
    let mut cpu_tb = TraceBuilder::new(&layout, cpu_vars)?;

    for (row, chunk) in ct.chunks(4).enumerate() {
        // The last chunk may be shorter than 4 bytes; zero-pad it.
        let mut buf = [0u8; 4];
        buf[..chunk.len()].copy_from_slice(chunk);
        cpu_tb.set_b32(CpuMlKemColumns::DATA, row, Block32::from(u32::from_le_bytes(buf)))?;
        cpu_tb.set_bit(CpuMlKemColumns::SELECTOR, row, Bit::ONE)?;
    }

    // Each 8-byte lane `i` of the secret is split into its low word (column
    // `SS_DATA + i`) and high word (column `SS_DATA + 4 + i`).
    for (i, lane) in shared_secret.chunks_exact(8).enumerate() {
        let lo = u32::from_le_bytes(lane[..4].try_into()?);
        let hi = u32::from_le_bytes(lane[4..].try_into()?);
        cpu_tb.set_b32(CpuMlKemColumns::SS_DATA + i, ss_row, Block32::from(lo))?;
        cpu_tb.set_b32(CpuMlKemColumns::SS_DATA + 4 + i, ss_row, Block32::from(hi))?;
    }
    cpu_tb.set_bit(CpuMlKemColumns::SS_SELECTOR, ss_row, Bit::ONE)?;

    Ok(cpu_tb.build())
}
/// End-to-end example: encapsulate with `pqcrypto_mlkem`, decapsulate inside
/// the chiplet, prove the decapsulation, and verify the proof.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Transcript domain separator. Prover and verifier must use the identical
    // label, so it is defined exactly once.
    let domain = b"ML-KEM-768_Decaps";

    let params = MlKemParams {
        ctrl_rows: 1 << 16,
        keccak_rows: 1 << 11,
        ntt_rows: 1 << 15,
        twiddle_rows: 1 << 15,
        basemul_rows: 1 << 12,
        ram_rows: 1 << 16,
    };
    let cpu_num_rows: usize = 1 << 10;

    // Reference encapsulation; the level below must match this parameter set.
    let (nist_pk, nist_sk) = mlkem768::keypair();
    let (nist_ss, nist_ct) = mlkem768::encapsulate(&nist_pk);
    let ct = nist_ct.as_bytes();
    let sk = nist_sk.as_bytes();
    let expected_ss = nist_ss.as_bytes();

    let mlkem_chiplet = MlKemChiplet::<F>::new(MlKemLevel::MLKEM_768, params);
    let (chiplet_traces, shared_secret) = mlkem_chiplet.generate_traces(ct, sk)?;
    assert_eq!(
        &shared_secret, expected_ss,
        "chiplet decapsulation disagrees with pqcrypto_mlkem"
    );

    let cpu_trace = generate_trace(ct, &shared_secret, cpu_num_rows)?;

    // Public inputs: the ciphertext packed as zero-padded little-endian 32-bit
    // words, matching the DATA column written by `generate_trace`.
    let ct_public: Vec<F> = ct
        .chunks(4)
        .map(|chunk| {
            let mut buf = [0u8; 4];
            buf[..chunk.len()].copy_from_slice(chunk);
            Block128(u32::from_le_bytes(buf) as u128)
        })
        .collect();

    let air = MlKemDecapsProgram {
        mlkem: mlkem_chiplet,
        num_public: ct_public.len(),
    };
    let instance = ProgramInstance::new(cpu_num_rows, ct_public);
    let witness = ProgramWitness::new(cpu_trace).with_chiplets(chiplet_traces);

    let mut config = Config {
        sumcheck_blinding_factor: 2,
        ..Config::default()
    };
    OsRng.try_fill_bytes(&mut config.matrix_seed)?;
    let mut blinding_seed = [0u8; 32];
    OsRng.try_fill_bytes(&mut blinding_seed)?;

    let proof = prove(domain, &air, &instance, &witness, &config, blinding_seed, None)?;

    let mut verifier_transcript = Transcript::<H>::new(domain);
    let ok = HekateVerifier::<F, H>::verify(&air, &instance, &proof, &mut verifier_transcript, &config)?;
    assert!(ok, "verifier rejected the decapsulation proof");
    Ok(())
}