将字段中的位扩展为掩码中所有(重叠+相邻)设置位的最快方法? [英] Fastest way to expand bits in a field to all (overlapping + adjacent) set bits in a mask?

查看:85
本文介绍了将字段中的位扩展为掩码中所有(重叠+相邻)设置位的最快方法?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

说我有2个名为IN和MASK的二进制输入.实际字段大小可能是32到256位,具体取决于用于完成任务的指令集.两个输入都会更改每个呼叫.

Say I have 2 binary inputs named IN and MASK. Actual field size could be 32 to 256 bits depending on what instruction set is used to accomplish the task. Both inputs change every call.

Inputs:
IN   = ...1100010010010100...
MASK = ...0001111010111011...
Output:
OUT  = ...0001111010111000...

一些评论讨论的另一个示例结果

edit: another example result from some comment discussion

IN   = ...11111110011010110...
MASK = ...01011011001111110...
Output:
OUT  = ...01011011001111110...

我想获得1位IN所在的MASK的连续相邻1位. (是否有这种操作的总称?也许我没有适当地称呼我的搜索词.)我正在尝试找到一种更快的方法来执行此操作.我愿意使用任何x86或x86 SIMD扩展,这些扩展都可以在最少的cpu周期内完成.首选更广泛的数据类型SIMD,因为它可以让我一次处理更多数据.

