@@ -2,16 +2,26 @@ use reikna::totient::totient;
22use reikna:: factor:: quick_factorize;
33use std:: collections:: HashMap ;
44
5- // Modular arithmetic functions using i64
5+ /// Modular arithmetic functions using i64
66fn mod_add ( a : i64 , b : i64 , p : i64 ) -> i64 {
77 ( a + b) % p
88}
99
10+ /// Modular multiplication
1011fn mod_mul ( a : i64 , b : i64 , p : i64 ) -> i64 {
1112 ( a * b) % p
1213}
1314
14- pub fn mod_exp ( mut base : i64 , mut exp : i64 , p : i64 ) -> i64 {
15+ /// Modular exponentiation
16+ /// # Arguments
17+ ///
18+ /// * `base` - Base of the exponentiation.
19+ /// * `exp` - Exponent.
20+ /// * `p` - Prime modulus for the operations.
21+ ///
22+ /// # Returns
23+ /// The result of the exponentiation modulo `p`.
24+ fn mod_exp ( mut base : i64 , mut exp : i64 , p : i64 ) -> i64 {
1525 let mut result = 1 ;
1626 base %= p;
1727 while exp > 0 {
@@ -24,6 +34,7 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 {
2434 result
2535}
2636
37+ /// Extended Euclidean algorithm
2738fn extended_gcd ( a : i64 , b : i64 ) -> ( i64 , i64 , i64 ) {
2839 if b == 0 {
2940 ( a, 1 , 0 ) // gcd, x, y
@@ -33,15 +44,23 @@ fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
3344 }
3445}
3546
36- pub fn mod_inv ( a : i64 , modulus : i64 ) -> i64 {
47+ /// Compute the modular inverse of a modulo modulus
48+ fn mod_inv ( a : i64 , modulus : i64 ) -> i64 {
3749 let ( gcd, x, _) = extended_gcd ( a, modulus) ;
3850 if gcd != 1 {
3951 panic ! ( "{} and {} are not coprime, no inverse exists" , a, modulus) ;
4052 }
4153 ( x % modulus + modulus) % modulus // Ensure a positive result
4254}
4355
44- // Compute n-th root of unity (omega) for p not necessarily prime
56+ /// Compute n-th root of unity (omega) for p not necessarily prime
57+ /// # Arguments
58+ ///
59+ /// * `modulus` - Modulus. n must divide each prime power factor.
60+ /// * `n` - Order of the root of unity.
61+ ///
62+ /// # Returns
63+ /// The n-th root of unity modulo `modulus`.
4564pub fn omega ( modulus : i64 , n : usize ) -> i64 {
4665 let factors = factorize ( modulus as i64 ) ;
4766 if factors. len ( ) == 1 {
@@ -56,7 +75,15 @@ pub fn omega(modulus: i64, n: usize) -> i64 {
5675 }
5776}
5877
59- // Forward transform using NTT, output bit-reversed
78+ /// Forward transform using NTT, output bit-reversed
79+ /// # Arguments
80+ ///
81+ /// * `a` - Input vector.
82+ /// * `omega` - Primitive root of unity modulo `p`.
83+ /// * `n` - Length of the input vector and the result.
84+ /// * `p` - Prime modulus for the operations.
85+ ///
86+ /// # Returns
6087pub fn ntt ( a : & [ i64 ] , omega : i64 , n : usize , p : i64 ) -> Vec < i64 > {
6188 let mut result = a. to_vec ( ) ;
6289 let mut step = n/2 ;
@@ -77,7 +104,16 @@ pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
77104 result
78105}
79106
80- // Inverse transform using INTT, input bit-reversed
107+ /// Inverse transform using INTT, input bit-reversed
108+ /// # Arguments
109+ ///
110+ /// * `a` - Input vector (bit-reversed).
111+ /// * `omega` - Primitive root of unity modulo `p`.
112+ /// * `n` - Length of the input vector and the result.
113+ /// * `p` - Prime modulus for the operations.
114+ ///
115+ /// # Returns
116+ /// A vector representing the inverse NTT of the input vector.
81117pub fn intt ( a : & [ i64 ] , omega : i64 , n : usize , p : i64 ) -> Vec < i64 > {
82118 let omega_inv = mod_inv ( omega, p) ;
83119 let n_inv = mod_inv ( n as i64 , p) ;
@@ -103,7 +139,16 @@ pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
103139 . collect ( )
104140}
105141
106- // Naive polynomial multiplication
142+ /// Naive polynomial multiplication
143+ /// # Arguments
144+ ///
145+ /// * `a` - First polynomial (as a vector of coefficients).
146+ /// * `b` - Second polynomial (as a vector of coefficients).
147+ /// * `n` - Length of the polynomials and the result.
148+ /// * `p` - Prime modulus for the operations.
149+ ///
150+ /// # Returns
151+ /// A vector representing the polynomial product modulo `p`.
107152pub fn polymul ( a : & Vec < i64 > , b : & Vec < i64 > , n : i64 , p : i64 ) -> Vec < i64 > {
108153 let mut result = vec ! [ 0 ; n as usize ] ;
109154 for i in 0 ..a. len ( ) {
@@ -145,7 +190,14 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec<i6
145190 c
146191}
147192
148- /// Compute the prime factorization of `n` (with multiplicities).
193+ /// Compute the prime factorization of `n` (with multiplicities)
194+ /// Uses reikna::quick_factorize internally
195+ /// # Arguments
196+ ///
197+ /// * `n` - Number to factorize.
198+ ///
199+ /// # Returns
200+ /// A HashMap with the prime factors of `n` as keys and their multiplicities as values.
149201fn factorize ( n : i64 ) -> HashMap < i64 , u32 > {
150202 let mut factors = HashMap :: new ( ) ;
151203 for factor in quick_factorize ( n as u64 ) {
@@ -167,6 +219,12 @@ pub fn primitive_root(p: i64, e: u32) -> i64 {
167219}
168220
169221/// Finds a primitive root modulo a prime p
222+ /// # Arguments
223+ ///
224+ /// * `p` - Prime modulus.
225+ ///
226+ /// # Returns
227+ /// A primitive root modulo `p`.
170228fn primitive_root_mod_p ( p : i64 ) -> i64 {
171229 let phi = p - 1 ;
172230 let factors = factorize ( phi) ; // Reusing factorize to get both prime factors and multiplicities
@@ -179,7 +237,16 @@ fn primitive_root_mod_p(p: i64) -> i64 {
179237 0 // Should never happen
180238}
181239
182- // the Chinese remainder theorem for two moduli
240+ /// the Chinese remainder theorem for two moduli
241+ /// # Arguments
242+ ///
243+ /// * `a1` - First residue.
244+ /// * `n1` - First modulus.
245+ /// * `a2` - Second residue.
246+ /// * `n2` - Second modulus.
247+ ///
248+ /// # Returns
249+ /// The solution to the system of congruences x = a1 (mod n1) and x = a2 (mod n2).
183250pub fn crt ( a1 : i64 , n1 : i64 , a2 : i64 , n2 : i64 ) -> i64 {
184251 let n = n1 * n2;
185252 let m1 = mod_inv ( n1, n2) ; // Inverse of n1 mod n2
@@ -188,10 +255,17 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
188255 if x < 0 { x + n } else { x }
189256}
190257
191- // computes an n^th root of unity modulo a composite modulus
192- // note we require that an n^th root of unity exists for each multiplicative group modulo p^e
193- // use the CRT isomorphism to pull back each n^th root of unity to the composite modulus
194- // for the NTT, we require than a 2n^th root of unity exists
258+ /// computes an n^th root of unity modulo a composite modulus
259+ /// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
260+ /// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus
261+ /// for the NTT, we require than a 2n^th root of unity exists
262+ /// # Arguments
263+ ///
264+ /// * `modulus` - Modulus. n must divide each prime power factor.
265+ /// * `n` - Order of the root of unity.
266+ ///
267+ /// # Returns
268+ /// The n-th root of unity modulo `modulus`.
195269pub fn root_of_unity ( modulus : i64 , n : i64 ) -> i64 {
196270 let factors = factorize ( modulus) ;
197271 let mut result = 1 ;
@@ -202,7 +276,15 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
202276 result
203277}
204278
205- //ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
279+ /// ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
280+ /// # Arguments
281+ ///
282+ /// * `omega` - n-th root of unity.
283+ /// * `n` - Order of the root of unity.
284+ /// * `modulus` - Modulus.
285+ ///
286+ /// # Returns
287+ /// True if the root of unity satisfies the condition.
206288pub fn verify_root_of_unity ( omega : i64 , n : i64 , modulus : i64 ) -> bool {
207289 assert ! ( mod_exp( omega, n, modulus as i64 ) == 1 , "omega is not an n-th root of unity" ) ;
208290 assert ! ( mod_exp( omega, n/2 , modulus as i64 ) == modulus-1 , "omgea^(n/2) != -1 (mod modulus)" ) ;
0 commit comments