使用SSE最快实现自然指数函数 [英] Fastest Implementation of the Natural Exponential Function Using SSE

查看:207
本文介绍了使用SSE最快实现自然指数函数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在寻找在SSE元素上运行的自然指数函数的近似值.即-__m128 exp( __m128 x ).

I'm looking for an approximation of the natural exponential function operating on SSE element. Namely - __m128 exp( __m128 x ).

我有一个快速执行的方法,但准确性似乎很低:

I have an implementation which is quick but seems to be very low in accuracy:

static inline __m128 FastExpSse(__m128 x)
{
    __m128 a = _mm_set1_ps(12102203.2f); // (1 << 23) / ln(2)
    __m128i b = _mm_set1_epi32(127 * (1 << 23) - 486411);
    __m128  m87 = _mm_set1_ps(-87);
    // fast exponential function, x should be in [-87, 87]
    __m128 mask = _mm_cmpge_ps(x, m87);

    __m128i tmp = _mm_add_epi32(_mm_cvtps_epi32(_mm_mul_ps(a, x)), b);
    return _mm_and_ps(_mm_castsi128_ps(tmp), mask);
}

有人能以更快(或更快)的精度实现更好的实现吗?

Could anybody have an implementation with better accuracy yet as fast (Or faster)?

如果它是用C风格编写的,我会很高兴.

I'd be happy if it is written in C Style.

谢谢.

推荐答案

以下C代码将我在基本思想是将标准指数函数的计算转换为2的幂的计算:expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504).我们将t = x * 1.44269504拆分为整数i和分数f,这样t = i + f0 <= f <= 1.现在,我们可以使用多项式逼近来计算2 f ,然后通过在单精度浮点结果的指数字段中添加i来将结果缩放2 i

The basic idea is to transform the computation of the standard exponential function into computation of a power of 2: expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504). We split t = x * 1.44269504 into an integer i and a fraction f, such that t = i + f and 0 <= f <= 1. We can now compute 2f with a polynomial approximation, then scale the result by 2i by adding i to the exponent field of the single-precision floating-point result.

SSE实现中存在的一个问题是我们要计算i = floorf (t),但是没有快速的方法来计算floor()函数.但是,我们观察到,对于floor(x) == trunc(x)为正数,对于floor(x) == trunc(x) - 1为负数,除非x是负整数.但是,由于核心近似值可以处理f1.0f,因此对负参数使用近似值是无害的. SSE提供了一条指令,可以将单精度浮点操作数转换为带有截断的整数,因此该解决方案非常有效.

One problem that exists with an SSE implementation is that we want to compute i = floorf (t), but there is no fast way to compute the floor() function. However, we observe that for positive numbers, floor(x) == trunc(x), and that for negative numbers, floor(x) == trunc(x) - 1, except when x is a negative integer. However, since the core approximation can handle an f value of 1.0f, using the approximation for negative arguments is harmless. SSE provides an instruction to convert single-precision floating point operands to integers with truncation, so this solution is efficient.

彼得·科德斯(Peter Cordes)指出SSE4.1支持快速下限功能_mm_floor_ps(),因此使用SSE4.1也如下所示.启用S​​SE 4.1代码生成后,并非所有工具链都会自动预定义宏__SSE4_1__,但是gcc会自动预定义宏.

Peter Cordes points out that SSE4.1 supports a fast floor function _mm_floor_ps(), so a variant using SSE4.1 is also shown below. Not all toolchains automatically predefine the macro __SSE4_1__ when SSE 4.1 code generation is enabled, but gcc does.

Compiler Explorer(Godbolt)显示gcc 7.2将以下代码编译为纯SSE的十六个指令针对SSE 4.1的十二个说明.

Compiler Explorer (Godbolt) shows that gcc 7.2 compiles the code below into sixteen instructions for plain SSE and twelve instructions for SSE 4.1.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <emmintrin.h>
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif

