在有限域上实现FFT [英] Implementing FFT over finite fields

查看:127
本文介绍了在有限域上实现FFT的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想使用NTT实现多项式的乘法.我遵循了数论转换(整数DFT),工作.

现在,我想在有限域Z_p[x]上实现多项式的乘法,其中p是任意质数.

与以前的无界情况相比,它现在改变了系数现在由p界定的任何东西吗?

特别是,原始NTT需要找到质数N作为大于(magnitude of largest element of input vector)^2 * (length of input vector) + 1的工作模数,以便结果永不溢出.如果结果无论如何都会受到该p质数的限制,则模数可以是多少?请注意,p - 1不必为(some positive integer) * (length of input vector)形式.

我从上面的链接复制粘贴了源代码以说明问题:

# 
# Number-theoretic transform library (Python 2, 3)
# 
# Copyright (c) 2017 Project Nayuki
# All rights reserved. Contact Nayuki for licensing.
# https://www.nayuki.io/page/number-theoretic-transform-integer-dft
#

import itertools, numbers

def find_params_and_transform(invec, minmod):
    check_int(minmod)
    mod = find_modulus(len(invec), minmod)
    root = find_primitive_root(len(invec), mod - 1, mod)
    return (transform(invec, root, mod), root, mod)

def check_int(n):
    if not isinstance(n, numbers.Integral):
        raise TypeError()

def find_modulus(veclen, minimum):
    check_int(veclen)
    check_int(minimum)
    if veclen < 1 or minimum < 1:
        raise ValueError()
    start = (minimum - 1 + veclen - 1) // veclen
    for i in itertools.count(max(start, 1)):
        n = i * veclen + 1
        assert n >= minimum
        if is_prime(n):
            return n

def is_prime(n):
    check_int(n)
    if n <= 1:
        raise ValueError()
    return all((n % i != 0) for i in range(2, sqrt(n) + 1))

def sqrt(n):
    check_int(n)
    if n < 0:
        raise ValueError()
    i = 1
    while i * i <= n:
        i *= 2
    result = 0
    while i > 0:
        if (result + i)**2 <= n:
            result += i
        i //= 2
    return result

