diff --git a/src/lib.rs b/src/lib.rs index e1bc79f..91f4696 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,6 @@ use reikna::totient::totient; - -fn gcd(mut a: i64, mut b: i64) -> i64 { - while b != 0 { - let temp = b; - b = a % b; - a = temp; - } - a.abs() -} +use reikna::factor::quick_factorize; +use std::collections::HashMap; // Modular arithmetic functions using i64 fn mod_add(a: i64, b: i64, p: i64) -> i64 { @@ -31,17 +24,36 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { result } -//compute the modular inverse of a modulo p using Fermat's little theorem, p not necessarily prime -fn mod_inv(a: i64, p: i64) -> i64 { - assert!(gcd(a, p) == 1, "{} and {} are not coprime", a, p); - mod_exp(a, totient(p as u64) as i64 - 1, p) // order of mult. group is Euler's totient function +fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) { + if b == 0 { + (a, 1, 0) // gcd, x, y + } else { + let (gcd, x1, y1) = extended_gcd(b, a % b); + (gcd, y1, x1 - (a / b) * y1) + } +} + +pub fn mod_inv(a: i64, modulus: i64) -> i64 { + let (gcd, x, _) = extended_gcd(a, modulus); + if gcd != 1 { + panic!("{} and {} are not coprime, no inverse exists", a, modulus); + } + (x % modulus + modulus) % modulus // Ensure a positive result } // Compute n-th root of unity (omega) for p not necessarily prime -pub fn omega(root: i64, p: i64, n: usize) -> i64 { - let grp_size = totient(p as u64) as i64; - assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size); - mod_exp(root, grp_size / n as i64, p) // order of mult. group is Euler's totient function +pub fn omega(modulus: i64, n: usize) -> i64 { + let factors = factorize(modulus as i64); + if factors.len() == 1 { + let (p, e) = factors.into_iter().next().unwrap(); + let root = primitive_root(p, e); // primitive root mod p + let grp_size = totient(modulus as u64) as i64; + assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size); + return mod_exp(root, grp_size / n as i64, modulus) // order of mult. group is Euler's totient function + } + else { + return root_of_unity(modulus, n as i64) + } } // Forward transform using NTT, output bit-reversed @@ -132,3 +144,68 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec HashMap { + let mut factors = HashMap::new(); + for factor in quick_factorize(n as u64) { + *factors.entry(factor as i64).or_insert(0) += 1; + } + factors +} + +/// Fast computation of a primitive root mod p^e +pub fn primitive_root(p: i64, e: u32) -> i64 { + let g = primitive_root_mod_p(p); + let mut g_lifted = g; // Lift it to p^e + for _ in 1..e { + if g_lifted.pow((p - 1) as u32) % p.pow(e) == 1 { + g_lifted += p.pow(e - 1); + } + } + g_lifted +} + +/// Finds a primitive root modulo a prime p +fn primitive_root_mod_p(p: i64) -> i64 { + let phi = p - 1; + let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities + for g in 2..p { + // Check if g is a primitive root by checking mod_exp conditions with all prime factors of phi + if factors.iter().all(|(&q, _)| mod_exp(g, phi / q, p) != 1) { + return g; + } + } + 0 // Should never happen +} + +// the Chinese remainder theorem for two moduli +pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { + let n = n1 * n2; + let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2 + let m2 = mod_inv(n2, n1); // Inverse of n2 mod n1 + let x = (a1 * m2 * n2 + a2 * m1 * n1) % n; + if x < 0 { x + n } else { x } +} + +// computes an n^th root of unity modulo a composite modulus +// note we require that an n^th root of unity exists for each multiplicative group modulo p^e +// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus +// for the NTT, we require than a 2n^th root of unity exists +pub fn root_of_unity(modulus: i64, n: i64) -> i64 { + let factors = factorize(modulus); + let mut result = 1; + for (&p, &e) in factors.iter() { + let omega = omega(p.pow(e), n.try_into().unwrap()); // Find primitive nth root of unity mod p^e + result = crt(result, modulus / p.pow(e), omega, p.pow(e)); // Combine with the running result using CRT + } + result +} + +//ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n +pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { + assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity"); + assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)"); + true +} + diff --git a/src/main.rs b/src/main.rs index 8b7286d..e7bb1f6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,11 @@ mod test; -use ntt::{omega, ntt, intt , polymul, polymul_ntt}; +use ntt::{ntt, intt , polymul, polymul_ntt, verify_root_of_unity}; fn main() { - let p: i64 = 17; // Prime modulus - let root: i64 = 3; // Primitive root of unity for the modulus + let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(root, p, n); // n-th root of unity: root^((p - 1) / n) % p + let omega = ntt::omega(modulus, n); // n-th root of unity // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -15,26 +14,27 @@ fn main() { b.resize(n, 0); // Perform the forward NTT - let a_ntt = ntt(&a, omega, n, p); - let b_ntt = ntt(&b, omega, n, p); + let a_ntt = ntt(&a, omega, n, modulus); + let b_ntt = ntt(&b, omega, n, modulus); // Perform the inverse NTT on the transformed A for verification - let a_ntt_intt = intt(&a_ntt, omega, n, p); + let a_ntt_intt = intt(&a_ntt, omega, n, modulus); // Pointwise multiplication in the NTT domain let c_ntt: Vec = a_ntt .iter() .zip(b_ntt.iter()) - .map(|(x, y)| (x * y) % p) + .map(|(x, y)| (x * y) % modulus) .collect(); // Inverse NTT to get the polynomial product - let c = intt(&c_ntt, omega, n, p); + let c = intt(&c_ntt, omega, n, modulus); - let c_std = polymul(&a, &b, n as i64, p); - let c_fast = polymul_ntt(&a, &b, n, p, omega); + let c_std = polymul(&a, &b, n as i64, modulus); + let c_fast = polymul_ntt(&a, &b, n, modulus, omega); // Output the results + println!("verify omega = {}", verify_root_of_unity(omega, n as i64, modulus)); println!("Polynomial A: {:?}", a); println!("Polynomial B: {:?}", b); println!("Transformed A: {:?}", a_ntt); @@ -44,4 +44,5 @@ fn main() { println!("Resultant Polynomial (c): {:?}", c); println!("Standard polynomial mult. result: {:?}", c_std); println!("Polynomial multiplication method using NTT: {:?}", c_fast); + } diff --git a/src/test.rs b/src/test.rs index 3553a79..25ba447 100644 --- a/src/test.rs +++ b/src/test.rs @@ -5,9 +5,8 @@ mod tests { #[test] fn test_polymul_ntt() { let p: i64 = 17; // Prime modulus - let root: i64 = 3; // Primitive root of unity let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(root, p, n); // n-th root of unity + let omega = omega(p, n); // n-th root of unity // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -28,9 +27,8 @@ mod tests { #[test] fn test_polymul_ntt_square_modulus() { let modulus: i64 = 17*17; // Prime modulus - let root: i64 = 3; // Primitive root of unity let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(root, modulus, n); // n-th root of unity + let omega = omega(modulus, n); // n-th root of unity // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -51,9 +49,8 @@ mod tests { #[test] fn test_polymul_ntt_prime_power_modulus() { let modulus: i64 = (17 as i64).pow(4); // modulus p^k - let root: i64 = 3; // Primitive root of unity let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(root, modulus, n); // n-th root of unity + let omega = omega(modulus, n); // n-th root of unity // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -70,5 +67,25 @@ mod tests { // Ensure both methods produce the same result assert_eq!(c_std, c_fast, "The results of polymul and polymul_ntt do not match"); } + + #[test] + fn test_polymul_ntt_non_prime_power_modulus() { + let moduli = [17*41, 17*73, 17*41*73]; // Different moduli to test + let n: usize = 8; // Length of the NTT (must be a power of 2) + + for &modulus in &moduli { + let omega = omega(modulus, n); + + let mut a = vec![1, 2, 3, 4]; + let mut b = vec![4, 5, 6, 7]; + a.resize(n, 0); + b.resize(n, 0); + + let c_std = polymul(&a, &b, n as i64, modulus); + let c_fast = polymul_ntt(&a, &b, n, modulus, omega); + + assert_eq!(c_std, c_fast, "Failed for modulus {}", modulus); + } + } }