AVX2 column population count algorithm over each bit-column separately


Question

For a project I'm working on, I need to count the number of set bits per column in ripped PDF image data.

I'm trying to get the total set-bit count for each column across the entire PDF job (all pages).

The data, once ripped, is stored in a MemoryMappedFile with no backing file (in memory).

The PDF page dimensions are 13952 pixels x 15125 pixels. The ripped data is 1 bit == 1 pixel, so the size of a ripped page in bytes is the height in pixels multiplied by the width in bytes: (13952 / 8) * 15125.

Note that the width is always a multiple of 64 bits.
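The arithmetic above can be written out as compile-time constants. This is only a sketch of the geometry stated in the question; the constant names are made up for illustration. It also shows where the values 218 and 3297250 used in the code further down come from.

```cpp
#include <cstddef>
#include <cstdint>

// Page geometry from the question (1 bit per pixel). Names are illustrative.
constexpr std::size_t kWidthPixels  = 13952;
constexpr std::size_t kHeightPixels = 15125;

// The question states that the width is always a multiple of 64 bits.
static_assert(kWidthPixels % 64 == 0, "row width must be a whole number of 64-bit words");

constexpr std::size_t kBytesPerRow  = kWidthPixels / 8;              // 1744
constexpr std::size_t kWordsPerRow  = kWidthPixels / 64;             // 218
constexpr std::size_t kBytesPerPage = kBytesPerRow * kHeightPixels;  // 26,378,000
constexpr std::size_t kWordsPerPage = kWordsPerRow * kHeightPixels;  // 3,297,250
```

So one page is roughly 26 MB, which is larger than typical last-level cache sizes, so memory traffic is likely to matter as well as instruction count.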

I'll have to count the set bits for each column in each page of a PDF (which could be tens of thousands of pages) after it has been ripped.

I first approached the problem with a basic solution: looping through each byte, counting the set bits, and placing the results in a vector. I've since whittled the algorithm down to what's shown below, going from an execution time of ~350 ms to ~120 ms.

static void count_dots( )
{
    using namespace diag;
    using namespace std::chrono;

    std::vector<std::size_t> dot_counts( 13952, 0 );
    uint64_t* ptr_dot_counts{ dot_counts.data( ) };

    std::vector<uint64_t> ripped_pdf_data( 3297250, 0xFFFFFFFFFFFFFFFFUL );
    const uint64_t* ptr_data{ ripped_pdf_data.data( ) };

    std::size_t line_count{ 0 };
    std::size_t counter{ ripped_pdf_data.size( ) };

    stopwatch sw;
    sw.start( );

    while( counter > 0 )
    {
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000001UL ) >> 0;

        ++ptr_data;
        --counter;
        if( ++line_count >= 218 )
        {
            ptr_dot_counts = dot_counts.data( );
            line_count = 0;
        }
    }   

    sw.stop( );
    std::cout << sw.elapsed<milliseconds>( ) << "ms\n";
}

Unfortunately, this would still add a lot of extra processing time, which isn't acceptable.

The above code is ugly and won't win any beauty contests, but it has helped reduce execution time. Since the original version I wrote, I've done the following:

  • Use pointers instead of indexers
  • Process the data in chunks of uint64 instead of uint8
  • Manually unroll the for loop for traversing each bit in each byte of a uint64
  • Use a final bit shift instead of __popcnt64 for counting the set bit after masking
For this test I'm generating phony ripped data in which every bit is set to 1. After the test has completed, every element of the dot_counts vector should contain 15125.

I'm hoping some folks here can help me get the algorithm's average execution time below 100 ms. I do not care whatsoever about portability here.

  • Target machine's CPU: Intel Xeon E5-2680 v4
  • Compiler: MSVC++ 14.23
  • OS: Windows 10
  • C++ version: C++17
  • Compiler flags: /O2 /arch:AVX2

A very similar question was asked ~8 years ago: How to quickly count bits into separate bins in a series of ints on Sandy Bridge?

(Editor's note: perhaps you missed "Count each bit-position separately over many 64-bit bitmasks, with AVX but not AVX2", which has some more recent, faster answers, at least for going down a column instead of along a row in contiguous memory. Maybe you can go 1 or 2 cache lines wide down a column so you can keep your counters hot in SIMD registers.)

When I compare what I have so far to the accepted answer, I'm fairly close. I was already processing in chunks of uint64 instead of uint8. I'm just wondering if there is more I can do, whether that be with intrinsics, assembly, or something simple like changing the data structures I'm using.

Recommended answer

It could be done with AVX2, as tagged.

In order to make this work out properly, I recommend vector<uint16_t> for the counts. Adding into the counts is the biggest problem, and the more we need to widen, the bigger the problem. uint16_t is enough to count one page, so you can count one page at a time and add the counters into a set of wider counters for the totals. That is some overhead, but much less than having to widen more in the main loop.

The big-endian ordering of the counts is very annoying, introducing even more shuffles to get it right. So I recommend getting it wrong and reordering the counts later (maybe while summing them into the totals?). The order of "right shift by 7 first, then 6, then 5" can be maintained for free, because we get to choose the shift counts for the 64-bit blocks any way we want. So in the code below, the actual order of counts is:

  • bit 7 of the least significant byte,
  • bit 7 of the second byte,
  • ...
  • bit 7 of the most significant byte,
  • bit 6 of the least significant byte,
  • ...

So every group of 8 is reversed. (At least this is what I intended to do; AVX2 unpacks are confusing.)

Code (not tested):

while( counter > 0 )
{
    __m256i data = _mm256_set1_epi64x(*ptr_data);        
    __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
    __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
    data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
    data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));

    __m256i zero = _mm256_setzero_si256();

    __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
    c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
    c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
    c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
    c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);

    ptr_dot_counts += 64;
    ++ptr_data;
    --counter;
    if( ++line_count >= 218 )
    {
        ptr_dot_counts = dot_counts.data( );
        line_count = 0;
    }
}

This can be further unrolled, handling multiple rows at once. That is good because, as mentioned earlier, summing into the counters is the biggest problem, and unrolling by rows would do less of that and more plain summing in registers.

Some intrinsics used:

  • _mm256_set1_epi64x, copies one int64_t to all 4 of the 64-bit elements of the vector. Also fine for uint64_t.
  • _mm256_set_epi64x, turns four 64-bit values into a vector.
  • _mm256_srlv_epi64, shift right logical, with variable count (can be a different count for each element).
  • _mm256_and_si256, just bitwise AND.
  • _mm256_add_epi16, addition, works on 16-bit elements.
  • _mm256_unpacklo_epi8 and _mm256_unpackhi_epi8, probably best explained by the diagrams on that page.
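Because the 256-bit unpack instructions work on each 128-bit lane separately, a scalar model can help when reasoning about which counter a byte ends up in. The sketch below models only the case used in this answer, where the second operand is zero, so each selected byte is zero-extended to a 16-bit element. The type aliases and function names are made up for illustration; this is plain C++ and does not call the intrinsics.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

using Bytes32 = std::array<std::uint8_t, 32>;   // a 256-bit vector as bytes
using Words16 = std::array<std::uint16_t, 16>;  // a 256-bit vector as 16-bit elements

// Scalar model of _mm256_unpacklo_epi8(v, zero), viewed as 16-bit elements:
// from each 128-bit lane, the LOW 8 bytes are zero-extended.
Words16 model_unpacklo_epi8_zero(const Bytes32& v)
{
    Words16 out{};
    for (std::size_t lane = 0; lane < 2; ++lane)
        for (std::size_t k = 0; k < 8; ++k)
            out[lane * 8 + k] = v[lane * 16 + k];
    return out;
}

// Scalar model of _mm256_unpackhi_epi8(v, zero), viewed as 16-bit elements:
// from each 128-bit lane, the HIGH 8 bytes are zero-extended.
Words16 model_unpackhi_epi8_zero(const Bytes32& v)
{
    Words16 out{};
    for (std::size_t lane = 0; lane < 2; ++lane)
        for (std::size_t k = 0; k < 8; ++k)
            out[lane * 8 + k] = v[lane * 16 + 8 + k];
    return out;
}
```

Under this model, the "lo" result takes 64-bit elements 0 and 2 of the input and the "hi" result takes elements 1 and 3. That is consistent with the shift counts being passed as (4, 6, 5, 7) rather than in descending order, since _mm256_set_epi64x lists its arguments from the highest element to the lowest.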

It's possible to sum "vertically", using one uint64_t to hold all the 0th bits of the 64 individual sums, another uint64_t to hold all the 1st bits of the sums, and so on. The addition can be done by emulating full adders (the circuit component) with bitwise arithmetic. Then, instead of adding just 0 or 1 to the counters, bigger numbers are added all at once.

The vertical sums can also be vectorized, but that would significantly inflate the code that adds the vertical sums to the column sums, so I didn't do that here. It should help, but it's just a lot of code.

Example (not tested):

size_t y;
// sum 7 rows at once
for (y = 0; (y + 6) < 15125; y += 7) {
    ptr_dot_counts = dot_counts.data( );
    ptr_data = ripped_pdf_data.data( ) + y * 218;
    for (size_t x = 0; x < 218; x++) {
        uint64_t dataA = ptr_data[0];
        uint64_t dataB = ptr_data[218];
        uint64_t dataC = ptr_data[218 * 2];
        uint64_t dataD = ptr_data[218 * 3];
        uint64_t dataE = ptr_data[218 * 4];
        uint64_t dataF = ptr_data[218 * 5];
        uint64_t dataG = ptr_data[218 * 6];
        // vertical sums, 7 bits to 3
        uint64_t abc0 = (dataA ^ dataB) ^ dataC;
        uint64_t abc1 = (dataA ^ dataB) & dataC | (dataA & dataB);
        uint64_t def0 = (dataD ^ dataE) ^ dataF;
        uint64_t def1 = (dataD ^ dataE) & dataF | (dataD & dataE);
        uint64_t bit0 = (abc0 ^ def0) ^ dataG;
        uint64_t c1   = (abc0 ^ def0) & dataG | (abc0 & def0);
        uint64_t bit1 = (abc1 ^ def1) ^ c1;
        uint64_t bit2 = (abc1 ^ def1) & c1 | (abc1 & def1);
        // add vertical sums to column counts
        __m256i bit0v = _mm256_set1_epi64x(bit0);
        __m256i data01 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data02 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(0, 2, 1, 3));
        data01 = _mm256_and_si256(data01, _mm256_set1_epi8(1));
        data02 = _mm256_and_si256(data02, _mm256_set1_epi8(1));
        __m256i bit1v = _mm256_set1_epi64x(bit1);
        __m256i data11 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data12 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(0, 2, 1, 3));
        data11 = _mm256_and_si256(data11, _mm256_set1_epi8(1));
        data12 = _mm256_and_si256(data12, _mm256_set1_epi8(1));
        data11 = _mm256_add_epi8(data11, data11);
        data12 = _mm256_add_epi8(data12, data12);
        __m256i bit2v = _mm256_set1_epi64x(bit2);
        __m256i data21 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data22 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(0, 2, 1, 3));
        data21 = _mm256_and_si256(data21, _mm256_set1_epi8(1));
        data22 = _mm256_and_si256(data22, _mm256_set1_epi8(1));
        data21 = _mm256_slli_epi16(data21, 2);
        data22 = _mm256_slli_epi16(data22, 2);
        __m256i data1 = _mm256_add_epi8(_mm256_add_epi8(data01, data11), data21);
        __m256i data2 = _mm256_add_epi8(_mm256_add_epi8(data02, data12), data22);

        __m256i zero = _mm256_setzero_si256();

        __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);


        ptr_dot_counts += 64;
        ++ptr_data;
    }
}
// leftover rows
for (; y < 15125; y++) {
    ptr_dot_counts = dot_counts.data( );
    ptr_data = ripped_pdf_data.data( ) + y * 218;
    for (size_t x = 0; x < 218; x++) {
        __m256i data = _mm256_set1_epi64x(*ptr_data);
        __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
        data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
        data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));

        __m256i zero = _mm256_setzero_si256();

        __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);


        ptr_dot_counts += 64;
        ++ptr_data;
    }
}

The second best so far was a simpler approach, more like the first version, except that it processes runs of yloopLen rows at once to take advantage of fast 8-bit sums:

size_t yloopLen = 32;
size_t yblock = yloopLen * 1;
size_t yy;
for (yy = 0; yy < 15125; yy += yblock) {
    for (size_t x = 0; x < 218; x++) {
        ptr_data = ripped_pdf_data.data() + x;
        ptr_dot_counts = dot_counts.data() + x * 64;
        __m256i zero = _mm256_setzero_si256();

        __m256i c1 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        __m256i c2 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        __m256i c3 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        __m256i c4 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);

        size_t end = std::min(yy + yblock, size_t(15125));
        size_t y;
        for (y = yy; y < end; y += yloopLen) {
            size_t len = std::min(size_t(yloopLen), end - y);
            __m256i count1 = zero;
            __m256i count2 = zero;

            for (size_t t = 0; t < len; t++) {
                __m256i data = _mm256_set1_epi64x(ptr_data[(y + t) * 218]);
                __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
                __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
                data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
                data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));
                count1 = _mm256_add_epi8(count1, data1);
                count2 = _mm256_add_epi8(count2, data2);
            }

            c1 = _mm256_add_epi16(_mm256_unpacklo_epi8(count1, zero), c1);
            c2 = _mm256_add_epi16(_mm256_unpackhi_epi8(count1, zero), c2);
            c3 = _mm256_add_epi16(_mm256_unpacklo_epi8(count2, zero), c3);
            c4 = _mm256_add_epi16(_mm256_unpackhi_epi8(count2, zero), c4);
        }

        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c1);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c2);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c3);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c4);
    }
}

There were some measurement issues earlier. In the end this wasn't actually better, but it was also not much worse than the fancier "vertical sum" version above.
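The constraint behind the run length is that an 8-bit accumulator gains at most 1 per row, so a run may cover at most 255 rows before the partial sums have to be widened. The answer uses 32. The scalar sketch below shows only that blocking structure for one 64-bit-wide strip. The function name and parameters are hypothetical, and it indexes counters by plain bit number rather than by the SIMD layout used in the answer.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Counts set bits per bit position for the strip formed by word `x` of every row.
// Narrow 8-bit sums are kept for `block_rows` rows, then widened into `counts`.
void count_strip_blocked(const std::vector<std::uint64_t>& data,
                         std::size_t words_per_row, std::size_t x,
                         std::size_t block_rows,
                         std::array<std::uint16_t, 64>& counts)
{
    assert(block_rows >= 1 && block_rows <= 255);  // 8-bit sums must not wrap
    assert(words_per_row != 0 && x < words_per_row);
    const std::size_t rows = data.size() / words_per_row;
    assert(rows <= 65535);                         // 16-bit sums must not wrap
    for (std::size_t y0 = 0; y0 < rows; y0 += block_rows) {
        const std::size_t y1 = std::min(y0 + block_rows, rows);
        std::array<std::uint8_t, 64> narrow{};     // cheap per-block accumulators
        for (std::size_t y = y0; y < y1; ++y) {
            const std::uint64_t word = data[y * words_per_row + x];
            for (unsigned b = 0; b < 64; ++b)
                narrow[b] = static_cast<std::uint8_t>(narrow[b] + ((word >> b) & 1u));
        }
        for (unsigned b = 0; b < 64; ++b)          // widen once per block
            counts[b] = static_cast<std::uint16_t>(counts[b] + narrow[b]);
    }
}
```

Larger blocks amortize the widening step over more rows, while the column-wise walk touches a new cache line on every row, so the best run length is something to measure rather than assume.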