def find_primitive_root(degree, totient, mod):
    check_int(degree)
    check_int(totient)
    check_int(mod)
    if not (1 <= degree <= totient < mod):
        raise ValueError()
    if totient % degree != 0:
        raise ValueError()
    gen = find_generator(totient, mod)
    root = pow(gen, totient // degree, mod)
    assert 0 <= root < mod
    return root

def find_generator(totient, mod):
    check_int(totient)
    check_int(mod)
    if not (1 <= totient < mod):
        raise ValueError()
    for i in range(1, mod):
        if is_generator(i, totient, mod):
            return i
    raise ValueError("No generator exists")

def is_generator(val, totient, mod):
    check_int(val)
    check_int(totient)
    check_int(mod)
    if not (0 <= val < mod):
        raise ValueError()
    if not (1 <= totient < mod):
        raise ValueError()
    pf = unique_prime_factors(totient)
    return pow(val, totient, mod) == 1 and all((pow(val, totient // p, mod) != 1) for p in pf)

def unique_prime_factors(n):
    check_int(n)
    if n < 1:
        raise ValueError()
    result = []
    i = 2
    end = sqrt(n)
    while i <= end:
        if n % i == 0:
            n //= i
            result.append(i)
            while n % i == 0:
                n //= i
            end = sqrt(n)
        i += 1
    if n > 1:
        result.append(n)
    return result

def transform(invec, root, mod):
    check_int(root)
    check_int(mod)
    if len(invec) >= mod:
        raise ValueError()
    if not all((0 <= val < mod) for val in invec):
        raise ValueError()
    if not (1 <= root < mod):
        raise ValueError()

    outvec = []
    for i in range(len(invec)):
        temp = 0
        for (j, val) in enumerate(invec):
            temp += val * pow(root, i * j, mod)
            temp %= mod
        outvec.append(temp)
    return outvec

def inverse_transform(invec, root, mod):
    outvec = transform(invec, reciprocal(root, mod), mod)
    scaler = reciprocal(len(invec), mod)
    return [(val * scaler % mod) for val in outvec]

def reciprocal(n, mod):
    check_int(n)
    check_int(mod)
    if not (0 <= n < mod):
        raise ValueError()
    x, y = mod, n
    a, b = 0, 1
    while y != 0:
        a, b = b, a - x // y * b
        x, y = y, x % y
    if x == 1:
        return a % mod
    else:
        raise ValueError("Reciprocal does not exist")

def circular_convolve(vec0, vec1):
    if not (0 < len(vec0) == len(vec1)):
        raise ValueError()
    if any((val < 0) for val in itertools.chain(vec0, vec1)):
        raise ValueError()
    maxval = max(val for val in itertools.chain(vec0, vec1))
    minmod = maxval**2 * len(vec0) + 1
    temp0, root, mod = find_params_and_transform(vec0, minmod)
    temp1 = transform(vec1, root, mod)
    temp2 = [(x * y % mod) for (x, y) in zip(temp0, temp1)]
    return inverse_transform(temp2, root, mod)

vec0 = [24, 12, 28, 8, 0, 0, 0, 0]
vec1 = [4, 26, 29, 23, 0, 0, 0, 0]

print(circular_convolve(vec0, vec1))

def modulo(vec, prime):
    return [x % prime for x in vec]

print(modulo(circular_convolve(vec0, vec1), 31))

打印:

[96, 672, 1120, 1660, 1296, 876, 184, 0]
[3, 21, 4, 17, 25, 8, 29, 0]

但是,当我将minmod = maxval**2 * len(vec0) + 1更改为minmod = maxval + 1时,它将停止工作:

[14, 16, 13, 20, 25, 15, 20, 0]
[14, 16, 13, 20, 25, 15, 20, 0]

为了能够按预期工作,最小的minmod(在上面的链接中的N)是多少?

解决方案

如果您输入的n整数绑定到某些素数q(任何mod q不仅是素数也将是相同的),则可以使用它作为max value +1的内容,但请注意,不能将其用作 NTT 的质数p,因为 NTT 质数p具有特殊的属性.他们都在这里:

因此,每个输入的最大值为q-1,但是在您的任务计算过程中(对2个 NTT 结果进行卷积),第一层结果的大小可以提高到n.(q-1),但是对它们进行卷积,最终 iNTT 的输入幅度将上升为:

m = n.((q-1)^2)

如果您在 NTT 上执行的操作不同于m等式,则可能会发生变化.

现在让我们回到p,因此,简而言之,您可以使用支持这些的任何素数p:

p mod n == 1
p > m

并且存在1 <= r,L < p这样:

 p mod (L-1) = 0
r^(L*i) mod p == 1 // i = { 0,n }
r^(L*i) mod p != 1 // i = { 1,2,3, ... n-1 }
 

如果这一切都得到满足,则p第n个统一根,并且可以用于 NTT .要找到这样的素数,还要r,L看看上面的链接(有C ++代码找到了这样的素数).

例如,在字符串乘法期间,我们对它们进行2个字符串 NTT ,然后对结果进行卷积,并 iNTT 返回结果(这是两个输入大小的总和).例如:

                                99999999999999999999999999999999
                               *99999999999999999999999999999999
----------------------------------------------------------------
9999999999999999999999999999999800000000000000000000000000000001

q = 10和两个操作数均为9 ^ 32,因此n=32因此是m = 9*9*32 = 2592,找到的素数是p = 2689.如您所见,结果匹配,因此不会发生溢出.但是,如果我使用仍然适合所有其他条件的任何较小的质数,则结果将不匹配.我专门用它来尽可能地拉伸NTT值(所有值都是q-1并且大小等于2的幂)

如果您的 NTT 很快而n不是2的幂,那么您需要对每个 NTT 零填充到最接近的较高或等于2的幂. >.但这不会影响m的值,因为零填充不应该增加值的大小.我的测试证明了这一点,因此您可以使用卷积:

m = (n1+n2).((q-1)^2)/2

其中n1,n2是零填充之前的原始输入大小.

有关实现 NTT 的更多信息,您可以在 C ++ (经过广泛优化)中查看我的信息:

因此,请回答您的问题:

  1. 是的,您可以利用输入为mod q的事实,但是不能将q用作p !!!

  2. 您只能将minmod = n * (maxval + 1)用于单个NTT(或第一层NTT),但是当您在NTT使用过程中将它们与卷积链接时,您不能将其用于最后的INTT阶段! !

但是,正如我在评论中提到的那样,最简单的方法是使用最大可能的p,它适合您所使用的数据类型,并且可用于支持2种输​​入大小的所有幂次. >

这基本上使您的问题无关紧要.我只能想到不可能/不希望出现这种情况的唯一情况是在没有最大"限制的任意精度数字上.绑定到变量p的性能有很多问题,因为对p的搜索确实很慢(可能比 NTT 本身还要慢),并且变量p禁用了许多对性能的优化模块化算法需要使 NTT 变慢.

I would like to implement multiplication of polynomials using NTT. I followed Number-theoretic transform (integer DFT) and it seems to work.

Now I would like to implement multiplication of polynomials over finite fields Z_p[x] where p is arbitrary prime number.

Does it changes anything that the coefficients are now bounded by p, compared to the former unbounded case?

In particular, original NTT required to find prime number N as the working modulus that is larger than (magnitude of largest element of input vector)^2 * (length of input vector) + 1 so that the result never overflows. If the result is going to be bounded by that p prime anyway, how small can the modulus be? Note that p - 1 does not have to be of form (some positive integer) * (length of input vector).

Edit: I copy-pasted the source from the link above to illustrate the problem:

# 
# Number-theoretic transform library (Python 2, 3)
# 
# Copyright (c) 2017 Project Nayuki
# All rights reserved. Contact Nayuki for licensing.
# https://www.nayuki.io/page/number-theoretic-transform-integer-dft
#

import itertools, numbers

def find_params_and_transform(invec, minmod):
    check_int(minmod)
    mod = find_modulus(len(invec), minmod)
    root = find_primitive_root(len(invec), mod - 1, mod)
    return (transform(invec, root, mod), root, mod)

def check_int(n):
    if not isinstance(n, numbers.Integral):
        raise TypeError()

def find_modulus(veclen, minimum):
    check_int(veclen)
    check_int(minimum)
    if veclen < 1 or minimum < 1:
        raise ValueError()
    start = (minimum - 1 + veclen - 1) // veclen
    for i in itertools.count(max(start, 1)):
        n = i * veclen + 1
        assert n >= minimum
        if is_prime(n):
            return n

def is_prime(n):
    check_int(n)
    if n <= 1:
        raise ValueError()
    return all((n % i != 0) for i in range(2, sqrt(n) + 1))

def sqrt(n):
    check_int(n)
    if n < 0:
        raise ValueError()
    i = 1
    while i * i <= n:
        i *= 2
    result = 0
    while i > 0:
        if (result + i)**2 <= n:
            result += i
        i //= 2
    return result

def find_primitive_root(degree, totient, mod):
    check_int(degree)
    check_int(totient)
    check_int(mod)
    if not (1 <= degree <= totient < mod):
        raise ValueError()
    if totient % degree != 0:
        raise ValueError()
    gen = find_generator(totient, mod)
    root = pow(gen, totient // degree, mod)
    assert 0 <= root < mod
    return root

def find_generator(totient, mod):
    check_int(totient)
    check_int(mod)
    if not (1 <= totient < mod):
        raise ValueError()
    for i in range(1, mod):
        if is_generator(i, totient, mod):
            return i
    raise ValueError("No generator exists")

def is_generator(val, totient, mod):
    check_int(val)
    check_int(totient)
    check_int(mod)
    if not (0 <= val < mod):
        raise ValueError()
    if not (1 <= totient < mod):
        raise ValueError()
    pf = unique_prime_factors(totient)
    return pow(val, totient, mod) == 1 and all((pow(val, totient // p, mod) != 1) for p in pf)

def unique_prime_factors(n):
    check_int(n)
    if n < 1:
        raise ValueError()
    result = []
    i = 2
    end = sqrt(n)
    while i <= end:
        if n % i == 0:
            n //= i
            result.append(i)
            while n % i == 0:
                n //= i
            end = sqrt(n)
        i += 1
    if n > 1:
        result.append(n)
    return result

def transform(invec, root, mod):
    check_int(root)
    check_int(mod)
    if len(invec) >= mod:
        raise ValueError()
    if not all((0 <= val < mod) for val in invec):
        raise ValueError()
    if not (1 <= root < mod):
        raise ValueError()

    outvec = []
    for i in range(len(invec)):
        temp = 0
        for (j, val) in enumerate(invec):
            temp += val * pow(root, i * j, mod)
            temp %= mod
        outvec.append(temp)
    return outvec

def inverse_transform(invec, root, mod):
    outvec = transform(invec, reciprocal(root, mod), mod)
    scaler = reciprocal(len(invec), mod)
    return [(val * scaler % mod) for val in outvec]

def reciprocal(n, mod):
    check_int(n)
    check_int(mod)
    if not (0 <= n < mod):
        raise ValueError()
    x, y = mod, n
    a, b = 0, 1
    while y != 0:
        a, b = b, a - x // y * b
        x, y = y, x % y
    if x == 1:
        return a % mod
    else:
        raise ValueError("Reciprocal does not exist")

def circular_convolve(vec0, vec1):
    if not (0 < len(vec0) == len(vec1)):
        raise ValueError()
    if any((val < 0) for val in itertools.chain(vec0, vec1)):
        raise ValueError()
    maxval = max(val for val in itertools.chain(vec0, vec1))
    minmod = maxval**2 * len(vec0) + 1
    temp0, root, mod = find_params_and_transform(vec0, minmod)
    temp1 = transform(vec1, root, mod)
    temp2 = [(x * y % mod) for (x, y) in zip(temp0, temp1)]
    return inverse_transform(temp2, root, mod)

vec0 = [24, 12, 28, 8, 0, 0, 0, 0]
vec1 = [4, 26, 29, 23, 0, 0, 0, 0]

print(circular_convolve(vec0, vec1))

def modulo(vec, prime):
    return [x % prime for x in vec]

print(modulo(circular_convolve(vec0, vec1), 31))

Prints:

[96, 672, 1120, 1660, 1296, 876, 184, 0]
[3, 21, 4, 17, 25, 8, 29, 0]

However, where I change minmod = maxval**2 * len(vec0) + 1 to minmod = maxval + 1, it stops working:

[14, 16, 13, 20, 25, 15, 20, 0]
[14, 16, 13, 20, 25, 15, 20, 0]

What is the smallest minmod (N in the link above) be in order to work as expected?

解决方案

If your input of n integers is bound to some prime q (any mod q not just prime will be the same) You can use it as a max value +1 but beware you can not use it as a prime p for the NTT because NTT prime p has special properties. All of them are here:

so our max value of each input is q-1 but during your task computation (Convolution on 2 NTT results) the magnitude of first layer results can rise up to n.(q-1) but as we are doing convolution on them the input magnitude of final iNTT will rise up to:

m = n.((q-1)^2)

If you are doing different operations on the NTTs than the m equation might change.

Now let us get back to the p so in a nutshell you can use any prime p that upholds these:

p mod n == 1
p > m

and there exist 1 <= r,L < p such that:

p mod (L-1) = 0
r^(L*i) mod p == 1 // i = { 0,n }
r^(L*i) mod p != 1 // i = { 1,2,3, ... n-1 }

If all this is satisfied then p is nth root of unity and can be used for NTT. To find such prime and also the r,L look at the link above (there is C++ code that finds such).

For example during string multiplication we take 2 strings do NTT on them then convolute the result and iNTT back the result (that is sum of both input sizes). So for example:

                                99999999999999999999999999999999
                               *99999999999999999999999999999999
----------------------------------------------------------------
9999999999999999999999999999999800000000000000000000000000000001

the q = 10 and both operands are 9^32 so n=32 hence m = 9*9*32 = 2592 and the found prime is p = 2689. As you can see the result matches so no overflow occurs. However if I use any smaller prime that still fit all the other conditions the result will not match. I used this specifically to stretch the NTT values as much as possible (all values are q-1 and sizes are equal to the same power of 2)

In case your NTT is fast and n is not a power of 2 then you need to zero pad to nearest higher or equal power of 2 size for each NTT. But that should not affect the m value as zero pad should not increase the magnitude of values. My testing proves it so for convolution you can use:

m = (n1+n2).((q-1)^2)/2

where n1,n2 are the raw inputs sizes before zeropad.

For more info about implementing NTT you can check out mine in C++ (extensively optimized):

So to answer your questions:

  1. yes you can take advantage of the fact that input is mod q but you can not use q as p !!!

  2. You can use minmod = n * (maxval + 1) only for single NTT (or first layer of NTTs) but as you are chaining them with convolution during your NTT usage you can not use that for the final INTT stage !!!

However as I mentioned in the comments easiest is to use max possible p that fits in the data type you are using and is usable for all power of 2 sizes of input supported.

Which basically renders your question irrelevant. The only case I can think of where this is not possible/desired is on arbitrary precision numbers where there is "no" max limit. There are many performance issues binded to variable p as the search for p is really slow (may be even slower than the NTT itself) and also variable p disables many performance optimizations of the modular arithmetics needed making the NTT really slow.

这篇关于在有限域上实现FFT的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