math - 在有限域上实现 FFT
问题描述
我想使用 NTT 实现多项式的乘法。我遵循数论变换(整数 DFT),它似乎有效。
现在我想在任意素数的有限域Z_p[x]
上实现多项式的乘法。p
p
与以前的无界情况相比,它是否改变了系数现在以 为界的任何内容?
特别是,原始 NTT 需要找到N
大于的工作模数作为工作模数,(magnitude of largest element of input vector)^2 * (length of input vector) + 1
以便结果永远不会溢出。如果结果无论如何都会受到那个p
素数的限制,那么模数可以有多小?请注意,p - 1
不必是 form (some positive integer) * (length of input vector)
。
编辑:我从上面的链接中复制粘贴了源代码来说明问题:
#
# Number-theoretic transform library (Python 2, 3)
#
# Copyright (c) 2017 Project Nayuki
# All rights reserved. Contact Nayuki for licensing.
# https://www.nayuki.io/page/number-theoretic-transform-integer-dft
#
import itertools, numbers
def find_params_and_transform(invec, minmod):
check_int(minmod)
mod = find_modulus(len(invec), minmod)
root = find_primitive_root(len(invec), mod - 1, mod)
return (transform(invec, root, mod), root, mod)
def check_int(n):
if not isinstance(n, numbers.Integral):
raise TypeError()
def find_modulus(veclen, minimum):
check_int(veclen)
check_int(minimum)
if veclen < 1 or minimum < 1:
raise ValueError()
start = (minimum - 1 + veclen - 1) // veclen
for i in itertools.count(max(start, 1)):
n = i * veclen + 1
assert n >= minimum
if is_prime(n):
return n
def is_prime(n):
check_int(n)
if n <= 1:
raise ValueError()
return all((n % i != 0) for i in range(2, sqrt(n) + 1))
def sqrt(n):
check_int(n)
if n < 0:
raise ValueError()
i = 1
while i * i <= n:
i *= 2
result = 0
while i > 0:
if (result + i)**2 <= n:
result += i
i //= 2
return result
def find_primitive_root(degree, totient, mod):
check_int(degree)
check_int(totient)
check_int(mod)
if not (1 <= degree <= totient < mod):
raise ValueError()
if totient % degree != 0:
raise ValueError()
gen = find_generator(totient, mod)
root = pow(gen, totient // degree, mod)
assert 0 <= root < mod
return root
def find_generator(totient, mod):
check_int(totient)
check_int(mod)
if not (1 <= totient < mod):
raise ValueError()
for i in range(1, mod):
if is_generator(i, totient, mod):
return i
raise ValueError("No generator exists")
def is_generator(val, totient, mod):
check_int(val)
check_int(totient)
check_int(mod)
if not (0 <= val < mod):
raise ValueError()
if not (1 <= totient < mod):
raise ValueError()
pf = unique_prime_factors(totient)
return pow(val, totient, mod) == 1 and all((pow(val, totient // p, mod) != 1) for p in pf)
def unique_prime_factors(n):
check_int(n)
if n < 1:
raise ValueError()
result = []
i = 2
end = sqrt(n)
while i <= end:
if n % i == 0:
n //= i
result.append(i)
while n % i == 0:
n //= i
end = sqrt(n)
i += 1
if n > 1:
result.append(n)
return result
def transform(invec, root, mod):
check_int(root)
check_int(mod)
if len(invec) >= mod:
raise ValueError()
if not all((0 <= val < mod) for val in invec):
raise ValueError()
if not (1 <= root < mod):
raise ValueError()
outvec = []
for i in range(len(invec)):
temp = 0
for (j, val) in enumerate(invec):
temp += val * pow(root, i * j, mod)
temp %= mod
outvec.append(temp)
return outvec
def inverse_transform(invec, root, mod):
outvec = transform(invec, reciprocal(root, mod), mod)
scaler = reciprocal(len(invec), mod)
return [(val * scaler % mod) for val in outvec]
def reciprocal(n, mod):
check_int(n)
check_int(mod)
if not (0 <= n < mod):
raise ValueError()
x, y = mod, n
a, b = 0, 1
while y != 0:
a, b = b, a - x // y * b
x, y = y, x % y
if x == 1:
return a % mod
else:
raise ValueError("Reciprocal does not exist")
def circular_convolve(vec0, vec1):
if not (0 < len(vec0) == len(vec1)):
raise ValueError()
if any((val < 0) for val in itertools.chain(vec0, vec1)):
raise ValueError()
maxval = max(val for val in itertools.chain(vec0, vec1))
minmod = maxval**2 * len(vec0) + 1
temp0, root, mod = find_params_and_transform(vec0, minmod)
temp1 = transform(vec1, root, mod)
temp2 = [(x * y % mod) for (x, y) in zip(temp0, temp1)]
return inverse_transform(temp2, root, mod)
vec0 = [24, 12, 28, 8, 0, 0, 0, 0]
vec1 = [4, 26, 29, 23, 0, 0, 0, 0]
print(circular_convolve(vec0, vec1))
def modulo(vec, prime):
return [x % prime for x in vec]
print(modulo(circular_convolve(vec0, vec1), 31))
印刷:
[96, 672, 1120, 1660, 1296, 876, 184, 0]
[3, 21, 4, 17, 25, 8, 29, 0]
但是,在我更改minmod = maxval**2 * len(vec0) + 1
为 的地方minmod = maxval + 1
,它停止工作:
[14, 16, 13, 20, 25, 15, 20, 0]
[14, 16, 13, 20, 25, 15, 20, 0]
为了按预期工作,最小的minmod
(在上面的链接中)是什么?N
解决方案
如果您的n
整数输入绑定到某个素数q
(任何mod q
不只是素数都是相同的)您可以将其用作 amax value +1
但请注意您不能将其用作 NTT 的素数,p
因为NTT素数具有特殊属性。他们都在这里:p
所以我们每个输入的最大值是,q-1
但是在你的任务计算期间(对 2 个NTT结果进行卷积),第一层结果的幅度可以上升到,n.(q-1)
但是当我们对它们进行卷积时,最终iNTT的输入幅度将上升到:
m = n.((q-1)^2)
如果您在NTT上执行不同的操作,则m
等式可能会改变。
现在让我们回到p
简而言之,你可以使用任何p
支持这些的素数:
p mod n == 1
p > m
并且存在1 <= r,L < p
这样的:
p mod (L-1) = 0
r^(L*i) mod p == 1 // i = { 0,n }
r^(L*i) mod p != 1 // i = { 1,2,3, ... n-1 }
如果所有这些都满足,那么p
是第 n 个单位根,可以用于NTT。要找到这样的素数,还要r,L
查看上面的链接(有 C++ 代码可以找到这样的)。
例如,在字符串乘法过程中,我们对两个字符串进行NTT,然后对结果进行卷积,然后iNTT返回结果(即两个输入大小的总和)。例如:
99999999999999999999999999999999
*99999999999999999999999999999999
----------------------------------------------------------------
9999999999999999999999999999999800000000000000000000000000000001
和q = 10
两个操作数都是 9^32n=32
因此m = 9*9*32 = 2592
和找到的素数是p = 2689
. 如您所见,结果匹配,因此不会发生溢出。但是,如果我使用任何仍然适合所有其他条件的较小素数,则结果将不匹配。我专门用它来尽可能地拉伸 NTT 值(所有值q-1
和大小都等于 2 的相同幂)
如果您的NTT速度很快并且不是 2 的幂,那么您需要将每个NTTn
归零到最接近的更高或等于 2 大小的幂。但这不应该影响值,因为零填充不应该增加值的大小。我的测试证明了这一点,因此对于卷积,您可以使用:m
m = (n1+n2).((q-1)^2)/2
n1,n2
zeropad 之前的原始输入大小在哪里。
有关实现NTT的更多信息,您可以查看我的C++(广泛优化):
所以回答你的问题:
mod q
是的,您可以利用 input is但不能用作q
as的事实p
!您
minmod = n * (maxval + 1)
只能用于单个 NTT(或第一层 NTT),但由于您在 NTT 使用期间将它们与卷积链接起来,因此您不能将其用于最后的 INTT 阶段!!!
但是,正如我在评论中提到的,最简单的方法是使用p
适合您正在使用的数据类型的最大可能,并且可用于支持的 2 种输入大小的所有功率。
这基本上使您的问题无关紧要。我能想到的唯一不可能/不希望的情况是在没有“最大”限制的任意精度数字上。有许多与变量相关p
的性能问题,因为搜索p
非常慢(可能甚至比NTT本身还慢),并且变量p
禁用了所需的模运算的许多性能优化,使得NTT非常慢。
推荐阅读
- php - 在 PHP、Codeigniter 中创建或编辑数据后如何返回上一页
- sql-server - SQL Server:多列键字典顺序
- c++ - 是否可以从整数序列生成变体?
- python - 如何将转换后的json数据保存到django中的数据库列中
- c++ - chrono 在 C++ 中可以测量的最大时间间隔是多少?
- python - 查找二元分类的阈值
- java - 如何使用多个 JFrame 将 JTable 数据传递给 JTextArea?
- docker - 使用 Docker 使用 Nginx 启用 HTTPS
- python - 当我们在 python selenium 中使用变量作为搜索文本时,查找包含特定文本的 web 元素会给出不同的结果
- html - 嵌套翻转卡和 Apple Safari