# Fastest Implementation of the Natural Exponential Function Using SSE

### Problem Description

I'm looking for an approximation of the natural exponential function operating on SSE elements, namely `__m128 exp( __m128 x )`.

I have an implementation which is quick but seems to have very low accuracy:

```c
static inline __m128 FastExpSse(__m128 x)
{
    __m128 a = _mm_set1_ps(12102203.2f); // (1 << 23) / ln(2)
    __m128i b = _mm_set1_epi32(127 * (1 << 23) - 486411);
    __m128  m87 = _mm_set1_ps(-87);
    // fast exponential function, x should be in [-87, 87]
    __m128 mask = _mm_cmpge_ps(x, m87);
    __m128i tmp = _mm_add_epi32(_mm_cvtps_epi32(_mm_mul_ps(a, x)), b);
    return _mm_and_ps(_mm_castsi128_ps(tmp), mask);
}
```

Does anybody have an implementation with better accuracy that is just as fast (or faster)?

I'd be happy if it is written in C style.

### Recommended Answer

The basic idea is to transform the computation of the standard exponential function into the computation of a power of 2: `expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504)`. We split `t = x * 1.44269504` into an integer `i` and a fraction `f`, such that `t = i + f` and `0 <= f <= 1`. We can now compute 2^f with a polynomial approximation, then scale the result by 2^i by adding `i` to the exponent field of the single-precision floating-point result.

One problem that exists with an SSE implementation is that we want to compute `i = floorf (t)`, but there is no fast way to compute the `floor()` function. However, we observe that for positive numbers, `floor(x) == trunc(x)`, and that for negative numbers, `floor(x) == trunc(x) - 1`, except when `x` is a negative integer. However, since the core approximation can handle an `f` value of `1.0f`, using the approximation for negative arguments is harmless. SSE provides an instruction to convert single-precision floating point operands to integers with truncation, so this solution is efficient.
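
In scalar form, the trick looks like this (a hypothetical illustration, not part of the answer's code):

```c
#include <stdint.h>
#include <string.h>

/* scalar sketch: floor(t) approximated as trunc(t) - signbit(t).
   exact everywhere except at negative integers, where it returns
   floor(t) - 1; as noted above, this is harmless because the core
   approximation also handles f == 1.0f */
static int32_t floor_via_trunc (float t)
{
    uint32_t bits;
    memcpy (&bits, &t, sizeof bits);           /* grab the sign bit */
    return (int32_t)t - (int32_t)(bits >> 31); /* trunc(t) - signbit(t) */
}
```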

Peter Cordes points out that SSE4.1 supports a fast floor function `_mm_floor_ps()`, so a variant using SSE4.1 is also shown below. Not all toolchains automatically predefine the macro `__SSE4_1__` when SSE 4.1 code generation is enabled, but gcc does.

Compiler Explorer (Godbolt) shows that gcc 7.2 compiles the code below into sixteen instructions for plain SSE and twelve instructions for SSE 4.1.

```c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <emmintrin.h>
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif

/* max. rel. error = 1.72863156e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 t, f, e, p, r;
    __m128i i, j;
    __m128 l2e = _mm_set1_ps (1.442695041f);  /* log2(e) */
    __m128 c0  = _mm_set1_ps (0.3371894346f);
    __m128 c1  = _mm_set1_ps (0.657636276f);
    __m128 c2  = _mm_set1_ps (1.00172476f);

    /* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */
    t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
#ifdef __SSE4_1__
    e = _mm_floor_ps (t);                /* floor(t) */
    i = _mm_cvtps_epi32 (e);             /* (int)floor(t) */
#else /* __SSE4_1__*/
    i = _mm_cvttps_epi32 (t);            /* i = (int)t */
    j = _mm_srli_epi32 (_mm_castps_si128 (x), 31); /* signbit(t) */
    i = _mm_sub_epi32 (i, j);            /* (int)t - signbit(t) */
    e = _mm_cvtepi32_ps (i);             /* floor(t) ~= (int)t - signbit(t) */
#endif /* __SSE4_1__*/
    f = _mm_sub_ps (t, e);               /* f = t - floor(t) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
    j = _mm_slli_epi32 (i, 23);          /* i << 23 */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i */
    return r;
}

int main (void)
{
    union {
        float f[4];
        unsigned int i[4];
    } arg, res;
    double relerr, maxrelerr = 0.0;
    int i, j;
    __m128 x, y;

    float start[2] = {-0.0f, 0.0f};
    float finish[2] = {-87.33654f, 88.72283f};

    for (i = 0; i < 2; i++) {

        arg.f[0] = start[i];
        arg.i[1] = arg.i[0] + 1;
        arg.i[2] = arg.i[0] + 2;
        arg.i[3] = arg.i[0] + 3;
        do {
            memcpy (&x, &arg, sizeof(x));
            y = fast_exp_sse (x);
            memcpy (&res, &y, sizeof(y));
            for (j = 0; j < 4; j++) {
                double ref = exp ((double)arg.f[j]);
                relerr = fabs ((res.f[j] - ref) / ref);
                if (relerr > maxrelerr) {
                    printf ("arg=% 15.8e  res=%15.8e  ref=%15.8e  err=%15.8e\n",
                            arg.f[j], res.f[j], ref, relerr);
                    maxrelerr = relerr;
                }
            }
            arg.i[0] += 4;
            arg.i[1] += 4;
            arg.i[2] += 4;
            arg.i[3] += 4;
        } while (fabsf (arg.f[3]) < fabsf (finish[i]));
    }
    printf ("maximum relative error = %15.8e\n", maxrelerr);
    return EXIT_SUCCESS;
}
```

An alternative design for `fast_sse_exp()` extracts the integer portion of the adjusted argument `x / log(2)` in round-to-nearest mode, using the well-known technique of adding the "magic" conversion constant 1.5 * 2^23 to force rounding in the correct bit position, then subtracting out the same number again. This requires that the SSE rounding mode in effect during the addition is "round to nearest or even", which is the default. wim pointed out in comments that some compilers may optimize out the addition and subtraction of the conversion constant `cvt` as redundant when aggressive optimization is used, interfering with the functionality of this code sequence, so it is recommended to inspect the generated machine code. The approximation interval for the computation of 2^f is now centered around zero, since `-0.5 <= f <= 0.5`, requiring a different core approximation.
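
The magic-constant rounding can be demonstrated in scalar code (a sketch assuming the default round-to-nearest-even mode; the `volatile` here is just one way to guard against the compiler folding the add/subtract pair, per the caveat above):

```c
/* scalar sketch: rint(t) via the magic constant 1.5 * 2^23.
   adding 12582912.0f pushes t's fraction bits out of the significand,
   so the hardware rounds to nearest-even; subtracting recovers rint(t).
   valid only for |t| well below 2^22. */
static float rint_magic (float t)
{
    const float cvt = 12582912.0f;   /* 1.5 * (1 << 23) */
    volatile float r = t + cvt;      /* volatile blocks (t + cvt) - cvt -> t */
    return r - cvt;
}
```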

```c
/* max. rel. error <= 1.72860465e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 t, f, p, r;
    __m128i i, j;

    const __m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
    const __m128 cvt = _mm_set1_ps (12582912.0f);  /* 1.5 * (1 << 23) */
    const __m128 c0 =  _mm_set1_ps (0.238428936f);
    const __m128 c1 =  _mm_set1_ps (0.703448006f);
    const __m128 c2 =  _mm_set1_ps (1.000443142f);

    /* exp(x) = 2^i * 2^f; i = rint (log2(e) * x), -0.5 <= f <= 0.5 */
    t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
    r = _mm_sub_ps (_mm_add_ps (t, cvt), cvt); /* r = rint (t) */
    f = _mm_sub_ps (t, r);               /* f = t - rint (t) */
    i = _mm_cvtps_epi32 (t);             /* i = (int)t */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= exp2(f) */
    j = _mm_slli_epi32 (i, 23);          /* i << 23 */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i */
    return r;
}
```

The algorithm for the code in the question appears to be taken from the work of Nicol N. Schraudolph, which cleverly exploits the semi-logarithmic nature of IEEE-754 binary floating-point formats:

N. N. Schraudolph. "A fast, compact approximation of the exponential function." Neural Computation, 11(4), May 1999, pp. 853-862.

After removal of the argument clamping code, it reduces to just three SSE instructions. The "magical" correction constant `486411` is not optimal for minimizing maximum relative error over the entire input domain. Based on simple binary search, the value `298765` seems to be superior, reducing maximum relative error for `FastExpSse()` to 3.56e-2 vs. maximum relative error of 1.73e-3 for `fast_exp_sse()`.

```c
/* max. rel. error = 3.55959567e-2 on [-87.33654, 88.72283] */
__m128 FastExpSse (__m128 x)
{
    __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
    __m128i b = _mm_set1_epi32 (127 * (1 << 23) - 298765);
    __m128i t = _mm_add_epi32 (_mm_cvtps_epi32 (_mm_mul_ps (a, x)), b);
    return _mm_castsi128_ps (t);
}
```

Schraudolph's algorithm basically uses the linear approximation 2^f ~= `1.0 + f` for `f` in [0,1], and its accuracy could be improved by adding a quadratic term. The clever part of Schraudolph's approach is computing 2^i * 2^f without explicitly separating the integer portion `i = floor(x * 1.44269504)` from the fraction. I see no way to extend that trick to a quadratic approximation, but one can certainly combine the `floor()` computation from Schraudolph with the quadratic approximation used above:

```c
/* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 f, p, r;
    __m128i t, j;
    const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
    const __m128i m = _mm_set1_epi32 (0xff800000); /* mask for integer bits */
    const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
    const __m128 c0 = _mm_set1_ps (0.3371894346f);
    const __m128 c1 = _mm_set1_ps (0.657636276f);
    const __m128 c2 = _mm_set1_ps (1.00172476f);

    t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
    j = _mm_and_si128 (t, m);            /* j = (int)(floor (x/log(2))) << 23 */
    t = _mm_sub_epi32 (t, j);
    f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i */
    return r;
}
```