I want to get the contiguous adjacent 1 bits of MASK that a 1 bit of IN is within. (Is there a general term for this kind of operation? Maybe I'm not phrasing my searches properly.) I'm trying to find a way to do this that is a bit faster. I'm open to using any x86 or x86 SIMD extensions that can get this done in minimum cpu cycles. A wider data type SIMD is preferred as it will allow me to process more data at once.

我想出的最好的天真解决方案是以下伪代码,该伪代码手动向左移动,直到不再有匹配的位,然后重复向右移动:

The best naive solution I've come up with is the following pseudocode, which manually shifts left until there are no more matching bits, then repeats shifting right:

// (using the variables above)
testL = testR = OUT = (IN & MASK);

LoopL:
testL = (testL << 1) & MASK;
if (testL != 0) {
    OUT = OUT | testL;
    goto LoopL;
}

LoopR:
testR = (testR >> 1) & MASK;
if (testR != 0) {
    OUT = OUT | testR;
    goto LoopR;
}

return OUT;

推荐答案

我猜@fuz评论在正确的轨道上. 以下示例显示了以下SSE和AVX2代码的工作方式. 该算法以IN_reduced = IN & MASK开头,因为我们不感兴趣 在MASK0的位置的IN位中.

I guess @fuz comment was on the right track. The following example shows how the SSE and AVX2 code below works. The algorithm starts with IN_reduced = IN & MASK because we are not interested in IN bits at positions where MASK is 0.

IN                                  = . . . 0 0 0 0 . . . . p q r s . . .
MASK                                = . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 . . 
IN_reduced = IN & MASK              = . . 0 0 0 0 0 0 . . 0 p q r s 0 . .

如果任何p q r s位为1,则IN_reduced + MASK具有进位位1X位置,即在 请求的连续位.

If any of the p q r s bits is 1, then IN_reduced + MASK has a carry bit 1 at position X, which is right left to the requested contiguous bits.

MASK                                = . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 . . 
IN_reduced                          = . . 0 0 0 0 0 0 . . 0 p q r s 0 . .
IN_reduced + MASK                   = . . 0 1 1 1 1 . . . 1 . . . . . .
                                                          X
(IN_reduced + MASK) >>1             = . . . 0 1 1 1 1 . . . 1 . . . . . .

对于>> 1,此进位位1与位p移至同一列 (连续位的第一位). 现在,(IN_reduced + MASK) >>1实际上是IN_reducedMASK的平均值. 为了避免加法的溢出,我们使用以下方法 平均:avg(a, b) = (a & b) + ((a ^ b) >> 1)(请参阅@Harold的评论, 另请参见此处此处. ) 有了average = avg(IN_reduced, MASK),我们得到了

With >> 1 this carry bit 1 is shifted to the same column as bit p (the first bit of the contiguous bits). Now, (IN_reduced + MASK) >>1 is actually an average of IN_reduced and MASK. In order to avoid possible overflow of addition we use the following average: avg(a, b) = (a & b) + ((a ^ b) >> 1) (See @Harold's comment, see also here and here.) With average = avg(IN_reduced, MASK) we get

MASK                                = . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 . . 
IN_reduced                          = . . 0 0 0 0 0 0 . . 0 p q r s 0 . .
average                             = . . . 0 1 1 1 1 . . . 1 . . . . . .
MASK >> 1                           = . . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 .  
leading_bits = (~(MASK>>1))&average = . . . 0 0 0 0 0 . . . 1 0 0 0 0 . .  

我们可以用 leading_bits = (~(MASK>>1) ) & average,因为MASK>>1在位置为零 进位位 我们感兴趣的.

We can isolate the leading carry bits with leading_bits = (~(MASK>>1) ) & average because MASK>>1 is zero at the positions of the carry bits that we are interested in.

正常加法时,进位从右向左传播.在这里我们使用 反向加法:从左向右携带. 反向添加MASKleading_bits: rev_added = bit_swap(bit_swap(MASK) + bit_swap(leading_bits)), 这将零位 想要的职位. 使用OUT = (~rev_added) & MASK,我们得到结果.

With normal addition the carry propagates from right to left. Here we use a reverse addition: with a carry from left to right. Reverse adding MASK and leading_bits: rev_added = bit_swap(bit_swap(MASK) + bit_swap(leading_bits)), This zeros the bits at the wanted positions. With OUT = (~rev_added) & MASK we get the result.

MASK                                = . . 0 1 1 1 1 0 . . 0 1 1 1 1 0 . . 
leading_bits                        = . . . 0 0 0 0 0 . . . 1 0 0 0 0 . .  
rev_added (MASK,leading_bits)       = . . . 1 1 1 1 0 . . . 0 0 0 0 1 . .
OUT = ~rev_added & MASK             = . . 0 0 0 0 0 0 . . . 1 1 1 1 0 . .

该算法尚未经过全面测试,但输出看起来还可以.

The algorithm was not thoroughly tested, but the output looks ok.

下面的代码块包含两个单独的代码: 上半部分是SSE代码, 下半部分是AVX2代码. (为了避免 用两个大代码块使答案过分膨胀.) SSE算法适用于2 x 64位元素,而AVX2版本适用于4 x 64位元素.

The code block below contains two separate codes: The upper half is the SSE code, and the lower half is the AVX2 code. (In order to avoid bloating the answer too much with two large code blocks.) The SSE algorithm works with 2 x 64-bit elements and the AVX2 version works with 4 x 64-bit elements.

在gcc 9.1中,算法编译为大约29条指令, 除了4 vmovdqa -s之外,还用于加载一些常量,这很可能 (在内联之后)在现实世界的应用程序中脱颖而出. 这29条指令是执行9个shuffle(vpshufb)的很好的组合 在Intel Skylake的端口5(p5)上,以及许多其他可能经常会出现的说明 在p0,p1或p5上执行.

With gcc 9.1, the algorithm compiles to about 29 instructions, aside from 4 vmovdqa-s for loading some constants, which are likely hoisted out of the loop in a real world application (after inlining). These 29 instructions are a good mix of 9 shuffles (vpshufb) that execute on port 5 (p5) on Intel Skylake, and many other instructions that often may execute on p0, p1 or p5.

因此,每个周期可能执行约3条指令. 在这种情况下,吞吐量约为1个函数调用(内联) 每10个周期.在AVX2情况下,这意味着每个结果有4个uint64_t OUT结果 大约10个周期.

Therefore, a performance of about 3 instructions per cycle might be possible. In that case the throughput would be about 1 function call (inlined) per 10 cycles. In the AVX2 case this means 4 uint64_t OUT results per about 10 cycles.

请注意,性能与data(!)无关,这是一个很好的选择 我认为这个答案的好处.该解决方案是无分支的和无循环的,并且 不会遭受分支预测失败的困扰.

Note that the performance is independent of the data(!), which is a great benefit of this answer I think. The solution is branchless, and loopless, and cannot suffer from failing branch prediction.


/*  gcc -O3 -m64 -Wall -march=skylake select_bits.c    */
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

int print_sse_128_bin(__m128i x);
__m128i bit_128_k(unsigned int k);
__m128i mm_bitreverse_epi64(__m128i x);
__m128i mm_revadd_epi64(__m128i x, __m128i y);


/* Select specific pieces of contiguous bits from `MASK` based on selector `IN`  */
__m128i mm_select_bits_epi64(__m128i IN, __m128i MASK){
    __m128i IN_reduced   = _mm_and_si128(IN, MASK);
    /* Compute the average of IN_reduced and MASK with avg(a,b)=(a&b)+((a^b)>>1)  */
    /* (IN_reduced & MASK) + ((IN_reduced ^ MASK) >>1) =                          */
    /* ((IN & MASK) & MASK) + ((IN_reduced ^ MASK) >>1) =                         */
    /* IN_reduced + ((IN_reduced ^ MASK) >>1)                                     */
    __m128i tmp          = _mm_xor_si128(IN_reduced, MASK);
    __m128i tmp_div2     = _mm_srli_epi64(tmp, 1);
    __m128i average      = _mm_add_epi64(IN_reduced, tmp_div2);   /* average is the average */
    __m128i MASK_div2    = _mm_srli_epi64(MASK, 1);
    __m128i leading_bits = _mm_andnot_si128(MASK_div2, average);
    __m128i rev_added    = mm_revadd_epi64(MASK, leading_bits);
    __m128i OUT          = _mm_andnot_si128(rev_added, MASK);
    /* Uncomment the next lines to check the arithmetic */ /*   
    printf("IN           ");print_sse_128_bin(IN           );       
    printf("MASK         ");print_sse_128_bin(MASK         ); 
    printf("IN_reduced   ");print_sse_128_bin(IN_reduced   );       
    printf("tmp          ");print_sse_128_bin(tmp          );       
    printf("tmp_div2     ");print_sse_128_bin(tmp_div2     );       
    printf("average      ");print_sse_128_bin(average      );       
    printf("MASK_div2    ");print_sse_128_bin(MASK_div2    );       
    printf("leading_bits ");print_sse_128_bin(leading_bits );       
    printf("rev_added    ");print_sse_128_bin(rev_added    );       
    printf("OUT          ");print_sse_128_bin(OUT          );       
    printf("\n");*/
    return OUT;       
}


int main(){
    __m128i IN   = _mm_set_epi64x(0b11111110011010110, 0b1100010010010100);
    __m128i MASK = _mm_set_epi64x(0b01011011001111110, 0b0001111010111011);
    __m128i OUT;    

    printf("Example 1 \n");
    OUT = mm_select_bits_epi64(IN, MASK);
    printf("IN           ");print_sse_128_bin(IN);
    printf("MASK         ");print_sse_128_bin(MASK);
    printf("OUT          ");print_sse_128_bin(OUT);
    printf("\n\n");

                      /*  0b7654321076543210765432107654321076543210765432107654321076543210  */
    IN   = _mm_set_epi64x(0b1000001001001010000010000000100000010000000000100000000111100011, 
                          0b11111110011010111);
    MASK = _mm_set_epi64x(0b1110011110101110111111000000000111011111101101111100011111000001, 
                          0b01011011001111111);

    printf("Example 2 \n");
    OUT = mm_select_bits_epi64(IN, MASK);
    printf("IN           ");print_sse_128_bin(IN);
    printf("MASK         ");print_sse_128_bin(MASK);
    printf("OUT          ");print_sse_128_bin(OUT);
    printf("\n\n");

    return 0;
}


int print_sse_128_bin(__m128i x){
    for (int i = 127; i >= 0; i--){
        printf("%1u", _mm_testnzc_si128(bit_128_k(i), x));
        if (((i & 7) == 0) && (i > 0)) printf(" ");
    }
    printf("\n");
    return 0;
}


/* From my answer here https://stackoverflow.com/a/39595704/2439725, adapted to 128-bit */
inline __m128i bit_128_k(unsigned int k){
  __m128i  indices     = _mm_set_epi32(96, 64, 32, 0);
  __m128i  one         = _mm_set1_epi32(1);

  __m128i  kvec        = _mm_set1_epi32(k);  
  __m128i  shiftcounts = _mm_sub_epi32(kvec, indices);
  __m128i  kbit        = _mm_sllv_epi32(one, shiftcounts);   
  return kbit;                             
}


/* Copied from Harold's answer https://stackoverflow.com/a/46318399/2439725         */
/* Adapted to epi64 and __m128i: bit reverse two 64 bit elements                    */
inline __m128i mm_bitreverse_epi64(__m128i x){
    __m128i shufbytes = _mm_setr_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8); 
    __m128i luthigh = _mm_setr_epi8(0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15);
    __m128i lutlow = _mm_slli_epi16(luthigh, 4);
    __m128i lowmask = _mm_set1_epi8(15);
    __m128i rbytes = _mm_shuffle_epi8(x, shufbytes);
    __m128i high = _mm_shuffle_epi8(lutlow, _mm_and_si128(rbytes, lowmask));
    __m128i low = _mm_shuffle_epi8(luthigh, _mm_and_si128(_mm_srli_epi16(rbytes, 4), lowmask));
    return _mm_or_si128(low, high);
}


/* Add in the reverse direction: With a carry from left to */
/* right, instead of right to left                         */
inline __m128i mm_revadd_epi64(__m128i x, __m128i y){
    x = mm_bitreverse_epi64(x);
    y = mm_bitreverse_epi64(y);
    __m128i sum = _mm_add_epi64(x, y);
    return mm_bitreverse_epi64(sum);
}
/* End of SSE code */


/************* AVX2 code starts here ********************************************/

/*  gcc -O3 -m64 -Wall -march=skylake select_bits256.c    */
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

int print_avx_256_bin(__m256i x);
__m256i bit_256_k(unsigned int k);
__m256i mm256_bitreverse_epi64(__m256i x);
__m256i mm256_revadd_epi64(__m256i x, __m256i y);


/* Select specific pieces of contiguous bits from `MASK` based on selector `IN`  */
__m256i mm256_select_bits_epi64(__m256i IN, __m256i MASK){
    __m256i IN_reduced   = _mm256_and_si256(IN, MASK);
    /* Compute the average of IN_reduced and MASK with avg(a,b)=(a&b)+((a^b)>>1)  */
    /* (IN_reduced & MASK) + ((IN_reduced ^ MASK) >>1) =                          */
    /* ((IN & MASK) & MASK) + ((IN_reduced ^ MASK) >>1) =                         */
    /* IN_reduced + ((IN_reduced ^ MASK) >>1)                                     */
    __m256i tmp          = _mm256_xor_si256(IN_reduced, MASK);
    __m256i tmp_div2     = _mm256_srli_epi64(tmp, 1);
    __m256i average      = _mm256_add_epi64(IN_reduced, tmp_div2);   /* average is the average */
    __m256i MASK_div2    = _mm256_srli_epi64(MASK, 1);
    __m256i leading_bits = _mm256_andnot_si256(MASK_div2, average);
    __m256i rev_added    = mm256_revadd_epi64(MASK, leading_bits);
    __m256i OUT          = _mm256_andnot_si256(rev_added, MASK);
    /* Uncomment the next lines to check the arithmetic */ /*   
    printf("IN           ");print_avx_256_bin(IN           );       
    printf("MASK         ");print_avx_256_bin(MASK         ); 
    printf("IN_reduced   ");print_avx_256_bin(IN_reduced   );       
    printf("tmp          ");print_avx_256_bin(tmp          );       
    printf("tmp_div2     ");print_avx_256_bin(tmp_div2     );       
    printf("average      ");print_avx_256_bin(average      );       
    printf("MASK_div2    ");print_avx_256_bin(MASK_div2    );       
    printf("leading_bits ");print_avx_256_bin(leading_bits );       
    printf("rev_added    ");print_avx_256_bin(rev_added    );       
    printf("OUT          ");print_avx_256_bin(OUT          );       
    printf("\n");*/
    return OUT;       
}


int main(){
    __m256i IN   = _mm256_set_epi64x(0b11111110011010110, 
                                     0b1100010010010100,
                                     0b1000001001001010000010000000100000010000000000100000000111100011, 
                                     0b11111110011010111
    );
    __m256i MASK = _mm256_set_epi64x(0b01011011001111110, 
                                     0b0001111010111011,
                                     0b1110011110101110111111000000000111011111101101111100011111000001, 
                                     0b01011011001111111);
    __m256i OUT;    

    printf("Example \n");
    OUT = mm256_select_bits_epi64(IN, MASK);
    printf("IN           ");print_avx_256_bin(IN);
    printf("MASK         ");print_avx_256_bin(MASK);
    printf("OUT          ");print_avx_256_bin(OUT);
    printf("\n");

    return 0;
}


int print_avx_256_bin(__m256i x){
    for (int i=255;i>=0;i--){
        printf("%1u",_mm256_testnzc_si256(bit_256_k(i),x));
        if (((i&7) ==0)&&(i>0)) printf(" ");
    }
    printf("\n");
    return 0;
}


/* From my answer here https://stackoverflow.com/a/39595704/2439725 */
inline __m256i bit_256_k(unsigned int k){
  __m256i  indices     = _mm256_set_epi32(224,192,160,128,96,64,32,0);
  __m256i  one         = _mm256_set1_epi32(1);

  __m256i  kvec        = _mm256_set1_epi32(k);  
  __m256i  shiftcounts = _mm256_sub_epi32(kvec, indices);
  __m256i  kbit        = _mm256_sllv_epi32(one, shiftcounts);   
  return kbit;                             
}


/* Copied from Harold's answer https://stackoverflow.com/a/46318399/2439725         */
/* Adapted to epi64: bit reverse four 64 bit elements                    */
inline __m256i mm256_bitreverse_epi64(__m256i x){
    __m256i shufbytes = _mm256_setr_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8);
    __m256i luthigh = _mm256_setr_epi8(0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15, 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15);
    __m256i lutlow = _mm256_slli_epi16(luthigh, 4);
    __m256i lowmask = _mm256_set1_epi8(15);
    __m256i rbytes = _mm256_shuffle_epi8(x, shufbytes);
    __m256i high = _mm256_shuffle_epi8(lutlow, _mm256_and_si256(rbytes, lowmask));
    __m256i low = _mm256_shuffle_epi8(luthigh, _mm256_and_si256(_mm256_srli_epi16(rbytes, 4), lowmask));
    return _mm256_or_si256(low, high);
}


