Radix-2 DIT

Radix-2 DIT FFT 由 Cooley 和 Tukey 于 1965 年发明,广泛应用于数字信号处理。

它是一种分治算法,递归地将问题拆分为更小的子问题。复杂度为 O(N log N)

Radix-2 表示点数是 2 的幂。DIT 是时间抽取,表示输入在时间域中被抽取。

问题背景

计算如下两个多项式的乘积:

$$ f(x) = a_0 + a_1 x + a_2 x^2 + a_3 x^3\\ g(x) = b_0 + b_1 x + b_2 x^2 + b_3 x^3 $$

得到

$$ \begin{align} h(x) = f(x) * g(x) &= (a_0 + a_1 x + a_2 x^2 + a_3 x^3) * (b_0 + b_1 x + b_2 x^2 + b_3 x^3) \\ &=c_0 + c_1 x + c_2 x^2 + c_3 x^3 + c_4 x^4 + c_5 x^5 + c_6 x^6 \end{align} $$

时间复杂度为 O(N^2),是否可以降低?

以上多项式的表现形式为系数形式,如果可以表示为点值形式:

$$ f(x) = (x_0, f(x_0)), (x_1, f(x_1)), (x_2, f(x_2)), (x_3, f(x_3)), (x_4, f(x_4)), (x_5, f(x_5)), (x_6, f(x_6)), (x_7, f(x_7)) $$

$$ g(x) = (x_0, g(x_0)), (x_1, g(x_1)), (x_2, g(x_2)), (x_3, g(x_3)), (x_4, g(x_4)), (x_5, g(x_5)), (x_6, g(x_6)), (x_7, g(x_7)) $$

就可以在 O(N) 时间复杂度内算出乘积:

$$ h(x) = (x_0, f(x_0) * g(x_0)), (x_1, f(x_1) * g(x_1)), (x_2, f(x_2) * g(x_2)), (x_3, f(x_3) * g(x_3)), (x_4, f(x_4) * g(x_4)), (x_5, f(x_5) * g(x_5)), (x_6, f(x_6) * g(x_6)), (x_7, f(x_7) * g(x_7)) $$

因为 h(x) 的度为 6,所以 f(x) 和 g(x) 的点数都被扩展到了 7,为方便起见,扩展到了 8,后面会解释。

有两个问题:

  • 如何将多项式从系数形式转换为点值形式?
  • 如何将多项式从点值形式转换为系数形式?

DFT

DFT(离散傅里叶变换,Discrete Fourier Transform) 是一种将离散信号从时域转换到频域的数学工具,它是傅里叶变换的离散版本。在复数域中计算,使用复数单位根 $e^{-i2\pi kn/N}$,公式:

$$ X_k=\sum^{N-1}_{n=0}x_n \cdot e^{-i\frac{2\pi}{N}kn} $$

输入和输出都是复数。

一个复数的 n 次单位根(complex nth root of unity) 是一个满足以下方程的复数 w:

$$ w^n=1 $$

其通用公式为:

$$ ω_n = e^{2πik/n}=cos(\frac{2πk}{n})+isin(\frac{2πk}{n}) $$

其中:

  • $k=0,1,2,…,n−1$
  • $i$ 是虚数单位($i^2=−1$)

所有 n 次单位根都位于复平面的单位圆上,相邻根之间的角度为 $\frac{2π}{n}$ 弧度。

如果 n 是 2 的幂,有 $w_n^i=-w_n^{n/2+i}$

公式可改写为

$$ X_k=\sum^{N-1}_{n=0}x_n \cdot e^{-i\frac{2\pi}{N}kn} = \sum^{N-1}_{n=0}x_n \cdot w_{N}^{-kn} $$

假如有一杯草莓、香蕉和橙子混合而成的果汁,时域就是这杯果汁,是一个随时间变化的整体;频域就是成分分析报告,如草莓占30%,香蕉占50%,橙子占20%。

由于使用复数存在不精确的问题,NTT 使用有限域来代替。

NTT

NTT(数论变换,Number Theoretic Transform)是 DFT 在有限域上的模拟。将 DFT 中的复数单位根替换为有限域中的 n 次本原单位根,公式:

