Каков самый быстрый способ получить k наименьших (или наибольших) элементов массива в Java?


У меня есть массив элементов (в примере это просто целые числа), которые сравниваются с помощью некоторого пользовательского компаратора. В этом примере я моделирую этот компаратор, определяя i SMALLER j тогда и только тогда, когда scores[i] <= scores[j].

У меня есть два подхода:

  • используя кучу текущих K кандидатов
  • используя массив текущих K кандидатов

Я обновляю две верхние структуры следующим образом:

  • куча: методы PriorityQueue.poll и PriorityQueue.offer,
  • массив: хранится индекс top худшего из лучших K кандидатов в массиве кандидатов. Если вновь увиденный пример лучше элемента с индексом top, то последний заменяется первым и top обновляется путем перебора всех k элементов массива.
Однако, когда я проверил, какой из подходов быстрее, я обнаружил, что это второй. Вопросы следующие:
    Является ли мое использование PriorityQueue неоптимальным?
  • каков самый быстрый способ вычисления k наименьших элементов?
Меня интересует случай, когда число примеров может быть большим, но число соседей относительно невелико (от 10 до 20).

Вот код:

public static void main(String[] args) {
    long kopica, navadno, sortiranje;

    int numTries = 10000;
    int numExamples = 1000;
    int numNeighbours = 10;

    navadno = testSimple(numExamples, numNeighbours, numTries);
    kopica = testHeap(numExamples, numNeighbours, numTries);

    sortiranje = testSort(numExamples, numNeighbours, numTries, false);
    System.out.println(String.format("tries: %d examples: %d neighbours: %dn time heap[ms]: %dn time simple[ms]: %d", numTries, numExamples, numNeighbours, kopica, navadno));
}