/* Add in the reverse direction: With a carry from left to */
/* right, instead of right to left                         */
inline __m256i mm256_revadd_epi64(__m256i x, __m256i y){
    x = mm256_bitreverse_epi64(x);
    y = mm256_bitreverse_epi64(y);
    __m256i sum = _mm256_add_epi64(x, y);
    return mm256_bitreverse_epi64(sum);
}


带有未注释的调试部分的SSE代码输出:

Output of the SSE code with an uncommented debugging section:

Example 1 
IN           00000000 00000000 00000000 00000000 00000000 00000001 11111100 11010110 00000000 00000000 00000000 00000000 00000000 00000000 11000100 10010100
MASK         00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111110 00000000 00000000 00000000 00000000 00000000 00000000 00011110 10111011
IN_reduced   00000000 00000000 00000000 00000000 00000000 00000000 10110100 01010110 00000000 00000000 00000000 00000000 00000000 00000000 00000100 10010000
tmp          00000000 00000000 00000000 00000000 00000000 00000000 00000010 00101000 00000000 00000000 00000000 00000000 00000000 00000000 00011010 00101011
tmp_div2     00000000 00000000 00000000 00000000 00000000 00000000 00000001 00010100 00000000 00000000 00000000 00000000 00000000 00000000 00001101 00010101
average      00000000 00000000 00000000 00000000 00000000 00000000 10110101 01101010 00000000 00000000 00000000 00000000 00000000 00000000 00010001 10100101
MASK_div2    00000000 00000000 00000000 00000000 00000000 00000000 01011011 00111111 00000000 00000000 00000000 00000000 00000000 00000000 00001111 01011101
leading_bits 00000000 00000000 00000000 00000000 00000000 00000000 10100100 01000000 00000000 00000000 00000000 00000000 00000000 00000000 00010000 10100000
rev_added    00000000 00000000 00000000 00000000 00000000 00000000 01001001 00000001 00000000 00000000 00000000 00000000 00000000 00000000 00000001 01000111
OUT          00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111110 00000000 00000000 00000000 00000000 00000000 00000000 00011110 10111000

