在Java中获取k个最小(或最大)数组元素的最快方法是什么? [英] What is the fastest way to get k smallest (or largest) elements of array in Java?

查看:81
本文介绍了在Java中获取k个最小(或最大)数组元素的最快方法是什么?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个元素数组(在示例中,这些元素只是简单的整数),使用一些自定义比较器对其进行比较.在此示例中,我通过定义i SMALLER j当且仅当scores[i] <= scores[j]来模拟此比较器.

I have an array of elements (in the example, these are simply integers), which are compared using some custom comparator. In this example, I simulate this comparator by defining i SMALLER j if and only if scores[i] <= scores[j].

我有两种方法:

  • 使用当前k个候选对象的堆
  • 使用当前k个候选数组

我通过以下方式更新上面的两个结构:

I update the upper two structures in the following way:

  • 堆:方法PriorityQueue.pollPriorityQueue.offer
  • array:存储候选数组中前k个候选者中最差的索引top.如果新看到的示例比索引top处的元素更好,则将后者替换为前一个索引,并通过遍历数组的所有k个元素来更新top.
  • heap: methods PriorityQueue.poll and PriorityQueue.offer,
  • array: index top of the worst among top k candidates in the array of candidates is stored. If a newly seen example is better than the element at the index top, the latter is replaced by the former and top is updated by iterating through all k elements of the array.

但是,当我测试了哪种方法更快时,我发现这是第二种方法.问题是:

However, when I have tested, which of the approaches is faster, I found out that this is the second. The questions are:

  • 我对PriorityQueue的使用不是最佳吗?
  • 计算k个最小元素的最快方法是什么?
  • Is my use of PriorityQueue suboptimal?
  • What is the fastest way to compute k smallest elements?

我感兴趣的是这种情况,示例的数量可以很多,但是邻居的数量相对较少(在10到20之间).

I am interested in the case, when the number of examples can be large, but the number of neighbours is relatively small (between 10 and 20).

这是代码:

public static void main(String[] args) {
    long kopica, navadno, sortiranje;

    int numTries = 10000;
    int numExamples = 1000;
    int numNeighbours = 10;

    navadno = testSimple(numExamples, numNeighbours, numTries);
    kopica = testHeap(numExamples, numNeighbours, numTries);

    sortiranje = testSort(numExamples, numNeighbours, numTries, false);
    System.out.println(String.format("tries: %d examples: %d neighbours: %d\n time heap[ms]: %d\n time simple[ms]: %d", numTries, numExamples, numNeighbours, kopica, navadno));
}

