(一)多项式乘法(无模数)

核心算法:FFT算法(本质上是利用复数的循环性,并通过分治来优化)

单位根的性质

  1. $\omega_n^k=\cos\frac{2k\pi}{n}+i\sin\frac{2k\pi}{n}$

  2. $\omega_n^0=\omega_n^n=1$

  3. $\omega_{2n}^{2k}=\omega_n^k$

  4. $\omega_n^{k+\frac{n}{2}}=-\omega_n^k$

正变换(即由系数表达式转变为点值表达式)

设 $A(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}$(不妨设 $n$ 为 $2$ 的整数次幂,不足的高次项系数补 $0$)。

按下标的奇偶分组:$A_1(x)=a_0+a_2x+a_4x^2+\dots+a_{n-2}x^{\frac{n}{2}-1}$,$A_2(x)=a_1+a_3x+a_5x^2+\dots+a_{n-1}x^{\frac{n}{2}-1}$

则 $A(x)=A_1(x^2)+xA_2(x^2)$。把 $x=\omega_n^k\ (k\in[0,\frac{n}{2}-1])$ 代入得:
$A(\omega_n^k)=A_1(\omega_n^{2k})+\omega_n^kA_2(\omega_n^{2k})=A_1(\omega_{\frac{n}{2}}^k)+\omega_n^kA_2(\omega_{\frac{n}{2}}^k)$
把 $x=\omega_n^{k+\frac{n}{2}}\ (k+\frac{n}{2}\in[\frac{n}{2},n-1])$ 代入得:

$A(\omega_n^{k+\frac{n}{2}})=A_1(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}A_2(\omega_n^{2k+n})=A_1(\omega_{\frac{n}{2}}^k)-\omega_n^kA_2(\omega_{\frac{n}{2}}^k)$

$A(\omega_n^k)$ 与 $A(\omega_n^{k+\frac{n}{2}})$ 都可由 $A_1(\omega_{\frac{n}{2}}^k)$ 和 $A_2(\omega_{\frac{n}{2}}^k)$ 来表示,因此可以递归求解。递归层数为 $\log n$,每层的总计算量为 $O(n)$,故时间复杂度为 $O(n\log n)$。

逆变换(即由点值表达式转变为由系数表达式)

逆变换问题转换为:已知点值表达式 $(\omega_n^k,A(\omega_n^k))$,$k\in[0,n-1]$,求 $A(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}$ 的各项系数 $a_0,a_1,\dots,a_{n-1}$。

引理:$na_k=c_k$,其中 $c_k=\sum_{i=0}^{n-1}y_i(\omega_n^{-k})^i$,$y_i=A(\omega_n^i)$。

证明($na_k=c_k$):

$$\begin{aligned} c_k &= \sum_{i=0}^{n-1}y_i(\omega_n^{-k})^i\\ &=\sum_{i=0}^{n-1}A(\omega_n^i)(\omega_n^{-k})^i\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j(\omega_n^i)^j(\omega_n^{-k})^i\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j(\omega_n^j)^i(\omega_n^{-k})^i\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j(\omega_n^{j-k})^i\\ &=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i \end{aligned}$$

下面求 $\sum_{i=0}^{n-1}(\omega_n^{j-k})^i$。
设 $S(x)=1+x+x^2+\dots+x^{n-1}$,即求 $S(\omega_n^{k})$(以下的 $k$ 表示上式中的 $(j-k)\bmod n$,故 $k\in[0,n-1]$)。
1. 当 $k\ne0$ 时:
$S(\omega_n^k)=1+\omega_n^k+\omega_n^{2k}+\dots+\omega_n^{(n-1)k}$,两边同时乘以 $\omega_n^k$ 得:

$\omega_n^kS(\omega_n^k)=\omega_n^k+\omega_n^{2k}+\omega_n^{3k}+\dots+\omega_n^{nk}=\omega_n^k+\omega_n^{2k}+\omega_n^{3k}+\dots+1=S(\omega_n^k)$

即 $(1-\omega_n^k)S(\omega_n^k)=0$。由 $k\in[1,n-1]$ 知 $\omega_n^k\ne1$,故 $S(\omega_n^k)=0$。

2. 当 $k=0$ 时,代入表达式可得 $S(\omega_n^k)=S(1)=n$。

故:

$$\begin{aligned} c_k&=\sum_{i=0}^{n-1}y_i(\omega_n^{-k})^i\\ &=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i\\ &=na_k \end{aligned}$$
(内层求和仅在 $j=k$ 时等于 $n$,其余情况均为 $0$。)

则原多项式系数为 $a_k=\frac{c_k}{n}$。
因此只需把 $y_i=A(\omega_n^i)$ 当成多项式的系数,把单位根 $\omega_n^k$ 换成 $\omega_n^{-k}$ 再做一次同样的变换,最后把结果除以 $n$ 即可。
但是要注意 FFT 的精度问题:当数据过大时会超出 FFT 的精度范围,此时不能使用 FFT,应当转变思路。

FFT 在实现时的两个优化:

  1. 由于递归的实现方式会使算法的常数过大,因此要改写为非递归(迭代)版本,利用蝴蝶变换(位逆序置换)来实现。

  2. 三步变两步优化:设 $a$、$b$ 为实系数多项式,$F=a+bi$,则 $F^2=a^2-b^2+2abi$,我们要求的 $ab$ 正是 $F^2$ 虚部的二分之一。所以可以把 $b(x)$ 放到 $a(x)$ 的虚部上得到 $F(x)$,求出 $F(x)^2$,然后把 $F(x)^2$ 的虚部取出来除以 $2$ 即可,这样只需两次 FFT(一次正变换、一次逆变换)。(但是要注意三步变两步优化会使 FFT 算法的精度变低)

例题(洛谷P3803)

给定一个 $n$ 次多项式 $F(x)$,和一个 $m$ 次多项式 $G(x)$。
请求出 $F(x)$ 和 $G(x)$ 的卷积。
保证输入中的系数大于等于 $0$ 且小于等于 $9$。
对于 $100\%$ 的数据:$1 \le n, m \le 10^6$。

参考代码:
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>

using namespace std;

// Capacity of the coefficient buffers and of the bit-reversal table.
// The transform length is the smallest power of two >= n + m + 1; with
// n, m <= 1e6 that is 2^21, so the arrays need at least 2^21 entries
// (the former 3e5 + 5 let FFT and main write far past the end of a/b/rev).
const int N = (1 << 21) + 5;
const double PI = acos(-1);

// Minimal complex number: `a` is the real part, `b` the imaginary part.
// It is deliberately left an aggregate (no constructors) so that callers can
// build values with a braced list {re, im}.
struct Complex {
    double a, b;

    // Component-wise sum.
    Complex operator+(const Complex &rhs) const {
        Complex sum = {a + rhs.a, b + rhs.b};
        return sum;
    }

    // Component-wise difference.
    Complex operator-(const Complex &rhs) const {
        Complex diff = {a - rhs.a, b - rhs.b};
        return diff;
    }

    // Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
    Complex operator*(const Complex &rhs) const {
        Complex prod = {a * rhs.a - b * rhs.b, a * rhs.b + b * rhs.a};
        return prod;
    }
};

// Working buffers for the two polynomials; main stores coefficients in the real parts.
Complex a[N], b[N];
// n, m: the two values read first in main; coefficient indices 0..n and 0..m are read.
int n , m;
// rev[i]: index i with its lowest `bit` bits reversed (filled in main);
// bit: exponent such that the transform length is 1 << bit.
int rev[N] , bit;

// In-place iterative FFT over the first (1 << bit) entries of `a`.
//   t =  1 : butterflies use the roots cos(pi/mid) + i*sin(pi/mid)
//   t = -1 : butterflies use the conjugate roots
// No 1/len scaling is applied here; the caller divides after the t = -1 pass.
// Reads the globals `bit` and `rev`.
void FFT(Complex a[] , int t){
    const int len = 1 << bit;

    // Bit-reversal permutation; the idx < partner test swaps each pair only once.
    for(int idx = 0 ; idx < len ; ++idx){
        const int partner = rev[idx];
        if(idx < partner)
            swap(a[idx] , a[partner]);
    }

    // Merge blocks of size `half` into blocks of size 2 * half.
    for(int half = 1 ; half < len ; half *= 2){
        const Complex step = {cos(PI / half) , t * sin(PI / half)};
        for(int base = 0 ; base < len ; base += half * 2){
            Complex cur = {1 , 0};  // running power of `step`
            for(int off = 0 ; off < half ; ++off){
                const Complex even = a[base + off];
                const Complex odd = cur * a[base + off + half];
                a[base + off] = even + odd;
                a[base + off + half] = even - odd;
                cur = cur * step;
            }
        }
    }
}

// Reads n, m and the coefficients of both polynomials, multiplies them with
// two forward FFTs, a point-wise product and one inverse FFT, then prints the
// n + m + 1 coefficients of the product rounded to the nearest integer.
int main(){
    cin >> n >> m;
    for(int idx = 0 ; idx <= n ; ++idx) cin >> a[idx].a;
    for(int idx = 0 ; idx <= m ; ++idx) cin >> b[idx].a;

    // Smallest power of two that can hold all n + m + 1 result coefficients.
    while((1 << bit) < n + m + 1) ++bit;
    const int len = 1 << bit;

    // rev[idx] = rev[idx / 2] shifted right, with idx's lowest bit moved to the top.
    for(int idx = 0 ; idx < len ; ++idx){
        const int top_bit = (idx & 1) << (bit - 1);
        rev[idx] = (rev[idx >> 1] >> 1) | top_bit;
    }

    FFT(a , 1);
    FFT(b , 1);
    for(int idx = 0 ; idx < len ; ++idx) a[idx] = a[idx] * b[idx];
    FFT(a , -1);

    // The inverse pass is unscaled, so divide by len; + 0.5 rounds to nearest.
    for(int idx = 0 ; idx <= n + m ; ++idx){
        const int coef = (int)(a[idx].a / len + 0.5);
        printf("%d " , coef);
    }
    return 0;
}

(二)多项式乘法(有模数且模数有原根)

前置知识:

  1. 原根:设 $m\in N^*$,$g\in Z$。若 $(g,m)=1$,且 $\delta_m(g)=\varphi(m)$(其中 $\delta_m(g)$ 表示 $g$ 模 $m$ 的阶),则称 $g$ 为模 $m$ 的原根。

  2. 原根的判定定理:设 $m\ge3$,$(g,m)=1$,则 $g$ 是模 $m$ 的原根的充要条件为:对于 $\varphi(m)$ 的每个素因子 $p$,都有 $g^{\frac{\varphi(m)}{p}}\not\equiv1\pmod m$。

  3. 原根的性质(设 $p$ 为素数,$g$ 为模 $p$ 的原根,$n$ 为整除 $p-1$ 的偶数,记 $g_n=g^{\frac{p-1}{n}}$,以下等式均在模 $p$ 意义下成立):
     $\bullet\ g_n^k=g_n^{k-1}\cdot g_n$
     $\bullet\ g_n^0=g_n^n=1$
     $\bullet\ g_{2n}^{2k}=g_n^k$
     $\bullet\ g_n^{\frac{n}{2}}=g^{\frac{p-1}{2}}=-1$

核心算法:NTT算法(本质上是利用整数在模的意义下的循环性)

NTT 算法的推导与实现和 FFT 相差无几,只需把 FFT 中的 $\omega_n^k$ 全部换成 $g_n^k$(并在模 $p$ 意义下进行运算)即可,类比去做就行。

参考代码:
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>

using namespace std;

typedef long long LL;

// p: the NTT modulus; G: the base used to derive roots of unity in NTT;
// Gi: the modular inverse of G modulo p (3 * 332748118 = p + 1).
const int p = 998244353, G = 3, Gi = 332748118;
// Capacity of every coefficient array below.
const int N = 5000007;

// Not referenced by the NTT code visible in this program.
const double PI = acos(-1);

// n, m: the two values read first in main; coefficient indices 0..n and 0..m are read.
int n, m;
// res, ans: not referenced by the code visible in this program.
int res, ans[N];
// Current transform length (a power of two); set by poly_mul.
int limit = 1;
// Number of binary digits of the transform length, i.e. limit == 1 << L.
int L;
// RR[i]: index i with its lowest L bits reversed; filled by poly_mul.
int RR[N];
// Coefficient buffers of the two polynomials; the product ends up in a.
LL a[N], b[N];

// Fast integer reader built on getchar().
// Skips every non-digit character (remembering a '-' seen among them), then
// parses the following run of decimal digits.
// Returns the signed value, or 0 if end of input is reached before any digit.
inline int read()
{
    // `register` was removed from the language in C++17, so plain locals are used.
    int value = 0, sign = 1;
    // Kept as int so that EOF is distinguishable from every real character;
    // without the EOF test the skip loop would spin forever on exhausted input.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9')) {
        if(ch == '-') sign = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

// Modular exponentiation: returns a^b mod p using square-and-multiply.
//   a : base; reduced into [0, p) first so that a * a cannot overflow LL,
//       whatever the magnitude or sign of the argument.
//   b : exponent, expected to be non-negative; a negative b is treated like 0
//       (previously it never terminated, because >> keeps a negative value negative).
LL qpow(LL a, LL b)
{
    a %= p;
    if(a < 0) a += p;
    LL res = 1 % p;
    for(; b > 0; b >>= 1) {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}

// Modular inverse of x modulo p, computed as x^(p - 2) (Fermat's little theorem).
LL inv(LL x)
{
    const LL exponent = p - 2;
    return qpow(x, exponent);
}

// In-place iterative number-theoretic transform over the first `limit` entries of A.
//   type ==  1 : forward transform
//   type == -1 : inverse transform (every stage uses the inverse root, and the
//                result is finally multiplied by the inverse of `limit`)
// Reads the globals `limit` and `RR`; all arithmetic is modulo p.
void NTT(LL *A, int type)
{
    // Bit-reversal permutation; the idx < partner test swaps each pair only once.
    for(int idx = 0; idx < limit; ++idx) {
        const int partner = RR[idx];
        if(idx < partner)
            swap(A[idx], A[partner]);
    }

    for(int half = 1; half < limit; half *= 2) {
        const int span = half * 2;
        // A power of G takes the place of the complex root of unity used by the FFT.
        // Alternative kept from the original notes: use qpow(type == 1 ? G : Gi, ...)
        // directly, or (if too slow) drop the inversion below and instead swap
        // A[i] with A[limit - i] for 1 <= i < limit / 2 inside the final
        // type == -1 branch.
        LL root = qpow(G, (p - 1) / span);
        // The inverse transform needs the inverse of every root.
        if(type == -1) root = qpow(root, p - 2);

        for(int base = 0; base < limit; base += span) {
            LL w = 1;  // running power of `root`
            for(int k = 0; k < half; ++k) {
                LL &lo = A[base + k];
                LL &hi = A[base + half + k];
                const int even = lo;
                const int odd = w * hi % p;
                lo = (even + odd) % p;
                hi = (even - odd + p) % p;
                w = (w * root) % p;
            }
        }
    }

    if(type != -1) return;

    // As with the FFT, the inverse transform divides by the length; under a
    // modulus that division is a multiplication by the modular inverse of limit.
    const LL scale = inv(limit);
    for(int idx = 0; idx < limit; ++idx)
        A[idx] = (A[idx] * scale) % p;
}
// Polynomial multiplication modulo p.
//   a, b : coefficient arrays; both must have room for `limit` entries (the
//          smallest power of two greater than deg) and be zero beyond their
//          own coefficients.
//   deg  : degree of the product, i.e. deg(a) + deg(b).
// On return a holds the product; b is left holding its own forward transform.
// Sets the globals limit, L and RR as a side effect.
void poly_mul(LL *a, LL *b, int deg)
{
    for(limit = 1, L = 0; limit <= deg; limit <<= 1) L ++ ;

    // Bit-reversal table. Index 0 is handled separately so that the shift by
    // (L - 1) is only evaluated when limit >= 2, i.e. L >= 1; with deg == 0
    // the old loop shifted by -1, which is undefined behaviour.
    RR[0] = 0;
    for(int i = 1; i < limit; ++ i) {
        RR[i] = (RR[i >> 1] >> 1) | ((i & 1) << (L - 1));
    }

    NTT(a, 1);
    NTT(b, 1);
    for(int i = 0; i < limit; ++ i) a[i] = a[i] * b[i] % p;
    NTT(a, -1);
}

// Reads n, m and the coefficients of both polynomials, multiplies them modulo p
// and prints the n + m + 1 coefficients of the product separated by spaces.
int main(){
    cin >> n >> m;

    // Inputs are normalised into [0, p) before being stored.
    for(int idx = 0 ; idx <= n ; ++idx){
        const int raw = read();
        a[idx] = (raw + p) % p;
    }
    for(int idx = 0 ; idx <= m ; ++idx){
        const int raw = read();
        b[idx] = (raw + p) % p;
    }

    const int top = n + m;
    poly_mul(a , b , top);

    for(int idx = 0 ; idx <= top ; ++idx){
        cout << a[idx] << ' ';
    }
    return 0;
}

例题(Problem E of The 2021 ICPC Asia Macau Regional Contest)

There are $n$ children playing with $n$ balls. Both children and balls are numbered from $1$ to $n$.
Before the game, $n$ integers $p_1, p_2, \cdots, p_n$ are given. In each round of the game, child $i$ will pass the ball he possesses to child $p_i$. It is guaranteed that no child will pass his ball to himself, which means $p_i \neq i$. Moreover, we also know that after each round, each child will hold exactly one ball.
Let $b_i$ be the ball possessed by child $i$. At the beginning of the game, child $i$ ($1 \le i \le n$) will be carrying ball $i$, which means $b_i=i$ initially. You're asked to process $q$ queries. For each query you're given an integer $k$ and you need to compute the value of $\sum\limits_{i=1}^{n} i \times b_i$ after $k$ rounds.
Input
There is only one test case for each test file.
The first line of the input contains two integers $n$ ($2 \le n \le 10^5$) and $q$ ($1 \le q \le 10^5$), indicating the number of children and the number of queries.
The second line contains $n$ integers $p_1, p_2, \cdots, p_n$ ($1 \le p_i \le n$) indicating how the children pass the balls around.
For the following $q$ lines, the $i$-th line contains one integer $k_i$ ($1 \le k_i \le 10^9$) indicating a query asking for the result after $k_i$ rounds.

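思路:首先从 $i$ 向 $p_i$ 连一条有向边。由于每轮之后每个孩子恰好持有一个球,$p$ 是一个排列,最终得到的有向图由若干个互不相交的环组成。计算价值时可以把环拆开:把每个环上的数再复制一遍接到后面去,实现拆环,此时传递 $k$ 轮后该环的贡献 $\sum_j c_j\cdot c_{(j-k)\bmod len}$ 正好是一个卷积,对每个环做一次多项式乘法即可预处理出所有偏移量对应的答案。长度相同的环可以合并统计,而不同的环长至多只有 $O(\sqrt{n})$ 种,所以每次询问只需枚举环长,总时间复杂度为 $O(n\log n+q\sqrt{n})$。但是要注意精度:结果可达 $10^{15}$ 量级,普通 double 精度的 FFT 会因为浮点运算发生精度丢失,故应使用 NTT 来求解(下面的参考代码采用的是 long double 精度的 FFT)。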

参考代码:
#include <set>
#include <map>
#include <cmath>
#include <stack>
#include <queue>
#include <ctime>
#include <vector>
#include <cstdio>
#include <random>
#include <chrono>
#include <bitset>
#include <cstring>
#include <sstream>
#include <iomanip>
#include <cassert>
#include <iostream>
#include <algorithm>
#include <unordered_map>
// `ios;` expands to the usual fast-iostream setup (unsync from C stdio, untie cin/cout).
// The macro body already ends in ';', so `ios;` yields one extra empty statement.
#define ios ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
// Prints the CPU time used so far, in milliseconds, to cerr.
#define TIME cerr << 1e3 * clock() / CLOCKS_PER_SEC << " ms\n";
// Prints "<expression text> : <value>" to cerr.
#define debug(x) cerr << #x << " : " << x << '\n'
// Expands to the begin/end iterator pair of container v.
#define all(v) v.begin() , v.end()
// std::mt19937 rnd(std::chrono::system_clock::now().time_since_epoch().count());

// Shorthand for pair members. NOTE: these replace EVERY later identifier named
// x or y in this file, including struct members and local variables.
#define x first
#define y second
// Makes every later `int` a 64-bit integer; this is why main is declared `signed main`.
#define int long long
// #define double long double

using namespace std;

typedef long long LL;
// Because of `#define int long long` above, this is a pair of 64-bit integers.
typedef pair<int , int> PII;

// Capacity of the per-node arrays declared below.
const int N = 4e5 + 5;
// INF and mod are not referenced by the code visible in this program.
const int INF = 0x3f3f3f3f;
const int mod = 998244353;

// Fast integer reader built on getchar(): skips non-digit characters
// (remembering a '-' seen among them), parses the following run of decimal
// digits and returns the signed value.
int read(){
    int magnitude = 0;
    int sign = 1;
    char ch = getchar();

    // Skip everything up to the first digit.
    for(; !(ch >= '0' && ch <= '9') ; ch = getchar()){
        if(ch == '-') sign = -1;
    }
    // Accumulate the digit run.
    for(; ch >= '0' && ch <= '9' ; ch = getchar()){
        magnitude = magnitude * 10 + ch - '0';
    }
    return magnitude * sign;
}

// Writes the decimal representation of integer x to stdout with putchar().
// Digits are taken from the value without negating it first, so the most
// negative value of T is printed correctly (x = -x overflowed for it).
// The digit buffer is a local, so the function keeps no state between calls.
template <typename T> inline void write(T x) {
    if (x == 0) return (void)putchar('0');
    const bool negative = x < 0;
    char digits[64];  // more than enough for any built-in integer type
    int count = 0;
    while (x != 0) {
        // For negative x the remainder is in [-9, 0]; take its absolute value.
        const int digit = static_cast<int>(x % 10);
        digits[count++] = static_cast<char>('0' + (digit < 0 ? -digit : digit));
        x /= 10;
    }
    if (negative) putchar('-');
    while (count > 0) putchar(digits[--count]);
}

// p[i]: the value read for child i in solve() (the child that i passes to).
int p[N];
// q[r]: nodes collected by dfs() for the component whose union-find root is r.
// s[c]: summed per-shift results of the c-th distinct cycle length (built in solve()).
vector<LL> q[N] , s[N];
// f[len]: one per-shift result vector for every cycle of length len (built in solve()).
vector<vector<LL>> f[N];

// st[u]: set once dfs() has visited node u.
bool st[N];
// e[u]: neighbours of u; solve() inserts every edge i - p[i] in both directions.
vector<int> e[N];
// fa[x]: union-find parent of x.
int fa[N];
// len[c]: cycle length belonging to the c-th entry of s.
int len[N];

// Union-find lookup with path compression: returns the representative of x's
// set and re-points x directly at it.
int find(int x){
    if(fa[x] == x) return x;
    const int root = find(fa[x]);
    fa[x] = root;
    return root;
}

// Depth-first walk from u. Every node visited for the first time is marked in
// `st` and appended, in visiting order, to q[find(u)]. Neighbours equal to `fa`
// (the node we arrived from) are not followed.
void dfs(int u, int fa){
    if(!st[u]){
        st[u] = 1;
        q[find(u)].push_back(u);
        for(const auto to : e[u]){
            if(to != fa) dfs(to , u);
        }
    }
}

// Requires C++17 or later.
using i64 = long long;

using db = long double;

// Complex number over long double: x is the real part, y the imaginary part.
// The constructor is intentionally implicit so a plain number converts to a
// complex value with zero imaginary part.
struct comp {
    db x, y;

    comp(db real = 0, db imag = 0) : x(real), y(imag) {}

    // Component-wise sum.
    comp operator+(comp rhs) const {
        return comp(x + rhs.x, y + rhs.y);
    }

    // Component-wise difference.
    comp operator-(comp rhs) const {
        return comp(x - rhs.x, y - rhs.y);
    }

    // Complex product: (x + yi)(u + vi) = (xu - yv) + (xv + yu)i.
    comp operator*(comp rhs) const {
        const db re = x * rhs.x - y * rhs.y;
        const db im = x * rhs.y + y * rhs.x;
        return comp(re, im);
    }
};

const db pi = acosl(-1.L);
std::vector<comp> fft_init(int L) {
std::vector<comp> w(L, 1);
for (int i = 2; i < L; i *= 2) {
auto w0 = w.begin() + i / 2, w1 = w.begin() + i;
comp wn(cosl(pi / i), sinl(pi / i));
for (int j = 0; j < i; j += 2) {
w1[j] = w0[j / 2];
w1[j + 1] = w1[j] * wn;
}
}
return w;
}

auto W = fft_init(1 << 20); // Notice !

void dft(std::vector<comp>& a) {
comp x, y;
const int n = a.size();
for (int k = n / 2; k; k /= 2) {
for (int i = 0; i < n; i += k * 2) {
for (int j = 0; j < k; j++) {
x = a[i + j], y = a[i + j + k];
a[i + j + k] = (a[i + j] - y) * W[k + j], a[i + j] = x + y;
}
}
}
}

void idft(std::vector<comp>& a) {
const int n = a.size();
comp x, y;
for (int k = 1; k < n; k *= 2) {
for (int i = 0; i < n; i += k * 2) {
for (int j = 0; j < k; j++) {
x = a[i + j], y = a[i + j + k] * W[k + j];
a[i + j + k] = x - y, a[i + j] = x + y;
}
}
}
for (int i = 0; i < n; i++) {
a[i].x /= n;
a[i].y /= n;
}
std::reverse(a.begin() + 1, a.end());
}

template<class T>
struct Poly : public std::vector<T> {
Poly() : std::vector<T>() {}
explicit constexpr Poly(int n) : std::vector<T>(n) {}
explicit constexpr Poly(const std::vector<T>& a) : std::vector<T>(a) {}
constexpr Poly(const std::initializer_list<T>& a) : std::vector<T>(a) {}

constexpr static int norm(int n) {
return 1 << (std::__lg(n - 1) + 1);
}

friend constexpr Poly operator+(const Poly& a, const Poly& b) {
Poly res(std::max(a.size(), b.size()));
for (int i = 0; i < a.size(); i++) {
res[i] += a[i];
}
for (int i = 0; i < b.size(); i++) {
res[i] += b[i];
}
return res;
}
friend constexpr Poly operator*(i64 k, const Poly& a) {
Poly ans{};
for (auto i : a) {
ans.push_back(k * i);
}
return ans;
}

friend constexpr Poly operator*(const Poly& a, const Poly& b) {
int n = a.size() + b.size() - 1;
std::vector<comp> c(norm(n));
for (int i = 0; i < a.size(); i++) {
c[i].x = a[i];
}
for (int i = 0; i < b.size(); i++) {
c[i].y = b[i];
}
dft(c);
for (auto& x : c) {
x = x * x;
}
idft(c);
Poly ans(n);
for (int i = 0; i < n; i++) {
ans[i] = T(c[i].y * .5L + .5L);
}
return ans;
}

constexpr Poly& operator*=(const Poly& b) {
return (*this) = (*this) * b;
}
};

// Solves one test case.
//   1. Reads n, the number of queries t and the array p, joining i and p[i]
//      in the union-find and in the adjacency lists.
//   2. Collects the nodes of every component with dfs() into q[root].
//   3. For every component of size len, multiplies the node sequence with a
//      reversed, doubled copy of itself; coefficients k .. 2k of the product
//      (k = len - 1) are the sums of q[a] * q[(a - r) mod len] for r = 0 .. k.
//   4. Adds up the results of components with equal size, so a query only
//      has to visit each distinct size once.
//   5. Answers every query k as the sum of s[c][k % len[c]] over the groups.
// The global tables are not reset, so this is meant to run once per process.
void solve(){
    int n , t;
    cin >> n >> t;

    for(int i = 1 ; i <= n ; i ++)fa[i] = i;
    for(int i = 1 ; i <= n ; i ++){
        cin >> p[i];
        e[i].push_back(p[i]);
        e[p[i]].push_back(i);
        int pa = find(i) , pb = find(p[i]);
        if(pa != pb)fa[pa] = pb;
    }

    // Start a walk from the first node of every component not yet visited.
    for(int i = 1 ; i <= n ; i ++){
        if(st[find(i)] == 0){
            dfs(i , -1);
        }
    }

    for(int i = 1 ; i <= n ; i ++){
        if(q[i].empty())continue;
        int k = q[i].size() - 1;

        Poly<LL> c(k + 1) , d(k * 2 + 2);
        for(int j = 0 ; j <= k ; j ++)c[j] = q[i][j];
        // d = reversed sequence followed by a second copy of itself.
        for(int j = 0 ; j <= k ; j ++)d[k - j] = q[i][j];
        for(int j = 0 ; j <= k ; j ++)d[k + j + 1] = d[j];
        c = c * d;

        // Keep coefficients k .. 2k: one value per shift 0 .. k.
        f[k + 1].emplace_back(c.begin() + k , c.begin() + 2 * k + 1);
    }

    int cnt = 0;
    for(int i = 1 ; i <= n ; i ++){
        if(f[i].empty())continue;
        cnt ++;
        len[cnt] = i;
        s[cnt].assign(i + 1 , LL(0));
        // Iterate by reference: copying every stored vector is unnecessary.
        for(const auto &cycle : f[i]){
            for(std::size_t r = 0 ; r < cycle.size() ; r ++){
                s[cnt][r] += cycle[r];
            }
        }
    }

    while(t --){
        int k;
        cin >> k;
        LL res = 0;
        for(int i = 1 ; i <= cnt ; i ++){
            res += s[i][k % len[i]];
        }
        // '\n' rather than endl: flushing after each of up to 1e5 answers is wasted work.
        cout << res << '\n';
    }
}

// Entry point. Declared `signed` because `int` is redefined as long long by a
// macro in this file. Sets up fast iostreams and runs solve() once.
signed main(){
    ios;
    // freopen("test.in","r",stdin);
    // freopen("test.out","w",stdout);
    int cases = 1;
    // cin >> cases;
    for(int done = 0 ; done < cases ; ++done){
        solve();
    }
    return 0;
}
/* stuff you should look for
* int overflow, array bounds
* special cases (n=1?)
* do smth instead of nothing and stay organized
* WRITE STUFF DOWN
* DON'T GET STUCK ON ONE APPROACH
*/