Oumuamua Labs

AES Chiplet

Crates.io Docs.rs CI License: Apache 2.0

FIPS 197 AES-128/256 expressed as a Hekate AIR native to the binary tower field: it proves that a batch of ciphertexts is the correct encryption of the corresponding plaintexts under a given key.

TL;DR

AES-128 and AES-256 encryption realized as a Hekate AIR over the binary tower field. SubBytes, ShiftRows, MixColumns and AddRoundKey are expressed as native GF(2^8) operations rather than bit-blasted boolean circuits. The S-box's GF(2^8) inversion lives in a separate ROM chiplet, which is wired into the round AIR through a LogUp bus; the round AIR in turn connects to a CPU AIR that handles plaintext/ciphertext I/O. This solves the problem of proving a batch of AES encryptions in under 100 µs of prover time per block without a million boolean gates.

Key Characteristics

Usage in Action

use hekate_aes::{
    Aes256Chiplet, AesRound256Air,
    CpuAes256Columns, CpuAes256Unit, PhysAes256Columns,
    trace::{Aes256Call, expand_key_256},
};
use hekate_core::config::Config;
use hekate_core::errors;
use hekate_core::trace::{ColumnTrace, ColumnType, TraceBuilder};
use hekate_crypto::DefaultHasher;
use hekate_crypto::transcript::Transcript;
use hekate_math::{Bit, Block8, Block128, TowerField};
use hekate_program::{
    Air, Program, ProgramInstance, ProgramWitness,
    chiplet::ChipletDef,
    constraint::{ConstraintAst, builder::ConstraintSystem},
    permutation::PermutationCheckSpec,
};
use hekate_prover_sys::prove;
use hekate_verifier::HekateVerifier;
use rand::{TryRngCore, rngs::OsRng};

/// Field the AIR, constraint system and proof are instantiated over.
type F = Block128;
/// Hasher used for the verifier transcript and by `HekateVerifier`.
type H = DefaultHasher;

/// Number of AES-256 blocks encrypted and proven in one batch.
const NUM_BLOCKS: usize = 31_250;
/// CPU trace rows per block: one input row and one output row
/// (written by `build_cpu256_trace`).
const CPU_IO_PER_BLOCK: usize = 2;
/// Chiplet trace rows occupied by each block; the ciphertext is read
/// from the last of them (see `extract_ciphertext`).
const ROWS_PER_BLOCK: usize = 15;
/// AES-256 round count (FIPS 197: Nr = 14); multiplied by `NUM_BLOCKS`
/// to size the S-box ROM chiplet in `main`.
const SBOX_ROUNDS: usize = 14;

/// AES-256 cipher key from FIPS 197 Appendix C.3.
///
/// The key is the 32 bytes `0x00, 0x01, ..., 0x1f` in ascending order, so it
/// is generated at compile time rather than spelled out byte by byte.
const FIPS256_KEY: [u8; 32] = {
    let mut key = [0u8; 32];
    let mut i = 0;
    while i < key.len() {
        key[i] = i as u8;
        i += 1;
    }
    key
};

/// Example program combining the CPU I/O AIR (its `Air` impl) with the
/// AES-256 chiplet, whose definitions are exposed through `Program::chiplet_defs`.
#[derive(Clone)]
struct Aes256ExampleProgram {
    /// AES-256 chiplet; `main` also uses it to generate the chiplet traces.
    aes: Aes256Chiplet<F>,
}

impl Air<F> for Aes256ExampleProgram {
    /// Column layout of the CPU I/O trace (`CpuAes256Columns`).
    fn column_layout(&self) -> &[ColumnType] {
        // `build_layout` returns an owned `Vec`, yet this method must return a
        // borrow that outlives the call, and `Self` has no field to keep the
        // layout in. A function-local `OnceLock` builds it once on first use
        // and keeps it alive for the rest of the program.
        static LAYOUT: std::sync::OnceLock<Vec<ColumnType>> = std::sync::OnceLock::new();
        LAYOUT.get_or_init(CpuAes256Columns::build_layout)
    }

