55#include < vector>
66
77template <int md> struct ModInt {
8+ static_assert (md > 1 );
89 using lint = long long ;
910 constexpr static int mod () { return md; }
1011 static int get_primitive_root () {
@@ -102,39 +103,50 @@ template <int md> struct ModInt {
102103 return this ->pow (md - 2 );
103104 }
104105 }
105- constexpr ModInt fac () const {
106- while (this ->val_ >= int (facs.size ())) _precalculation (facs.size () * 2 );
107- return facs[this ->val_ ];
106+
107+ constexpr static ModInt fac (int n) {
108+ assert (n >= 0 );
109+ if (n >= md) return ModInt (0 );
110+ while (n >= int (facs.size ())) _precalculation (facs.size () * 2 );
111+ return facs[n];
108112 }
109- constexpr ModInt facinv () const {
110- while (this ->val_ >= int (facs.size ())) _precalculation (facs.size () * 2 );
111- return facinvs[this ->val_ ];
113+
114+ constexpr static ModInt facinv (int n) {
115+ assert (n >= 0 );
116+ if (n >= md) return ModInt (0 );
117+ while (n >= int (facs.size ())) _precalculation (facs.size () * 2 );
118+ return facinvs[n];
112119 }
113- constexpr ModInt doublefac () const {
114- lint k = (this ->val_ + 1 ) / 2 ;
115- return (this ->val_ & 1 ) ? ModInt (k * 2 ).fac () / (ModInt (2 ).pow (k) * ModInt (k).fac ())
116- : ModInt (k).fac () * ModInt (2 ).pow (k);
120+
121+ constexpr static ModInt doublefac (int n) {
122+ assert (n >= 0 );
123+ if (n >= md) return ModInt (0 );
124+ long long k = (n + 1 ) / 2 ;
125+ return (n & 1 ) ? ModInt::fac (k * 2 ) / (ModInt (2 ).pow (k) * ModInt::fac (k))
126+ : ModInt::fac (k) * ModInt (2 ).pow (k);
117127 }
118128
119- constexpr ModInt nCr (int r) const {
120- if (r < 0 or this ->val_ < r) return ModInt (0 );
121- return this ->fac () * (*this - r).facinv () * ModInt (r).facinv ();
129+ constexpr static ModInt nCr (int n, int r) {
130+ assert (n >= 0 );
131+ if (r < 0 or n < r) return ModInt (0 );
132+ return ModInt::fac (n) * ModInt::facinv (r) * ModInt::facinv (n - r);
122133 }
123134
124- constexpr ModInt nPr (int r) const {
125- if (r < 0 or this ->val_ < r) return ModInt (0 );
126- return this ->fac () * (*this - r).facinv ();
135+ constexpr static ModInt nPr (int n, int r) {
136+ assert (n >= 0 );
137+ if (r < 0 or n < r) return ModInt (0 );
138+ return ModInt::fac (n) * ModInt::facinv (n - r);
127139 }
128140
129141 static ModInt binom (int n, int r) {
130142 static long long bruteforce_times = 0 ;
131143
132144 if (r < 0 or n < r) return ModInt (0 );
133- if (n <= bruteforce_times or n < (int )facs.size ()) return ModInt (n). nCr (r);
145+ if (n <= bruteforce_times or n < (int )facs.size ()) return ModInt:: nCr (n, r);
134146
135147 r = std::min (r, n - r);
136148
137- ModInt ret = ModInt (r). facinv ();
149+ ModInt ret = ModInt:: facinv (r );
138150 for (int i = 0 ; i < r; ++i) ret *= n - i;
139151 bruteforce_times += r;
140152
@@ -148,18 +160,23 @@ template <int md> struct ModInt {
148160 int sum = 0 ;
149161 for (int k : ks) {
150162 assert (k >= 0 );
151- ret *= ModInt (k). facinv (), sum += k;
163+ ret *= ModInt:: facinv (k ), sum += k;
152164 }
153- return ret * ModInt (sum).fac ();
165+ return ret * ModInt::fac (sum);
166+ }
167+ template <class ... Args> static ModInt multinomial (Args... args) {
168+ int sum = (0 + ... + args);
169+ ModInt result = (1 * ... * ModInt::facinv (args));
170+ return ModInt::fac (sum) * result;
154171 }
155172
156- // Catalan number, C_n = binom(2n, n) / (n + 1)
173+ // Catalan number, C_n = binom(2n, n) / (n + 1) = # of Dyck words of length 2n
157174 // C_0 = 1, C_1 = 1, C_2 = 2, C_3 = 5, C_4 = 14, ...
158175 // https://oeis.org/A000108
159176 // Complexity: O(n)
160177 static ModInt catalan (int n) {
161178 if (n < 0 ) return ModInt (0 );
162- return ModInt (n * 2 ). fac () * ModInt (n + 1 ). facinv () * ModInt (n). facinv ();
179+ return ModInt::fac (n * 2 ) * ModInt::facinv (n + 1 ) * ModInt:: facinv (n );
163180 }
164181
165182 ModInt sqrt () const {
0 commit comments