Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 94 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
use reikna::totient::totient;

fn gcd(mut a: i64, mut b: i64) -> i64 {
while b != 0 {
let temp = b;
b = a % b;
a = temp;
}
a.abs()
}
use reikna::factor::quick_factorize;
use std::collections::HashMap;

// Modular arithmetic functions using i64
fn mod_add(a: i64, b: i64, p: i64) -> i64 {
Expand All @@ -31,17 +24,36 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 {
result
}

//compute the modular inverse of a modulo p using Fermat's little theorem, p not necessarily prime
fn mod_inv(a: i64, p: i64) -> i64 {
assert!(gcd(a, p) == 1, "{} and {} are not coprime", a, p);
mod_exp(a, totient(p as u64) as i64 - 1, p) // order of mult. group is Euler's totient function
fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
if b == 0 {
(a, 1, 0) // gcd, x, y
} else {
let (gcd, x1, y1) = extended_gcd(b, a % b);
(gcd, y1, x1 - (a / b) * y1)
}
}

pub fn mod_inv(a: i64, modulus: i64) -> i64 {
let (gcd, x, _) = extended_gcd(a, modulus);
if gcd != 1 {
panic!("{} and {} are not coprime, no inverse exists", a, modulus);
}
(x % modulus + modulus) % modulus // Ensure a positive result
}

// Compute n-th root of unity (omega) for p not necessarily prime
pub fn omega(root: i64, p: i64, n: usize) -> i64 {
let grp_size = totient(p as u64) as i64;
assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size);
mod_exp(root, grp_size / n as i64, p) // order of mult. group is Euler's totient function
pub fn omega(modulus: i64, n: usize) -> i64 {
let factors = factorize(modulus as i64);
if factors.len() == 1 {
let (p, e) = factors.into_iter().next().unwrap();
let root = primitive_root(p, e); // primitive root mod p
let grp_size = totient(modulus as u64) as i64;
assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size);
return mod_exp(root, grp_size / n as i64, modulus) // order of mult. group is Euler's totient function
}
else {
return root_of_unity(modulus, n as i64)
}
}

// Forward transform using NTT, output bit-reversed
Expand Down Expand Up @@ -132,3 +144,68 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec<i6

c
}

/// Compute the prime factorization of `n` (with multiplicities).
fn factorize(n: i64) -> HashMap<i64, u32> {
let mut factors = HashMap::new();
for factor in quick_factorize(n as u64) {
*factors.entry(factor as i64).or_insert(0) += 1;
}
factors
}

/// Fast computation of a primitive root mod p^e
pub fn primitive_root(p: i64, e: u32) -> i64 {
let g = primitive_root_mod_p(p);
let mut g_lifted = g; // Lift it to p^e
for _ in 1..e {
if g_lifted.pow((p - 1) as u32) % p.pow(e) == 1 {
g_lifted += p.pow(e - 1);
}
}
g_lifted
}

/// Finds a primitive root modulo a prime p
fn primitive_root_mod_p(p: i64) -> i64 {
let phi = p - 1;
let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities
for g in 2..p {
// Check if g is a primitive root by checking mod_exp conditions with all prime factors of phi
if factors.iter().all(|(&q, _)| mod_exp(g, phi / q, p) != 1) {
return g;
}
}
0 // Should never happen
}

// the Chinese remainder theorem for two moduli
pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
let n = n1 * n2;
let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2
let m2 = mod_inv(n2, n1); // Inverse of n2 mod n1
let x = (a1 * m2 * n2 + a2 * m1 * n1) % n;
if x < 0 { x + n } else { x }
}

// computes an n^th root of unity modulo a composite modulus
// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus
// for the NTT, we require than a 2n^th root of unity exists
pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
let factors = factorize(modulus);
let mut result = 1;
for (&p, &e) in factors.iter() {
let omega = omega(p.pow(e), n.try_into().unwrap()); // Find primitive nth root of unity mod p^e
result = crt(result, modulus / p.pow(e), omega, p.pow(e)); // Combine with the running result using CRT
}
result
}

//ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool {
assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity");
assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)");
true
}

23 changes: 12 additions & 11 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
mod test;

use ntt::{omega, ntt, intt , polymul, polymul_ntt};
use ntt::{ntt, intt , polymul, polymul_ntt, verify_root_of_unity};