    /// Bus checks tying the CPU trace to the AES-256 round AIR: one on the
    /// data link bus and one on the key bus.
    fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
        vec![
            (AesRound256Air::LINK_BUS_ID.into(), CpuAes256Unit::linking_spec()),
            (AesRound256Air::KEY_BUS_ID.into(), CpuAes256Unit::key_linking_spec()),
        ]
    }

    /// Local constraints of the CPU trace: both selector columns are boolean.
    ///
    /// NOTE(review): the AES computation itself is not constrained here;
    /// presumably that is covered by the bus checks above and the chiplet AIRs.
    fn constraint_ast(&self) -> ConstraintAst<F> {
        let cs = ConstraintSystem::<F>::new();
        cs.assert_boolean(cs.col(CpuAes256Columns::SELECTOR));
        cs.assert_boolean(cs.col(CpuAes256Columns::KEY_SELECTOR));

        cs.build()
    }
}

impl Program<F> for Aes256ExampleProgram {
    /// Chiplet definitions contributed by the AES-256 chiplet, obtained by
    /// flattening its composite definition.
    fn chiplet_defs(&self) -> errors::Result<Vec<ChipletDef<F>>> {
        self.aes.composite().flatten_defs()
    }
}

/// Reads the 16-byte ciphertext of block `block_idx` out of a chiplet trace.
///
/// Block `block_idx` occupies `rows_per_block` consecutive rows starting at
/// `block_idx * rows_per_block`; its ciphertext is on the final row of that
/// slice, one byte per column in columns `state_in_col..state_in_col + 16`.
///
/// # Panics
///
/// Panics if one of those columns is not a `B8` column, or if the column or
/// row index is out of range for `chiplet_trace`.
fn extract_ciphertext(
    chiplet_trace: &ColumnTrace,
    state_in_col: usize,
    rows_per_block: usize,
    block_idx: usize,
) -> [u8; 16] {
    // Final row of this block's slice of the trace.
    let last_row = block_idx * rows_per_block + (rows_per_block - 1);

    std::array::from_fn(|offset| {
        let column = &chiplet_trace.columns[state_in_col + offset];
        let cells = column.as_b8_slice().unwrap();
        cells[last_row].to_tower().0
    })
}

/// Builds the CPU I/O trace that the AES-256 round AIR is linked to.
///
/// Each call occupies two consecutive rows, in call order:
///
/// 1. **Input row**: `DATA` holds the plaintext XOR-ed with round key 0 (the
///    initial AddRoundKey), `KEY` holds the 32-byte cipher key, and both
///    `SELECTOR` and `KEY_SELECTOR` are set.
/// 2. **Output row**: `DATA` holds the ciphertext, `SELECTOR` is set and
///    `KEY_SELECTOR` stays unset.
///
/// Rows after the last call are not written here and keep whatever value the
/// `TraceBuilder` initialises them to.
///
/// # Panics
///
/// Panics if `num_rows` is not a power of two, if `calls` and `ciphertexts`
/// have different lengths, or if `2 * calls.len()` rows do not fit in
/// `num_rows`.
fn build_cpu256_trace(
    calls: &[Aes256Call],
    ciphertexts: &[[u8; 16]],
    num_rows: usize,
) -> ColumnTrace {
    const ROWS_PER_CALL: usize = 2;

    // `trailing_zeros` equals log2 only for powers of two; any other height
    // would silently produce a shorter trace than requested.
    assert!(
        num_rows.is_power_of_two(),
        "CPU trace height must be a power of two, got {num_rows}"
    );
    // `zip` below would otherwise silently drop the unmatched tail.
    assert_eq!(
        calls.len(),
        ciphertexts.len(),
        "each AES call needs exactly one ciphertext"
    );
    let rows_needed = calls.len() * ROWS_PER_CALL;
    assert!(
        rows_needed <= num_rows,
        "{} calls need {rows_needed} CPU rows, but the trace has only {num_rows}",
        calls.len()
    );

    let num_vars = num_rows.trailing_zeros() as usize;
    let mut tb = TraceBuilder::new(&CpuAes256Columns::build_layout(), num_vars)
        .expect("failed to create CPU AES-256 trace builder");

    // Row indices are bounds-checked by the asserts above, so the `set_*`
    // calls below are not expected to fail.
    for (idx, (call, ct)) in calls.iter().zip(ciphertexts).enumerate() {
        let input_row = idx * ROWS_PER_CALL;
        let output_row = input_row + 1;

        // Input row: initial AddRoundKey applied to the plaintext, plus the key.
        for (j, &pt_byte) in call.plaintext.iter().enumerate() {
            let whitened = pt_byte ^ call.round_keys[0][j];
            tb.set_b8(CpuAes256Columns::DATA + j, input_row, Block8(whitened)).unwrap();
        }
        for (j, &key_byte) in call.key.iter().enumerate() {
            tb.set_b8(CpuAes256Columns::KEY + j, input_row, Block8(key_byte)).unwrap();
        }
        tb.set_bit(CpuAes256Columns::SELECTOR, input_row, Bit::ONE).unwrap();
        tb.set_bit(CpuAes256Columns::KEY_SELECTOR, input_row, Bit::ONE).unwrap();

        // Output row: ciphertext only.
        for (j, &ct_byte) in ct.iter().enumerate() {
            tb.set_b8(CpuAes256Columns::DATA + j, output_row, Block8(ct_byte)).unwrap();
        }
        tb.set_bit(CpuAes256Columns::SELECTOR, output_row, Bit::ONE).unwrap();
    }

    tb.build()
}

