#!/usr/bin/python3
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import io
import os
import requests
import sys
import tarfile
def remove_include_guard(x: io.StringIO, guard: str) -> io.StringIO:
out = io.StringIO()
depth = 0
inside_guard =
False
for line
in x.readlines():
tokens = line.split()
if tokens
and tokens[0]
in [
"#if", "#ifdef", "#ifndef"]:
depth += 1
if len(tokens) > 1
and tokens[0] ==
"#ifndef" and tokens[1] == guard:
assert depth == 1,
"error: nested include guard"
inside_guard =
True
continue
if len(tokens) > 1
and tokens[0] ==
"#define" and tokens[1] == guard:
continue
if tokens
and tokens[0] ==
"#endif":
depth -= 1
if depth == 0
and inside_guard:
inside_guard =
False
continue
out.write(line)
out.seek(0)
return out
def remove_includes(x: io.StringIO) -> io.StringIO:
out = io.StringIO()
for line
in x.readlines():
tokens = line.split()
if tokens
and tokens[0] ==
"#include":
continue
out.write(line)
out.seek(0)
return out
# Take the else branch of any #ifdef KYBER90s ... #else ... #endif
def remove_kyber90s(x: io.StringIO) -> io.StringIO:
out = io.StringIO()
states = [
"before",
"during-drop",
"during-keep"]
state =
"before"
current_depth = 0
kyber90s_depth =
None
for line
in x.readlines():
tokens = line.split()
if tokens
and tokens[0]
in [
"#if", "#ifdef", "#ifndef"]:
current_depth += 1
if len(tokens) > 1
and tokens[0] ==
"#ifdef" and tokens[1] == "KYBER_90S":
assert kyber90s_depth ==
None,
"cannot handle nested #ifdef KYBER90S"
kyber90s_depth = current_depth
state =
"during-drop"
continue
if len(tokens) > 1
and tokens[0] ==
"#ifndef" and tokens[1] == "KYBER_90S":
assert kyber90s_depth ==
None,
"cannot handle nested #ifndef KYBER90S"
kyber90s_depth = current_depth
state =
"during-keep"
continue
if current_depth == kyber90s_depth
and tokens:
if tokens[0] ==
"#else":
assert state !=
"before"
state =
"during-keep" if state ==
"during-drop" else "during-drop"
continue
if tokens[0] ==
"#elif":
assert False,
"cannot handle #elif branch of #ifdef KYBER90S"
if tokens[0] ==
"#endif":
assert state !=
"before"
state =
"before"
kyber90s_depth =
None
current_depth -= 1
continue
if tokens
and tokens[0] ==
"#endif":
current_depth -= 1
if state ==
"during-drop":
continue
out.write(line)
out.seek(0)
return out
def add_static_to_fns(x: io.StringIO) -> io.StringIO:
out = io.StringIO()
depth = 0
for line
in x.readlines():
tokens = line.split()
# assumes return type starts on column 0
if depth == 0
and any(
line.startswith(typ)
for typ
in [
"void",
"uint32_t",
"int16_t",
"int"]
):
out.write(
"static " + line)
else:
out.write(line)
if "{" in line:
depth += 1
if "}" in line:
depth -= 1
out.seek(0)
return out
def file_block(x: io.StringIO, filename: str) -> io.StringIO:
out = io.StringIO()
out.write(f
"\n/** begin: {filename} **/\n")
out.write(x.read().strip())
out.write(f
"\n/** end: {filename} **/\n")
out.seek(0)
return out
def test():
assert 0 == len(remove_includes(io.StringIO(
"#include ")).read())
assert 0 == len(remove_kyber90s(io.StringIO(
"#ifdef KYBER_90S\nx\n#endif")).read())
test_remove_kyber90s_expect =
"#ifdef OTHER\nx\n#else\nx\n#endif"
test_remove_ifdef_kyber90s = f
"""
#ifdef KYBER_90S
x
{test_remove_kyber90s_expect}
x
#else
{test_remove_kyber90s_expect}
#endif
"""
test_remove_ifdef_kyber90s_actual = (
remove_kyber90s(io.StringIO(test_remove_ifdef_kyber90s)).read().strip()
)
assert (
test_remove_kyber90s_expect == test_remove_ifdef_kyber90s_actual
),
"remove_kyber90s unit test"
test_remove_ifndef_kyber90s = f
"""
#ifndef KYBER_90S
{test_remove_kyber90s_expect}
#else
x
{test_remove_kyber90s_expect}
x
#endif
"""
test_remove_ifndef_kyber90s_actual = (
remove_kyber90s(io.StringIO(test_remove_ifndef_kyber90s)).read().strip()
)
assert (
test_remove_kyber90s_expect == test_remove_ifndef_kyber90s_actual
),
"remove_kyber90s unit test"
test_add_static_to_fns =
"""\
void fn() {
int x[3] = {1,2,3};
}
"""
assert (
f
"static {test_add_static_to_fns}"
== add_static_to_fns(io.StringIO(test_add_static_to_fns)).read()
)
test_remove_include_guard =
"""\
#ifndef TEST_H
#define TEST_H
#endif"""
assert 0 == len(
remove_include_guard(io.StringIO(test_remove_include_guard),
"TEST_H").read()
)
assert (
test_remove_include_guard
== remove_include_guard(
io.StringIO(test_remove_include_guard),
"OTHER_H"
).read()
)
def is_hex(s: str) -> bool:
try:
int(s, 16)
except ValueError:
return False
return True
if __name__ ==
"__main__":
test()
repo = f
"https://github.com/pq-crystals/kyber"
out =
"kyber-pqcrystals-ref.c"
out_api =
"kyber-pqcrystals-ref.h"
out_orig =
"kyber-pqcrystals-ref.c.orig"
if len(sys.argv) == 2
and len(sys.argv[1]) >= 6
and is_hex(sys.argv[1]):
commit = sys.argv[1]
print(f
"* using commit id {commit}")
else:
print(
f
"""\
Usage: python3 {sys.argv[0]} [commit]
where [commit]
is an 8+ hex digit commit id
from {repo}.
"""
)
sys.exit(1)
short_commit = commit[:8]
tarball_url = f
"{repo}/tarball/{commit}"
archive = f
"kyber-{short_commit}.tar.gz"
headers = [
"params.h",
"reduce.h",
"ntt.h",
"poly.h",
"cbd.h",
"polyvec.h",
"indcpa.h",
"fips202.h",
"symmetric.h",
"kem.h",
]
sources = [
"reduce.c",
"cbd.c",
"ntt.c",
"poly.c",
"polyvec.c",
"indcpa.c",
"fips202.c",
"symmetric-shake.c",
"kem.c",
]
if not os.path.isfile(archive):
print(f
"* fetching {tarball_url}")
req = requests.request(method=
"GET", url=tarball_url)
if not req.ok:
print(f
"* failed: {req.reason}")
sys.exit(1)
with open(archive,
"wb")
as f:
f.write(req.content)
print(f
"* extracting files from {archive}")
with open(archive,
"rb")
as f:
tarball = tarfile.open(mode=
"r:gz", fileobj=f)
topdir = tarball.members[0].path
assert (
topdir == f
"pq-crystals-kyber-{commit[:7]}"
),
"tarball directory structure changed"
# Write a single-file copy without modifications for easy diffing
print(f
"* writing unmodified files to {out_orig}")
with open(out_orig,
"w")
as f:
for filename
in headers:
x = tarball.extractfile(f
"{topdir}/ref/{filename}")
x = io.StringIO(x.read().decode(
"utf-8"))
x = file_block(x,
"ref/" + filename)
f.write(x.read())
for filename
in sources:
x = tarball.extractfile(f
"{topdir}/ref/{filename}")
x = io.StringIO(x.read().decode(
"utf-8"))
x = file_block(x,
"ref/" + filename)
f.write(x.read())
comment = io.StringIO()
comment.write(
f
"""/*
* SPDX-License-Identifier: Apache-2.0
*
* This file was generated
from
*
https://github.com/pq-crystals/kyber/commit/{short_commit}
*
* Files
from that repository are listed here surrounded by
*
"* begin: [file] *" and "* end: [file] *" comments.
*
* The following changes have been made:
* - include guards have been removed,
* - include directives have been removed,
* -
"#ifdef KYBER90S" blocks have been evaluated with "KYBER90S" undefined,
* - functions outside of kem.c have been made static.
*/
"""
)
for filename
in [
"LICENSE",
"AUTHORS"]:
comment.write(f
"""\n/** begin: ref/{filename} **\n""")
x = tarball.extractfile(f
"{topdir}/{filename}")
x = io.StringIO(x.read().decode(
"utf-8"))
for line
in x.readlines():
comment.write(line)
comment.write(f
"""** end: ref/{filename} **/\n""")
comment.seek(0)
print(f
"* writing modified files to {out}")
with open(out,
"w")
as f:
f.write(comment.read())
f.write(
"""
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#ifdef FREEBL_NO_DEPEND
#include "stubs.h"
#endif
#include "secport.h"
// We need to provide an implementation of randombytes to avoid an unused
// function warning. We don
't use the randomized API in freebl, so we'll make
// calling randombytes an error.
static void randombytes(uint8_t *out, size_t outlen) {
// this memset
is to avoid
"maybe-uninitialized" warnings that gcc-11 issues
//
for the (unused) crypto_kem_keypair
and crypto_kem_enc functions.
memset(out, 0, outlen);
assert(0);
}
/*************************************************
* Name: verify
*
* Description: Compare two arrays
for equality
in constant time.
*
* Arguments: const uint8_t *a: pointer to first byte array
* const uint8_t *b: pointer to second byte array
* size_t len: length of the byte arrays
*
* Returns 0
if the byte arrays are equal, 1 otherwise
**************************************************/
static int verify(const uint8_t *a, const uint8_t *b, size_t len) {
return NSS_SecureMemcmp(a, b, len);
}
/*************************************************
* Name: cmov
*
* Description: Copy len bytes
from x to r
if b
is 1;
* don
't modify x if b is 0. Requires b to be in {0,1};
* assumes two
's complement representation of negative integers.
* Runs
in constant time.
*
* Arguments: uint8_t *r: pointer to output byte array
* const uint8_t *x: pointer to input byte array
* size_t len: Amount of bytes to be copied
* uint8_t b: Condition bit; has to be
in {0,1}
**************************************************/
static void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b)
{
NSS_SecureSelect(r, r, x, len, b);
}
"""
)
for filename
in headers:
x = tarball.extractfile(f
"{topdir}/ref/{filename}")
x = io.StringIO(x.read().decode(
"utf-8"))
x = remove_include_guard(x, filename.upper().replace(
".",
"_"))
x = remove_includes(x)
x = remove_kyber90s(x)
if filename
not in [
"kem.h",
"fips202.h"]:
x = add_static_to_fns(x)
x = file_block(x,
"ref/" + filename)
f.write(x.read())
for filename
in sources:
x = tarball.extractfile(f
"{topdir}/ref/{filename}")
x = io.StringIO(x.read().decode(
"utf-8"))
x = remove_includes(x)
x = remove_kyber90s(x)
if filename
not in [
"kem.c",
"fips202.c"]:
x = add_static_to_fns(x)
x = file_block(x,
"ref/" + filename)
f.write(x.read())
print(f
"* writing private header to {out_api}")
with open(out_api,
"w")
as f:
filename =
"api.h"
comment.seek(0)
f.write(comment.read())
f.write(
"""
#ifndef KYBER_PQCRYSTALS_REF_H
#define KYBER_PQCRYSTALS_REF_H
"""
)
x = tarball.extractfile(f
"{topdir}/ref/{filename}")
x = io.StringIO(x.read().decode(
"utf-8"))
x = remove_include_guard(x, filename.upper().replace(
".",
"_"))
x = file_block(x,
"ref/" + filename)
f.write(x.read())
f.write(
f
"""
#endif // KYBER_PQCRYSTALS_REF_H
"""
)
print(
f
"""* done!
You should now:
1) Check the output by running `diff {out_orig} {out}`
2) Move {out} to lib/freebl/{out}
3) Move {out_api} to lib/freebl/{out_api}
4) Delete {out_orig}
and {archive}.
"""
)