模运算与NTT（有限域DFT）优化 [英] Modular arithmetics and NTT (finite field DFT) optimizations

查看:267
本文介绍了模块化算术和NTT(有限域DFT)优化的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想用NTT实现快速平方（参见“快速大数平方计算”），但即使对非常大的数（超过12000位），结果也很慢。



所以我的问题是


  1. 有没有办法优化我的NTT变换？
    我指的不是通过并行（线程）来加速；这里只涉及底层部分。









  2. 有没有办法加快我的模运算？

    这是我（已经优化过的）NTT的C++源代码（它是完整的，在C++中100%可用，不需要任何第三方库，而且应该是线程安全的。注意：源数组会被用作临时存储！！！另外，它不能原地变换数组，即dst不能与src是同一个数组）。

      // ---------- -------------------------------------------------- --------------- 
    class fourier_NTT //数论理论变换
    {

    public:
    DWORD r,L, p,N;
    DWORD W,iW,rN;
    fourier_NTT(){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; }

    //主接口
    void NTT(DWORD * dst,DWORD * src,DWORD n = 0); // DWORD dst [n] = fast NTT(DWORD src [n])
    void INTT(DWORD * dst,DWORD * src,DWORD n = 0) // DWORD dst [n] = fast INTT(DWORD src [n])

    //帮助函数
    bool init(DWORD n); // init r,L,p,W,iW,rN
    void NTT_fast(DWORD * dst,DWORD * src,DWORD n,DWORD w); // DWORD dst [n] = fast NTT(DWORD src [n])

    //仅用于测试
    void NTT_slow(DWORD * dst,DWORD * src,DWORD n,DWORD w ); // DWORD dst [n] = slow NTT(DWORD src [n])
    void INTT_slow(DWORD * dst,DWORD * src,DWORD n,DWORD w); // DWORD dst [n] = slow INTT(DWORD src [n])

    // DWORD算术
    DWORD shl(DWORD a);
    DWORD shr(DWORD a);

    //模块化算术
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
    };

    // ---------------------------------------- -----------------------------------
    void fourier_NTT :: NTT(DWORD * dst,DWORD * src,DWORD n)
    {
    if(n> 0)init(n);
    NTT_fast(dst,src,N,W);
    // NTT_slow(dst,src,N,W);
    }

    // ----------------------------------- ----------------------------------------
    void fourier_NTT :: INTT( DWORD * dst,DWORD * src,DWORD n)
    {
    if(n> 0)init(n)
    NTT_fast(dst,src,N,iW);
    for(DWORD i = 0; i // INTT_slow(dst,src,N,W);
    }

    // ----------------------------------- ----------------------------------------
    bool fourier_NTT :: init( DWORD n)
    {
    //(max(src [])^ 2)* n < p else NTT溢出可以ocur!
    r = 2; p = 0xC0000001; if((n <2)||(n> 0x10000000)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x30000000 / n; // 32:30 bit for unsigned 32 bit
    // r = 2; p = 0x78000001; if((n <2)||(n> 0x04000000)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x3c000000 / n; // 31:27 bit for signed 32 bit
    // r = 2; p = 0x00010001; if((n <2)||(n> 0x00000020)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x00000020 / n; // 17:16 bit for 16 bit
    // r = 2; p = 0x0a000001; if((n <2)||(n> 0x01000000)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x01000000 / n; // 28:25 bit
    N = n; //向量的大小[DWORDs]
    W = modpow(r,L); // Wn for NTT
    iW = modpow(r,p-1-L); // Wn for INTT
    rN = modpow(n,p-2); // scale for INTT
    return true;
    }

    // ----------------------------------- ----------------------------------------
    void fourier_NTT :: NTT_fast( DWORD * dst,DWORD * src,DWORD n,DWORD w)
    {
    if(n <= 1){if(n == 1)dst [0] = src [0]返回; } w B = b DWORD i,j,a0,a1,n2 = n> 1,w2 = modmul(w,w)
    //重新排序even,odd
    for(i = 0,j = 0; i for(j = 1; i //递归
    NTT_fast(src,dst,n2,w2); // even
    NTT_fast(src + n2,dst + n2,n2,w2); // odd
    //恢复结果
    for(w2 = 1,i = 0,j = n2; i {
    a0 = src [i];
    a1 = modmul(src [j],w2);
    dst [i] = modadd(a0,a1);
    dst [j] = modsub(a0,a1);
    }
    }

    // ------------------------------ ---------------------------------------------
    void fourier_NTT :: NTT_slow(DWORD * dst,DWORD * src,DWORD n,DWORD w)
    {
    DWORD i,j,wj,wi,a,n2 = n>
    for(wj = 1,j = 0; j {
    a = 0;
    for(wi = 1,i = 0; i {
    a = modadd(a,modmul(wi,src [i]))
    wi = modmul(wi,wj);
    }
    dst [j] = a;
    wj = modmul(wj,w);
    }
    }

    // ------------------------------ ---------------------------------------------
    void fourier_NTT :: INTT_slow(DWORD * dst,DWORD * src,DWORD n,DWORD w)
    {
    DWORD i,j,wi = 1,wj = 1,a,n2 = n> 1 ;
    for(wj = 1,j = 0; j {
    a = 0;
    for(wi = 1,i = 0; i {
    a = modadd(a,modmul(wi,src [i]))
    wi = modmul(wi,wj);
    }
    dst [j] = modmul(a,rN);
    wj = modmul(wj,iW);
    }
    }

    // ------------------------------ ---------------------------------------------
    DWORD fourier_NTT :: shl(DWORD a){return(a << 1)& 0xFFFFFFFE; }
    DWORD fourier_NTT :: shr(DWORD a){return(a>> 1)& 0x7FFFFFFF; }

    // --------------------------------------- ------------------------------------
    DWORD fourier_NTT :: mod(DWORD a)
    {
    DWORD bb;
    for(bb = p;(DWORD(a)> DWORD(bb))&&(!DWORD(bb& 0x80000000)); bb = shl(bb)
    for(;;)
    {
    if(DWORD(a)> = DWORD(bb))a- = bb;
    if(bb == p)break;
    bb = shr(bb);
    }
    return a;
    }

    // ----------------------------------- ----------------------------------------
    DWORD fourier_NTT :: modadd( DWORD a,DWORD b)
    {
    DWORD d,cy;
    a = mod(a);
    b = mod(b);
    d = a + b;
    cy =(shr(a)+ shr(b)+ shr((a& 1)+(b& 1)))& 0x80000000;
    if(cy)d- = p;
    if(DWORD(d)> = DWORD(p))d- = p;
    return d;
    }

    // ----------------------------------- ----------------------------------------
    DWORD fourier_NTT :: modsub( DWORD a,DWORD b)
    {
    DWORD d;
    a = mod(a);
    b = mod(b);
    d = a-b; if(DWORD(a) if(DWORD(d)> = DWORD(p))d- = p;
    return d;
    }

    // ----------------------------------- ----------------------------------------
    DWORD fourier_NTT :: modmul( DWORD a,DWORD b)
    {// b bez orezania!
    int i;
    DWORD d;
    a = mod(a);
    for(d = 0,i = 0; i <32; i ++)
    {
    if(DWORD(a& 1))d = modadd(d,b)
    a = shr(a);
    b = modadd(b,b);
    }
    return d;
    }

    // ----------------------------------- ----------------------------------------
    DWORD fourier_NTT :: modpow( DWORD a,DWORD b)
    {// a,b bez orezania!
    int i;
    DWORD d = 1;
    for(i = 0; i <32; i ++)
    {
    d = modmul(d,d);
    if(DWORD(b& 0x80000000))d = modmul(d,a);
    b = shl(b);
    }
    return d;
    }
    // --------------------------------------- ------------------------------------

    我的NTT类的使用示例:

      fourier_NTT ntt; 
    const DWORD n = 32
    DWORD x [N] = {0,1,2,3,... 31},y [N] = {32,33,34,35, .63},z [N];

    ntt.NTT(z,x,N); // z [N] = NTT(x [N]),也是N
    的初始化常量ntt.NTT(x,y); // x [N] = NTT(y [N]),不重新计算常数,使用最后N
    //模数卷积y [] = z []。 0; i ntt.INTT(x,y); // x [N] = INTT(y [N]),不重新计算常量,使用最后N
    // x [] =原始x []的卷积y []

    优化前的一些测量(非NTT类):

      a = 0.98765588997654321000 | 389 * 32位
    循环1次
    sqr1 [3.177 ms]快速sqr
    sqr2 [720.419 ms] NTT sqr
    mul1 [5.588 ms] simpe mul
    mul2 [ 3.172 ms] karatsuba mul
    mul3 [1053.382 ms] NTT mul

    当前代码（减少了递归参数的大小/数量，并改进了模运算）的测量结果:

      a = 0.98765588997654321000 | 389 * 32位
    循环1x次
    sqr1 [3.214 ms]快速sqr
    sqr2 [208.298 ms] NTT sqr
    mul1 [5.564 ms] simpe mul
    mul2 [ 3.113 ms] karatsuba mul
    mul3 [302.740 ms] NTT mul

    请看NTT mul和NTT sqr的耗时（我的优化使其加快了3倍多一点）。这里只循环了1次，所以不是很精确（误差约10%），但加速效果已经很明显了（通常我会循环1000次或更多，但我的NTT太慢了）。



    你可以自由使用我的代码……只要保留我的昵称和/或指向本页面的链接（写在代码注释、readme.txt、“关于”对话框或任何地方都可以）。希望它能有所帮助……（我在任何地方都没找到快速NTT的C++源码，所以只好自己写了）。对于 fourier_NTT::init(DWORD n) 所接受的所有N，其单位根都已经测试过。



    PS:有关NTT的详细信息,请参阅 http://stackoverflow.com/a/18547575/2521214



    [edit1:]代码中的更改



    我利用模数质数固定为0xC0000001这一点，并去掉了不必要的函数调用，进一步优化了模运算。现在的加速效果非常惊人（超过40倍）。在大约1500*32位的阈值之后，NTT乘法已经比Karatsuba更快。顺便说一句，我的NTT现在与基于64位double的优化DFFT速度相当。



    一些测量:

      a = 0.98765588997654321000 | 1553 * 32bits 
    looped 10x times
    mul2 [28.585 ms] karatsuba mul
    mul3 [26.311 ms] NTT mul

    模块化算法的新源代码:

      // ------ -------------------------------------------------- ------------------- 
    DWORD fourier_NTT::mod(DWORD a)
        {
        // A single conditional subtraction is enough because p=0xC0000001 > 2^31,
        // so every 32-bit value is < 2p. It must be ">=": a==p has to become 0.
        if (a>=p) a-=p;
        return a;
        }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modadd(DWORD a,DWORD b)
        {
        DWORD d,cy;
        if (a>=p) a-=p;
        if (b>=p) b-=p;
        d=a+b;
        // carry out of bit 31 of a+b, computed without a wider integer type
        cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;
        if (cy) d-=p;       // real sum is d+2^32; subtracting p wraps it back into range
        if (d>=p) d-=p;
        return d;
        }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modsub(DWORD a,DWORD b)
        {
        DWORD d;
        if (a>=p) a-=p;
        if (b>=p) b-=p;
        d=a-b;
        if (a<b) d+=p;      // borrow: add p back
        if (d>=p) d-=p;
        return d;
        }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modmul(DWORD a,DWORD b)
        {
    #if defined(_MSC_VER) && defined(_M_IX86)
        // 32-bit x86: edx:eax = a*b, then edx = (edx:eax) % p.
        // div raises #DE if the quotient does not fit in 32 bits; that cannot
        // happen while a<p or b<p.
        DWORD _a=a,_b=b,_p=p;
        __asm
            {
            mov eax,_a
            mul _b
            div _p
            mov _a,edx
            }
        return _a;
    #else
        // portable: full 64-bit product, then reduce (no inline asm on x64 MSVC/gcc/clang)
        return DWORD(((unsigned long long)a*(unsigned long long)b)%p);
    #endif
        }

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modpow(DWORD a,DWORD b)
        {   // b is the raw exponent and is never reduced mod p
        int i;
        DWORD d=1;
        if (a>=p) a-=p;
        for (i=0;i<32;i++)          // left-to-right square-and-multiply
            {
            d=modmul(d,d);
            if (DWORD(b&0x80000000)) d=modmul(d,a);
            b<<=1;
            }
        return d;
        }

    // ----------------------------------- ----------------------------------------

    可以看到,函数 shl shr 不再使用。我认为modpow可以进一步优化,但它不是一个关键的函数,因为它只被调用很少次。



    更多问题:




    • 是否还有其他选项来加速NTT?

    • 我的模运算优化是否安全？



    [edit2]新优化

      a = 0.99991970486 | 2000 * 32位
    循环10x
    sqr1 [13.908 ms]快速sqr
    sqr2 [13.649 ms] NTT sqr
    mul1 [19.726 ms] simpe mul
    mul2 [31.808 ms] karatsuba mul
    mul3 [19.373 ms] NTT mul

    我实现了所有可用的优化:

    • 删除不必要的安全取模（Mandalf The Beige）：+2.5%
    • 使用预先计算的 W、iW 幂表：+34.9%
    • 总计：+35%



    实际完整源代码

      // ---------------------------- ----------------------------------------------- 
    // ---数论变换:2.03 -------------------------------------
    // ---------------------------------------------- -----------------------------
    #ifndef _fourier_NTT_h
    #define _fourier_NTT_h
    // - -------------------------------------------------- ------------------------
    // -------------------- -------------------------------------------------- -----
    class fourier_NTT //数论理论变换
    {
    public:
    DWORD r,L,p,N;
    DWORD W,iW,rN; // W =(r ^ L)mod p,iW = inverse W,rN = inverse N
    DWORD * WW,* iWW,NN; // Precomputed(W,iW)^(0,..,NN-1)powers

    // Internals
    fourier_NTT(){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; WW = NULL; iWW = NULL; NN = 0; }
    〜fourier_NTT(){_free(); }
    void _free(); //自由预计算W,iW幂表
    void _alloc(DWORD n); //分配和预计算W,iW幂表

    //主接口
    void NTT(DWORD * dst,DWORD * src,DWORD n = 0) // DWORD dst [n] = fast NTT(DWORD src [n])
    void iNTT(DWORD * dst,DWORD * src,DWORD n = 0) // DWORD dst [n] = fast INTT(DWORD src [n])

    //帮助函数
    bool init(DWORD n); // init r,L,p,W,iW,rN
    void NTT_fast(DWORD * dst,DWORD * src,DWORD n,DWORD w); // DWORD dst [n] = fast NTT(DWORD src [n])
    void NTT_fast(DWORD * dst,DWORD * src,DWORD n,DWORD * w2,DWORD i2)

    //仅用于测试
    void NTT_slow(DWORD * dst,DWORD * src,DWORD n,DWORD w); // DWORD dst [n] = slow NTT(DWORD src [n])
    void iNTT_slow(DWORD * dst,DWORD * src,DWORD n,DWORD w); // DWORD dst [n] = slow INTT(DWORD src [n])

    //模块化算术(优化,但只适用于p> = 0x80000000 !!!)
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
    };
    // -------------------------------------------- -------------------------------

    // --------- -------------------------------------------------- ----------------
    void fourier_NTT :: _ free()
    {
    NN = 0;
    if(WW)delete [] WW; WW = NULL;
    if(iWW)delete [] iWW; iWW = NULL;
    }

    // ----------------------------------- ----------------------------------------
    void fourier_NTT :: _ alloc( DWORD n)
    {
    if(n <= NN)return;
    DWORD * tmp,i,w;
    tmp = new DWORD [n]; if((NN)&(WW))for(i = 0; i tmp = new DWORD [n]; if((NN)&(iWW))for(i = 0; i NN = n;
    }

    // ----------------------------------- ----------------------------------------
    void fourier_NTT :: NTT( DWORD * dst,DWORD * src,DWORD n)
    {
    if(n> 0)init(n)
    NTT_fast(dst,src,N,WW,1);
    // NTT_fast(dst,src,N,W);
    // NTT_slow(dst,src,N,W);
    }

    // ----------------------------------- ----------------------------------------
    void fourier_NTT :: iNTT( DWORD * dst,DWORD * src,DWORD n)
    {
    if(n> 0)init(n)
    NTT_fast(dst,src,N,iWW,1);
    // NTT_fast(dst,src,N,iW);
    for(DWORD i = 0; i // iNTT_slow(dst,src,N,W);
    }

    // ----------------------------------- ----------------------------------------
    bool fourier_NTT :: init( DWORD n)
    {
    //(max(src [])^ 2)* n < p else NTT溢出可以ocur!
    r = 2; p = 0xC0000001; if((n <2)||(n> 0x10000000)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x30000000 / n; // 32:30 bit for unsigned 32 bit
    // r = 2; p = 0x78000001; if((n <2)||(n> 0x04000000)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x3c000000 / n; // 31:27 bit for signed 32 bit
    // r = 2; p = 0x00010001; if((n <2)||(n> 0x00000020)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x00000020 / n; // 17:16 bit for 16 bit
    // r = 2; p = 0x0a000001;如果((n <2)||(n> 0x01000000)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x01000000 / n; // 28:25 bit
    N = n; //向量的大小[DWORDs]
    W = modpow(r,L); // Wn for NTT
    iW = modpow(r,p-1-L); // Wn for INTT
    rN = modpow(n,p-2); // Scale for INTT
    _alloc(n>> 1); // Precompute W,iW powers
    return true;
    }

    // ----------------------------------- ----------------------------------------
    void fourier_NTT :: NTT_fast( DWORD * dst,DWORD * src,DWORD n,DWORD w)
    {
    if(n <= 1){if(n == 1)dst [0] = src [0]返回; } w B = b DWORD i,j,a0,a1,n2 = n> 1,w2 = modmul(w,w)

    //重新排序even,odd
    for(i = 0,j = 0; i (j = 1; i
    //递归
    NTT_fast(src,dst,n2,w2); // Even
    NTT_fast(src + n2,dst + n2,n2,w2); // Odd

    //为(w2 = 1,i = 0,j = n2; i
    {
    a0 = src [i];
    a1 = modmul(src [j],w2);
    dst [i] = modadd(a0,a1);
    dst [j] = modsub(a0,a1);
    }
    }

    // ------------------------------ ---------------------------------------------
    void fourier_NTT :: NTT_fast(DWORD * dst,DWORD * src,DWORD n,DWORD * w2,DWORD i2)
    {
    if(n <= 1){if(n == 1)dst [0 ] = src [0];返回; }
    DWORD i,j,a0,a1,n2 = n> 1;

    //重新排序even,odd
    for(i = 0,j = 0; i for(j = 1; i
    //递归
    i = i2<< 1;
    NTT_fast(src,dst,n2,w2,i); // Even
    NTT_fast(src + n2,dst + n2,n2,w2,i); // Odd

    //恢复结果
    for(i = 0,j = n2; i {
    a0 = src [i];
    a1 = modmul(src [j],* w2);
    dst [i] = modadd(a0,a1);
    dst [j] = modsub(a0,a1);
    }
    }

    // ------------------------------ ---------------------------------------------
    void fourier_NTT :: NTT_slow(DWORD * dst,DWORD * src,DWORD n,DWORD w)
    {
    DWORD i,j,wj,wi,a;
    for(wj = 1,j = 0; j {
    a = 0;
    for(wi = 1,i = 0; i {
    a = modadd(a,modmul(wi,src [i]))
    wi = modmul(wi,wj)
    }
    dst [j] = a;
    wj = modmul(wj,w);
    }
    }

    // ------------------------------ ---------------------------------------------
    void fourier_NTT :: iNTT_slow(DWORD * dst,DWORD * src,DWORD n,DWORD w)
    {
    DWORD i,j,wi = 1,wj = 1,a;
    for(wj = 1,j = 0; j {
    a = 0;
    for(wi = 1,i = 0; i {
    a = modadd(a,modmul(wi,src [i]))
    wi = modmul(wi,wj);
    }
    dst [j] = modmul(a,rN);
    wj = modmul(wj,iW);
    }
    }

    // ------------------------------ ---------------------------------------------
    DWORD fourier_NTT :: mod(DWORD a)
    {
    if(a> p)a- = p;
    return a;
    }

    // ----------------------------------- ----------------------------------------
    DWORD fourier_NTT :: modadd( DWORD a,DWORD b)
    {
    DWORD d,cy;
    // if(a> p)a- = p;
    // if(b> p)b- = p;
    d = a + b;
    cy =((a> 1)+(b> 1)+(((a& 1)+(b& 1))> 1))& 0x80000000
    if(cy)d- = p;
    if(d> p)d- = p;
    return d;
    }

    // ----------------------------------- ----------------------------------------
    DWORD fourier_NTT :: modsub( DWORD a,DWORD b)
    {
    DWORD d;
    // if(a> p)a- = p;
    // if(b> p)b- = p;
    d = a-b;
    if(a< b)d + = p;
    if(d> p)d- = p;
    return d;
    }

    // ----------------------------------- ----------------------------------------
    DWORD fourier_NTT :: modmul( DWORD a,DWORD b)
    {
    DWORD _a,_b,_p;
    _a = a;
    _b = b;
    _p = p;
    asm {
    mov eax,_a
    mov ebx,_b
    mul ebx // H(edx),L(eax)= eax * ebx
    mov ebx, _p
    div ebx // eax = H(edx),L(eax)/ ebx
    mov _a,edx // edx = H(edx),L(eax)%ebx
    }
    return _a;
    }

    // ----------------------------------- ----------------------------------------
    DWORD fourier_NTT :: modpow( DWORD a,DWORD b)
    {// b不是mod(p)!
    int i;
    DWORD d = 1;
    // if(a> p)a- = p;
    for(i = 0; i <32; i ++)
    {
    d = modmul(d,d)
    if(DWORD(b& 0x80000000))d = modmul(d,a);
    b << = 1;
    }
    return d;
    }
    // --------------------------------------- ------------------------------------
    // -------- -------------------------------------------------- -----------------
    #endif
    // --------------------- -------------------------------------------------- ----
    // ---------------------------------------- -----------------------------------

    把 NTT_fast 拆分为两个函数（一个使用 WW[]，另一个使用 iWW[]）还可以进一步减少栈开销，因为递归调用时可以少传一个参数。但我并不指望由此获得多少收益（只省下一个32位指针），我更愿意保留一个函数，便于将来维护代码。现在许多函数都处于闲置状态（仅用于测试），比如慢速版本、mod 以及旧的快速函数（它使用 w 参数，而不是 *w2,i2）。为了避免大数据集溢出，请把输入数值限制在 p/4 位以内，其中 p 是每个 NTT 元素的位数。因此对于这个32位版本，输入值最多使用 32/4 = 8 位。



    [edit3]简单字符串 bigint 测试乘法

      // --------------------------------------------- ------------------------------ 
    char * mul_NTT(const char * sx,const char * sy)
    {
    char * s;
    int i,j,k,n;
    // n = 2的最小幂≤2最大长度(x,y)
    for(i = 0; sx [i]; i ++); for(n = 1; n for(j = 0; sx [j]; j ++);对于(n = 1; n DWORD * x,* y,* xx,* yy,a;
    x = new DWORD [n]; xx = new DWORD [n];
    y = new DWORD [n]; yy = new DWORD [n];

    //对于(k = 0; i> = 0; i - ,k ++)x [k] = sx [i] - '0'的零填充
    ; for(; k for(k = 0; j> = 0; j-,k ++)y [k] = sy [j] for(; k
    // NTT
    fourier_NTT ntt;
    ntt.NTT(xx,x,n);
    ntt.NTT(yy,y);

    //(i = 0; i xx [i] = ntt.modmul(xx [i],yy [i]

    // INTT
    ntt.iNTT(yy,xx);

    // suma
    a = 0; s = new char [n + 1]; for(i = 0; i delete [] x; delete [] xx;
    delete [] y; delete [] yy;

    return s;
    }
    // --------------------------------------- ------------------------------------

    我平时使用 AnsiString，所以把它移植成了 char* 版本，希望没有引入错误。与 AnsiString 版本对比，它看起来工作正常。




    • sx,sy 是十进制整数

    • 返回已分配字符串(char *)= sx * sy



    这里每个32位数据字只使用约4位（一个十进制数字），所以没有溢出风险，当然速度也较慢。在我的 bignum 库中使用的是二进制表示，每个32位WORD放8位数据块来做NTT；如果 N 很大，使用更多的位是有风险的……



    有乐趣

    解决方案

    首先,非常感谢发布和使其免费使用。我真的很感激。



    我能够使用一些技巧来消除一些分支,重新排列主循环,并修改了程序集,并得到一个1.35 x加速。



    此外，我为64位添加了预处理器条件，因为Visual Studio不允许在64位模式下使用内联汇编（谢谢你，Microsoft ;-)）。



    优化modsub()时发生了奇怪的事情。我像处理modadd那样，用位运算技巧重写了它（modadd因此变快了），但出于某种原因，modsub的位运算版本反而更慢。不知道为什么，可能只是我电脑的问题。

      // 
    // Mandalf The Beige
    //基于:
    // Spektre
    // http://stackoverflow.com/questions/18577076/modular-arithmetics-and-ntt-finite-field-dft-optimizations
    //
    //这段代码可以随意选择,只要它伴随着这个通知。
    //




    #ifndef H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
    #define H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR

    #include <string.h>     // memcpy
    #include <stdint.h>     // uint32_t, uint64_t

    // The modular arithmetic below depends on exact 32-bit wrap-around, but
    // "unsigned long" is 64 bits wide on LP64 systems (Linux, macOS), so the
    // word types are mapped to fixed-width integers.
    #ifndef uint32
    #define uint32 uint32_t
    #endif

    #ifndef uint64
    #define uint64 uint64_t
    #endif


    class fast_ntt // number theoretic transform, p = 0xC0000001
    {
    public:
        fast_ntt()
        {
            r = 0; L = 0; N = 0;    // N too: NTT()/INTT() with n == 0 read it
            W = 0; iW = 0; rN = 0;
        }
        // main interface
        // n must be a power of two in [2, 0x10000000]; n == 0 reuses the size of the
        // previous call, an invalid n turns the call into a no-op (N becomes 0).
        // src[] is used as scratch space and is overwritten; dst == src is allowed.
        inline void NTT(uint32 *dst, uint32 *src, uint32 n = 0);   // uint32 dst[n] = fast NTT(uint32 src[n])
        inline void INTT(uint32 *dst, uint32 *src, uint32 n = 0);  // uint32 dst[n] = fast INTT(uint32 src[n])

    private:
        // Every member is declared inline because the definitions live in this
        // header; otherwise including it from two .cpp files breaks the link (ODR).
        // helper functions
        inline bool init(uint32 n);                                                 // init r,L,W,iW,rN
        inline void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w);         // recursive kernel, dst != src
        inline void NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w);         // uint32 dst[n] = fast NTT(uint32 src[n])
        inline void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);   // same, src[] preserved
        // only for testing
        inline void NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);         // uint32 dst[n] = slow NTT(uint32 src[n])
        inline void INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);        // uint32 dst[n] = slow INTT(uint32 src[n])

        // modular arithmetics (modadd/modsub expect operands already < p)
        inline uint32 modadd(uint32 a, uint32 b);
        inline uint32 modsub(uint32 a, uint32 b);
        inline uint32 modmul(uint32 a, uint32 b);
        inline uint32 modpow(uint32 a, uint32 b);

        uint32 r, L, N;
        uint32 W, iW, rN;

        const uint32 p = 0xC0000001;
    };

    // ---------------------------------------------------------------------------
    void fast_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
    {
        if (n > 0)
        {
            init(n);
        }
        NTT_fast(dst, src, N, W);
        // NTT_slow(dst, src, N, W);
    }

    // ---------------------------------------------------------------------------
    void fast_ntt::INTT(uint32 *dst, uint32 *src, uint32 n)
    {
        if (n > 0)
        {
            init(n);
        }
        NTT_fast(dst, src, N, iW);
        for (uint32 i = 0; i < N; i++)
        {
            dst[i] = modmul(dst[i], rN);    // scale by 1/N (mod p)
        }
        // INTT_slow(dst, src, N, W);
    }

    // ---------------------------------------------------------------------------
    bool fast_ntt::init(uint32 n)
    {
        // (max(src[])^2)*n < p, otherwise the NTT can overflow!
        // n must be a power of two: for other sizes the radix-2 split in NTT_calc
        // reaches odd lengths and reads src[] one item past the end.
        r = 2;
        if ((n < 2) || (n > 0x10000000) || (n & (n - 1)))
        {
            r = 0; L = 0; W = 0;
            iW = 0; rN = 0; N = 0;
            return false;
        }
        L = 0x30000000 / n; // 32:30 bit for unsigned 32 bit
        // r = 2; p = 0x78000001; if((n <2)||(n> 0x04000000)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x3c000000 / n; // 31:27 bit for signed 32 bit
        // r = 2; p = 0x00010001; if((n <2)||(n> 0x00000020)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x00000020 / n; // 17:16 bit for 16 bit
        // r = 2; p = 0x0a000001; if((n <2)||(n> 0x01000000)){r = 0; L = 0; p = 0; W = 0; iW = 0; rN = 0; N = 0; return false; } L = 0x01000000 / n; // 28:25 bit
        N = n;                      // size of vectors [uint32s]
        W = modpow(r, L);           // Wn for NTT
        iW = modpow(r, p - 1 - L);  // Wn for INTT
        rN = modpow(n, p - 2);      // scale for INTT
        return true;
    }

    //---------------------------------------------------------------------------
    void fast_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        if (n > 1)
        {
            if (dst != src)
            {
                NTT_calc(dst, src, n, w);
            }
            else
            {
                // in-place request: NTT_calc needs two distinct buffers
                uint32 *temp = new uint32[n];
                NTT_calc(temp, src, n, w);
                memcpy(dst, temp, n * sizeof(uint32));
                delete[] temp;
            }
        }
        else if (n == 1)
        {
            dst[0] = src[0];
        }
    }

    void fast_ntt::NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w)
    {
        if (n > 1)
        {
            // work on a copy so the caller's src[] stays intact
            uint32 *temp = new uint32[n];
            memcpy(temp, src, n * sizeof(uint32));
            NTT_calc(dst, temp, n, w);
            delete[] temp;
        }
        else if (n == 1)
        {
            dst[0] = src[0];
        }
    }

    void fast_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        if (n > 1)
        {
            uint32 i, j, a0, a1,
                n2 = n >> 1,
                w2 = modmul(w, w);

            // reorder even,odd
            for (i = 0, j = 0; i < n2; i++, j += 2)
            {
                dst[i] = src[j];
            }
            for (j = 1; i < n; i++, j += 2)
            {
                dst[i] = src[j];
            }
            // recursion (results land in src[])
            if (n2 > 1)
            {
                NTT_calc(src, dst, n2, w2);         // even
                NTT_calc(src + n2, dst + n2, n2, w2); // odd
            }
            else if (n2 == 1)
            {
                src[0] = dst[0];
                src[1] = dst[1];
            }

            // restore results; the first butterfly has twiddle 1, so no modmul
            w2 = 1, i = 0, j = n2;
            a0 = src[i];
            a1 = src[j];
            dst[i] = modadd(a0, a1);
            dst[j] = modsub(a0, a1);
            while (++i < n2)
            {
                w2 = modmul(w2, w);
                j++;
                a0 = src[i];
                a1 = modmul(src[j], w2);
                dst[i] = modadd(a0, a1);
                dst[j] = modsub(a0, a1);
            }
        }
    }

    //---------------------------------------------------------------------------
    void fast_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        uint32 i, j, wj, wi, a;
        for (wj = 1, j = 0; j < n; j++)
        {
            a = 0;
            for (wi = 1, i = 0; i < n; i++)
            {
                a = modadd(a, modmul(wi, src[i]));
                wi = modmul(wi, wj);
            }
            dst[j] = a;
            wj = modmul(wj, w);
        }
    }

    //---------------------------------------------------------------------------
    void fast_ntt::INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        // note: the root is always the member iW, the parameter w is not used
        uint32 i, j, wi, wj, a;

        for (wj = 1, j = 0; j < n; j++)
        {
            a = 0;
            for (wi = 1, i = 0; i < n; i++)
            {
                a = modadd(a, modmul(wi, src[i]));
                wi = modmul(wi, wj);
            }
            dst[j] = modmul(a, rN);
            wj = modmul(wj, iW);
        }
    }


    //---------------------------------------------------------------------------
    uint32 fast_ntt::modadd(uint32 a, uint32 b)
    {
        uint32 d;
        d = a + b;

        if (d < a)      // wrapped: real sum is d+2^32, subtracting p brings it back
        {
            d -= p;
        }
        if (d >= p)
        {
            d -= p;
        }
        return d;
    }

    //---------------------------------------------------------------------------
    uint32 fast_ntt::modsub(uint32 a, uint32 b)
    {
        uint32 d;
        d = a - b;
        if (d > a)      // borrow happened (a < b): add p back
        {
            d += p;
        }
        return d;
    }

    //---------------------------------------------------------------------------
    uint32 fast_ntt::modmul(uint32 a, uint32 b)
    {
    #if defined(_MSC_VER) && defined(_M_IX86)
        // 32-bit MSVC: edx:eax = a*b, edx = remainder of the division by p.
        // The result is returned explicitly; relying on eax after falling off the
        // end of a non-void function is undefined behavior.
        uint32 _a = a;
        uint32 _b = b;
        uint32 _p = p;
        __asm
        {
            mov eax, _a
            mul _b
            div _p
            mov _a, edx
        };
        return _a;
    #else
        // portable: 64-bit product, then reduce
        return (uint32)(((uint64)a * b) % p);
    #endif
    }


    uint32 fast_ntt::modpow(uint32 a, uint32 b)
    {
        // right-to-left binary exponentiation; M is all ones when b is odd, which
        // makes the start value branch-free: D = (b odd) ? a : 1
        uint64 D, M, A;

        A = a;
        M = 0llu - (b & 1);
        D = (M & A) | ((~M) & 1);

        while ((b >>= 1) != 0)
        {
            A = modmul((uint32)A, (uint32)A);

            if ((b & 1) == 1)
            {
                D = modmul((uint32)D, (uint32)A);
            }
        }
        return (uint32)D;
    }

    New modmul

    uint32 fast_ntt::modmul(uint32 a, uint32 b)
    {
    #if defined(_MSC_VER) && defined(_M_IX86)
        // 32-bit MSVC only. Computes (a*b) mod 3221225473 (= p = 0xC0000001).
        // It first estimates the quotient from the high half of the product and a
        // fixed-point reciprocal of p (2863311530 = 0xAAAAAAAA). It then subtracts
        // quotient*p and applies the correction steps, so no div instruction is needed.
        // The result is stored to a local and returned explicitly. Flowing off the
        // end of a non-void function and relying on eax is undefined behavior.
        uint32 res;
        __asm
        {
            mov eax, a;
            mul b;
            mov ebx, eax;
            mov eax, 2863311530;
            mov ecx, edx;
            mul edx;
            shld edx, eax, 1;
            mov eax, 3221225473;

            mul edx;
            sub ebx, eax;
            mov eax, 3221225473;
            sbb ecx, edx;
            jc addback;

            neg ecx;
            and ecx, eax;
            sub ebx, ecx;

            sub ebx, eax;
            sbb edx, edx;
            and eax, edx;
        addback:
            add eax, ebx;
            mov res, eax;
        };
        return res;
    #else
        // Portable version. The divisor is a compile-time constant (must equal p),
        // so on 64-bit targets the compiler replaces the division by a
        // multiply/shift sequence, which is the same idea as the asm above.
        return (uint32)(((uint64)a * b) % 0xC0000001u);
    #endif
    }


    Spektre, based on your feedback I changed the modadd & modsub back to their original. I also realized I made some changes to the recursive NTT function I shouldn’t have.




    Removed unneeded if statements and bitwise functions.




    Added new modmul inline assembly.


    I wanted to use NTT for fast squaring (see Fast bignum square computation), but the result is slow even for really big numbers .. more than 12000 bits.

    So my question is:

    1. Is there a way to optimize my NTT transform? I did not mean to speed it by parallelism (threads); this is low-level layer only.
    2. Is there a way to speed up my modular arithmetics?

    This is my (already optimized) source code in C++ for NTT. It's complete and 100% working in C++ without any need for third-party libs, and should also be thread-safe. Beware: the source array is used as a temporary!!! Also, it cannot transform an array in place, i.e. dst must not be the same array as src.

    //---------------------------------------------------------------------------
    class fourier_NTT                                    // Number theoretic transform
        {
    
    public:
        DWORD r,L,p,N;
        DWORD W,iW,rN;
        fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; }
    
        // main interface
        void  NTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast  NTT(DWORD src[n])
        void INTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast INTT(DWORD src[n])
    
        // Helper functions
        bool init(DWORD n);                                       // init r,L,p,W,iW,rN
        void  NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = fast  NTT(DWORD src[n])
    
        // Only for testing
        void  NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow  NTT(DWORD src[n])
        void INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow INTT(DWORD src[n])
    
        // DWORD arithmetics
        DWORD shl(DWORD a);
        DWORD shr(DWORD a);
    
        // Modular arithmetics
        DWORD mod(DWORD a);
        DWORD modadd(DWORD a,DWORD b);
        DWORD modsub(DWORD a,DWORD b);
        DWORD modmul(DWORD a,DWORD b);
        DWORD modpow(DWORD a,DWORD b);
        };
    
    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
        {
        // Forward transform. n==0 keeps the size/constants of the previous init().
        // The result of init() is not checked: after a rejected n, N is 0 and
        // nothing is transformed.
        if (n!=0)
            {
            init(n);
            }
        NTT_fast(dst,src,N,W);      // fast recursive transform with root W
    //    NTT_slow(dst,src,N,W);    // O(n^2) reference, for testing
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT::INTT(DWORD *dst,DWORD *src,DWORD n)
        {
        // Inverse transform: forward transform with root iW, then scale by rN=1/N.
        if (n!=0) init(n);              // n==0: reuse the constants of the last init()
        NTT_fast(dst,src,N,iW);
        DWORD *end=dst+N;
        for (DWORD *q=dst;q<end;q++)
            *q=modmul(*q,rN);
           //    INTT_slow(dst,src,N,W);
        }
    
    //---------------------------------------------------------------------------
    bool fourier_NTT::init(DWORD n)
        {
        // (max(src[])^2)*n < p else NTT overflow can occur !!!
        // n must also be a power of 2: NTT_fast() halves the size at every level,
        // and for odd lengths its even/odd split reads src[] one item past the end.
        r=2; p=0xC0000001; if ((n<2)||(n>0x10000000)||(n&(n-1))) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit
    //    r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
    //    r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
    //    r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
         N=n;                // size of vectors [DWORDs]
         W=modpow(r,    L);    // Wn for NTT
        iW=modpow(r,p-1-L);    // Wn for INTT
        rN=modpow(n,p-2  );    // scale for INTT (1/n via Fermat's little theorem)
        return true;
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
        {
        // Recursive radix-2 decimation-in-time transform with root w.
        // src[] doubles as scratch space, so its contents are destroyed.
        if (n<2)
            {
            if (n==1) dst[0]=src[0];    // a size-1 transform is the identity
            return;
            }
        DWORD half=n>>1;
        DWORD wsq=modmul(w,w);          // root used by both half-size transforms
        DWORD k;
        // split: even-indexed items -> dst[0..half-1], odd-indexed -> dst[half..]
        for (k=0;k<half;k++) dst[k]=src[k<<1];
        for (k=half;k<n;k++) dst[k]=src[((k-half)<<1)+1];
        // transform both halves; their results land back in src[]
        NTT_fast(src     ,dst     ,half,wsq);   // even part
        NTT_fast(src+half,dst+half,half,wsq);   // odd part
        // butterflies: X[k]=E[k]+w^k*O[k], X[k+half]=E[k]-w^k*O[k]
        DWORD tw=1;                     // w^k
        for (k=0;k<half;k++)
            {
            DWORD even=src[k];
            DWORD odd =modmul(src[k+half],tw);
            dst[k     ]=modadd(even,odd);
            dst[k+half]=modsub(even,odd);
            tw=modmul(tw,w);
            }
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
        {
        // O(n^2) reference transform: dst[k] = sum_i src[i]*w^(i*k) (mod p)
        DWORD k=0,wk=1;                 // wk = w^k
        while (k<n)
            {
            DWORD sum=0,t=1;            // t = wk^i
            for (DWORD i=0;i<n;i++)
                {
                sum=modadd(sum,modmul(t,src[i]));
                t=modmul(t,wk);
                }
            dst[k]=sum;
            wk=modmul(wk,w);
            k++;
            }
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT::INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
        {
        // O(n^2) reference inverse: dst[k] = rN * sum_i src[i]*iW^(i*k) (mod p).
        // NOTE: the parameter w is ignored; the root is always the member iW and
        // the scale is rN, both set by init().
        DWORD k=0,wk=1;                 // wk = iW^k
        while (k<n)
            {
            DWORD sum=0,t=1;            // t = wk^i
            for (DWORD i=0;i<n;i++)
                {
                sum=modadd(sum,modmul(t,src[i]));
                t=modmul(t,wk);
                }
            dst[k]=modmul(sum,rN);
            wk=modmul(wk,iW);
            k++;
            }
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::shl(DWORD a)
        {
        // Logical shift left by one bit; the mask clears bit 0 and keeps only the low 32 bits.
        const DWORD keep=0xFFFFFFFE;
        return keep&(a<<1);
        }
    DWORD fourier_NTT::shr(DWORD a)
        {
        // Logical shift right by one bit; the mask forces bit 31 (and anything above it) to 0.
        const DWORD keep=0x7FFFFFFF;
        return keep&(a>>1);
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::mod(DWORD a)
        {
        // a % p without a hardware divide (binary long division, remainder only).
        // First scale p up by powers of two until it reaches a, or until another
        // doubling would push a set bit out of bit 31...
        DWORD m=p;
        while ((a>m)&&((m&0x80000000)==0)) m=shl(m);
        // ...then walk back down to p, subtracting every scaled copy that still fits.
        for (;;)
            {
            if (a>=m) a-=m;
            if (m==p) return a;
            m=shr(m);
            }
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modadd(DWORD a,DWORD b)
        {
        // (a+b) mod p for arbitrary 32-bit inputs; both operands are reduced first.
        const DWORD x=mod(a);
        const DWORD y=mod(b);
        // Carry out of bit 31, computed without relying on wrap-around:
        // x/2 + y/2 + (carry of the two low bits) has bit 31 set exactly when x+y >= 2^32.
        const DWORD lowcarry=shr((x&1)+(y&1));
        const bool overflow=((shr(x)+shr(y)+lowcarry)&0x80000000)!=0;
        DWORD sum=x+y;
        if (overflow) sum-=p;               // true sum is sum+2^32; subtracting p folds it back
        if (!(DWORD(sum)<DWORD(p))) sum-=p;
        return sum;
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modsub(DWORD a,DWORD b)
        {
        // (a-b) mod p for arbitrary 32-bit inputs; both operands are reduced first.
        const DWORD x=mod(a);
        const DWORD y=mod(b);
        DWORD diff=x-y;
        if (x<y) diff+=p;                   // borrow: add p back (the unsigned wrap cancels out)
        if (diff>=p) diff-=p;
        return diff;
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modmul(DWORD a,DWORD b)
        {
        // a*b mod p by shift-and-add, using only modadd (no 64-bit product needed).
        // Only a is reduced up front. b is not; it gets reduced inside modadd when doubled.
        DWORD acc=0;
        DWORD multiplier=mod(a);
        for (int bit=0;bit<32;bit++)
            {
            if (multiplier&1) acc=modadd(acc,b);    // add b*2^bit when this bit of a is set
            multiplier=shr(multiplier);
            b=modadd(b,b);                          // b*2^(bit+1) mod p
            }
        return acc;
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modpow(DWORD a,DWORD b)
        {
        // a^b mod p by left-to-right square-and-multiply over the 32 bits of b.
        // Neither a nor b is reduced here; the exponent is used as-is.
        DWORD result=1;
        DWORD bit=0x80000000;
        for (int k=0;k<32;k++,bit>>=1)
            {
            result=modmul(result,result);
            if (b&bit) result=modmul(result,a);
            }
        return result;
        }
    //---------------------------------------------------------------------------
    

    Example of usage of my NTT class:

    fourier_NTT ntt;
    const DWORD N=32;
    DWORD x[N],y[N],z[N],i;
    for (i=0;i<N;i++) { x[i]=i; y[i]=N+i; }   // x={0,1,...,31}, y={32,33,...,63}

    ntt.NTT(z,x,N);    // z[N]=NTT(x[N]), also init constants for N (x is clobbered as scratch)
    ntt.NTT(x,y);      // x[N]=NTT(y[N]), no recompute of constants, use last N (y is clobbered)
    // pointwise product in the transform domain: y[]=z[].x[]
    for (i=0;i<N;i++) y[i]=ntt.modmul(z[i],x[i]);
    ntt.INTT(x,y);     // x[N]=INTT(y[N]), no recompute of constants, use last N
    // x[] = cyclic (length N) convolution of the original x[] and y[]
    

    Some measurements before optimizations (non Class NTT):

    a = 0.98765588997654321000 | 389*32 bits
    looped 1x times
    sqr1[ 3.177 ms ] fast sqr
    sqr2[ 720.419 ms ] NTT sqr
    mul1[ 5.588 ms ] simpe mul
    mul2[ 3.172 ms ] karatsuba mul
    mul3[ 1053.382 ms ] NTT mul
    

    Some measurements after my optimizations (current code, lower recursion parameter size/count, and better modular arithmetics):

    a = 0.98765588997654321000 | 389*32 bits
    looped 1x times
    sqr1[ 3.214 ms ] fast sqr
    sqr2[ 208.298 ms ] NTT sqr
    mul1[ 5.564 ms ] simpe mul
    mul2[ 3.113 ms ] karatsuba mul
    mul3[ 302.740 ms ] NTT mul
    

    Check the NTT mul and NTT sqr times (my optimizations sped them up a little over 3x). The benchmark loop ran only once, so it is not very precise (error ~10%), but the speedup is already noticeable (normally I loop it 1000x or more, but my NTT is too slow for that).

    You can use my code freely... Just keep my nick and/or link to this page somewhere (rem in code, readme.txt, about or whatever). I hope it helps... (I did not see C++ source for fast NTTs anywhere so I had to write it by myself). Roots of unity were tested for all accepted N, see the fourier_NTT::init(DWORD n) function.

    P.S.: For more information about NTT, see http://stackoverflow.com/a/18547575/2521214. This code is based on my posts inside that link.

    [edit1:] Further changes in the code

    I managed to further optimize my modular arithmetic by exploiting the fact that the prime modulus is always 0xC0000001 and by eliminating unnecessary calls. The resulting speedup is stunning (more than 40x), and NTT multiplication is now faster than Karatsuba above a threshold of about 1500 * 32 bits. BTW, the speed of my NTT is now the same as my optimized DFFT on 64-bit doubles.

    Some measurements:

    a = 0.98765588997654321000 | 1553*32bits
    looped 10x times
    mul2[ 28.585 ms ] karatsuba mul
    mul3[ 26.311 ms ] NTT mul
    

    New source code for modular arithmetics:

    //---------------------------------------------------------------------------
    DWORD fourier_NTT::mod(DWORD a)
        {
        // Reduce a into [0,p) with one conditional subtraction. One subtraction is enough
        // only because p (always 0xC0000001 here) exceeds 2^31, so every 32-bit a is < 2p.
        // The comparison must be >=: with > the value a==p would be returned unreduced.
        if (a>=p) a-=p;
        return a;
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modadd(DWORD a,DWORD b)
        {
        // (a+b) mod p. Relies on p > 2^31 (p is 0xC0000001), so a single conditional
        // subtraction reduces any 32-bit operand. All comparisons use >= so that a value
        // equal to p, or a sum a+b==p, maps to 0 instead of the non-canonical p.
        DWORD d,cy;
        if (a>=p) a-=p;
        if (b>=p) b-=p;
        d=a+b;
        // carry out of bit 31 of a+b, computed from the halves so it cannot overflow itself
        cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;
        if (cy) d-=p;                   // true sum is d+2^32; wrap-around makes d-p exact
        if (d>=p) d-=p;
        return d;
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modsub(DWORD a,DWORD b)
        {
        // (a-b) mod p. Relies on p > 2^31 (p is 0xC0000001), so one conditional subtraction
        // reduces any 32-bit operand. The >= comparisons make a==p map to 0, not p.
        DWORD d;
        if (a>=p) a-=p;
        if (b>=p) b-=p;
        d=a-b;
        if (a<b) d+=p;                  // borrow: add p back (the unsigned wrap cancels out)
        if (d>=p) d-=p;
        return d;
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modmul(DWORD a,DWORD b)
        {
        // a*b mod p through an exact 64-bit intermediate product.
        // This replaces the former x86-32 inline asm (mul + div), for two reasons:
        //  - inline asm does not build on x64 compilers such as MSVC x64;
        //  - 'div' raises a divide-error exception whenever the quotient does not fit in
        //    32 bits, i.e. a*b >= p*2^32, which happens for unreduced operands
        //    (e.g. a=b=0xFFFFFFFF).
        // This form is correct for all 32-bit inputs. On 32-bit targets the compiler may call
        // a 64-bit division helper, so measure if that target matters.
        return DWORD(((unsigned long long)a*(unsigned long long)b)%(unsigned long long)p);
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modpow(DWORD a,DWORD b)
        {
        // a^b mod p, left-to-right square-and-multiply over the 32 bits of b.
        // The exponent b is used as-is (not reduced mod p).
        DWORD d=1;
        if (a>p) a-=p;                  // single conditional subtraction on the base
        for (DWORD mask=0x80000000;mask;mask>>=1)
            {
            d=modmul(d,d);
            if (b&mask) d=modmul(d,a);
            }
        return d;
        }
    
    //---------------------------------------------------------------------------
    

    As you can see, the functions shl and shr are no longer used. I think that modpow can be further optimized, but it's not a critical function because it is called only a few times. The most critical function is modmul, and that seems to be in the best shape possible.

    Further questions:

    • Is there any other option to speedup NTT?
    • Are my optimizations of modular arithmetics safe? (Results seem to be the same, but I could miss something.)

    [edit2] New optimizations

    a = 0.99991970486 | 2000*32 bits
    looped 10x
    sqr1[  13.908 ms ] fast sqr
    sqr2[  13.649 ms ] NTT sqr
    mul1[  19.726 ms ] simpe mul
    mul2[  31.808 ms ] karatsuba mul
    mul3[  19.373 ms ] NTT mul
    

    I implemented all the usable stuff from all of your comments (thanks for the insight).

    Speedups:

    • +2.5% by removing unnecessary safety mods (Mandalf The Beige)
    • +34.9% by use of precomputed W,iW powers (Mysticial)
    • +35% total

    Actual full source code:

    //---------------------------------------------------------------------------
    //--- Number theoretic transforms: 2.03 -------------------------------------
    //---------------------------------------------------------------------------
    #ifndef _fourier_NTT_h
    #define _fourier_NTT_h
    //---------------------------------------------------------------------------
    //---------------------------------------------------------------------------
    class fourier_NTT        // Number theoretic transform
        {
    public:
        DWORD r,L,p,N;
        DWORD W,iW,rN;        // W=(r^L) mod p, iW=inverse W, rN = inverse N
        DWORD *WW,*iWW,NN;    // Precomputed (W,iW)^(0,..,NN-1) powers
    
        // Internals
        fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; WW=NULL; iWW=NULL; NN=0; }
        ~fourier_NTT(){ _free(); }
        void _free();                                            // Free precomputed W,iW powers tables
        void _alloc(DWORD n);                                    // Allocate and precompute W,iW powers tables
    
        // Main interface
        void  NTT(DWORD *dst,DWORD *src,DWORD n=0);                // DWORD dst[n] = fast  NTT(DWORD src[n])
        void iNTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast INTT(DWORD src[n])
    
        // Helper functions
        bool init(DWORD n);                                          // init r,L,p,W,iW,rN
        void  NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = fast  NTT(DWORD src[n])
        void  NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2);
    
        // Only for testing
        void  NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow  NTT(DWORD src[n])
        void iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow INTT(DWORD src[n])
    
        // Modular arithmetics (optimized, but it works only for p >= 0x80000000!!!)
        DWORD mod(DWORD a);
        DWORD modadd(DWORD a,DWORD b);
        DWORD modsub(DWORD a,DWORD b);
        DWORD modmul(DWORD a,DWORD b);
        DWORD modpow(DWORD a,DWORD b);
        };
    //---------------------------------------------------------------------------
    
    //---------------------------------------------------------------------------
    void fourier_NTT::_free()
        {
        // Release both power tables and mark them empty. delete[] on a NULL pointer is a
        // no-op, so no null checks are needed.
        delete[]  WW;  WW=NULL;
        delete[] iWW; iWW=NULL;
        NN=0;
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT::_alloc(DWORD n)
        {
        // Ensure the tables hold at least n entries: WW[i]=W^i and iWW[i]=iW^i (mod p).
        // init() calls this with N/2 after recomputing W,iW for the new N.
        // The tables are only valid for the roots they were built from (WW[1]==W,
        // iWW[1]==iW). When init() switches to a different N, the roots change. Without
        // this check, the early return below kept the old powers, or new entries were
        // appended to them, and NTT()/iNTT() silently used the wrong roots.
        if ((NN>1)&&((WW[1]!=W)||(iWW[1]!=iW))) _free();
        if (n<=NN) return;
        DWORD *tmp,i,w;
        const DWORD first=NN?NN:1;      // first entry still to be computed (entry 0 is always 1)

        tmp=new DWORD[n];
        for (i=0;i<NN;i++) tmp[i]=WW[i];
        delete[] WW; WW=tmp; WW[0]=1;
        for (i=first,w=WW[i-1];i<n;i++) { w=modmul(w,W); WW[i]=w; }

        tmp=new DWORD[n];
        for (i=0;i<NN;i++) tmp[i]=iWW[i];
        delete[] iWW; iWW=tmp; iWW[0]=1;
        for (i=first,w=iWW[i-1];i<n;i++) { w=modmul(w,iW); iWW[i]=w; }

        NN=n;
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
        {
        // Forward transform dst[N]=NTT(src[N]); src is clobbered (used as scratch).
        // A non-zero n re-initializes the constants for that size; n==0 reuses the last N.
        if (n) init(n);
        NTT_fast(dst,src,N,WW,1);       // stride 1: the top level walks every W power
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT::iNTT(DWORD *dst,DWORD *src,DWORD n)
        {
        // Inverse transform: the same algorithm with the iW power table, then scale by rN=N^-1.
        // src is clobbered; a non-zero n re-initializes the constants, n==0 reuses the last N.
        if (n) init(n);
        NTT_fast(dst,src,N,iWW,1);
        DWORD *const end=dst+N;
        for (DWORD *q=dst;q<end;q++) *q=modmul(*q,rN);
        }
    
    //---------------------------------------------------------------------------
    bool fourier_NTT::init(DWORD n)
        {
        // Set the modulus, roots and scale for an n-point transform. Returns false, with all
        // constants (N included) zeroed, when n is outside [2, 0x10000000].
        // Overflow caveat: results are exact only if (max(src[])^2)*n < p.
        r=2;
        p=0xC0000001;                   // 32:30 bit, best for unsigned 32 bit
        if (!((n>=2)&&(n<=0x10000000)))
            {
            r=0; L=0; p=0; W=0; iW=0; rN=0; N=0;
            return false;
            }
        L=0x30000000/n;
        // Alternatives kept for reference:
        //    r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
        //    r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
        //    r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
        N =n;                           // size of vectors [DWORDs]
        W =modpow(r,L);                 // Wn for NTT
        iW=modpow(r,p-1-L);             // Wn for INTT: r^-L, since r^(p-1)=1 (Fermat)
        rN=modpow(n,p-2);               // n^-1 mod p (Fermat), scale for INTT
        _alloc(n>>1);                   // precompute W,iW power tables
        return true;
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
        {
        // Recursive radix-2 NTT with scalar root w (the older, table-free variant).
        // src is overwritten as scratch; dst must not alias src.
        if (n<2) { if (n) dst[0]=src[0]; return; }
        const DWORD half=n>>1;
        const DWORD subroot=modmul(w,w);
        DWORD k;

        // even samples -> dst[0..half), odd samples -> dst[half..n)
        for (k=0;k<half;k++) dst[k]=src[2*k];
        for (k=half;k<n;k++) dst[k]=src[2*(k-half)+1];

        // sub-transforms write into src, using dst as their scratch
        NTT_fast(src     ,dst     ,half,subroot);
        NTT_fast(src+half,dst+half,half,subroot);

        // combine the halves with twiddles w^k
        DWORD twiddle=1;
        for (k=0;k<half;k++,twiddle=modmul(twiddle,w))
            {
            const DWORD even=src[k];
            const DWORD odd =modmul(src[half+k],twiddle);
            dst[k]     =modadd(even,odd);
            dst[half+k]=modsub(even,odd);
            }
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2)
        {
        // Recursive radix-2 NTT driven by a precomputed power table.
        // w2 points at W^0,W^1,...; i2 is the stride into it for this level. The stride is 1
        // at the top and doubles per level, because each sub-transform uses the squared root.
        // src is overwritten as scratch; dst must not alias src.
        if (n<2) { if (n) dst[0]=src[0]; return; }
        const DWORD half=n>>1;
        DWORD k;

        // even samples -> dst[0..half), odd samples -> dst[half..n)
        for (k=0;k<half;k++) dst[k]=src[k+k];
        for (k=half;k<n;k++) dst[k]=src[2*(k-half)+1];

        const DWORD stride=i2+i2;
        NTT_fast(src     ,dst     ,half,w2,stride);
        NTT_fast(src+half,dst+half,half,w2,stride);

        // butterflies; the twiddle for column k is w2[k*i2]
        const DWORD *tw=w2;
        for (k=0;k<half;k++,tw+=i2)
            {
            const DWORD e=src[k];
            const DWORD o=modmul(src[half+k],*tw);
            dst[k]     =modadd(e,o);
            dst[half+k]=modsub(e,o);
            }
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
        {
        // Direct O(n^2) evaluation, for testing: dst[j] = sum_{i<n} src[i]*w^(i*j) mod p.
        // dst must not alias src.
        DWORD row=1;                            // w^j for the current output j
        for (DWORD j=0;j<n;j++,row=modmul(row,w))
            {
            DWORD sum=0,pw=1;                   // pw = w^(i*j)
            for (DWORD i=0;i<n;i++,pw=modmul(pw,row)) sum=modadd(sum,modmul(pw,src[i]));
            dst[j]=sum;
            }
        }
    
    //---------------------------------------------------------------------------
    void fourier_NTT::iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
        {
        // Direct O(n^2) inverse, for testing: dst[j] = rN * sum_{i<n} src[i]*iW^(i*j) mod p.
        // NOTE: the w parameter is not used. The root is always the member iW and the scale
        // is rN, both set by init(). dst must not alias src.
        DWORD row=1;                            // iW^j
        for (DWORD j=0;j<n;j++,row=modmul(row,iW))
            {
            DWORD sum=0,pw=1;                   // pw = iW^(i*j)
            for (DWORD i=0;i<n;i++,pw=modmul(pw,row)) sum=modadd(sum,modmul(pw,src[i]));
            dst[j]=modmul(sum,rN);
            }
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::mod(DWORD a)
        {
        // Reduce a into [0,p) with one conditional subtraction. One subtraction is enough
        // because p >= 0x80000000 (required by this class), so every 32-bit a is < 2p.
        // The comparison must be >=: with > the value a==p would be returned unreduced.
        if (a>=p) a-=p;
        return a;
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modadd(DWORD a,DWORD b)
        {
        // (a+b) mod p for already-reduced operands (a,b < p). The input reductions were
        // removed for speed, so callers must pass canonical residues.
        DWORD d,cy;
        d=a+b;
        // carry out of bit 31 of a+b, computed from the halves so it cannot overflow itself
        cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;
        if (cy) d-=p;                   // true sum is d+2^32; wrap-around makes d-p exact
        if (d>=p) d-=p;                 // >=, so a+b==p yields 0 rather than the non-canonical p
        return d;
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modsub(DWORD a,DWORD b)
        {
        // (a-b) mod p for already-reduced operands (a,b < p). The input reductions were
        // removed for speed.
        DWORD d;
        d=a-b;
        if (a<b) d+=p;                  // borrow: add p back (the unsigned wrap cancels out)
        if (d>=p) d-=p;                 // safety net: a stray a==p still maps into [0,p)
        return d;
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modmul(DWORD a,DWORD b)
        {
        // a*b mod p through an exact 64-bit intermediate product.
        // This replaces the former x86-32 inline asm (mul + div), for two reasons:
        //  - inline asm does not build on x64 compilers such as MSVC x64;
        //  - 'div' raises a divide-error exception whenever the quotient does not fit in
        //    32 bits, i.e. a*b >= p*2^32, which happens for unreduced operands
        //    (e.g. a=b=0xFFFFFFFF).
        // This form is correct for all 32-bit inputs. On 32-bit targets the compiler may call
        // a 64-bit division helper, so measure if that target matters.
        return DWORD(((unsigned long long)a*(unsigned long long)b)%(unsigned long long)p);
        }
    
    //---------------------------------------------------------------------------
    DWORD fourier_NTT::modpow(DWORD a,DWORD b)
        {
        // a^b mod p, scanning the exponent from bit 31 down to bit 0.
        // b is not reduced mod p (init() passes exponents such as p-2), and a is not
        // pre-reduced either.
        DWORD d=1;
        for (int i=31;i>=0;i--)
            {
            d=modmul(d,d);
            if ((b>>i)&1) d=modmul(d,a);
            }
        return d;
        }
    //---------------------------------------------------------------------------
    //---------------------------------------------------------------------------
    #endif
    //---------------------------------------------------------------------------
    //---------------------------------------------------------------------------
    

    There is still the possibility of reducing stack usage by splitting NTT_fast into two functions: one using WW[] and the other using iWW[], which removes one parameter from the recursive calls. But I do not expect much from it (it is only a 32-bit pointer), and I would rather keep one function for easier code maintenance in the future. Many functions are now dormant (kept for testing), like the slow variants, mod, and the older fast function (with the w parameter instead of *w2,i2).

    To avoid overflows for big datasets, limit input values to p/4 bits, where p is the number of bits per NTT element; for this 32-bit version, use input values of at most 8 bits (32/4 = 8).

    [edit3] Simple string bigint multiplication for testing

    //---------------------------------------------------------------------------
    char* mul_NTT(const char *sx,const char *sy)
        {
        char *s;
        int i,j,k,n;
        // n = min power of 2 <= 2 max length(x,y)
        for (i=0;sx[i];i++); for (n=1;n<i;n<<=1);        i--;
        for (j=0;sx[j];j++); for (n=1;n<j;n<<=1); n<<=1; j--;
        DWORD *x,*y,*xx,*yy,a;
        x=new DWORD[n]; xx=new DWORD[n];
        y=new DWORD[n]; yy=new DWORD[n];
    
        // Zero padding
        for (k=0;i>=0;i--,k++) x[k]=sx[i]-'0'; for (;k<n;k++) x[k]=0;
        for (k=0;j>=0;j--,k++) y[k]=sy[j]-'0'; for (;k<n;k++) y[k]=0;
    
        //NTT
        fourier_NTT ntt;
        ntt.NTT(xx,x,n);
        ntt.NTT(yy,y);
    
        // Convolution
        for (i=0;i<n;i++) xx[i]=ntt.modmul(xx[i],yy[i]);
    
        //INTT
        ntt.iNTT(yy,xx);
    
        //suma
        a=0; s=new char[n+1]; for (i=0;i<n;i++) { a+=yy[i]; s[n-i-1]=(a%10)+'0'; a/=10; } s[n]=0;
        delete[] x; delete[] xx;
        delete[] y; delete[] yy;
    
        return s;
        }
    //---------------------------------------------------------------------------
    

    I normally use AnsiStrings, so I ported this to char*; hopefully I did not make any mistakes. It looks like it works properly (compared with the AnsiString version).

    • sx,sy are decimal integers (digit strings)
    • Returns allocated string (char*)=sx*sy

    This uses only ~4 bits per 32-bit data word, so there is no risk of overflow, but it is of course slower. In my bignum library I use a binary representation with 8-bit chunks per 32-bit word for the NTT. More than that is risky if N is big...

    Have fun with this

    解决方案

    First off, thank you very much for posting and making it free to use. I really appreciate that.

    I was able to use some bit tricks to eliminate some branching, rearranged the main loop, and modified the assembly, and was able to get a 1.35x speedup.

    Also, I added a preprocessor condition for 64 bit, seeing as Visual Studio doesn't allow inline assembly in 64 bit mode (thank you Microsoft; feel free to go screw yourself).

    Something strange happened when I was optimizing the modsub() function. I rewrote it using bit hacks like I did modadd (which was faster). But for some reason, the bit wise version of modsub was slower. Not sure why. Might just be my computer.

    //
    // Mandalf The Beige
    // Based on:
    // Spektre
    // http://stackoverflow.com/questions/18577076/modular-arithmetics-and-ntt-finite-field-dft-optimizations
    //
    // This code may be freely used however you choose, so long as it is accompanied by this notice.
    //
    
    
    
    
    #ifndef H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
    #define H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
    
    #include <string.h>
    
    #ifndef uint32
    // Must be exactly 32 bits: the modular arithmetic relies on 32-bit wrap-around
    // (e.g. the carry test d<a in modadd). 'unsigned long' is 64 bits on LP64 systems
    // (Linux/macOS x86-64), while 'unsigned int' is 32 bits on all common ABIs.
    #define uint32 unsigned int
    #endif

    #ifndef uint64
    #define uint64 unsigned long long int
    #endif
    static_assert(sizeof(uint32)==4,"uint32 must be exactly 32 bits wide");
    static_assert(sizeof(uint64)==8,"uint64 must be exactly 64 bits wide");
    
    
    class fast_ntt                                   // number theoretic transform
    {
        public:
        fast_ntt()
        {
            r = 0; L = 0;
            W = 0; iW = 0; rN = 0;
        }
        // main interface
        void  NTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast  NTT(uint32 src[n])
        void INTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast INTT(uint32 src[n])
        // helper functions
    
        private:
        bool init(uint32 n);                                     // init r,L,p,W,iW,rN
        void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])
    
        void  NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])
        void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);
        // only for testing
        void  NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = slow  NTT(uint32 src[n])
        void INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = slow INTT(uint32 src[n])
        // uint32 arithmetics
    
    
        // modular arithmetics
        inline uint32 modadd(uint32 a, uint32 b);
        inline uint32 modsub(uint32 a, uint32 b);
        inline uint32 modmul(uint32 a, uint32 b);
        inline uint32 modpow(uint32 a, uint32 b);
    
        uint32 r, L, N;//, p;
        uint32 W, iW, rN;
    
        const uint32 p = 0xC0000001;
    };
    
    //---------------------------------------------------------------------------
    void fast_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
    {
        // Forward transform of N values. A non-zero n re-initializes the constants for that
        // size; n == 0 keeps the size from the previous call.
        if (n != 0)
        {
            init(n);
        }
        NTT_fast(dst, src, N, W);
    }
    
    //---------------------------------------------------------------------------
    void fast_ntt::INTT(uint32 *dst, uint32 *src, uint32 n)
    {
        // Inverse transform: run the forward algorithm with the inverse root iW, then scale
        // every output by rN = N^-1 mod p.
        if (n != 0)
        {
            init(n);
        }
        NTT_fast(dst, src, N, iW);
        for (uint32 *it = dst, *end = dst + N; it != end; ++it)
        {
            *it = modmul(*it, rN);
        }
    }
    
    //---------------------------------------------------------------------------
    bool fast_ntt::init(uint32 n)
    {
        // Prepare the constants for an n-point transform modulo the fixed prime p.
        // Returns false (and zeroes the constants, N included) for n outside [2, 0x10000000].
        // Overflow caveat: results are exact only while (max(src[])^2)*n < p.
        r = 2;
        const bool supported = (n >= 2) && (n <= 0x10000000);
        if (!supported)
        {
            r = 0; L = 0; W = 0;
            iW = 0; rN = 0; N = 0;
            return false;
        }
        L = 0x30000000 / n;          // 32:30 bit best for unsigned 32 bit
        N = n;                       // size of vectors [uint32s]
        W = modpow(r, L);            // Wn for NTT
        iW = modpow(r, p - 1 - L);   // Wn for INTT: r^-L, since r^(p-1) == 1 (Fermat)
        rN = modpow(n, p - 2);       // n^-1 mod p (Fermat), scale for INTT
        return true;
    }
    
    //---------------------------------------------------------------------------
    
    void fast_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        // Entry point for mutable input: src doubles as scratch and is overwritten.
        // In-place calls (dst == src) go through a temporary so NTT_calc gets two buffers.
        if (n == 0)
        {
            return;
        }
        if (n == 1)
        {
            dst[0] = src[0];
            return;
        }
        if (dst == src)
        {
            uint32 *scratch = new uint32[n];
            NTT_calc(scratch, src, n, w);
            memcpy(dst, scratch, n * sizeof(uint32));
            delete [] scratch;
            return;
        }
        NTT_calc(dst, src, n, w);
    }
    
    void fast_ntt::NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w)
    {
        // Entry point for read-only input: the samples are copied into a scratch buffer,
        // so src itself is never modified.
        switch (n)
        {
        case 0:
            return;
        case 1:
            dst[0] = src[0];
            return;
        default:
            {
                uint32 *copy = new uint32[n];
                memcpy(copy, src, n * sizeof(uint32));
                NTT_calc(dst, copy, n, w);
                delete[] copy;
            }
        }
    }
    
    
    
    void fast_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        // Recursive radix-2 decimation-in-time core. dst and src must be distinct buffers;
        // src is used as scratch and ends up overwritten.
        if (n <= 1)
        {
            return;
        }
        const uint32 half = n >> 1;
        const uint32 wsq = modmul(w, w);   // root for the half-length transforms

        // split into even-indexed, then odd-indexed samples
        uint32 k;
        for (k = 0; k < half; k++)
        {
            dst[k] = src[2 * k];
        }
        for (k = half; k < n; k++)
        {
            dst[k] = src[2 * (k - half) + 1];
        }

        // transform each half back into src; a 1-point transform is the identity, so
        // size-2 inputs just copy their two samples
        if (half == 1)
        {
            src[0] = dst[0];
            src[1] = dst[1];
        }
        else
        {
            NTT_calc(src, dst, half, wsq);
            NTT_calc(src + half, dst + half, half, wsq);
        }

        // butterflies; column 0 has twiddle 1, so its multiply is skipped
        uint32 e = src[0];
        uint32 o = src[half];
        dst[0] = modadd(e, o);
        dst[half] = modsub(e, o);
        uint32 tw = 1;
        for (k = 1; k < half; k++)
        {
            tw = modmul(tw, w);
            e = src[k];
            o = modmul(src[half + k], tw);
            dst[k] = modadd(e, o);
            dst[half + k] = modsub(e, o);
        }
    }
    
    //---------------------------------------------------------------------------
    void fast_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        // Reference O(n^2) transform, for testing NTT_fast:
        //   dst[j] = sum_{i<n} src[i] * w^(i*j) mod p
        // dst must not alias src, because src is re-read for every output j.
        uint32 i, j, wj, wi, a;
        for (wj = 1, j = 0; j < n; j++)         // wj = w^j
        {
            a = 0;
            for (wi = 1, i = 0; i < n; i++)     // wi = (w^j)^i
            {
                a = modadd(a, modmul(wi, src[i]));
                wi = modmul(wi, wj);
            }
            dst[j] = a;
            wj = modmul(wj, w);
        }
    }
    
    //---------------------------------------------------------------------------
    void fast_ntt::INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
    {
        // Reference O(n^2) inverse transform, for testing:
        //   dst[j] = rN * sum_{i<n} src[i] * iW^(i*j) mod p
        // NOTE: the w argument is ignored. The root is always the member iW and the scale is
        // rN, both set by init(n). dst must not alias src.
        (void)w;
        uint32 i, j, wi, wj, a;
        for (wj = 1, j = 0; j < n; j++)         // wj = iW^j
        {
            a = 0;
            for (wi = 1, i = 0; i < n; i++)     // wi = (iW^j)^i
            {
                a = modadd(a, modmul(wi, src[i]));
                wi = modmul(wi, wj);
            }
            dst[j] = modmul(a, rN);
            wj = modmul(wj, iW);
        }
    }
    
    
    //---------------------------------------------------------------------------
    uint32 fast_ntt::modadd(uint32 a, uint32 b)
    {
        // (a + b) mod p for a, b < p. If the 32-bit sum wrapped (carry out of bit 31), the
        // true sum is 2^32 larger, and subtracting p mod 2^32 yields the correct residue.
        uint32 sum = a + b;
        const bool wrapped = sum < a;
        if (wrapped)
        {
            sum -= p;
        }
        return (sum >= p) ? sum - p : sum;
    }
    
    //---------------------------------------------------------------------------
    uint32 fast_ntt::modsub(uint32 a, uint32 b)
    {
        // (a - b) mod p for a, b < p. A borrow shows up as a wrapped result larger than a;
        // adding p (mod 2^32) then gives the correct residue.
        const uint32 diff = a - b;
        return (diff > a) ? diff + p : diff;
    }
    
    //---------------------------------------------------------------------------
    uint32 fast_ntt::modmul(uint32 a, uint32 b)
    {
        // a * b mod p using an exact 64-bit intermediate product.
        // The previous MSVC inline asm left its result in EAX and fell off the end of a
        // non-void function, which is undefined behavior in standard C++. It also cannot be
        // compiled where MSVC-style __asm blocks are unavailable (MSVC x64, GCC).
        // This form is exact for any 32-bit a, b.
        return (uint32)(((uint64)a * (uint64)b) % p);
    }
    
    
    uint32 fast_ntt::modpow(uint32 a, uint32 b)
    {
        // a^b mod p by right-to-left binary exponentiation.
        // The lowest exponent bit selects the starting value (a or 1) directly. The base is
        // not reduced first, so for b == 1 the unreduced a is returned as-is.
        uint32 base = a;
        uint32 result = (b & 1) ? a : 1;
        for (b >>= 1; b != 0; b >>= 1)
        {
            base = modmul(base, base);
            if (b & 1)
            {
                result = modmul(result, base);
            }
        }
        return result;
    }
    

    New modmul

    uint32 fast_ntt::modmul(uint32 a, uint32 b)
    {
        // a * b mod p (p = 0xC0000001 = 3221225473) using an exact 64-bit product.
        // The reciprocal-multiplication inline asm this replaces had no return statement:
        // it relied on the value left in EAX, which is undefined behavior in standard C++.
        // It also declared unused locals and only built with MSVC x86 inline asm.
        // This form is exact for every 32-bit a, b and lets the compiler choose the
        // multiply/divide sequence.
        return (uint32)(((uint64)a * (uint64)b) % p);
    }
    

    [EDIT] Spektre, based on your feedback I changed modadd & modsub back to their original versions. I also realized I had made some changes to the recursive NTT function that I shouldn't have.

    [EDIT2] Removed unneeded if statements and bitwise functions.

    [EDIT3] Added new modmul inline assembly.

    这篇关于模块化算术和NTT(有限域DFT)优化的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