IN           00000000 00000000 00000000 00000000 00000000 00000001 11111100 11010110 00000000 00000000 00000000 00000000 00000000 00000000 11000100 10010100
MASK         00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111110 00000000 00000000 00000000 00000000 00000000 00000000 00011110 10111011
OUT          00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111110 00000000 00000000 00000000 00000000 00000000 00000000 00011110 10111000


Example 2 
IN           10000010 01001010 00001000 00001000 00010000 00000010 00000001 11100011 00000000 00000000 00000000 00000000 00000000 00000001 11111100 11010111
MASK         11100111 10101110 11111100 00000001 11011111 10110111 11000111 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111111
IN_reduced   10000010 00001010 00001000 00000000 00010000 00000010 00000001 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110100 01010111
tmp          01100101 10100100 11110100 00000001 11001111 10110101 11000110 00000000 00000000 00000000 00000000 00000000 00000000 00000000 00000010 00101000
tmp_div2     00110010 11010010 01111010 00000000 11100111 11011010 11100011 00000000 00000000 00000000 00000000 00000000 00000000 00000000 00000001 00010100
average      10110100 11011100 10000010 00000000 11110111 11011100 11100100 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110101 01101011
MASK_div2    01110011 11010111 01111110 00000000 11101111 11011011 11100011 11100000 00000000 00000000 00000000 00000000 00000000 00000000 01011011 00111111
leading_bits 10000100 00001000 10000000 00000000 00010000 00000100 00000100 00000001 00000000 00000000 00000000 00000000 00000000 00000000 10100100 01000000
rev_added    00010000 01100001 00000010 00000001 11000000 01110000 00100000 00100000 00000000 00000000 00000000 00000000 00000000 00000000 01001001 00000000
OUT          11100111 10001110 11111100 00000000 00011111 10000111 11000111 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111111

IN           10000010 01001010 00001000 00001000 00010000 00000010 00000001 11100011 00000000 00000000 00000000 00000000 00000000 00000001 11111100 11010111
MASK         11100111 10101110 11111100 00000001 11011111 10110111 11000111 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111111
OUT          11100111 10001110 11111100 00000000 00011111 10000111 11000111 11000001 00000000 00000000 00000000 00000000 00000000 00000000 10110110 01111111

这篇关于将字段中的位扩展为掩码中所有(重叠+相邻)设置位的最快方法?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