Skip to content

Commit 309d256

Browse files
add documentation
add arguments and return value docstrings
1 parent b254a5a commit 309d256

File tree

1 file changed

+96
-14
lines changed

1 file changed

+96
-14
lines changed

src/lib.rs

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,26 @@ use reikna::totient::totient;
22
use reikna::factor::quick_factorize;
33
use std::collections::HashMap;
44

5-
// Modular arithmetic functions using i64
5+
/// Modular arithmetic functions using i64
66
fn mod_add(a: i64, b: i64, p: i64) -> i64 {
77
(a + b) % p
88
}
99

10+
/// Modular multiplication
1011
fn mod_mul(a: i64, b: i64, p: i64) -> i64 {
1112
(a * b) % p
1213
}
1314

14-
pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 {
15+
/// Modular exponentiation
16+
/// # Arguments
17+
///
18+
/// * `base` - Base of the exponentiation.
19+
/// * `exp` - Exponent.
20+
/// * `p` - Prime modulus for the operations.
21+
///
22+
/// # Returns
23+
/// The result of the exponentiation modulo `p`.
24+
fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 {
1525
let mut result = 1;
1626
base %= p;
1727
while exp > 0 {
@@ -24,6 +34,7 @@ pub fn mod_exp(mut base: i64, mut exp: i64, p: i64) -> i64 {
2434
result
2535
}
2636

37+
/// Extended Euclidean algorithm
2738
fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
2839
if b == 0 {
2940
(a, 1, 0) // gcd, x, y
@@ -33,15 +44,23 @@ fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
3344
}
3445
}
3546

36-
pub fn mod_inv(a: i64, modulus: i64) -> i64 {
47+
/// Compute the modular inverse of a modulo modulus
48+
fn mod_inv(a: i64, modulus: i64) -> i64 {
3749
let (gcd, x, _) = extended_gcd(a, modulus);
3850
if gcd != 1 {
3951
panic!("{} and {} are not coprime, no inverse exists", a, modulus);
4052
}
4153
(x % modulus + modulus) % modulus // Ensure a positive result
4254
}
4355

44-
// Compute n-th root of unity (omega) for p not necessarily prime
56+
/// Compute n-th root of unity (omega) for p not necessarily prime
57+
/// # Arguments
58+
///
59+
/// * `modulus` - Modulus. n must divide each prime power factor.
60+
/// * `n` - Order of the root of unity.
61+
///
62+
/// # Returns
63+
/// The n-th root of unity modulo `modulus`.
4564
pub fn omega(modulus: i64, n: usize) -> i64 {
4665
let factors = factorize(modulus as i64);
4766
if factors.len() == 1 {
@@ -56,7 +75,15 @@ pub fn omega(modulus: i64, n: usize) -> i64 {
5675
}
5776
}
5877

59-
// Forward transform using NTT, output bit-reversed
78+
/// Forward transform using NTT, output bit-reversed
79+
/// # Arguments
80+
///
81+
/// * `a` - Input vector.
82+
/// * `omega` - Primitive root of unity modulo `p`.
83+
/// * `n` - Length of the input vector and the result.
84+
/// * `p` - Prime modulus for the operations.
85+
///
86+
/// # Returns
6087
pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
6188
let mut result = a.to_vec();
6289
let mut step = n/2;
@@ -77,7 +104,16 @@ pub fn ntt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
77104
result
78105
}
79106

80-
// Inverse transform using INTT, input bit-reversed
107+
/// Inverse transform using INTT, input bit-reversed
108+
/// # Arguments
109+
///
110+
/// * `a` - Input vector (bit-reversed).
111+
/// * `omega` - Primitive root of unity modulo `p`.
112+
/// * `n` - Length of the input vector and the result.
113+
/// * `p` - Prime modulus for the operations.
114+
///
115+
/// # Returns
116+
/// A vector representing the inverse NTT of the input vector.
81117
pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
82118
let omega_inv = mod_inv(omega, p);
83119
let n_inv = mod_inv(n as i64, p);
@@ -103,7 +139,16 @@ pub fn intt(a: &[i64], omega: i64, n: usize, p: i64) -> Vec<i64> {
103139
.collect()
104140
}
105141

106-
// Naive polynomial multiplication
142+
/// Naive polynomial multiplication
143+
/// # Arguments
144+
///
145+
/// * `a` - First polynomial (as a vector of coefficients).
146+
/// * `b` - Second polynomial (as a vector of coefficients).
147+
/// * `n` - Length of the polynomials and the result.
148+
/// * `p` - Prime modulus for the operations.
149+
///
150+
/// # Returns
151+
/// A vector representing the polynomial product modulo `p`.
107152
pub fn polymul(a: &Vec<i64>, b: &Vec<i64>, n: i64, p: i64) -> Vec<i64> {
108153
let mut result = vec![0; n as usize];
109154
for i in 0..a.len() {
@@ -145,7 +190,14 @@ pub fn polymul_ntt(a: &[i64], b: &[i64], n: usize, p: i64, omega: i64) -> Vec<i6
145190
c
146191
}
147192

148-
/// Compute the prime factorization of `n` (with multiplicities).
193+
/// Compute the prime factorization of `n` (with multiplicities)
194+
/// Uses reikna::quick_factorize internally
195+
/// # Arguments
196+
///
197+
/// * `n` - Number to factorize.
198+
///
199+
/// # Returns
200+
/// A HashMap with the prime factors of `n` as keys and their multiplicities as values.
149201
fn factorize(n: i64) -> HashMap<i64, u32> {
150202
let mut factors = HashMap::new();
151203
for factor in quick_factorize(n as u64) {
@@ -167,6 +219,12 @@ pub fn primitive_root(p: i64, e: u32) -> i64 {
167219
}
168220

169221
/// Finds a primitive root modulo a prime p
222+
/// # Arguments
223+
///
224+
/// * `p` - Prime modulus.
225+
///
226+
/// # Returns
227+
/// A primitive root modulo `p`.
170228
fn primitive_root_mod_p(p: i64) -> i64 {
171229
let phi = p - 1;
172230
let factors = factorize(phi); // Reusing factorize to get both prime factors and multiplicities
@@ -179,7 +237,16 @@ fn primitive_root_mod_p(p: i64) -> i64 {
179237
0 // Should never happen
180238
}
181239

182-
// the Chinese remainder theorem for two moduli
240+
/// the Chinese remainder theorem for two moduli
241+
/// # Arguments
242+
///
243+
/// * `a1` - First residue.
244+
/// * `n1` - First modulus.
245+
/// * `a2` - Second residue.
246+
/// * `n2` - Second modulus.
247+
///
248+
/// # Returns
249+
/// The solution to the system of congruences x = a1 (mod n1) and x = a2 (mod n2).
183250
pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
184251
let n = n1 * n2;
185252
let m1 = mod_inv(n1, n2); // Inverse of n1 mod n2
@@ -188,10 +255,17 @@ pub fn crt(a1: i64, n1: i64, a2: i64, n2: i64) -> i64 {
188255
if x < 0 { x + n } else { x }
189256
}
190257

191-
// computes an n^th root of unity modulo a composite modulus
192-
// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
193-
// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus
194-
// for the NTT, we require than a 2n^th root of unity exists
258+
/// computes an n^th root of unity modulo a composite modulus
259+
/// note we require that an n^th root of unity exists for each multiplicative group modulo p^e
260+
/// use the CRT isomorphism to pull back each n^th root of unity to the composite modulus
261+
/// for the NTT, we require than a 2n^th root of unity exists
262+
/// # Arguments
263+
///
264+
/// * `modulus` - Modulus. n must divide each prime power factor.
265+
/// * `n` - Order of the root of unity.
266+
///
267+
/// # Returns
268+
/// The n-th root of unity modulo `modulus`.
195269
pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
196270
let factors = factorize(modulus);
197271
let mut result = 1;
@@ -202,7 +276,15 @@ pub fn root_of_unity(modulus: i64, n: i64) -> i64 {
202276
result
203277
}
204278

205-
//ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
279+
/// ensure the root of unity satisfies sum_{j=0}^{n-1} omega^{jk} = 0 for 1 \le k < n
280+
/// # Arguments
281+
///
282+
/// * `omega` - n-th root of unity.
283+
/// * `n` - Order of the root of unity.
284+
/// * `modulus` - Modulus.
285+
///
286+
/// # Returns
287+
/// True if the root of unity satisfies the condition.
206288
pub fn verify_root_of_unity(omega: i64, n: i64, modulus: i64) -> bool {
207289
assert!(mod_exp(omega, n, modulus as i64) == 1, "omega is not an n-th root of unity");
208290
assert!(mod_exp(omega, n/2, modulus as i64) == modulus-1, "omgea^(n/2) != -1 (mod modulus)");

0 commit comments

Comments
 (0)