From ed74c76d78eeca271c4d4e4206f0304c2a901f0b Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Tue, 11 Feb 2025 21:19:18 -0500 Subject: [PATCH 01/20] compute a cyclic subgroup of mult. group compute a cyclic subgroup C of the multiplicative group of Z/NZ such that n divides the order of C. --- src/lib.rs | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 12 ++++++++- 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index e1bc79f..676265c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,6 @@ use reikna::totient::totient; +use reikna::factor::quick_factorize; +use std::collections::HashMap; fn gcd(mut a: i64, mut b: i64) -> i64 { while b != 0 { @@ -9,6 +11,11 @@ fn gcd(mut a: i64, mut b: i64) -> i64 { a.abs() } +/// Compute LCM of two numbers. +fn lcm(a: i64, b: i64) -> i64 { + (a * b) / gcd(a, b) +} + // Modular arithmetic functions using i64 fn mod_add(a: i64, b: i64, p: i64) -> i64 { (a + b) % p @@ -132,3 +139,73 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec HashMap { + let mut factors = HashMap::new(); + for factor in quick_factorize(n as u64) { + *factors.entry(factor as i64).or_insert(0) += 1; + } + factors +} + +/// Fast computation of a primitive root mod p^e +pub fn primitive_root(p: i64, e: u32) -> i64 { + + // Find a primitive root mod p + let g = find_primitive_root_mod_p(p); + + // Lift it to p^e + let mut g_lifted = g; + for _ in 1..e { + if g_lifted.pow((p - 1) as u32) % p.pow(e) == 1 { + g_lifted += p.pow(e - 1); + } + } + g_lifted +} + +/// Finds a primitive root modulo a prime p +fn find_primitive_root_mod_p(p: i64) -> i64 { + let phi = p - 1; + let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities + + for g in 2..p { + // Check if g is a primitive root by checking mod_exp conditions with all prime factors of phi + if factors.iter().all(|(&q, _)| mod_exp(g, phi / q, p) != 1) { + return g; + } + } + 0 // Should never happen +} + +/// Finds an element in (Z/NZ)* whose order is divisible by `n`. +pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> (i64, i64) { + if n == 0 || (n & (n - 1)) != 0 { + panic!("n must be a power of 2"); + } + + let factors = factorize(modulus); + let mut generators = Vec::new(); + let mut orders = Vec::new(); + + for (&p, &e) in &factors { + let phi = (p - 1) * p.pow(e - 1); + let g = primitive_root(p, e); + generators.push(g); + orders.push(phi); + } + + let mut chosen_element = 1; + let mut chosen_order = 1; + + for (&g, &k) in generators.iter().zip(orders.iter()) { + let required_order = lcm(k, n); // Ensure the subgroup order is divisible by n + let exponent = required_order / gcd(k, required_order); // Pick the exponent carefully + chosen_element = (chosen_element * mod_exp(g, exponent, modulus)) % modulus; + chosen_order = lcm(chosen_order, required_order); + } + + (chosen_element, chosen_order) +} + diff --git a/src/main.rs b/src/main.rs index 8b7286d..ff98b19 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod test; -use ntt::{omega, ntt, intt , polymul, polymul_ntt}; +use reikna::totient::totient; +use ntt::{omega, ntt, intt , polymul, polymul_ntt, find_cyclic_subgroup, primitive_root}; fn main() { let p: i64 = 17; // Prime modulus @@ -44,4 +45,13 @@ fn main() { println!("Resultant Polynomial (c): {:?}", c); println!("Standard polynomial mult. result: {:?}", c_std); println!("Polynomial multiplication method using NTT: {:?}", c_fast); + + let modulus = 45; // Example modulus + let n = 4; // Must be a power of 2 + let (g, g_order) = find_cyclic_subgroup(modulus, n); + let root = primitive_root(23, 2); + println!("Primitive root: {}", root); + println!("(g, order) = {}, {}", g, g_order); + println!("g^g_order: {}", g.pow(g_order as u32) % modulus); + println!("Totient of {}: {}", modulus, totient(modulus as u64)); } From b1013ab2cb564a4a74fb2e18517af8079662be38 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Tue, 11 Feb 2025 21:50:01 -0500 Subject: [PATCH 02/20] compute omega --- src/lib.rs | 18 +++++++++++++----- src/main.rs | 10 ++++++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 676265c..0a5484c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -181,16 +181,19 @@ fn find_primitive_root_mod_p(p: i64) -> i64 { /// Finds an element in (Z/NZ)* whose order is divisible by `n`. pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> (i64, i64) { + // Check if n is a power of 2 if n == 0 || (n & (n - 1)) != 0 { panic!("n must be a power of 2"); } + // Factorize modulus (assuming a function exists) let factors = factorize(modulus); let mut generators = Vec::new(); let mut orders = Vec::new(); + // Loop through factors to find generators and orders for (&p, &e) in &factors { - let phi = (p - 1) * p.pow(e - 1); + let phi = (p - 1) * p.pow(e - 1); // Euler's totient function let g = primitive_root(p, e); generators.push(g); orders.push(phi); @@ -199,13 +202,18 @@ pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> (i64, i64) { let mut chosen_element = 1; let mut chosen_order = 1; + // Loop through generators and orders to find element with required order for (&g, &k) in generators.iter().zip(orders.iter()) { - let required_order = lcm(k, n); // Ensure the subgroup order is divisible by n - let exponent = required_order / gcd(k, required_order); // Pick the exponent carefully - chosen_element = (chosen_element * mod_exp(g, exponent, modulus)) % modulus; - chosen_order = lcm(chosen_order, required_order); + // Calculate required order + let required_order = lcm(k, n); // Least common multiple + let exponent = required_order / gcd(k, required_order); // Adjust exponent + chosen_element = (chosen_element * mod_exp(g, exponent, modulus)) % modulus; // mod_exp computes power mod modulus + chosen_order = lcm(chosen_order, k / gcd(k, n)); // Adjust chosen order } + // Assert the order is divisible by n + assert_eq!(chosen_order % n, 0); + (chosen_element, chosen_order) } diff --git a/src/main.rs b/src/main.rs index ff98b19..6fcb88c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod test; use reikna::totient::totient; -use ntt::{omega, ntt, intt , polymul, polymul_ntt, find_cyclic_subgroup, primitive_root}; +use ntt::{omega, ntt, intt , polymul, polymul_ntt, find_cyclic_subgroup, primitive_root, mod_exp}; fn main() { let p: i64 = 17; // Prime modulus @@ -47,11 +47,13 @@ fn main() { println!("Polynomial multiplication method using NTT: {:?}", c_fast); let modulus = 45; // Example modulus - let n = 4; // Must be a power of 2 + let n = 2; // Must be a power of 2 let (g, g_order) = find_cyclic_subgroup(modulus, n); + let omega = mod_exp(g, g_order / n, modulus); let root = primitive_root(23, 2); + println!("Totient of {}: {}", modulus, totient(modulus as u64)); println!("Primitive root: {}", root); println!("(g, order) = {}, {}", g, g_order); - println!("g^g_order: {}", g.pow(g_order as u32) % modulus); - println!("Totient of {}: {}", modulus, totient(modulus as u64)); + println!("omega: {}", omega); + println!("omega^n: {}", mod_exp(omega, n, modulus)); } From 66e71b586608931e80aaa57dfe09606bd036c894 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Tue, 11 Feb 2025 23:30:04 -0500 Subject: [PATCH 03/20] use CRT --- src/lib.rs | 68 ++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0a5484c..47b0cb6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,16 @@ fn gcd(mut a: i64, mut b: i64) -> i64 { a.abs() } +fn extended_gcd(a: i64, b: i64) -> (i64, i64) { + if b == 0 { + return (1, 0); + } + let (x1, y1) = extended_gcd(b, a % b); + let x = y1; + let y = x1 - y1 * (a / b); + (x, y) +} + /// Compute LCM of two numbers. fn lcm(a: i64, b: i64) -> i64 { (a * b) / gcd(a, b) @@ -179,41 +189,53 @@ fn find_primitive_root_mod_p(p: i64) -> i64 { 0 // Should never happen } -/// Finds an element in (Z/NZ)* whose order is divisible by `n`. pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> (i64, i64) { - // Check if n is a power of 2 if n == 0 || (n & (n - 1)) != 0 { panic!("n must be a power of 2"); } - // Factorize modulus (assuming a function exists) let factors = factorize(modulus); - let mut generators = Vec::new(); - let mut orders = Vec::new(); + let mut result_element = 1; // Initialize to 1 for CRT + let mut result_order = 1; - // Loop through factors to find generators and orders for (&p, &e) in &factors { - let phi = (p - 1) * p.pow(e - 1); // Euler's totient function - let g = primitive_root(p, e); - generators.push(g); - orders.push(phi); + let phi = (p - 1) * p.pow(e - 1); + if phi % n == 0 { + let g = primitive_root(p, e); + let order = phi / n; // Find an element of order n (or a multiple of n) + let element_in_factor = mod_exp(g, order, p.pow(e)); + + //Lift using CRT + result_element = crt(result_element, p.pow(e), element_in_factor, modulus/p.pow(e)); + return (result_element, n) + } } - let mut chosen_element = 1; - let mut chosen_order = 1; - - // Loop through generators and orders to find element with required order - for (&g, &k) in generators.iter().zip(orders.iter()) { - // Calculate required order - let required_order = lcm(k, n); // Least common multiple - let exponent = required_order / gcd(k, required_order); // Adjust exponent - chosen_element = (chosen_element * mod_exp(g, exponent, modulus)) % modulus; // mod_exp computes power mod modulus - chosen_order = lcm(chosen_order, k / gcd(k, n)); // Adjust chosen order + if result_order == 1{ + panic!("could not find element of order n"); } + (result_element, result_order) +} - // Assert the order is divisible by n - assert_eq!(chosen_order % n, 0); +fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { + // Solve x = a1 mod n1 and x = a2 mod n2 + let n = n1 * n2; - (chosen_element, chosen_order) + // Find the modular inverses + let (inv1, _) = extended_gcd(n1, n2); // inv1 is the inverse of n2 mod n1 + let (inv2, _) = extended_gcd(n2, n1); // inv2 is the inverse of n1 mod n2 + + // CRT formula: x = (a1 * n2 * inv(n2, n1) + a2 * n1 * inv(n1, n2)) mod n + let x = (a1 * n2 * inv1 + a2 * n1 * inv2) % n; + + // Ensure non-negative result + if x < 0 { + x + n + } else { + x + } } + + + From 3b89f06c2c59483863018b058cba97075ce9a352 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Wed, 12 Feb 2025 00:56:23 -0500 Subject: [PATCH 04/20] remove return --- src/lib.rs | 25 +++++++++---------------- src/main.rs | 2 +- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 47b0cb6..9468bd0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,7 +207,7 @@ pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> (i64, i64) { //Lift using CRT result_element = crt(result_element, p.pow(e), element_in_factor, modulus/p.pow(e)); - return (result_element, n) + result_order = n } } @@ -217,25 +217,18 @@ pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> (i64, i64) { (result_element, result_order) } -fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { - // Solve x = a1 mod n1 and x = a2 mod n2 - let n = n1 * n2; - - // Find the modular inverses - let (inv1, _) = extended_gcd(n1, n2); // inv1 is the inverse of n2 mod n1 - let (inv2, _) = extended_gcd(n2, n1); // inv2 is the inverse of n1 mod n2 - - // CRT formula: x = (a1 * n2 * inv(n2, n1) + a2 * n1 * inv(n1, n2)) mod n - let x = (a1 * n2 * inv1 + a2 * n1 * inv2) % n; - - // Ensure non-negative result - if x < 0 { +fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64{ + //Solve x = a1 mod n1 and x = a2 mod n2 + let n = n1*n2; + let (inv1, _) = extended_gcd(n1, n2); + let (inv2, _) = extended_gcd(n2, n1); + let x = (a1*inv2%n*n2%n + a2*inv1%n*n1%n)%n; + if x < 0{ x + n - } else { + }else{ x } } - diff --git a/src/main.rs b/src/main.rs index 6fcb88c..0e575f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,7 +47,7 @@ fn main() { println!("Polynomial multiplication method using NTT: {:?}", c_fast); let modulus = 45; // Example modulus - let n = 2; // Must be a power of 2 + let n = 4; // Must be a power of 2 let (g, g_order) = find_cyclic_subgroup(modulus, n); let omega = mod_exp(g, g_order / n, modulus); let root = primitive_root(23, 2); From 9d373aa6f43a20abe4bdf8265d25149a262217db Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Wed, 12 Feb 2025 01:53:59 -0500 Subject: [PATCH 05/20] simplify crt --- src/lib.rs | 28 +++++++--------------------- src/main.rs | 7 ++++++- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9468bd0..1233ff5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,16 +11,6 @@ fn gcd(mut a: i64, mut b: i64) -> i64 { a.abs() } -fn extended_gcd(a: i64, b: i64) -> (i64, i64) { - if b == 0 { - return (1, 0); - } - let (x1, y1) = extended_gcd(b, a % b); - let x = y1; - let y = x1 - y1 * (a / b); - (x, y) -} - /// Compute LCM of two numbers. fn lcm(a: i64, b: i64) -> i64 { (a * b) / gcd(a, b) @@ -201,6 +191,7 @@ pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> (i64, i64) { for (&p, &e) in &factors { let phi = (p - 1) * p.pow(e - 1); if phi % n == 0 { + println!("phi: {}", phi); let g = primitive_root(p, e); let order = phi / n; // Find an element of order n (or a multiple of n) let element_in_factor = mod_exp(g, order, p.pow(e)); @@ -217,17 +208,12 @@ pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> (i64, i64) { (result_element, result_order) } -fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64{ - //Solve x = a1 mod n1 and x = a2 mod n2 - let n = n1*n2; - let (inv1, _) = extended_gcd(n1, n2); - let (inv2, _) = extended_gcd(n2, n1); - let x = (a1*inv2%n*n2%n + a2*inv1%n*n1%n)%n; - if x < 0{ - x + n - }else{ - x - } +pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { + let n = n1 * n2; + let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2 + let m2 = mod_inv(n2, n1); // Inverse of n2 mod n1 + let x = (a1 * m2 * n2 + a2 * m1 * n1) % n; + if x < 0 { x + n } else { x } } diff --git a/src/main.rs b/src/main.rs index 0e575f0..34aad99 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod test; use reikna::totient::totient; -use ntt::{omega, ntt, intt , polymul, polymul_ntt, find_cyclic_subgroup, primitive_root, mod_exp}; +use ntt::{omega, ntt, intt , polymul, polymul_ntt, find_cyclic_subgroup, primitive_root, mod_exp, crt}; fn main() { let p: i64 = 17; // Prime modulus @@ -56,4 +56,9 @@ fn main() { println!("(g, order) = {}, {}", g, g_order); println!("omega: {}", omega); println!("omega^n: {}", mod_exp(omega, n, modulus)); + + let x1 = crt(2, 3, 3, 5); + println!("Expected: 8, Computed: {}", x1); + let x3 = crt(4, 7, 5, 11); + println!("Expected: 60, Computed: {}", x3); } From 8193fc947816b4bc144c864fbb4675e1aa5d7122 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Wed, 12 Feb 2025 02:15:51 -0500 Subject: [PATCH 06/20] add comments, clean up --- src/lib.rs | 36 ++++++++---------------------------- src/main.rs | 13 ++----------- 2 files changed, 10 insertions(+), 39 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1233ff5..a1e5dbe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,11 +11,6 @@ fn gcd(mut a: i64, mut b: i64) -> i64 { a.abs() } -/// Compute LCM of two numbers. -fn lcm(a: i64, b: i64) -> i64 { - (a * b) / gcd(a, b) -} - // Modular arithmetic functions using i64 fn mod_add(a: i64, b: i64, p: i64) -> i64 { (a + b) % p @@ -179,33 +174,18 @@ fn find_primitive_root_mod_p(p: i64) -> i64 { 0 // Should never happen } -pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> (i64, i64) { - if n == 0 || (n & (n - 1)) != 0 { - panic!("n must be a power of 2"); - } - - let factors = factorize(modulus); - let mut result_element = 1; // Initialize to 1 for CRT - let mut result_order = 1; - +pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> i64 { + let factors = factorize(modulus); // factor the modulus for (&p, &e) in &factors { - let phi = (p - 1) * p.pow(e - 1); + let phi = (p - 1) * p.pow(e - 1); // Euler's totient function if phi % n == 0 { - println!("phi: {}", phi); - let g = primitive_root(p, e); - let order = phi / n; // Find an element of order n (or a multiple of n) - let element_in_factor = mod_exp(g, order, p.pow(e)); - - //Lift using CRT - result_element = crt(result_element, p.pow(e), element_in_factor, modulus/p.pow(e)); - result_order = n + let g = primitive_root(p, e); // find a primitive root mod p^e + let exp = phi / n; // exponent of the primitive root + let order_n_elem = mod_exp(g, exp, p.pow(e)); // element of mult. order n mod p^e + return crt(1, modulus/p.pow(e), order_n_elem, p.pow(e)); // lift using CRT } } - - if result_order == 1{ - panic!("could not find element of order n"); - } - (result_element, result_order) + panic!("could not find element of order n"); } pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { diff --git a/src/main.rs b/src/main.rs index 34aad99..369b08d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod test; use reikna::totient::totient; -use ntt::{omega, ntt, intt , polymul, polymul_ntt, find_cyclic_subgroup, primitive_root, mod_exp, crt}; +use ntt::{omega, ntt, intt , polymul, polymul_ntt, find_cyclic_subgroup, mod_exp}; fn main() { let p: i64 = 17; // Prime modulus @@ -48,17 +48,8 @@ fn main() { let modulus = 45; // Example modulus let n = 4; // Must be a power of 2 - let (g, g_order) = find_cyclic_subgroup(modulus, n); - let omega = mod_exp(g, g_order / n, modulus); - let root = primitive_root(23, 2); + let omega = find_cyclic_subgroup(modulus, n); println!("Totient of {}: {}", modulus, totient(modulus as u64)); - println!("Primitive root: {}", root); - println!("(g, order) = {}, {}", g, g_order); println!("omega: {}", omega); println!("omega^n: {}", mod_exp(omega, n, modulus)); - - let x1 = crt(2, 3, 3, 5); - println!("Expected: 8, Computed: {}", x1); - let x3 = crt(4, 7, 5, 11); - println!("Expected: 60, Computed: {}", x3); } From e8811b0d9dc7f0f66867d2efee9175f728258cdf Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Wed, 12 Feb 2025 03:14:04 -0500 Subject: [PATCH 07/20] add test for non prime power case --- src/lib.rs | 25 ++++++++++++++++++++----- src/main.rs | 7 +++---- src/test.rs | 31 +++++++++++++++++++++++++------ 3 files changed, 48 insertions(+), 15 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a1e5dbe..d2647be 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,13 @@ fn gcd(mut a: i64, mut b: i64) -> i64 { a.abs() } +pub fn is_prime_power(n: i64) -> bool { + if n <= 1 { + return false; // 1 and numbers <= 1 are not prime powers + } + factorize(n).len() == 1 +} + // Modular arithmetic functions using i64 fn mod_add(a: i64, b: i64, p: i64) -> i64 { (a + b) % p @@ -40,10 +47,18 @@ fn mod_inv(a: i64, p: i64) -> i64 { } // Compute n-th root of unity (omega) for p not necessarily prime -pub fn omega(root: i64, p: i64, n: usize) -> i64 { - let grp_size = totient(p as u64) as i64; - assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size); - mod_exp(root, grp_size / n as i64, p) // order of mult. group is Euler's totient function +pub fn omega(modulus: i64, n: usize) -> i64 { + let factors = factorize(modulus as i64); + if factors.len() == 1 { + let (p, e) = factors.into_iter().next().unwrap(); + let root = primitive_root(p, e); // primitive root mod p + let grp_size = totient(modulus as u64) as i64; + assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size); + return mod_exp(root, grp_size / n as i64, modulus) // order of mult. group is Euler's totient function + } + else { + return cyclic_subgroup_gen(modulus, n as i64) + } } // Forward transform using NTT, output bit-reversed @@ -174,7 +189,7 @@ fn find_primitive_root_mod_p(p: i64) -> i64 { 0 // Should never happen } -pub fn find_cyclic_subgroup(modulus: i64, n: i64) -> i64 { +pub fn cyclic_subgroup_gen(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); // factor the modulus for (&p, &e) in &factors { let phi = (p - 1) * p.pow(e - 1); // Euler's totient function diff --git a/src/main.rs b/src/main.rs index 369b08d..2b56ef2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,12 @@ mod test; use reikna::totient::totient; -use ntt::{omega, ntt, intt , polymul, polymul_ntt, find_cyclic_subgroup, mod_exp}; +use ntt::{omega, ntt, intt , polymul, polymul_ntt, cyclic_subgroup_gen, mod_exp}; fn main() { let p: i64 = 17; // Prime modulus - let root: i64 = 3; // Primitive root of unity for the modulus let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(root, p, n); // n-th root of unity: root^((p - 1) / n) % p + let omega = omega(p, n); // n-th root of unity: root^((p - 1) / n) % p // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -48,7 +47,7 @@ fn main() { let modulus = 45; // Example modulus let n = 4; // Must be a power of 2 - let omega = find_cyclic_subgroup(modulus, n); + let omega = cyclic_subgroup_gen(modulus, n); println!("Totient of {}: {}", modulus, totient(modulus as u64)); println!("omega: {}", omega); println!("omega^n: {}", mod_exp(omega, n, modulus)); diff --git a/src/test.rs b/src/test.rs index 3553a79..ee7e450 100644 --- a/src/test.rs +++ b/src/test.rs @@ -5,9 +5,8 @@ mod tests { #[test] fn test_polymul_ntt() { let p: i64 = 17; // Prime modulus - let root: i64 = 3; // Primitive root of unity let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(root, p, n); // n-th root of unity + let omega = omega(p, n); // n-th root of unity // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -28,9 +27,8 @@ mod tests { #[test] fn test_polymul_ntt_square_modulus() { let modulus: i64 = 17*17; // Prime modulus - let root: i64 = 3; // Primitive root of unity let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(root, modulus, n); // n-th root of unity + let omega = omega(modulus, n); // n-th root of unity // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -51,9 +49,30 @@ mod tests { #[test] fn test_polymul_ntt_prime_power_modulus() { let modulus: i64 = (17 as i64).pow(4); // modulus p^k - let root: i64 = 3; // Primitive root of unity let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(root, modulus, n); // n-th root of unity + let omega = omega(modulus, n); // n-th root of unity + + // Input polynomials (padded to length `n`) + let mut a = vec![1, 2, 3, 4]; + let mut b = vec![4, 5, 6, 7]; + a.resize(n, 0); + b.resize(n, 0); + + // Perform the standard polynomial multiplication + let c_std = polymul(&a, &b, n as i64, modulus); + + // Perform the NTT-based polynomial multiplication + let c_fast = polymul_ntt(&a, &b, n, modulus, omega); + + // Ensure both methods produce the same result + assert_eq!(c_std, c_fast, "The results of polymul and polymul_ntt do not match"); + } + + #[test] + fn test_polymul_ntt_non_prime_power_modulus() { + let modulus: i64 = 45; // modulus p^k + let n: usize = 4; // Length of the NTT (must be a power of 2) + let omega = omega(modulus, n); // n-th root of unity // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; From 8cc15ad9ecf1dcc2f4cbf1d477ade533ef347a6a Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Wed, 12 Feb 2025 03:17:21 -0500 Subject: [PATCH 08/20] update test --- src/test.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/test.rs b/src/test.rs index ee7e450..c5302ef 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use ntt::{omega, polymul, polymul_ntt}; + use ntt::{omega, polymul, polymul_ntt,mod_exp}; #[test] fn test_polymul_ntt() { @@ -70,15 +70,14 @@ mod tests { #[test] fn test_polymul_ntt_non_prime_power_modulus() { - let modulus: i64 = 45; // modulus p^k + let modulus: i64 = 45; // modulus not of the form p^k let n: usize = 4; // Length of the NTT (must be a power of 2) let omega = omega(modulus, n); // n-th root of unity + println!("omega^n: {}", mod_exp(omega,n as i64,modulus)); // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; let mut b = vec![4, 5, 6, 7]; - a.resize(n, 0); - b.resize(n, 0); // Perform the standard polynomial multiplication let c_std = polymul(&a, &b, n as i64, modulus); From 811fbfa06b989078d8214256685e1b3d00b2cfbb Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Wed, 12 Feb 2025 03:50:37 -0500 Subject: [PATCH 09/20] compute mod_inv with extended_gcd --- src/lib.rs | 35 +++++++++++++++-------------------- src/main.rs | 7 ++++--- src/test.rs | 6 ++++-- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d2647be..be6d4ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,22 +2,6 @@ use reikna::totient::totient; use reikna::factor::quick_factorize; use std::collections::HashMap; -fn gcd(mut a: i64, mut b: i64) -> i64 { - while b != 0 { - let temp = b; - b = a % b; - a = temp; - } - a.abs() -} - -pub fn is_prime_power(n: i64) -> bool { - if n <= 1 { - return false; // 1 and numbers <= 1 are not prime powers - } - factorize(n).len() == 1 -} - // Modular arithmetic functions using i64 fn mod_add(a: i64, b: i64, p: i64) -> i64 { (a + b) % p @@ -40,10 +24,21 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { result } -//compute the modular inverse of a modulo p using Fermat's little theorem, p not necessarily prime -fn mod_inv(a: i64, p: i64) -> i64 { - assert!(gcd(a, p) == 1, "{} and {} are not coprime", a, p); - mod_exp(a, totient(p as u64) as i64 - 1, p) // order of mult. group is Euler's totient function +fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) { + if b == 0 { + (a, 1, 0) // gcd, x, y + } else { + let (gcd, x1, y1) = extended_gcd(b, a % b); + (gcd, y1, x1 - (a / b) * y1) + } +} + +pub fn mod_inv(a: i64, modulus: i64) -> i64 { + let (gcd, x, _) = extended_gcd(a, modulus); + if gcd != 1 { + panic!("{} and {} are not coprime, no inverse exists", a, modulus); + } + (x % modulus + modulus) % modulus // Ensure a positive result } // Compute n-th root of unity (omega) for p not necessarily prime diff --git a/src/main.rs b/src/main.rs index 2b56ef2..f4c4924 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod test; use reikna::totient::totient; -use ntt::{omega, ntt, intt , polymul, polymul_ntt, cyclic_subgroup_gen, mod_exp}; +use ntt::{omega, ntt, intt , polymul, polymul_ntt, cyclic_subgroup_gen, mod_exp, mod_inv}; fn main() { let p: i64 = 17; // Prime modulus @@ -45,10 +45,11 @@ fn main() { println!("Standard polynomial mult. result: {:?}", c_std); println!("Polynomial multiplication method using NTT: {:?}", c_fast); - let modulus = 45; // Example modulus - let n = 4; // Must be a power of 2 + let modulus = 51; // Example modulus + let n = 8; // Must be a power of 2 let omega = cyclic_subgroup_gen(modulus, n); println!("Totient of {}: {}", modulus, totient(modulus as u64)); println!("omega: {}", omega); println!("omega^n: {}", mod_exp(omega, n, modulus)); + println!("n^-1 = {}", mod_inv(n as i64, modulus)); } diff --git a/src/test.rs b/src/test.rs index c5302ef..ada8d6a 100644 --- a/src/test.rs +++ b/src/test.rs @@ -70,14 +70,16 @@ mod tests { #[test] fn test_polymul_ntt_non_prime_power_modulus() { - let modulus: i64 = 45; // modulus not of the form p^k - let n: usize = 4; // Length of the NTT (must be a power of 2) + let modulus: i64 = 51; // modulus not of the form p^k + let n: usize = 8; // Length of the NTT (must be a power of 2) let omega = omega(modulus, n); // n-th root of unity println!("omega^n: {}", mod_exp(omega,n as i64,modulus)); // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; let mut b = vec![4, 5, 6, 7]; + a.resize(n, 0); + b.resize(n, 0); // Perform the standard polynomial multiplication let c_std = polymul(&a, &b, n as i64, modulus); From fddba744b4d8610102ef34026a7d92e60429e9e1 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Wed, 12 Feb 2025 03:59:01 -0500 Subject: [PATCH 10/20] clean up main --- src/main.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main.rs b/src/main.rs index f4c4924..651ccbe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,12 @@ mod test; use reikna::totient::totient; -use ntt::{omega, ntt, intt , polymul, polymul_ntt, cyclic_subgroup_gen, mod_exp, mod_inv}; +use ntt::{ntt, intt , polymul, polymul_ntt, mod_exp, mod_inv}; fn main() { let p: i64 = 17; // Prime modulus let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(p, n); // n-th root of unity: root^((p - 1) / n) % p + let omega = ntt::omega(p, n); // n-th root of unity: root^((p - 1) / n) % p // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -47,9 +47,9 @@ fn main() { let modulus = 51; // Example modulus let n = 8; // Must be a power of 2 - let omega = cyclic_subgroup_gen(modulus, n); + let omega = ntt::omega(modulus, n); // n-th root of unity println!("Totient of {}: {}", modulus, totient(modulus as u64)); println!("omega: {}", omega); - println!("omega^n: {}", mod_exp(omega, n, modulus)); + println!("omega^n: {}", mod_exp(omega, n as i64, modulus)); println!("n^-1 = {}", mod_inv(n as i64, modulus)); } From 673a6fa79432b1064198b951ac3f53bfad925915 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Wed, 12 Feb 2025 11:22:25 -0500 Subject: [PATCH 11/20] add debugging --- src/lib.rs | 5 +++-- src/main.rs | 25 ++++++++++++++++++++++++- src/test.rs | 1 - 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index be6d4ae..612df45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,7 +52,7 @@ pub fn omega(modulus: i64, n: usize) -> i64 { return mod_exp(root, grp_size / n as i64, modulus) // order of mult. group is Euler's totient function } else { - return cyclic_subgroup_gen(modulus, n as i64) + return root(modulus, n as i64) } } @@ -184,7 +184,8 @@ fn find_primitive_root_mod_p(p: i64) -> i64 { 0 // Should never happen } -pub fn cyclic_subgroup_gen(modulus: i64, n: i64) -> i64 { +// Compute the n-th root of unity modulo a composite modulus +pub fn root(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); // factor the modulus for (&p, &e) in &factors { let phi = (p - 1) * p.pow(e - 1); // Euler's totient function diff --git a/src/main.rs b/src/main.rs index 651ccbe..14fb4f5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -45,11 +45,34 @@ fn main() { println!("Standard polynomial mult. result: {:?}", c_std); println!("Polynomial multiplication method using NTT: {:?}", c_fast); + //test the composite modulus case let modulus = 51; // Example modulus let n = 8; // Must be a power of 2 let omega = ntt::omega(modulus, n); // n-th root of unity println!("Totient of {}: {}", modulus, totient(modulus as u64)); println!("omega: {}", omega); - println!("omega^n: {}", mod_exp(omega, n as i64, modulus)); + (0..=n).for_each(|i| println!("omega^{}: {}", i, mod_exp(omega, i as i64, modulus))); println!("n^-1 = {}", mod_inv(n as i64, modulus)); + let a_ntt = ntt(&a, omega, n, modulus); + let b_ntt = ntt(&b, omega, n, modulus); + let a_ntt_intt = intt(&a_ntt, omega, n, modulus); + let b_ntt_intt = intt(&b_ntt, omega, n, modulus); + let c_ntt: Vec = a_ntt + .iter() + .zip(b_ntt.iter()) + .map(|(x, y)| (x * y) % modulus) + .collect(); + let c = intt(&c_ntt, omega, n, modulus); + let c_std = polymul(&a, &b, n as i64, modulus); + let c_fast = polymul_ntt(&a, &b, n, modulus, omega); + println!("A: {:?}", a); + println!("Transformed A: {:?}", a_ntt); + println!("Transformed B: {:?}", b_ntt); + println!("Recovered A: {:?}", a_ntt_intt); + println!("Recovered B: {:?}", b_ntt_intt); + println!("Pointwise Product in NTT Domain: {:?}", c_ntt); + println!("Standard polynomial mult. result: {:?}", c_std); + println!("Resultant Polynomial (c): {:?}", c); + println!("Polynomial multiplication method using NTT: {:?}", c_fast); + } diff --git a/src/test.rs b/src/test.rs index ada8d6a..9bbf7c4 100644 --- a/src/test.rs +++ b/src/test.rs @@ -73,7 +73,6 @@ mod tests { let modulus: i64 = 51; // modulus not of the form p^k let n: usize = 8; // Length of the NTT (must be a power of 2) let omega = omega(modulus, n); // n-th root of unity - println!("omega^n: {}", mod_exp(omega,n as i64,modulus)); // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; From 883b5162bb4c8116a198490772cbcc47d3d4aed7 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Thu, 13 Feb 2025 14:08:05 -0500 Subject: [PATCH 12/20] write function to verify root of unity the root of unity needs to satisfy a summation condition. since there are zero divisors in this ring, it is now possible that omega-1 is a zero divisor. --- src/lib.rs | 26 +++++++++++++++++++------- src/main.rs | 3 ++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 612df45..e33b438 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,7 +52,7 @@ pub fn omega(modulus: i64, n: usize) -> i64 { return mod_exp(root, grp_size / n as i64, modulus) // order of mult. group is Euler's totient function } else { - return root(modulus, n as i64) + return root_of_unity(modulus, n as i64) } } @@ -156,9 +156,8 @@ fn factorize(n: i64) -> HashMap { /// Fast computation of a primitive root mod p^e pub fn primitive_root(p: i64, e: u32) -> i64 { - - // Find a primitive root mod p - let g = find_primitive_root_mod_p(p); + + let g = primitive_root_mod_p(p); // Lift it to p^e let mut g_lifted = g; @@ -171,10 +170,9 @@ pub fn primitive_root(p: i64, e: u32) -> i64 { } /// Finds a primitive root modulo a prime p -fn find_primitive_root_mod_p(p: i64) -> i64 { +fn primitive_root_mod_p(p: i64) -> i64 { let phi = p - 1; let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities - for g in 2..p { // Check if g is a primitive root by checking mod_exp conditions with all prime factors of phi if factors.iter().all(|(&q, _)| mod_exp(g, phi / q, p) != 1) { @@ -185,7 +183,7 @@ fn find_primitive_root_mod_p(p: i64) -> i64 { } // Compute the n-th root of unity modulo a composite modulus -pub fn root(modulus: i64, n: i64) -> i64 { +pub fn root_of_unity(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); // factor the modulus for (&p, &e) in &factors { let phi = (p - 1) * p.pow(e - 1); // Euler's totient function @@ -199,6 +197,20 @@ pub fn root(modulus: i64, n: i64) -> i64 { panic!("could not find element of order n"); } +pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { + assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity"); + for k in 1..n { + let mut sum = 0i64; + for j in 0..n { + sum = (sum + mod_exp(omega, j * k, modulus)) % modulus; + } + if sum != 0 { + return false; + } + } + true +} + pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { let n = n1 * n2; let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2 diff --git a/src/main.rs b/src/main.rs index 14fb4f5..c1dc735 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod test; use reikna::totient::totient; -use ntt::{ntt, intt , polymul, polymul_ntt, mod_exp, mod_inv}; +use ntt::{ntt, intt , polymul, polymul_ntt, mod_exp, mod_inv, verify_root_of_unity}; fn main() { let p: i64 = 17; // Prime modulus @@ -51,6 +51,7 @@ fn main() { let omega = ntt::omega(modulus, n); // n-th root of unity println!("Totient of {}: {}", modulus, totient(modulus as u64)); println!("omega: {}", omega); + println!("verify omega = {}", verify_root_of_unity(omega, n as i64, modulus)); (0..=n).for_each(|i| println!("omega^{}: {}", i, mod_exp(omega, i as i64, modulus))); println!("n^-1 = {}", mod_inv(n as i64, modulus)); let a_ntt = ntt(&a, omega, n, modulus); From 3a27cb8a19a73d76fe2727bc34557fc8c989ec68 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Thu, 13 Feb 2025 15:35:50 -0500 Subject: [PATCH 13/20] generalize n^th root of unity function now for each factor in the mult. group, compute a divisor d of the order phi by using the gcd. ensure that the product of these divisors is n. then compute elements of order d for each factor, pull them back along the CRT isomorphism, and take their product mod modulus. --- src/lib.rs | 38 ++++++++++++++++++++++++++++++++++++++ src/main.rs | 1 + 2 files changed, 39 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index e33b438..9cf10f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,14 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { result } +fn gcd(a: i64, b: i64) -> i64 { + if b == 0 { + a.abs() + } else { + gcd(b, a % b) + } +} + fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) { if b == 0 { (a, 1, 0) // gcd, x, y @@ -183,6 +191,7 @@ fn primitive_root_mod_p(p: i64) -> i64 { } // Compute the n-th root of unity modulo a composite modulus +/* pub fn root_of_unity(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); // factor the modulus for (&p, &e) in &factors { @@ -196,6 +205,35 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 { } panic!("could not find element of order n"); } +*/ + +pub fn root_of_unity(modulus: i64, n: i64) -> i64 { + let factors = factorize(modulus); + let mut remaining_n = n; + let mut result = 1; + let mut current_modulus = modulus; // Start with the full modulus + + for (&p, &e) in &factors { + let phi = (p - 1) * p.pow(e - 1); + let d = gcd(remaining_n, phi); // GCD with the current factor + remaining_n /= d; + + if d > 1 { + let g = primitive_root(p, e); + let exp = phi / d; + let order_d_elem = mod_exp(g, exp, p.pow(e)); + + current_modulus /= p.pow(e); // Remove this factor before CRT + result = crt(result, current_modulus, order_d_elem, p.pow(e)); // Combine using CRT + } + } + + if remaining_n != 1 { + panic!("Could not find all factors of n in the group"); + } + + result +} pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity"); diff --git a/src/main.rs b/src/main.rs index c1dc735..b4ec721 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,6 +35,7 @@ fn main() { let c_fast = polymul_ntt(&a, &b, n, p, omega); // Output the results + println!("verify omega = {}", verify_root_of_unity(omega, n as i64, p)); println!("Polynomial A: {:?}", a); println!("Polynomial B: {:?}", b); println!("Transformed A: {:?}", a_ntt); From 641e674a78f0750a7ed94730e533f7cb47135d7d Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Thu, 13 Feb 2025 16:47:48 -0500 Subject: [PATCH 14/20] compute divisors whose lcm is n given a list of integers [phi_i], we want to compute a list of divisors d_i such that lcm({d_i}) = n. --- src/lib.rs | 67 +++++++++++++++++++++++++++++++++-------------------- src/main.rs | 8 ++++++- 2 files changed, 49 insertions(+), 26 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9cf10f9..eaf2ad7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,31 +207,58 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 { } */ +pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { + let n = n1 * n2; + let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2 + let m2 = mod_inv(n2, n1); // Inverse of n2 mod n1 + let x = (a1 * m2 * n2 + a2 * m1 * n1) % n; + if x < 0 { x + n } else { x } +} + +pub fn divisors_with_given_lcm(phi: &[i64], n: i64) -> Option> { + let n_factors = factorize(n); + let phi_factors: Vec> = phi.iter().map(|&x| factorize(x)).collect(); + let num_phi = phi.len(); + let mut d: Vec = vec![1; num_phi]; + + for (&prime, &n_power) in n_factors.iter() { + let mut found = false; + for i in 0..num_phi { + if phi_factors[i].contains_key(&prime) && phi_factors[i][&prime] >= n_power { + d[i] *= prime.pow(n_power); + found = true; + break; + } + } + if !found { + return None; + } + } + + Some(d) +} + pub fn root_of_unity(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); - let mut remaining_n = n; let mut result = 1; - let mut current_modulus = modulus; // Start with the full modulus - for (&p, &e) in &factors { - let phi = (p - 1) * p.pow(e - 1); - let d = gcd(remaining_n, phi); // GCD with the current factor - remaining_n /= d; + // Compute the divisors d_i such that lcm(d_i for all i) = n + let phi: Vec = factors.iter().map(|(&p, &e)| (p - 1) * p.pow(e - 1)).collect(); + let divisors = divisors_with_given_lcm(&phi, n).expect("Could not find divisors with LCM equal to n"); + for (i, (&p, &e)) in factors.iter().enumerate() { + let d = divisors[i]; // Use the divisor for the current factor if d > 1 { - let g = primitive_root(p, e); - let exp = phi / d; - let order_d_elem = mod_exp(g, exp, p.pow(e)); + let g = primitive_root(p, e); // Find primitive root mod p^e + let phi = (p - 1) * p.pow(e - 1); // Euler's totient function + let exp = phi / d; // Compute exponent for order d + let order_d_elem = mod_exp(g, exp, p.pow(e)); // Element of order d - current_modulus /= p.pow(e); // Remove this factor before CRT - result = crt(result, current_modulus, order_d_elem, p.pow(e)); // Combine using CRT + // Combine with the running result using CRT + result = crt(result, modulus / p.pow(e), order_d_elem, p.pow(e)); } } - if remaining_n != 1 { - panic!("Could not find all factors of n in the group"); - } - result } @@ -249,13 +276,3 @@ pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { true } -pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { - let n = n1 * n2; - let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2 - let m2 = mod_inv(n2, n1); // Inverse of n2 mod n1 - let x = (a1 * m2 * n2 + a2 * m1 * n1) % n; - if x < 0 { x + n } else { x } -} - - - diff --git a/src/main.rs b/src/main.rs index b4ec721..e73e0df 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod test; use reikna::totient::totient; -use ntt::{ntt, intt , polymul, polymul_ntt, mod_exp, mod_inv, verify_root_of_unity}; +use ntt::{ntt, intt , polymul, polymul_ntt, mod_exp, mod_inv, verify_root_of_unity, divisors_with_given_lcm}; fn main() { let p: i64 = 17; // Prime modulus @@ -77,4 +77,10 @@ fn main() { println!("Resultant Polynomial (c): {:?}", c); println!("Polynomial multiplication method using NTT: {:?}", c_fast); + //check that we can take a list of numbers and compute divisors such that their lcm = a number + let n = 12; + let phis = vec![4, 6, 6]; // Example phi values + let divisors = divisors_with_given_lcm(&phis, n); + println!("{:?}", divisors); // Output a set of divisors whose LCM is n + } From 4cdd7f0977571ccdc8aa64b5529a128fa7ba846f Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Thu, 13 Feb 2025 17:03:15 -0500 Subject: [PATCH 15/20] try N=85 --- src/lib.rs | 8 -------- src/main.rs | 6 +++--- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index eaf2ad7..4ba5986 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,14 +24,6 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 { result } -fn gcd(a: i64, b: i64) -> i64 { - if b == 0 { - a.abs() - } else { - gcd(b, a % b) - } -} - fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) { if b == 0 { (a, 1, 0) // gcd, x, y diff --git a/src/main.rs b/src/main.rs index e73e0df..9ace6df 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,7 +47,7 @@ fn main() { println!("Polynomial multiplication method using NTT: {:?}", c_fast); //test the composite modulus case - let modulus = 51; // Example modulus + let modulus = 85; // Example modulus let n = 8; // Must be a power of 2 let omega = ntt::omega(modulus, n); // n-th root of unity println!("Totient of {}: {}", modulus, totient(modulus as u64)); @@ -78,8 +78,8 @@ fn main() { println!("Polynomial multiplication method using NTT: {:?}", c_fast); //check that we can take a list of numbers and compute divisors such that their lcm = a number - let n = 12; - let phis = vec![4, 6, 6]; // Example phi values + let n = 8; + let phis = vec![8, 2]; // Example phi values let divisors = divisors_with_given_lcm(&phis, n); println!("{:?}", divisors); // Output a set of divisors whose LCM is n From 6491944faa690af0d895f2d46432b4ecf980917b Mon Sep 17 00:00:00 2001 From: tjaysilver Date: Thu, 13 Feb 2025 18:42:02 -0500 Subject: [PATCH 16/20] fix root_of_unity omega(modulus,n) will now compute an element that is a primitive nth root of unity mod all prime factors of n --- src/lib.rs | 25 +++++++------------------ src/main.rs | 2 +- src/test.rs | 4 ++-- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4ba5986..789a7fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -234,24 +234,13 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); let mut result = 1; - // Compute the divisors d_i such that lcm(d_i for all i) = n - let phi: Vec = factors.iter().map(|(&p, &e)| (p - 1) * p.pow(e - 1)).collect(); - let divisors = divisors_with_given_lcm(&phi, n).expect("Could not find divisors with LCM equal to n"); - - for (i, (&p, &e)) in factors.iter().enumerate() { - let d = divisors[i]; // Use the divisor for the current factor - if d > 1 { - let g = primitive_root(p, e); // Find primitive root mod p^e - let phi = (p - 1) * p.pow(e - 1); // Euler's totient function - let exp = phi / d; // Compute exponent for order d - let order_d_elem = mod_exp(g, exp, p.pow(e)); // Element of order d - - // Combine with the running result using CRT - result = crt(result, modulus / p.pow(e), order_d_elem, p.pow(e)); - } - } - - result + for (&p, &e) in factors.iter() { + // Find primitive nth root of unity mod p^e + let omega = omega(p.pow(e), n.try_into().unwrap()); + // Combine with the running result using CRT + result = crt(result, modulus / p.pow(e), omega, p.pow(e)); + } + result } pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { diff --git a/src/main.rs b/src/main.rs index 9ace6df..f50da26 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,7 +47,7 @@ fn main() { println!("Polynomial multiplication method using NTT: {:?}", c_fast); //test the composite modulus case - let modulus = 85; // Example modulus + let modulus = 697; // Example modulus let n = 8; // Must be a power of 2 let omega = ntt::omega(modulus, n); // n-th root of unity println!("Totient of {}: {}", modulus, totient(modulus as u64)); diff --git a/src/test.rs b/src/test.rs index 9bbf7c4..f038058 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use ntt::{omega, polymul, polymul_ntt,mod_exp}; + use ntt::{omega, polymul, polymul_ntt}; #[test] fn test_polymul_ntt() { @@ -70,7 +70,7 @@ mod tests { #[test] fn test_polymul_ntt_non_prime_power_modulus() { - let modulus: i64 = 51; // modulus not of the form p^k + let modulus: i64 = 697; // modulus not of the form p^k let n: usize = 8; // Length of the NTT (must be a power of 2) let omega = omega(modulus, n); // n-th root of unity From 97dd0d3bb8490f6ba8c4eb326a89c072a93fd7e5 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Fri, 14 Feb 2025 14:08:48 -0500 Subject: [PATCH 17/20] assert omega is square remove unnecessary functions and assert omega is square --- src/lib.rs | 58 +++++++++++------------------------------------------ src/main.rs | 8 +------- 2 files changed, 13 insertions(+), 53 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 789a7fb..43e5879 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,7 +48,7 @@ pub fn omega(modulus: i64, n: usize) -> i64 { let (p, e) = factors.into_iter().next().unwrap(); let root = primitive_root(p, e); // primitive root mod p let grp_size = totient(modulus as u64) as i64; - assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size); + assert!(grp_size % 2*n as i64 == 0, "{} does not divide {}", 2*n, grp_size); return mod_exp(root, grp_size / n as i64, modulus) // order of mult. group is Euler's totient function } else { @@ -182,23 +182,7 @@ fn primitive_root_mod_p(p: i64) -> i64 { 0 // Should never happen } -// Compute the n-th root of unity modulo a composite modulus -/* -pub fn root_of_unity(modulus: i64, n: i64) -> i64 { - let factors = factorize(modulus); // factor the modulus - for (&p, &e) in &factors { - let phi = (p - 1) * p.pow(e - 1); // Euler's totient function - if phi % n == 0 { - let g = primitive_root(p, e); // find a primitive root mod p^e - let exp = phi / n; // exponent of the primitive root - let order_n_elem = mod_exp(g, exp, p.pow(e)); // element of mult. order n mod p^e - return crt(1, modulus/p.pow(e), order_n_elem, p.pow(e)); // lift using CRT - } - } - panic!("could not find element of order n"); -} -*/ - +// the Chinese remainder theorem for two moduli pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { let n = n1 * n2; let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2 @@ -207,42 +191,24 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { if x < 0 { x + n } else { x } } -pub fn divisors_with_given_lcm(phi: &[i64], n: i64) -> Option> { - let n_factors = factorize(n); - let phi_factors: Vec> = phi.iter().map(|&x| factorize(x)).collect(); - let num_phi = phi.len(); - let mut d: Vec = vec![1; num_phi]; - - for (&prime, &n_power) in n_factors.iter() { - let mut found = false; - for i in 0..num_phi { - if phi_factors[i].contains_key(&prime) && phi_factors[i][&prime] >= n_power { - d[i] *= prime.pow(n_power); - found = true; - break; - } - } - if !found { - return None; - } - } - - Some(d) -} - +// computes an n^th root of unity modulo a composite modulus +// note we require that an n^th root of unity exists for each multiplicative group modulo p^e +// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus +// for the NTT, we require than a 2n^th root of unity exists pub fn root_of_unity(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); let mut result = 1; - + let mut omega_is_square = false; for (&p, &e) in factors.iter() { - // Find primitive nth root of unity mod p^e - let omega = omega(p.pow(e), n.try_into().unwrap()); - // Combine with the running result using CRT - result = crt(result, modulus / p.pow(e), omega, p.pow(e)); + let omega = omega(p.pow(e), n.try_into().unwrap()); // Find primitive nth root of unity mod p^e + result = crt(result, modulus / p.pow(e), omega, p.pow(e)); // Combine with the running result using CRT + omega_is_square = omega_is_square || totient(p.pow(e) as u64) % (2*n as u64) == 0; } + assert!(omega_is_square, "no 2n-th root of unity exists modulo {}", modulus); result } +//ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity"); for k in 1..n { diff --git a/src/main.rs b/src/main.rs index f50da26..912349f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod test; use reikna::totient::totient; -use ntt::{ntt, intt , polymul, polymul_ntt, mod_exp, mod_inv, verify_root_of_unity, divisors_with_given_lcm}; +use ntt::{ntt, intt , polymul, polymul_ntt, mod_exp, mod_inv, verify_root_of_unity}; fn main() { let p: i64 = 17; // Prime modulus @@ -77,10 +77,4 @@ fn main() { println!("Resultant Polynomial (c): {:?}", c); println!("Polynomial multiplication method using NTT: {:?}", c_fast); - //check that we can take a list of numbers and compute divisors such that their lcm = a number - let n = 8; - let phis = vec![8, 2]; // Example phi values - let divisors = divisors_with_given_lcm(&phis, n); - println!("{:?}", divisors); // Output a set of divisors whose LCM is n - } From 0f2ceec85ef525032dccb968fa0905e4b2c3a58f Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Fri, 14 Feb 2025 14:11:47 -0500 Subject: [PATCH 18/20] omega doesn't need to be square --- src/lib.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 43e5879..4cc2afe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,7 +48,7 @@ pub fn omega(modulus: i64, n: usize) -> i64 { let (p, e) = factors.into_iter().next().unwrap(); let root = primitive_root(p, e); // primitive root mod p let grp_size = totient(modulus as u64) as i64; - assert!(grp_size % 2*n as i64 == 0, "{} does not divide {}", 2*n, grp_size); + assert!(grp_size % n as i64 == 0, "{} does not divide {}", n, grp_size); return mod_exp(root, grp_size / n as i64, modulus) // order of mult. group is Euler's totient function } else { @@ -198,13 +198,10 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 { pub fn root_of_unity(modulus: i64, n: i64) -> i64 { let factors = factorize(modulus); let mut result = 1; - let mut omega_is_square = false; for (&p, &e) in factors.iter() { let omega = omega(p.pow(e), n.try_into().unwrap()); // Find primitive nth root of unity mod p^e result = crt(result, modulus / p.pow(e), omega, p.pow(e)); // Combine with the running result using CRT - omega_is_square = omega_is_square || totient(p.pow(e) as u64) % (2*n as u64) == 0; } - assert!(omega_is_square, "no 2n-th root of unity exists modulo {}", modulus); result } From 5dcb428818e7300fd5e1d2fe0dba7a75e781a49f Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Fri, 14 Feb 2025 14:22:37 -0500 Subject: [PATCH 19/20] simplify verification step rather than checking that the sums sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n, we can simplify this to just omega^{n/2} == -1 (mod N). --- src/lib.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4cc2afe..4423aea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -208,15 +208,7 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 { //ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool { assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity"); - for k in 1..n { - let mut sum = 0i64; - for j in 0..n { - sum = (sum + mod_exp(omega, j * k, modulus)) % modulus; - } - if sum != 0 { - return false; - } - } + assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)"); true } From a7ec5cbd9081e8ba86af7ed88fbb191863c30537 Mon Sep 17 00:00:00 2001 From: Jackson Walters Date: Fri, 14 Feb 2025 14:45:04 -0500 Subject: [PATCH 20/20] more test cases add some more test cases to the non prime power modulus case --- src/lib.rs | 5 +---- src/main.rs | 54 +++++++++++------------------------------------------ src/test.rs | 32 +++++++++++++++---------------- 3 files changed, 27 insertions(+), 64 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4423aea..91f4696 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -156,11 +156,8 @@ fn factorize(n: i64) -> HashMap { /// Fast computation of a primitive root mod p^e pub fn primitive_root(p: i64, e: u32) -> i64 { - let g = primitive_root_mod_p(p); - - // Lift it to p^e - let mut g_lifted = g; + let mut g_lifted = g; // Lift it to p^e for _ in 1..e { if g_lifted.pow((p - 1) as u32) % p.pow(e) == 1 { g_lifted += p.pow(e - 1); diff --git a/src/main.rs b/src/main.rs index 912349f..e7bb1f6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,11 @@ mod test; -use reikna::totient::totient; -use ntt::{ntt, intt , polymul, polymul_ntt, mod_exp, mod_inv, verify_root_of_unity}; +use ntt::{ntt, intt , polymul, polymul_ntt, verify_root_of_unity}; fn main() { - let p: i64 = 17; // Prime modulus + let modulus: i64 = 17; // modulus, n must divide phi(p^k) for each prime factor p let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = ntt::omega(p, n); // n-th root of unity: root^((p - 1) / n) % p + let omega = ntt::omega(modulus, n); // n-th root of unity // Input polynomials (padded to length `n`) let mut a = vec![1, 2, 3, 4]; @@ -15,27 +14,27 @@ fn main() { b.resize(n, 0); // Perform the forward NTT - let a_ntt = ntt(&a, omega, n, p); - let b_ntt = ntt(&b, omega, n, p); + let a_ntt = ntt(&a, omega, n, modulus); + let b_ntt = ntt(&b, omega, n, modulus); // Perform the inverse NTT on the transformed A for verification - let a_ntt_intt = intt(&a_ntt, omega, n, p); + let a_ntt_intt = intt(&a_ntt, omega, n, modulus); // Pointwise multiplication in the NTT domain let c_ntt: Vec = a_ntt .iter() .zip(b_ntt.iter()) - .map(|(x, y)| (x * y) % p) + .map(|(x, y)| (x * y) % modulus) .collect(); // Inverse NTT to get the polynomial product - let c = intt(&c_ntt, omega, n, p); + let c = intt(&c_ntt, omega, n, modulus); - let c_std = polymul(&a, &b, n as i64, p); - let c_fast = polymul_ntt(&a, &b, n, p, omega); + let c_std = polymul(&a, &b, n as i64, modulus); + let c_fast = polymul_ntt(&a, &b, n, modulus, omega); // Output the results - println!("verify omega = {}", verify_root_of_unity(omega, n as i64, p)); + println!("verify omega = {}", verify_root_of_unity(omega, n as i64, modulus)); println!("Polynomial A: {:?}", a); println!("Polynomial B: {:?}", b); println!("Transformed A: {:?}", a_ntt); @@ -46,35 +45,4 @@ fn main() { println!("Standard polynomial mult. result: {:?}", c_std); println!("Polynomial multiplication method using NTT: {:?}", c_fast); - //test the composite modulus case - let modulus = 697; // Example modulus - let n = 8; // Must be a power of 2 - let omega = ntt::omega(modulus, n); // n-th root of unity - println!("Totient of {}: {}", modulus, totient(modulus as u64)); - println!("omega: {}", omega); - println!("verify omega = {}", verify_root_of_unity(omega, n as i64, modulus)); - (0..=n).for_each(|i| println!("omega^{}: {}", i, mod_exp(omega, i as i64, modulus))); - println!("n^-1 = {}", mod_inv(n as i64, modulus)); - let a_ntt = ntt(&a, omega, n, modulus); - let b_ntt = ntt(&b, omega, n, modulus); - let a_ntt_intt = intt(&a_ntt, omega, n, modulus); - let b_ntt_intt = intt(&b_ntt, omega, n, modulus); - let c_ntt: Vec = a_ntt - .iter() - .zip(b_ntt.iter()) - .map(|(x, y)| (x * y) % modulus) - .collect(); - let c = intt(&c_ntt, omega, n, modulus); - let c_std = polymul(&a, &b, n as i64, modulus); - let c_fast = polymul_ntt(&a, &b, n, modulus, omega); - println!("A: {:?}", a); - println!("Transformed A: {:?}", a_ntt); - println!("Transformed B: {:?}", b_ntt); - println!("Recovered A: {:?}", a_ntt_intt); - println!("Recovered B: {:?}", b_ntt_intt); - println!("Pointwise Product in NTT Domain: {:?}", c_ntt); - println!("Standard polynomial mult. result: {:?}", c_std); - println!("Resultant Polynomial (c): {:?}", c); - println!("Polynomial multiplication method using NTT: {:?}", c_fast); - } diff --git a/src/test.rs b/src/test.rs index f038058..25ba447 100644 --- a/src/test.rs +++ b/src/test.rs @@ -70,24 +70,22 @@ mod tests { #[test] fn test_polymul_ntt_non_prime_power_modulus() { - let modulus: i64 = 697; // modulus not of the form p^k + let moduli = [17*41, 17*73, 17*41*73]; // Different moduli to test let n: usize = 8; // Length of the NTT (must be a power of 2) - let omega = omega(modulus, n); // n-th root of unity - - // Input polynomials (padded to length `n`) - let mut a = vec![1, 2, 3, 4]; - let mut b = vec![4, 5, 6, 7]; - a.resize(n, 0); - b.resize(n, 0); - - // Perform the standard polynomial multiplication - let c_std = polymul(&a, &b, n as i64, modulus); - - // Perform the NTT-based polynomial multiplication - let c_fast = polymul_ntt(&a, &b, n, modulus, omega); - - // Ensure both methods produce the same result - assert_eq!(c_std, c_fast, "The results of polymul and polymul_ntt do not match"); + + for &modulus in &moduli { + let omega = omega(modulus, n); + + let mut a = vec![1, 2, 3, 4]; + let mut b = vec![4, 5, 6, 7]; + a.resize(n, 0); + b.resize(n, 0); + + let c_std = polymul(&a, &b, n as i64, modulus); + let c_fast = polymul_ntt(&a, &b, n, modulus, omega); + + assert_eq!(c_std, c_fast, "Failed for modulus {}", modulus); + } } }