$$ X_k=\sum^{N-1}_{n=0}x_n \cdot w^{kn} \space \text{mod} \space p $$

其中 $x$ 是系数,$w$ 是 n 次本原单位根。

输入和输出都是有限域元素,计算中没有浮点数和舍入误差,结果精确。

n 次本原单位根(primitive n-th root of unity)指的是一个数 ω,满足:

  • $ω^n=1$ (是 n 次单位根,n-th root of unity)
  • $ω^k \ne 1$ 对于任意 $1 \le k < p$ (是本原的,primitive root)

n 次本原单位根并不是唯一的。

如果 n 是 2 的幂,有 $w=-w^{n/2+i}$

FFT

FFT(快速傅里叶变换)是 DFT NTT 的高效实现算法

  • 复杂度:DFT 直接计算需 O(N^2),FFT 仅需 O(NlogN)
  • 常见算法:Cooley-Tukey, Stockham, Rader

FFT 方法采用分治策略,分别使用 $f(x)$ 的偶数索引和奇数索引系数来定义两个新的度数为 n/2 的多项式 $f_{even}(x)$ 和 $f_{odd}(x)$。注意,这里 n 是 2 的幂:

$$ f(x) =f_{even}(x^2) + x * f_{odd}(x^2)\\ f(-x) = f_{even}(x^2) - x * f_{odd}(x^2)\\ \forall x \in \left\{1, w, w^2, ..., w^n\right\} $$

例如 p=17,$w=2$ 作为 8 次本原单位根,可以得到

$$ f(x) = 1+2x+3^2+4x^3=f_{even}(x^2) + x * f_{odd}(x^2)=(1+3x^2)+x(2+4x^2) \space \forall x \in \left\{1, 2, 4, 8, 16, 15, 13, 9\right\} $$

先计算

$$ f_{even}(x^2)=1+3x^2, \space f_{odd}(x^2)=2+4x^2 \space \forall x \in \left\{1, 4, 16, 13 \right\} $$

进行递归分割,直到 domain 长度为 1,然后递归合并,就可以得到 f(x) 的值了。

步骤如下:

  1. 先得到 $f_{even}(x^2)$ 和 $f_{odd}(x^2)$ 在 1, 2, 4, 8 上的值
  2. 然后通过组合 $f_{even}(x^2) + xf_{odd}(x^2)$ 得到 $f(x)$ 在 1, 2, 4, 8 上的值
  3. 然后通过组合 $f_{even}(x^2) - xf_{odd}(x^2)$ 得到 $f(x)$ 在 -1, -2, -4, -8 (实际上是 16, 15, 13, 9)上的值

这样 FFT 便实现了将多项式从系数形式转换为点值形式,IFFT 可反转回来,即将多项式从点值形式转换为系数形式。

FFT 递归形式:

# fft: from coefficients to evaluations
def fft(vals, domain):
    if (len(domain) & (len(domain) - 1)) != 0:
            raise ValueError("Domain length must be a power of 2.")  

    if len(vals) < len(domain):
            if len(vals) == 0:  
                    zero = Field(0)
            else:  
                    zero = vals[0] - vals[0]
            vals = vals + [zero] * (len(domain) - len(vals))

    if len(vals) == 1:
        return vals
    half_domain = halve_domain(domain)
    f0 = fft(vals[::2], half_domain)
    f1 = fft(vals[1::2], half_domain)
    left = [L+x*R for L,R,x in zip(f0, f1, domain)]
    right = [L-x*R for L,R,x in zip(f0, f1, domain)]
    return left+right

FFT 迭代形式:

# 按时间抽取(DIT)的快速傅里叶变换(FFT)
def dit(self, a, roots):
    n = len(a)
    
    # 位反转重排
    # 输入索引:[0,1,2,3](二进制:00,01,10,11)
    # 位反转后:[0,2,1,3](二进制:00,10,01,11)
    a = self.bit_reverse(a)  
 
    logn = n.bit_length() - 1
    
    for s in range(1, logn + 1): # 分治阶段,从最小子问题(m=2)逐步扩展到整个序列(m=n)
        m = 1 << s           # 当前子问题大小 (2^s)
        wm = roots[n // m]   # 单位根 ω_m = ω^(n/m)
        for k in range(0, n, m): # 处理每个子块
            w = 1  # 初始化旋转因子 ω_m^0
            for j in range(m // 2): # 蝶形运算
                u = a[k + j]
                v = self.gf.mul(w, a[k + j + m // 2])
                a[k + j] = self.gf.add(u, v)          # 上半部分
                a[k + j + m // 2] = self.gf.sub(u, v) # 下半部分
                w = self.gf.mul(w, wm)                # 更新旋转因子
    
    return a  # 返回频域结果

Radix-2 bowers FFT

Bowers FFT 的主要目标是通过减少 twiddle factors 访问次数和优化内存访问模式来提高计算效率。

Dit 和 Dif FFT

在 Dit FFT 中,输入数据按位反转顺序排列,而输出按自然顺序排列,从 2-DFT 开始,然后是 4-DFT……直到 N-DFT

在 Dif FFT 中,输入数据按自然顺序排列,而输出按位反转顺序排列,从 N-DFT 开始,然后是 N/2-DFT……直到 2-DFT

在 DIT 和 DIF FFT 中,蝶形运算均采用迭代方法而非递归方法执行。

Bowers FFT

Bowers FFT 与 DIF FFT 非常相似,其核心区别在于 twiddle factors 的访问方式。以 8 点 FFT 为例,在 DIF FFT 中,不同阶段的 twiddle factors 如下:

第一层:

$$ w_8^0,w_8^1,w_8^2,w_8^3 $$

第二层:

$$ w_8^0,w_8^2,w_8^0,w_8^2 $$

第三层:

$$ w_8^0,w_8^0,w_8^0,w_8^0 $$

而在 Bower FFT 中,不同阶段的 twiddle factors 如下:

第一层:

$$ w_8^0,w_8^0,w_8^0,w_8^0 $$

第二层:

$$ w_8^0,w_8^0,w_8^2,w_8^2 $$

第三层:

$$ w_8^0,w_8^1,w_8^2,w_8^3 $$

重点关注第二层:在 Bowers FFT 中,内存访问会更加连续。

DIF FFT 和 bower FFT 如下:

def dif(self, a, roots):
    n = len(a)
    logn = n.bit_length() - 1
    for s in range(logn, 0, -1):
        m = 1 << s
        wm = roots[n//m]
        for k in range(0, n, m):
            w = 1
            for j in range(m // 2):
                u = a[k + j]
                v = a[k + j + m // 2]
                a[k + j] = self.gf.add(u, v)
                a[k + j + m // 2] = self.gf.mul(w, self.gf.sub(u, v))
                w = self.gf.mul(w, wm)
    return self.bit_reverse(a)

def bower_g(self,a):
    n = len(a)
    a = self.bit_reverse(a)
    roots=self.get_forward_roots(n)
    roots=self.bit_reverse(roots[:len(roots) // 2])
    logn = n.bit_length() - 1
    for s in range(1, logn + 1):
        m = 1 << s
        for k in range(0, n, m):
            w = roots[k//m]
            for j in range(m // 2):
                u = a[k + j]
                v = a[k + j + m // 2]
                a[k + j] = self.gf.add(u, v)
                a[k + j + m // 2] = self.gf.mul(w, self.gf.sub(u, v))

Four-step FFT

四步 FFT(也称为 Bailey FFT)是一种用于计算快速傅里叶变换(FFT)的高性能算法。它是 Cooley-Tukey FFT 算法的变体。

原理

  1. 先将数据(按自然顺序)排列成矩阵
  2. 使用标准 FFT 算法独立处理矩阵的每一列
  3. 矩阵的每个元素乘以一个校正系数(称为 twiddle factors)
  4. 使用标准 FFT 算法独立处理矩阵的每一行

具有如下优势:

  • 将大规模 FFT 转换为两个较小的 FFT(行 FFT 和列 FFT),每个 FFT 每次仅对矩阵的一部分进行运算。与传统的 FFT 相比,这减少了随机内存访问,因为传统的 FFT 中的蝶形运算需要将数据分布在整个输入上
  • 行 FFT 和列 FFT 相互独立,可以并行计算。FPGA 和 GPU 等现代硬件可以利用这种并行性同时计算多个 FFT
  • 四步 FFT 不存储整个 FFT 输入/输出,而是处理适合有限片上 RAM 的较小块(例如矩阵的行或列)

代码:

class NTT:
    # generate forward and inverse roots, bit-reverse, dit and dif FFT, forward and inverse dit or dif.
    def __init__(self, modulus, n):
        self.gf = Field(modulus)
        self.n = n
    
    def get_forward_roots(self,n):
        return self.gf.roots_of_unity(n)
    
    def get_inverse_roots(self,n):
        forward_roots=self.gf.roots_of_unity(n)        
        return [self.gf.inv(r) for r in forward_roots]

    def bit_reversed_indices(self, n):
        logn = n.bit_length() - 1
        return [int(f"{i:0{logn}b}"[::-1], 2) for i in range(n)]

    def bit_reverse(self, a):
        n = len(a)
        indices = self.bit_reversed_indices(n)
        return [a[i] for i in indices]
        
    def matrix(self, a, log_rows, log_cols): 
        # transfer array into matrix
        rows = 1 << log_rows
        cols = 1 << log_cols
        return np.array(a).reshape((rows, cols))

    def transpose_and_flatten(self, matrix):
        # Transpose the matrix, and flatten it
        return [element for row in matrix.T for element in row]

    def apply_twiddles(self, wm, matrix):
        # each matrix[i,j] mul wm^(i*j), wm is the root of n-domain
        n, m = matrix.shape
        for i in range(n):
            for j in range(m):
                factor = self.gf.pow(wm, i * j)
                matrix[i, j] = self.gf.mul(matrix[i, j], factor)

    def apply_column_fft(self, matrix):
        # do fft for each column in matrix
        n_rows, n_cols = matrix.shape  
        for j in range(n_cols):  
            column = matrix[:, j].tolist()  
            fft_result = self.forward_dit(column)  
            matrix[:, j] = fft_result  

    def apply_row_fft(self,matrix):
        # do fft for each row in matrix
        for i in range(matrix.shape[0]):
            matrix[i] = self.forward_dit(matrix[i].tolist())
            
    def forward_dit(self, a):
        roots=self.get_forward_roots(len(a))
        return self.dit(a,roots)
        
    def dit(self, a, roots):
        n = len(a)
        a = self.bit_reverse(a)
        logn = n.bit_length() - 1
        for s in range(1, logn + 1):
            m = 1 << s
            wm = roots[n//m]
            for k in range(0, n, m):
                w = 1
                for j in range(m // 2):
                    u = a[k + j]
                    v = self.gf.mul(w, a[k + j + m // 2])
                    a[k + j] = self.gf.add(u, v)
                    a[k + j + m // 2] = self.gf.sub(u, v)
                    w = self.gf.mul(w, wm)
        return a
def four_step(array, log_rows,modulus):
    n = len(array)
    logn = n.bit_length() - 1
    log_cols = logn - log_rows
    assert log_rows > 0
    assert log_cols > 0
    assert modulus > n

    gf = Field(modulus)
    ntt = NTT(modulus, n)

    # first step: transfer the array into matrix
    matrix = ntt.matrix(array, log_cols, log_rows)
    print("origin matrix is:",matrix)

    # second step: do FFT for each column
    ntt.apply_column_fft(matrix)
    print("after column fft, matrix is:",matrix)

    # third step: apply twiddles wm^(i*j)
    wm = ntt.get_forward_roots(n)[1]
    ntt.apply_twiddles(wm, matrix)

    # fourth step: do FFT for each row
    ntt.apply_row_fft(matrix)
    print("after row fft, matrix is:",matrix)    

    # Transpose the matrix, and flatten it into array
    out_array = ntt.transpose_and_flatten(matrix)
    print("after transpose and flatten, array is:",out_array)   
    return out_array

参考

plonky3-python-notebook/fft