Radix-2 DIT
Radix-2 DIT FFT 由 Cooley 和 Tukey 于 1965 年发明,广泛应用于数字信号处理。
它是一种分治算法,递归地将问题拆分为更小的子问题。复杂度为 O(N log N)
Radix-2 表示点数是 2 的幂。DIT 是时间抽取,表示输入在时间域中被抽取。
问题背景
计算如下两个多项式的乘积:
$$ f(x) = a_0 + a_1 x + a_2 x^2 + a_3 x^3\\ g(x) = b_0 + b_1 x + b_2 x^2 + b_3 x^3 $$
得到
$$ \begin{align} h(x) = f(x) * g(x) &= (a_0 + a_1 x + a_2 x^2 + a_3 x^3) * (b_0 + b_1 x + b_2 x^2 + b_3 x^3) \\ &=c_0 + c_1 x + c_2 x^2 + c_3 x^3 + c_4 x^4 + c_5 x^5 + c_6 x^6 \end{align} $$
时间复杂度为 O(N^2),是否可以降低?
以上多项式的表现形式为系数形式,如果可以表示为点值形式:
$$ f(x) = (x_0, f(x_0)), (x_1, f(x_1)), (x_2, f(x_2)), (x_3, f(x_3)), (x_4, f(x_4)), (x_5, f(x_5)), (x_6, f(x_6)), (x_7, f(x_7)) $$
$$ g(x) = (x_0, g(x_0)), (x_1, g(x_1)), (x_2, g(x_2)), (x_3, g(x_3)), (x_4, g(x_4)), (x_5, g(x_5)), (x_6, g(x_6)), (x_7, g(x_7)) $$
就可以在 O(N) 时间复杂度内算出乘积:
$$ h(x) = (x_0, f(x_0) * g(x_0)), (x_1, f(x_1) * g(x_1)), (x_2, f(x_2) * g(x_2)), (x_3, f(x_3) * g(x_3)), (x_4, f(x_4) * g(x_4)), (x_5, f(x_5) * g(x_5)), (x_6, f(x_6) * g(x_6)), (x_7, f(x_7) * g(x_7)) $$
因为 h(x) 的度为 6,所以 f(x) 和 g(x) 的点数都被扩展到了 7,为方便起见,扩展到了 8,后面会解释。
有两个问题:
- 如何将多项式从系数形式转换为点值形式?
- 如何将多项式从点值形式转换为系数形式?
DFT
DFT(离散傅里叶变换,Discrete Fourier Transform) 是一种将离散信号从时域转换到频域的数学工具,它是傅里叶变换的离散版本。在复数域中计算,使用复数单位根 $e^{-i2\pi kn/N}$,公式:
$$ X_k=\sum^{N-1}_{n=0}x_n \cdot e^{-i\frac{2\pi}{N}kn} $$
输入和输出都是复数。
一个复数的 n 次单位根(complex nth root of unity) 是一个满足以下方程的复数 w:
$$ w^n=1 $$
其通用公式为:
$$ ω_n = e^{2πik/n}=cos(\frac{2πk}{n})+isin(\frac{2πk}{n}) $$
其中:
- $k=0,1,2,…,n−1$
- $i$ 是虚数单位($i^2=−1$)
所有 n 次单位根都位于复平面的单位圆上,相邻根之间的角度为 $\frac{2π}{n}$ 弧度。
如果 n 是 2 的幂,有 $w_n^i=-w_n^{n/2+i}$
公式可改写为
$$ X_k=\sum^{N-1}_{n=0}x_n \cdot e^{-i\frac{2\pi}{N}kn} = \sum^{N-1}_{n=0}x_n \cdot w_{N}^{-kn} $$
假如有一杯草莓、香蕉和橙子混合而成的果汁,时域就是这杯果汁,是一个随时间变化的整体;频域就是成分分析报告,如草莓占30%,香蕉占50%,橙子占20%。
由于使用复数存在不精确的问题,NTT 使用有限域来代替。
NTT
NTT(数论变换,Number Theoretic Transform)是 DFT 在有限域上的模拟。将 DFT 中的复数单位根替换为有限域中的 n 次本原单位根,公式:
$$ X_k=\sum^{N-1}_{n=0}x_n \cdot w^{kn} \space \text{mod} \space p $$
其中 $x$ 是系数,$w$ 是 n 次本原单位根。
输入和输出都是有限域元素,计算中没有浮点数和舍入误差,结果精确。
n 次本原单位根(primitive n-th root of unity)指的是一个数 ω,满足:
- $ω^n=1$ (是 n 次单位根,n-th root of unity)
- $ω^k \ne 1$ 对于任意 $1 \le k < p$ (是本原的,primitive root)
n 次本原单位根并不是唯一的。
如果 n 是 2 的幂,有 $w=-w^{n/2+i}$
FFT
FFT(快速傅里叶变换)是 DFT NTT 的高效实现算法:
- 复杂度:DFT 直接计算需 O(N^2),FFT 仅需 O(NlogN)
- 常见算法:Cooley-Tukey, Stockham, Rader
FFT 方法采用分治策略,分别使用 $f(x)$ 的偶数索引和奇数索引系数来定义两个新的度数为 n/2 的多项式 $f_{even}(x)$ 和 $f_{odd}(x)$。注意,这里 n 是 2 的幂:
$$ f(x) =f_{even}(x^2) + x * f_{odd}(x^2)\\ f(-x) = f_{even}(x^2) - x * f_{odd}(x^2)\\ \forall x \in \left\{1, w, w^2, ..., w^n\right\} $$
例如 p=17,$w=2$ 作为 8 次本原单位根,可以得到
$$ f(x) = 1+2x+3^2+4x^3=f_{even}(x^2) + x * f_{odd}(x^2)=(1+3x^2)+x(2+4x^2) \space \forall x \in \left\{1, 2, 4, 8, 16, 15, 13, 9\right\} $$
先计算
$$ f_{even}(x^2)=1+3x^2, \space f_{odd}(x^2)=2+4x^2 \space \forall x \in \left\{1, 4, 16, 13 \right\} $$
进行递归分割,直到 domain 长度为 1,然后递归合并,就可以得到 f(x) 的值了。
步骤如下:
- 先得到 $f_{even}(x^2)$ 和 $f_{odd}(x^2)$ 在 1, 2, 4, 8 上的值
- 然后通过组合 $f_{even}(x^2) + xf_{odd}(x^2)$ 得到 $f(x)$ 在 1, 2, 4, 8 上的值
- 然后通过组合 $f_{even}(x^2) - xf_{odd}(x^2)$ 得到 $f(x)$ 在 -1, -2, -4, -8 (实际上是 16, 15, 13, 9)上的值
这样 FFT 便实现了将多项式从系数形式转换为点值形式,IFFT 可反转回来,即将多项式从点值形式转换为系数形式。
FFT 递归形式:
# fft: from coefficients to evaluations
def fft(vals, domain):
if (len(domain) & (len(domain) - 1)) != 0:
raise ValueError("Domain length must be a power of 2.")
if len(vals) < len(domain):
if len(vals) == 0:
zero = Field(0)
else:
zero = vals[0] - vals[0]
vals = vals + [zero] * (len(domain) - len(vals))
if len(vals) == 1:
return vals
half_domain = halve_domain(domain)
f0 = fft(vals[::2], half_domain)
f1 = fft(vals[1::2], half_domain)
left = [L+x*R for L,R,x in zip(f0, f1, domain)]
right = [L-x*R for L,R,x in zip(f0, f1, domain)]
return left+right
FFT 迭代形式:
# 按时间抽取(DIT)的快速傅里叶变换(FFT)
def dit(self, a, roots):
n = len(a)
# 位反转重排
# 输入索引:[0,1,2,3](二进制:00,01,10,11)
# 位反转后:[0,2,1,3](二进制:00,10,01,11)
a = self.bit_reverse(a)
logn = n.bit_length() - 1
for s in range(1, logn + 1): # 分治阶段,从最小子问题(m=2)逐步扩展到整个序列(m=n)
m = 1 << s # 当前子问题大小 (2^s)
wm = roots[n // m] # 单位根 ω_m = ω^(n/m)
for k in range(0, n, m): # 处理每个子块
w = 1 # 初始化旋转因子 ω_m^0
for j in range(m // 2): # 蝶形运算
u = a[k + j]
v = self.gf.mul(w, a[k + j + m // 2])
a[k + j] = self.gf.add(u, v) # 上半部分
a[k + j + m // 2] = self.gf.sub(u, v) # 下半部分
w = self.gf.mul(w, wm) # 更新旋转因子
return a # 返回频域结果
Radix-2 bowers FFT
Bowers FFT 的主要目标是通过减少 twiddle factors 访问次数和优化内存访问模式来提高计算效率。
Dit 和 Dif FFT
在 Dit FFT 中,输入数据按位反转顺序排列,而输出按自然顺序排列,从 2-DFT 开始,然后是 4-DFT……直到 N-DFT
在 Dif FFT 中,输入数据按自然顺序排列,而输出按位反转顺序排列,从 N-DFT 开始,然后是 N/2-DFT……直到 2-DFT
在 DIT 和 DIF FFT 中,蝶形运算均采用迭代方法而非递归方法执行。
Bowers FFT
Bowers FFT 与 DIF FFT 非常相似,其核心区别在于 twiddle factors 的访问方式。以 8 点 FFT 为例,在 DIF FFT 中,不同阶段的 twiddle factors 如下:
第一层:
$$ w_8^0,w_8^1,w_8^2,w_8^3 $$
第二层:
$$ w_8^0,w_8^2,w_8^0,w_8^2 $$
第三层:
$$ w_8^0,w_8^0,w_8^0,w_8^0 $$
而在 Bower FFT 中,不同阶段的 twiddle factors 如下:
第一层:
$$ w_8^0,w_8^0,w_8^0,w_8^0 $$
第二层:
$$ w_8^0,w_8^0,w_8^2,w_8^2 $$
第三层:
$$ w_8^0,w_8^1,w_8^2,w_8^3 $$
重点关注第二层:在 Bowers FFT 中,内存访问会更加连续。
DIF FFT 和 bower FFT 如下:
def dif(self, a, roots):
n = len(a)
logn = n.bit_length() - 1
for s in range(logn, 0, -1):
m = 1 << s
wm = roots[n//m]
for k in range(0, n, m):
w = 1
for j in range(m // 2):
u = a[k + j]
v = a[k + j + m // 2]
a[k + j] = self.gf.add(u, v)
a[k + j + m // 2] = self.gf.mul(w, self.gf.sub(u, v))
w = self.gf.mul(w, wm)
return self.bit_reverse(a)
def bower_g(self,a):
n = len(a)
a = self.bit_reverse(a)
roots=self.get_forward_roots(n)
roots=self.bit_reverse(roots[:len(roots) // 2])
logn = n.bit_length() - 1
for s in range(1, logn + 1):
m = 1 << s
for k in range(0, n, m):
w = roots[k//m]
for j in range(m // 2):
u = a[k + j]
v = a[k + j + m // 2]
a[k + j] = self.gf.add(u, v)
a[k + j + m // 2] = self.gf.mul(w, self.gf.sub(u, v))
Four-step FFT
四步 FFT(也称为 Bailey FFT)是一种用于计算快速傅里叶变换(FFT)的高性能算法。它是 Cooley-Tukey FFT 算法的变体。
原理
- 先将数据(按自然顺序)排列成矩阵
- 使用标准 FFT 算法独立处理矩阵的每一列
- 矩阵的每个元素乘以一个校正系数(称为 twiddle factors)
- 使用标准 FFT 算法独立处理矩阵的每一行
具有如下优势:
- 将大规模 FFT 转换为两个较小的 FFT(行 FFT 和列 FFT),每个 FFT 每次仅对矩阵的一部分进行运算。与传统的 FFT 相比,这减少了随机内存访问,因为传统的 FFT 中的蝶形运算需要将数据分布在整个输入上
- 行 FFT 和列 FFT 相互独立,可以并行计算。FPGA 和 GPU 等现代硬件可以利用这种并行性同时计算多个 FFT
- 四步 FFT 不存储整个 FFT 输入/输出,而是处理适合有限片上 RAM 的较小块(例如矩阵的行或列)
代码:
class NTT:
# generate forward and inverse roots, bit-reverse, dit and dif FFT, forward and inverse dit or dif.
def __init__(self, modulus, n):
self.gf = Field(modulus)
self.n = n
def get_forward_roots(self,n):
return self.gf.roots_of_unity(n)
def get_inverse_roots(self,n):
forward_roots=self.gf.roots_of_unity(n)
return [self.gf.inv(r) for r in forward_roots]
def bit_reversed_indices(self, n):
logn = n.bit_length() - 1
return [int(f"{i:0{logn}b}"[::-1], 2) for i in range(n)]
def bit_reverse(self, a):
n = len(a)
indices = self.bit_reversed_indices(n)
return [a[i] for i in indices]
def matrix(self, a, log_rows, log_cols):
# transfer array into matrix
rows = 1 << log_rows
cols = 1 << log_cols
return np.array(a).reshape((rows, cols))
def transpose_and_flatten(self, matrix):
# Transpose the matrix, and flatten it
return [element for row in matrix.T for element in row]
def apply_twiddles(self, wm, matrix):
# each matrix[i,j] mul wm^(i*j), wm is the root of n-domain
n, m = matrix.shape
for i in range(n):
for j in range(m):
factor = self.gf.pow(wm, i * j)
matrix[i, j] = self.gf.mul(matrix[i, j], factor)
def apply_column_fft(self, matrix):
# do fft for each column in matrix
n_rows, n_cols = matrix.shape
for j in range(n_cols):
column = matrix[:, j].tolist()
fft_result = self.forward_dit(column)
matrix[:, j] = fft_result
def apply_row_fft(self,matrix):
# do fft for each row in matrix
for i in range(matrix.shape[0]):
matrix[i] = self.forward_dit(matrix[i].tolist())
def forward_dit(self, a):
roots=self.get_forward_roots(len(a))
return self.dit(a,roots)
def dit(self, a, roots):
n = len(a)
a = self.bit_reverse(a)
logn = n.bit_length() - 1
for s in range(1, logn + 1):
m = 1 << s
wm = roots[n//m]
for k in range(0, n, m):
w = 1
for j in range(m // 2):
u = a[k + j]
v = self.gf.mul(w, a[k + j + m // 2])
a[k + j] = self.gf.add(u, v)
a[k + j + m // 2] = self.gf.sub(u, v)
w = self.gf.mul(w, wm)
return a
def four_step(array, log_rows,modulus):
n = len(array)
logn = n.bit_length() - 1
log_cols = logn - log_rows
assert log_rows > 0
assert log_cols > 0
assert modulus > n
gf = Field(modulus)
ntt = NTT(modulus, n)
# first step: transfer the array into matrix
matrix = ntt.matrix(array, log_cols, log_rows)
print("origin matrix is:",matrix)
# second step: do FFT for each column
ntt.apply_column_fft(matrix)
print("after column fft, matrix is:",matrix)
# third step: apply twiddles wm^(i*j)
wm = ntt.get_forward_roots(n)[1]
ntt.apply_twiddles(wm, matrix)
# fourth step: do FFT for each row
ntt.apply_row_fft(matrix)
print("after row fft, matrix is:",matrix)
# Transpose the matrix, and flatten it into array
out_array = ntt.transpose_and_flatten(matrix)
print("after transpose and flatten, array is:",out_array)
return out_array
没有评论