写在前面的话:
快速傅里叶应用于多项式,来加速多项式乘法使其复杂度从o(
)变成o(n(
)),好了先一步步来,这个算法确实很难,但也不是完全理解不了,只是方式不对。
首先,正常的多项式相乘是:
A[n]=
+
x+
+
+~~+
,B[n]=
~~
;
C=A[n]*B[n]=
~~
正常肯定因为这个应该不能再被简化了,但是没有绝对的最简,人们觉得这个复杂度太高,没有艺术美,于是FFT诞生了,多项式再是局限于o(
)。
先讲清楚它的思想支柱-点值法。多个点可以确定一个函数,而一个函数也可以分解为多个点。
比如你解一个二元一次方程需要{a+bx=y, 2a+bx=y}如果想解出a,b的值是不是至少需要两组x,y的值。代入,然后解方程。换言之,表示一个n次多项式是不是可以用n组解来表示。记不记得高斯消元法,就是这个,n组解确定,代入方程,n组方程,利用高斯消元法,就可以得到一组解,就可以确定原方程。
下面问题来了,点值法可以用,但是如何找到n组合适的解呢?不然就算点值法做,复杂度也是不会变的。
好了主题来了,轮到虚数出场了,问题的关键就是找到合适的解。
先来看一下虚数的定义:a+bi,a是实根,b是虚根。如果a+bi的长度为一的话,a,b可以怎么表示呢,a=cos
,b=sin
。原式等于cos
+sin
i。这里扯一下著名的欧拉公式
,当
时,
+1=0。美!公式本身就是一个艺术。证明用泰勒展开。
好了回归正题,
,
,也就是角度乘2。
关键点来了:
设
为1,
表示
,
就是
对应上面的解矩阵,开始是1+x+x^2+x^3+…..+x^(n-1)表示一组解x=
。下一组解
=1,好了关键中的关键来了,第二组解就是在第一组解的基础上乘本身
,然后继续乘x=
,第三组同理。好了,开始神奇的变换了。
A(x)=a0+a1∗x+a2∗x^2+a3∗x^3+a4∗x^4+a5∗x^5+⋯+an−2∗x^n−2+an−1∗x^n−1。
A(x)=(a0+a2∗x^2+a4∗x^4+⋯+an−2∗x^n−2)+(a1∗x+a3∗x^3+a5∗x^5+⋯+an−1∗x^n−1)
A1(x)=a0+a2∗x+a4∗x^2+⋯+an−2∗x^n/2−1
A2(x)=a1+a3∗x+a5∗x^2+⋯+an−1∗x^n/2−1
A(x)=A1(x^2)+x*A2(x^2)
而上面已经证明
,A(
)=A1(
)+
A2(
)=A1(
)+
A2(
)
A(
)=A1(
)+
A(
)=A1(
)−
A2(
)=A1(
)-
A2(
)
现在我们就可以得出A(
) (k
)是前一半,而后一半直接就可以得出A(
),复杂度减一半,递归下去。
下面是递归代码实现:
double PI = acos(-1);
class complex {
private:
int r, i;
public:
complex(double _r=0,double _i=0):r(_r),i(_i){}
complex operator+(const complex &u) { return complex(r+u.r,i+u.i); }
complex operator-(const complex &u) { return complex(r - u.r, i - u.i); }
complex operator*(const complex &u) { return complex(r*u.r - i * u.i, r*u.i + i * u.r); }
};
void FFT(complex *a,int n,int type) {
if (n == 1)return;
complex *a1 = new complex[n >> 1], *a2 = new complex[n >> 1];
for (int i = 0; i <n; i+=2) {
a1[i>>1] = a[i];
a2[i >>1] = a[i + 1];
}
FFT(a1, n >> 1, type);//递归偶数部分
FFT(a2, n >> 1, type);//递归奇数部分
complex wn(cos(2 * PI / n), type*sin(2 * PI / n)), w(1, 0);
/*这段代码是核心,a[0]~a[n]相当于y[0]~y[n]
a[i + n >> 1] = a1[i] - w * a2[i];相当于计算下一半
*/
for (int i = 0; i < n >> 1; i++,w=w*wn) {
a[i] = a1[i] + w * a2[i];//前一半
a[i + n >> 1] = a1[i] - w * a2[i];//后一半
}
delete &a1; delete &a2;
}
现在已经将系数转换为点值了,点值相乘,然后再转换为系数。
XA=Y,
,A就是解方程的系数矩阵,
那么现在只需要求X的逆矩阵就可以了,回到上面的欧拉公式,
,
所以只需要将他的符号改一下就变成了逆矩阵
这又是一个FFT只不过符号要变为负的,最后值要除n。
c++complex模板实现:
#include <iostream>
using namespace std;
#include <complex>
#include <cmath>
double PI = acos(-1);
complex<double> a[400010], b[400010], c[400010];
void fft(complex<double> *a, int n, int op)
{
if (n == 1) return;
complex<double> w(1, 0), wn(cos(2 * PI*op / n), sin(2 * PI*op / n));
complex<double>*a1 = new complex<double>[n >> 1], *a2 = new complex<double>[n >> 1];
for (int i = 0; i < (n >> 1); i++)
a1[i] = a[i << 1], a2[i] = a[(i << 1) + 1];
fft(a1, n >> 1, op), fft(a2, n >> 1, op);
for (int i = 0; i < (n >> 1); i++, w *= wn)
a[i] = a1[i] + w * a2[i], a[i + (n >> 1)] = a1[i] - w * a2[i];
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 0; i <= n; i++) scanf("%lf", &a[i]);
for (int i = 0; i <= m; i++) scanf("%lf", &b[i]);
m += n, n = 1;
while (n <= m) n <<= 1;
fft(a, n, 1), fft(b, n, 1);
for (int i = 0; i < n; i++) c[i] = a[i] * b[i];
fft(c, n, -1);
for (int i = 0; i <= m; i++) printf("%d ", int(c[i].real() / n + 0.5));
//system("pause");
return 0;
}
实际算法中使用的是迭代法:
奇偶数可以用二进制反转来分,这样就不需要递归了,网上找的迭代法
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <complex>
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;
typedef long long ll;
template <class T>
void read(T &x){
char c;
bool op = 0;
while(c = getchar(), c < '0' || c > '9')
if(c == '-') op = 1;
x = c - '0';
while(c = getchar(), c >= '0' && c <= '9')
x = x * 10 + c - '0';
if(op) x = -x;
}
template <class T>
void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x >= 10) write(x / 10);
putchar('0' + x % 10);
}
const int N = 1000005;
const double PI = acos(-1);
typedef complex <double> cp;
char sa[N], sb[N];
int n = 1, lena, lenb, res[N];
cp a[N], b[N], omg[N], inv[N];
void init(){
for(int i = 0; i < n; i++){
omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
}
}
void fft(cp *a, cp *omg){
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++){
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));//每次移位,比较第一位看是否为1,如果是1,那么对应的另一端要变为1
if(i < t) swap(a[i], a[t]); // i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
for(int l = 2; l <= n; l *= 2){
int m = l / 2;
for(cp *p = a; p != a + n; p += l)
for(int i = 0; i < m; i++){
cp t = omg[n / l * i] * p[i + m];//这步看不懂的可以推一下总和等于奇数加偶数的公式,每次提取出的数是从小到大的
p[i + m] = p[i] - t;
p[i] += t;
}
}
}
int main(){
scanf("%s%s", sa, sb);
lena = strlen(sa), lenb = strlen(sb);
while(n < lena + lenb) n *= 2;
for(int i = 0; i < lena; i++)
a[i].real(sa[lena - 1 - i] - '0');
for(int i = 0; i < lenb; i++)
b[i].real(sb[lenb - 1 - i] - '0');
init();
fft(a, omg);
fft(b, omg);
for(int i = 0; i < n; i++)
a[i] *= b[i];
fft(a, inv);
for(int i = 0; i < n; i++){
res[i] += floor(a[i].real() / n + 0.5);
res[i + 1] += res[i] / 10;
res[i] %= 10;
}
for(int i = res[lena + lenb - 1] ? lena + lenb - 1: lena + lenb - 2; i >= 0; i--)
putchar('0' + res[i]);
enter;
return 0;
}