/// Encrypts `NUM_BLOCKS` random plaintexts under the FIPS 197 Appendix C.3
/// AES-256 key, proves the batch, and verifies the proof.
///
/// Steps:
/// 1. Expand the key.
/// 2. Draw random plaintexts from the OS RNG.
/// 3. Generate the AES chiplet traces.
/// 4. Read each ciphertext back out of the first chiplet trace.
/// 5. Build the CPU I/O trace.
/// 6. Prove.
/// 7. Verify with a fresh transcript under the same label passed to `prove`.
///
/// # Errors
///
/// Propagates OS RNG failures and errors from chiplet construction, trace
/// generation, proving and verification, and returns an error if the verifier
/// rejects the proof.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let round_keys = expand_key_256(&FIPS256_KEY);

    // Trace heights are padded up to the next power of two.
    let chiplet_rows = (NUM_BLOCKS * ROWS_PER_BLOCK).next_power_of_two();
    let cpu_rows = (NUM_BLOCKS * CPU_IO_PER_BLOCK).next_power_of_two();
    let sbox_rom_rows = (NUM_BLOCKS * SBOX_ROUNDS).next_power_of_two();

    let mut plaintexts = vec![[0u8; 16]; NUM_BLOCKS];
    for pt in &mut plaintexts {
        OsRng.try_fill_bytes(pt)?;
    }

    // Every block uses the same key, so the expanded schedule is copied into each call.
    let calls: Vec<Aes256Call> = plaintexts
        .iter()
        .map(|pt| Aes256Call { key: FIPS256_KEY, plaintext: *pt, round_keys })
        .collect();

    let aes = Aes256Chiplet::<F>::new(chiplet_rows, sbox_rom_rows)?;
    let chiplet_traces = aes.generate_traces(&calls)?;

    // NOTE(review): assumes trace 0 is the round chiplet laid out by
    // `PhysAes256Columns`; confirm against `Aes256Chiplet::generate_traces`.
    let round_trace = &chiplet_traces[0];
    let ciphertexts: Vec<[u8; 16]> = (0..NUM_BLOCKS)
        .map(|i| extract_ciphertext(round_trace, PhysAes256Columns::P_STATE_IN, ROWS_PER_BLOCK, i))
        .collect();

    let cpu_trace = build_cpu256_trace(&calls, &ciphertexts, cpu_rows);
    let air = Aes256ExampleProgram { aes };

    let mut config = Config {
        sumcheck_blinding_factor: 2,
        ..Config::default()
    };
    OsRng.try_fill_bytes(&mut config.matrix_seed)?;

    let mut blinding_seed = [0u8; 32];
    OsRng.try_fill_bytes(&mut blinding_seed)?;

    let instance = ProgramInstance::new(cpu_rows, vec![]);
    let witness = ProgramWitness::new(cpu_trace).with_chiplets(chiplet_traces);

    let proof = prove(b"AES256_Example", &air, &instance, &witness, &config, blinding_seed, None)?;

    // Same label as the one passed to `prove` above.
    let mut vt = Transcript::<H>::new(b"AES256_Example");
    let ok = HekateVerifier::<F, H>::verify(&air, &instance, &proof, &mut vt, &config)?;

    if !ok {
        return Err("verifier rejected the AES-256 batch proof".into());
    }

    Ok(())
}