public static long testHeap(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
        PriorityQueue<Integer> myHeap = new PriorityQueue(numberNeighbours, new Comparator<Integer>(){
            @Override
            public int compare(Integer o1, Integer o2) {
                return -Double.compare(scores[o1], scores[o2]);
            }
        });

        int top;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                myHeap.offer(i);
            } else{
                top = myHeap.peek();
                if(scores[top] > scores[i]){
                    myHeap.poll();
                    myHeap.offer(i);
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

public static long testSimple(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
        int[] candidates = new int[numberNeighbours];
        int top = 0;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                candidates[i] = i;
                if(scores[candidates[top]] < scores[candidates[i]]) top = i;
            } else{
                if(scores[candidates[top]] > scores[i]){
                    candidates[top] = i;
                    top = 0;
                    for(int j = 1; j < numberNeighbours; j++){
                        if(scores[candidates[top]] < scores[candidates[j]]) top = j;                            
                    }
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

这将产生以下结果:

tries: 10000 examples: 1000 neighbours: 10
   time heap[ms]: 393
   time simple[ms]: 388

推荐答案

首先,您的基准测试方法不正确.您正在测量输入数据的创建以及算法性能,并且没有在测量之前预热JVM.通过 JMH

First of all, your benchmarking method is incorrect. You are measuring input data creation along with an algorithm performance, and you aren't warming up the JVM before measuring. Results for your code, when tested through the JMH:

Benchmark                     Mode  Cnt      Score   Error  Units
CounterBenchmark.testHeap    thrpt    2  18103,296          ops/s
CounterBenchmark.testSimple  thrpt    2  59490,384          ops/s

修改后的基准 pastebin .

关于两个提供的解决方案之间的3倍差异.用big-O表示法,您的第一个算法可能看起来更好,但实际上big-O表示法仅告诉您缩放方面的算法有多好,而从不告诉您它的执行速度如何(请参阅此问题).在您的情况下,缩放不是问题,因为您的numNeighbours限于20.换句话说,大O标记表示完成算法需要多少滴答声,但并不限制运算符的持续时间.滴答声,它只是说输入变化时滴答持续时间不会改变.就滴答复杂度而言,您的第二种算法肯定会获胜.

Regarding 3x times difference between two provided solutions. In the terms of big-O notation your first algorithm may seem better, but in fact big-O notation only tells your how good the algorithm is in the terms of scaling, it never tells you how fast it performs (see this question also). And in your case scaling is not the issue, as your numNeighbours is limited to 20. In other words big-O notation describes how many ticks of algorithm is necessary for it to complete, but it doesn't limit the duration of a tick, it just says that the tick duration doesn't change when inputs change. And in terms of tick complexity your second algorithm surely wins.

计算k个最小元素的最快方法是什么?

What is the fastest way to compute k smallest elements?

我提出了下一个解决方案,我相信该解决方案可以分支预测来完成其工作:

I've came up with the next solution which I do believe allows branch prediction to do its job:

@Benchmark
public void testModified(Blackhole bh) {
    final double[] scores = sampleData;
    int[] candidates = new int[numberNeighbours];
    for (int i = 0; i < numberNeighbours; i++) {
        candidates[i] = i;
    }
    // sorting candidates so scores[candidates[0]] is the largest
    for (int i = 0; i < numberNeighbours; i++) {
        for (int j = i+1; j < numberNeighbours; j++) {
            if (scores[candidates[i]] < scores[candidates[j]]) {
                int temp = candidates[i];
                candidates[i] = candidates[j];
                candidates[j] = temp;
            }
        }
    }
    // processing other scores, while keeping candidates array sorted in the descending order
    for (int i = numberNeighbours; i < numberExamples; i++) {
        if (scores[i] > scores[candidates[0]]) {
            continue;
        }
        // moving all larger candidates to the left, to keep the array sorted
        int j; // here the branch prediction should kick-in
        for (j = 1; j < numberNeighbours && scores[i] < scores[candidates[j]]; j++) {
            candidates[j - 1] = candidates[j];
        }
        // inserting the new item
        candidates[j - 1] = i;
    }
    bh.consume(candidates);
}

基准测试结果(比当前解决方案快2倍):

Benchmark results (2x times faster than your current solution):

(10 neighbours) CounterBenchmark.testModified    thrpt    2  136492,151          ops/s
(20 neighbours) CounterBenchmark.testModified    thrpt    2  118395,598          ops/s

其他人提到快速选择,但是正如人们所期望的那样,该算法的复杂性忽略了它的复杂性.在您的情况下有利的一面:

Others mentioned quickselect, but as one may expect, the complexity of that algorithm neglects its strong sides in your case:

@Benchmark
public void testQuickSelect(Blackhole bh) {
    final int[] candidates = new int[sampleData.length];
    for (int i = 0; i < candidates.length; i++) {
        candidates[i] = i;
    }
    final int[] resultIndices = new int[numberNeighbours];
    int neighboursToAdd = numberNeighbours;

    int left = 0;
    int right = candidates.length - 1;
    while (neighboursToAdd > 0) {
        int partitionIndex = partition(candidates, left, right);
        int smallerItemsPartitioned = partitionIndex - left;
        if (smallerItemsPartitioned <= neighboursToAdd) {
            while (left < partitionIndex) {
                resultIndices[numberNeighbours - neighboursToAdd--] = candidates[left++];
            }
        } else {
            right = partitionIndex - 1;
        }
    }
    bh.consume(resultIndices);
}

private int partition(int[] locations, int left, int right) {
    final int pivotIndex = ThreadLocalRandom.current().nextInt(left, right + 1);
    final double pivotValue = sampleData[locations[pivotIndex]];
    int storeIndex = left;
    for (int i = left; i <= right; i++) {
        if (sampleData[locations[i]] <= pivotValue) {
            final int temp = locations[storeIndex];
            locations[storeIndex] = locations[i];
            locations[i] = temp;

            storeIndex++;
        }
    }
    return storeIndex;
}

在这种情况下,基准测试结果令人不快:

Benchmark results are pretty upsetting in this case:

CounterBenchmark.testQuickSelect  thrpt    2   11586,761          ops/s

这篇关于在Java中获取k个最小(或最大)数组元素的最快方法是什么?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