/* max. rel. error = 1.72863156e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 t, f, e, p, r;
    __m128i i, j;
    __m128 l2e = _mm_set1_ps (1.442695041f);  /* log2(e) */
    __m128 c0  = _mm_set1_ps (0.3371894346f);
    __m128 c1  = _mm_set1_ps (0.657636276f);
    __m128 c2  = _mm_set1_ps (1.00172476f);

    /* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */   
    t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
#ifdef __SSE4_1__
    e = _mm_floor_ps (t);                /* floor(t) */
    i = _mm_cvtps_epi32 (e);             /* (int)floor(t) */
#else /* __SSE4_1__*/
    i = _mm_cvttps_epi32 (t);            /* i = (int)t */
    j = _mm_srli_epi32 (_mm_castps_si128 (x), 31); /* signbit(t) */
    i = _mm_sub_epi32 (i, j);            /* (int)t - signbit(t) */
    e = _mm_cvtepi32_ps (i);             /* floor(t) ~= (int)t - signbit(t) */
#endif /* __SSE4_1__*/
    f = _mm_sub_ps (t, e);               /* f = t - floor(t) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
    j = _mm_slli_epi32 (i, 23);          /* i << 23 */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}

int main (void)
{
    union {
        float f[4];
        unsigned int i[4];
    } arg, res;
    double relerr, maxrelerr = 0.0;
    int i, j;
    __m128 x, y;

    float start[2] = {-0.0f, 0.0f};
    float finish[2] = {-87.33654f, 88.72283f};

    for (i = 0; i < 2; i++) {

        arg.f[0] = start[i];
        arg.i[1] = arg.i[0] + 1;
        arg.i[2] = arg.i[0] + 2;
        arg.i[3] = arg.i[0] + 3;
        do {
            memcpy (&x, &arg, sizeof(x));
            y = fast_exp_sse (x);
            memcpy (&res, &y, sizeof(y));
            for (j = 0; j < 4; j++) {
                double ref = exp ((double)arg.f[j]);
                relerr = fabs ((res.f[j] - ref) / ref);
                if (relerr > maxrelerr) {
                    printf ("arg=% 15.8e  res=%15.8e  ref=%15.8e  err=%15.8e\n", 
                            arg.f[j], res.f[j], ref, relerr);
                    maxrelerr = relerr;
                }
            }   
            arg.i[0] += 4;
            arg.i[1] += 4;
            arg.i[2] += 4;
            arg.i[3] += 4;
        } while (fabsf (arg.f[3]) < fabsf (finish[i]));
    }
    printf ("maximum relative errror = %15.8e\n", maxrelerr);
    return EXIT_SUCCESS;
}

fast_sse_exp()的另一种设计,是使用四舍五入到最接近模式提取调整后的自变量x / log(2)的整数部分的方法,它使用了众所周知的添加魔术"转换常数1.5 * 2 23的技术. 强制四舍五入到正确的位位置,然后再次减去相同的数字.这要求加法期间有效的SSE舍入模式为舍入到最接近或什至",这是默认设置. wim 在评论中指出,一些编译器可能会在激进时优化转换常数cvt的加减运算由于使用了优化,因此会干扰此代码序列的功能,因此建议检查生成的机器代码.现在计算2 f 的近似间隔以零为中心,因为-0.5 <= f <= 0.5需要不同的核心近似.

An alternative design for fast_sse_exp() extracts the integer portion of the adjusted argument x / log(2) in round-to-nearest mode, using the well-known technique of adding the "magic" conversion constant 1.5 * 223 to force rounding in the correct bit position, then subtracting out the same number again. This requires that the SSE rounding mode in effect during the addition is "round to nearest or even", which is the default. wim pointed out in comments that some compilers may optimize out the addition and subtraction of the conversion constant cvt as redundant when aggressive optimization is used, interfering with the functionality of this code sequence, so it is recommended to inspect the machine code generated. The approximation interval for computation of 2f is now centered around zero, since -0.5 <= f <= 0.5, requiring a different core approximation.