public static long testHeap(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
        PriorityQueue<Integer> myHeap = new PriorityQueue(numberNeighbours, new Comparator<Integer>(){
            @Override
            public int compare(Integer o1, Integer o2) {
                return -Double.compare(scores[o1], scores[o2]);
            }
        });

        int top;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                myHeap.offer(i);
            } else{
                top = myHeap.peek();
                if(scores[top] > scores[i]){
                    myHeap.poll();
                    myHeap.offer(i);
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

public static long testSimple(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
        int[] candidates = new int[numberNeighbours];
        int top = 0;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                candidates[i] = i;
                if(scores[candidates[top]] < scores[candidates[i]]) top = i;
            } else{
                if(scores[candidates[top]] > scores[i]){
                    candidates[top] = i;
                    top = 0;
                    for(int j = 1; j < numberNeighbours; j++){
                        if(scores[candidates[top]] < scores[candidates[j]]) top = j;                            
                    }
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

Это приводит к следующему результату:

tries: 10000 examples: 1000 neighbours: 10
   time heap[ms]: 393
   time simple[ms]: 388
3 3

3 ответа:

Во-первых, ваш метод бенчмаркинга неверен. Вы измеряете создание входных данных вместе с производительностью алгоритма, и вы не разогреваете JVM перед измерением. Результаты для вашего кода, при тестировании через JMH:

Benchmark                     Mode  Cnt      Score   Error  Units
CounterBenchmark.testHeap    thrpt    2  18103,296          ops/s
CounterBenchmark.testSimple  thrpt    2  59490,384          ops/s

Модифицированный бенчмаркпастебин .

Относительно 3-кратной разницы между двумя предоставленными решениями. В терминах Big-O нотации ваш первый алгоритм может показаться лучше, но на самом деле Big-O нотация только говорит вам, как алгоритм хорош с точки зрения масштабирования, он никогда не говорит вам, как быстро он работает (см. Также этот вопрос ). И в вашем случае масштабирование не является проблемой, так как ваш numNeighbours ограничен 20. Другими словами, обозначение big-O описывает, сколько тиков алгоритма необходимо для его завершения, но оно не ограничивает длительность тика, а просто говорит, что длительность тика не меняется при изменении входных данных. И с точки зрения сложности тика ваш второй алгоритм, безусловно, выигрывает.

Каков самый быстрый способ вычисления k наименьших элементов?

Я придумал следующее решение, которое, как я считаю, позволяет предсказанию ветвей делать свою работу:

@Benchmark
public void testModified(Blackhole bh) {
    final double[] scores = sampleData;
    int[] candidates = new int[numberNeighbours];
    for (int i = 0; i < numberNeighbours; i++) {
        candidates[i] = i;
    }
    // sorting candidates so scores[candidates[0]] is the largest
    for (int i = 0; i < numberNeighbours; i++) {
        for (int j = i+1; j < numberNeighbours; j++) {
            if (scores[candidates[i]] < scores[candidates[j]]) {
                int temp = candidates[i];
                candidates[i] = candidates[j];
                candidates[j] = temp;
            }
        }
    }
    // processing other scores, while keeping candidates array sorted in the descending order
    for (int i = numberNeighbours; i < numberExamples; i++) {
        if (scores[i] > scores[candidates[0]]) {
            continue;
        }
        // moving all larger candidates to the left, to keep the array sorted
        int j; // here the branch prediction should kick-in
        for (j = 1; j < numberNeighbours && scores[i] < scores[candidates[j]]; j++) {
            candidates[j - 1] = candidates[j];
        }
        // inserting the new item
        candidates[j - 1] = i;
    }
    bh.consume(candidates);
}

Результаты бенчмарка (в 2 раза быстрее, чем ваше текущее решение):

(10 neighbours) CounterBenchmark.testModified    thrpt    2  136492,151          ops/s
(20 neighbours) CounterBenchmark.testModified    thrpt    2  118395,598          ops/s

Другие упоминали быстрый выбор , но, как и следовало ожидать, сложность этого алгоритма пренебрегает его сильными сторонами в вашем случае:

@Benchmark
public void testQuickSelect(Blackhole bh) {
    final int[] candidates = new int[sampleData.length];
    for (int i = 0; i < candidates.length; i++) {
        candidates[i] = i;
    }
    final int[] resultIndices = new int[numberNeighbours];
    int neighboursToAdd = numberNeighbours;

    int left = 0;
    int right = candidates.length - 1;
    while (neighboursToAdd > 0) {
        int partitionIndex = partition(candidates, left, right);
        int smallerItemsPartitioned = partitionIndex - left;
        if (smallerItemsPartitioned <= neighboursToAdd) {
            while (left < partitionIndex) {
                resultIndices[numberNeighbours - neighboursToAdd--] = candidates[left++];
            }
        } else {
            right = partitionIndex - 1;
        }
    }
    bh.consume(resultIndices);
}

private int partition(int[] locations, int left, int right) {
    final int pivotIndex = ThreadLocalRandom.current().nextInt(left, right + 1);
    final double pivotValue = sampleData[locations[pivotIndex]];
    int storeIndex = left;
    for (int i = left; i <= right; i++) {
        if (sampleData[locations[i]] <= pivotValue) {
            final int temp = locations[storeIndex];
            locations[storeIndex] = locations[i];
            locations[i] = temp;

            storeIndex++;
        }
    }
    return storeIndex;
}

Результаты бенчмарка довольно расстраивают в этом корпус:

CounterBenchmark.testQuickSelect  thrpt    2   11586,761          ops/s

Создание самого быстрого алгоритма никогда не бывает простым, вам нужно учитывать много вещей. Например, k элементов должны быть возвращены отсортированными или нет, ваше исследование должно быть стабильным (если два элемента равны, вам нужно извлечь перед первым или не нужно) или нет?

В этом конкурсе теоретически лучшим решением является сохранение k наименьших элементов в упорядоченной структуре данных. Поскольку вставка может часто происходить в середине этой структуры данных сбалансированный сортированное дерево кажется оптимальным решением. Но реальность сильно отличается от этого.

Вероятно, наилучшим решением является смешение различных структур данных в зависимости от размера исходного массива и значения k:

  • Если k мало, используйте массив для сохранения k наименьших значений
  • Если k велико, используйте сбалансированное дерево
  • Если k очень большой и близок к размерности массива, просто отсортируйте массив (и если вы не можете создать новую отсортированную копию из него), затем извлеките первые k элементов.
Этот вид алгоритма называетсяhibryd algorithm . Известным гибридным алгоритмом являетсяTim Sort , который используется в классах java для сортировки коллекций.

Примечание: Если вы можете использовать силу многопоточности, то различные алгоритмы и поэтому различные структуры данных можно использовать.


Дополнительная заметка о микро-бенчмарке . На ваши показатели эффективности может сильно влиять внешний фактор. факторы, не связанные с эффективностью вашего алгоритма. Для создания объектов, как это делается в обеих функциях, может потребоваться память, которая недоступна, требующая дополнительной работы, выполняемой GC. Такого рода факторы очень сильно влияют на ваши результаты. По крайней мере, попытайтесь минимизировать код, который не сильно связан с частью кода, подлежащей исследованию. Повторите тесты в разных порядках, подождите перед вызовом тестов, чтобы убедиться, что ни один GC не находится в действии.

Первое решение имеет временную сложность O(numberExamples * log numberNeighbours), а второе - O(numberExamples * numberNeighbours), поэтому оно должно быть медленнее для достаточно больших входных данных. Второе решение быстрее, потому что вы тестируете для малого numberNeighbours, и PriorityQueue имеет больше накладных расходов, чем простой массив. Вы используете PriorityQueue optimal.

Быстрее, но не оптимально, было бы просто отсортировать массив, и тогда наименьшие элементы находятся в k месте.

В любом случае вы можете захотеть реализовать алгоритм QuickSelect, если вы выберете элемент pivot smartly you должен иметь лучшую производительность. Возможно, вы захотите увидеть это https://discuss.leetcode.com/topic/55501/2ms-java-quick-select-only-2-points-to-mention