fn main() {
let p: i64 = 17; // Prime modulus
let root: i64 = 3; // Primitive root of unity for the modulus
let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p
let n: usize = 8; // Length of the NTT (must be a power of 2)
let omega = omega(root, p, n); // n-th root of unity: root^((p - 1) / n) % p
let omega = ntt::omega(modulus, n); // n-th root of unity

// Input polynomials (padded to length `n`)
let mut a = vec![1, 2, 3, 4];
Expand All @@ -15,26 +14,27 @@ fn main() {
b.resize(n, 0);

// Perform the forward NTT
let a_ntt = ntt(&a, omega, n, p);
let b_ntt = ntt(&b, omega, n, p);
let a_ntt = ntt(&a, omega, n, modulus);
let b_ntt = ntt(&b, omega, n, modulus);

// Perform the inverse NTT on the transformed A for verification
let a_ntt_intt = intt(&a_ntt, omega, n, p);
let a_ntt_intt = intt(&a_ntt, omega, n, modulus);

// Pointwise multiplication in the NTT domain
let c_ntt: Vec<i64> = a_ntt
.iter()
.zip(b_ntt.iter())
.map(|(x, y)| (x * y) % p)
.map(|(x, y)| (x * y) % modulus)
.collect();

// Inverse NTT to get the polynomial product
let c = intt(&c_ntt, omega, n, p);
let c = intt(&c_ntt, omega, n, modulus);

let c_std = polymul(&a, &b, n as i64, p);
let c_fast = polymul_ntt(&a, &b, n, p, omega);
let c_std = polymul(&a, &b, n as i64, modulus);
let c_fast = polymul_ntt(&a, &b, n, modulus, omega);

// Output the results
println!("verify omega = {}", verify_root_of_unity(omega, n as i64, modulus));
println!("Polynomial A: {:?}", a);
println!("Polynomial B: {:?}", b);
println!("Transformed A: {:?}", a_ntt);
Expand All @@ -44,4 +44,5 @@ fn main() {
println!("Resultant Polynomial (c): {:?}", c);
println!("Standard polynomial mult. result: {:?}", c_std);
println!("Polynomial multiplication method using NTT: {:?}", c_fast);

}
29 changes: 23 additions & 6 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ mod tests {
#[test]
fn test_polymul_ntt() {
let p: i64 = 17; // Prime modulus
let root: i64 = 3; // Primitive root of unity
let n: usize = 8; // Length of the NTT (must be a power of 2)
let omega = omega(root, p, n); // n-th root of unity
let omega = omega(p, n); // n-th root of unity

// Input polynomials (padded to length `n`)
let mut a = vec![1, 2, 3, 4];
Expand All @@ -28,9 +27,8 @@ mod tests {
#[test]
fn test_polymul_ntt_square_modulus() {
let modulus: i64 = 17*17; // Prime modulus
let root: i64 = 3; // Primitive root of unity
let n: usize = 8; // Length of the NTT (must be a power of 2)
let omega = omega(root, modulus, n); // n-th root of unity
let omega = omega(modulus, n); // n-th root of unity

// Input polynomials (padded to length `n`)
let mut a = vec![1, 2, 3, 4];
Expand All @@ -51,9 +49,8 @@ mod tests {
#[test]
fn test_polymul_ntt_prime_power_modulus() {
let modulus: i64 = (17 as i64).pow(4); // modulus p^k
let root: i64 = 3; // Primitive root of unity
let n: usize = 8; // Length of the NTT (must be a power of 2)
let omega = omega(root, modulus, n); // n-th root of unity
let omega = omega(modulus, n); // n-th root of unity

// Input polynomials (padded to length `n`)
let mut a = vec![1, 2, 3, 4];
Expand All @@ -70,5 +67,25 @@ mod tests {
// Ensure both methods produce the same result
assert_eq!(c_std, c_fast, "The results of polymul and polymul_ntt do not match");
}

#[test]
fn test_polymul_ntt_non_prime_power_modulus() {
let moduli = [17*41, 17*73, 17*41*73]; // Different moduli to test
let n: usize = 8; // Length of the NTT (must be a power of 2)

for &modulus in &moduli {
let omega = omega(modulus, n);

let mut a = vec![1, 2, 3, 4];
let mut b = vec![4, 5, 6, 7];
a.resize(n, 0);
b.resize(n, 0);

let c_std = polymul(&a, &b, n as i64, modulus);
let c_fast = polymul_ntt(&a, &b, n, modulus, omega);

assert_eq!(c_std, c_fast, "Failed for modulus {}", modulus);
}
}

}