diff --git a/Cargo.toml b/Cargo.toml index ecb851a..fc0d0b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntt" -version = "0.1.2" +version = "0.1.3" edition = "2021" description = "Implements the fast NTT (number theoretic transform) for polynomial multiplcation." license = "MIT" diff --git a/src/lib.rs b/src/lib.rs index 34b583c..8aa0600 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,11 @@ fn mod_inv(a: i64, p: i64) -> i64 { mod_exp(a, p - 2, p) // Using Fermat's Little Theorem } +// Compute n-th root of unity (omega = root^((p - 1) / n) % p) +pub fn omega(root: i64, p: i64, n: usize) -> i64{ + mod_exp(root, (p - 1) / n as i64, p) +} + // Forward transform using NTT, output bit-reversed pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec { let mut result = a.to_vec(); @@ -94,9 +99,7 @@ pub fn polymul(a: &Vec, b: &Vec, n: i64, p: i64) -> Vec { /// /// # Returns /// A vector representing the polynomial product modulo `p`. -pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, root: i64) -> Vec { - // Compute n-th root of unity (omega = root^((p - 1) / n) % p) - let omega = mod_exp(root, (p - 1) / n as i64, p); +pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec { // Step 1: Perform the NTT (forward transform) on both polynomials let a_ntt = ntt(a, omega, n, p); diff --git a/src/main.rs b/src/main.rs index 7502ff0..8b7286d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,12 @@ mod test; -use ntt::{ntt, intt, mod_exp, polymul, polymul_ntt}; +use ntt::{omega, ntt, intt , polymul, polymul_ntt}; fn main() { let p: i64 = 17; // Prime modulus let root: i64 = 3; // Primitive root of unity for the modulus let n: usize = 8; // Length of the NTT (must be a power of 2) - - // Compute n-th root of unity: ω = g^((p - 1) / n) % p - let omega = mod_exp(root, (p - 1) / n as i64, p); + let omega = omega(root, p, n); // n-th root of unity: root^((p - 1) / n) % p // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -34,7 +32,7 @@ fn main() { let c = intt(&c_ntt, omega, n, p); let c_std = polymul(&a, &b, n as i64, p); - let c_fast = polymul_ntt(&a, &b, n, p, root); + let c_fast = polymul_ntt(&a, &b, n, p, omega); // Output the results println!("Polynomial A: {:?}", a); diff --git a/src/test.rs b/src/test.rs index a815b7b..697c8cb 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,12 +1,13 @@ #[cfg(test)] mod tests { - use ntt::{polymul, polymul_ntt}; + use ntt::{omega, polymul, polymul_ntt}; #[test] fn test_polymul_ntt() { let p: i64 = 17; // Prime modulus let root: i64 = 3; // Primitive root of unity let n: usize = 8; // Length of the NTT (must be a power of 2) + let omega = omega(root, p, n); // n-th root of unity // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -18,7 +19,7 @@ mod tests { let c_std = polymul(&a, &b, n as i64, p); // Perform the NTT-based polynomial multiplication - let c_fast = polymul_ntt(&a, &b, n, p, root); + let c_fast = polymul_ntt(&a, &b, n, p, omega); // Ensure both methods produce the same result assert_eq!(c_std, c_fast, "The results of polymul and polymul_ntt do not match");