/* max. rel. error <= 1.72860465e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 t, f, p, r;
    __m128i i, j;

    const __m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
    const __m128 cvt = _mm_set1_ps (12582912.0f);  /* 1.5 * (1 << 23) */
    const __m128 c0 =  _mm_set1_ps (0.238428936f);
    const __m128 c1 =  _mm_set1_ps (0.703448006f);
    const __m128 c2 =  _mm_set1_ps (1.000443142f);

    /* exp(x) = 2^i * 2^f; i = rint (log2(e) * x), -0.5 <= f <= 0.5 */
    t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
    r = _mm_sub_ps (_mm_add_ps (t, cvt), cvt); /* r = rint (t) */
    f = _mm_sub_ps (t, r);               /* f = t - rint (t) */
    i = _mm_cvtps_epi32 (t);             /* i = (int)t */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= exp2(f) */
    j = _mm_slli_epi32 (i, 23);          /* i << 23 */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}

问题代码的算法似乎取自Nicol N. Schraudolph的工作,后者巧妙地利用了IEEE-754二进制浮点格式的半对数性质:

The algorithm for the code in the question appears to be taken from the work of Nicol N. Schraudolph, which cleverly exploits the semi-logarithmic nature of IEEE-754 binary floating-point formats:

N. N. Schraudolph. 指数函数的快速紧凑逼近." 神经计算,1999年5月,第11版,第853-862页.

N. N. Schraudolph. "A fast, compact approximation of the exponential function." Neural Computation, 11(4), May 1999, pp.853-862.

除去自变量夹紧代码后,它减少为仅三个SSE指令.对于将整个输入域上的最大相对误差最小化,魔术"校正常数486411不是最佳的.根据简单的二进制搜索,值298765似乎更好,将FastExpSse()的最大相对误差降低到3.56e-2,而fast_exp_sse()的最大相对误差为1.73e-3.

After removal of the argument clamping code, it reduces to just three SSE instructions. The "magical" correction constant 486411 is not optimal for minimizing maximum relative error over the entire input domain. Based on simple binary search, the value 298765 seems to be superior, reducing maximum relative error for FastExpSse() to 3.56e-2 vs. maximum relative error of 1.73e-3 for fast_exp_sse().

/* max. rel. error = 3.55959567e-2 on [-87.33654, 88.72283] */
__m128 FastExpSse (__m128 x)
{
    __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
    __m128i b = _mm_set1_epi32 (127 * (1 << 23) - 298765);
    __m128i t = _mm_add_epi32 (_mm_cvtps_epi32 (_mm_mul_ps (a, x)), b);
    return _mm_castsi128_ps (t);
}

Schraudolph的算法基本上对[0,1]中的f使用线性逼近2 f 〜= 1.0 + f,其精度可以通过添加二次项来提高. Schraudolph方法的巧妙部分是计算2 i * 2 f ,而没有明确地将整数部分i = floor(x * 1.44269504)从分数中分离出来.我看不到将这种技巧扩展到二次逼近的方法,但是可以肯定的是,可以将Schraudolph的floor()计算与上面使用的二次逼近相结合:

Schraudolph's algorithm basically uses the linear approximation 2f ~= 1.0 + f for f in [0,1], and its accuracy could be improved by adding a quadratic term. The clever part of Schraudolph's approach is computing 2i * 2f without explicitly separating the integer portion i = floor(x * 1.44269504) from the fraction. I see no way to extend that trick to a quadratic approximation, but one can certainly combine the floor() computation from Schraudolph with the quadratic approximation used above:

/* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 f, p, r;
    __m128i t, j;
    const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
    const __m128i m = _mm_set1_epi32 (0xff800000); /* mask for integer bits */
    const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
    const __m128 c0 = _mm_set1_ps (0.3371894346f);
    const __m128 c1 = _mm_set1_ps (0.657636276f);
    const __m128 c2 = _mm_set1_ps (1.00172476f);

    t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
    j = _mm_and_si128 (t, m);            /* j = (int)(floor (x/log(2))) << 23 */
    t = _mm_sub_epi32 (t, j);
    f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}

这篇关于使用SSE最快实现自然指数函数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