Skip to content

Commit b2e8187

Browse files
liyafan82kou
authored andcommitted
ARROW-7400: [Java] Avoid the worst case for quick sort
This issue is in response of a discussion in: apache/arrow#5540 (comment). The quick sort algorithm can degenerate to an O(n^2) algorithm, if the pivot is selected poorly. This is an important problem, as the worst case can happen, if the input vector is alrady sorted, which is frequently encountered in practice. After some investigation, we solve the problem with a simple but effective approach: take 3 samples and choose the median (with at most 3 comparisons) as the pivot. This sorts the vector which is already sorted in O(nlogn) time. Closes #6039 from liyafan82/fly_1213_sort and squashes the following commits: 0943b0692 <liyafan82> Make tests more readable 7cdf0a694 <liyafan82> Fix the bug of choosing pivot and add more tests e6ab2bf1f <liyafan82> Apply insertion sort when the range is small 1167176b4 <liyafan82> Avoids the worst case for quick sort Authored-by: liyafan82 <fan_li_ya@foxmail.com> Signed-off-by: Micah Kornfield <emkornfield@gmail.com>
1 parent 96a8adc commit b2e8187

File tree

7 files changed

+513
-15
lines changed

7 files changed

+513
-15
lines changed

algorithm/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
<version>${project.version}</version>
2828
<classifier>${arrow.vector.classifier}</classifier>
2929
</dependency>
30+
<dependency>
31+
<groupId>org.apache.arrow</groupId>
32+
<artifactId>arrow-vector</artifactId>
33+
<version>${project.version}</version>
34+
<type>test-jar</type>
35+
</dependency>
3036
<dependency>
3137
<groupId>org.apache.arrow</groupId>
3238
<artifactId>arrow-memory</artifactId>

algorithm/src/main/java/org/apache/arrow/algorithm/sort/FixedWidthInPlaceVectorSorter.java

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,25 @@
2626
*/
2727
public class FixedWidthInPlaceVectorSorter<V extends BaseFixedWidthVector> implements InPlaceVectorSorter<V> {
2828

29-
private VectorValueComparator<V> comparator;
29+
/**
30+
* If the number of items is smaller than this threshold, we will use another algorithm to sort the data.
31+
*/
32+
public static final int CHANGE_ALGORITHM_THRESHOLD = 15;
33+
34+
static final int STOP_CHOOSING_PIVOT_THRESHOLD = 3;
35+
36+
VectorValueComparator<V> comparator;
3037

3138
/**
3239
* The vector to sort.
3340
*/
34-
private V vec;
41+
V vec;
3542

3643
/**
3744
* The buffer to hold the pivot.
3845
* It always has length 1.
3946
*/
40-
private V pivotBuffer;
47+
V pivotBuffer;
4148

4249
@Override
4350
public void sortInPlace(V vec, VectorValueComparator<V> comparator) {
@@ -64,6 +71,12 @@ private void quickSort() {
6471
int high = rangeStack.pop();
6572
int low = rangeStack.pop();
6673
if (low < high) {
74+
if (high - low < CHANGE_ALGORITHM_THRESHOLD) {
75+
// switch to insertion sort
76+
InsertionSorter.insertionSort(vec, low, high, comparator, pivotBuffer);
77+
continue;
78+
}
79+
6780
int mid = partition(low, high);
6881

6982
// push the larger part to stack first,
@@ -86,8 +99,55 @@ private void quickSort() {
8699
}
87100
}
88101

102+
/**
103+
* Select the pivot as the median of 3 samples.
104+
*/
105+
void choosePivot(int low, int high) {
106+
// we need at least 3 items
107+
if (high - low + 1 < STOP_CHOOSING_PIVOT_THRESHOLD) {
108+
pivotBuffer.copyFrom(low, 0, vec);
109+
return;
110+
}
111+
112+
comparator.attachVector(vec);
113+
int mid = low + (high - low) / 2;
114+
115+
// find the median by at most 3 comparisons
116+
int medianIdx;
117+
if (comparator.compare(low, mid) < 0) {
118+
if (comparator.compare(mid, high) < 0) {
119+
medianIdx = mid;
120+
} else {
121+
if (comparator.compare(low, high) < 0) {
122+
medianIdx = high;
123+
} else {
124+
medianIdx = low;
125+
}
126+
}
127+
} else {
128+
if (comparator.compare(mid, high) > 0) {
129+
medianIdx = mid;
130+
} else {
131+
if (comparator.compare(low, high) < 0) {
132+
medianIdx = low;
133+
} else {
134+
medianIdx = high;
135+
}
136+
}
137+
}
138+
139+
// move the pivot to the low position, if necessary
140+
if (medianIdx != low) {
141+
pivotBuffer.copyFrom(medianIdx, 0, vec);
142+
vec.copyFrom(low, medianIdx, vec);
143+
vec.copyFrom(0, low, pivotBuffer);
144+
}
145+
146+
comparator.attachVectors(vec, pivotBuffer);
147+
}
148+
89149
private int partition(int low, int high) {
90-
pivotBuffer.copyFrom(low, 0, vec);
150+
choosePivot(low, high);
91151

92152
while (low < high) {
93153
while (low < high && comparator.compare(high, 0) >= 0) {

algorithm/src/main/java/org/apache/arrow/algorithm/sort/IndexSorter.java

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
*/
2929
public class IndexSorter<V extends ValueVector> {
3030

31+
/**
32+
* If the number of items is smaller than this threshold, we will use another algorithm to sort the data.
33+
*/
34+
public static final int CHANGE_ALGORITHM_THRESHOLD = 15;
35+
3136
/**
3237
* Comparator for vector indices.
3338
*/
@@ -68,6 +73,11 @@ private void quickSort() {
6873
int low = rangeStack.pop();
6974

7075
if (low < high) {
76+
if (high - low < CHANGE_ALGORITHM_THRESHOLD) {
77+
InsertionSorter.insertionSort(indices, low, high, comparator);
78+
continue;
79+
}
80+
7181
int mid = partition(low, high, indices, comparator);
7282

7383
// push the larger part to stack first,
@@ -90,6 +100,53 @@ private void quickSort() {
90100
}
91101
}
92102

103+
/**
104+
* Select the pivot as the median of 3 samples.
105+
*/
106+
static <T extends ValueVector> int choosePivot(
107+
int low, int high, IntVector indices, VectorValueComparator<T> comparator) {
108+
// we need at least 3 items
109+
if (high - low + 1 < FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD) {
110+
return indices.get(low);
111+
}
112+
113+
int mid = low + (high - low) / 2;
114+
115+
// find the median by at most 3 comparisons
116+
int medianIdx;
117+
if (comparator.compare(indices.get(low), indices.get(mid)) < 0) {
118+
if (comparator.compare(indices.get(mid), indices.get(high)) < 0) {
119+
medianIdx = mid;
120+
} else {
121+
if (comparator.compare(indices.get(low), indices.get(high)) < 0) {
122+
medianIdx = high;
123+
} else {
124+
medianIdx = low;
125+
}
126+
}
127+
} else {
128+
if (comparator.compare(indices.get(mid), indices.get(high)) > 0) {
129+
medianIdx = mid;
130+
} else {
131+
if (comparator.compare(indices.get(low), indices.get(high)) < 0) {
132+
medianIdx = low;
133+
} else {
134+
medianIdx = high;
135+
}
136+
}
137+
}
138+
139+
// move the pivot to the low position, if necessary
140+
if (medianIdx != low) {
141+
int tmp = indices.get(medianIdx);
142+
indices.set(medianIdx, indices.get(low));
143+
indices.set(low, tmp);
144+
return tmp;
145+
} else {
146+
return indices.get(low);
147+
}
148+
}
149+
93150
/**
94151
* Partition a range of values in a vector into two parts, with elements in one part smaller than
95152
* elements from the other part. The partition is based on the element indices, so it does
@@ -103,7 +160,7 @@ private void quickSort() {
103160
*/
104161
public static <T extends ValueVector> int partition(
105162
int low, int high, IntVector indices, VectorValueComparator<T> comparator) {
106-
int pivotIndex = indices.get(low);
163+
int pivotIndex = choosePivot(low, high, indices, comparator);
107164

108165
while (low < high) {
109166
while (low < high && comparator.compare(indices.get(high), pivotIndex) >= 0) {
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.arrow.algorithm.sort;
19+
20+
import org.apache.arrow.vector.BaseFixedWidthVector;
21+
import org.apache.arrow.vector.IntVector;
22+
import org.apache.arrow.vector.ValueVector;
23+
24+
/**
25+
* Insertion sorter.
26+
*/
27+
class InsertionSorter {
28+
29+
/**
30+
* Sorts the range of a vector by insertion sort.
31+
*
32+
* @param vector the vector to be sorted.
33+
* @param startIdx the start index of the range (inclusive).
34+
* @param endIdx the end index of the range (inclusive).
35+
* @param buffer an extra buffer with capacity 1 to hold the current key.
36+
* @param comparator the criteria for vector element comparison.
37+
* @param <V> the vector type.
38+
*/
39+
static <V extends BaseFixedWidthVector> void insertionSort(
40+
V vector, int startIdx, int endIdx, VectorValueComparator<V> comparator, V buffer) {
41+
comparator.attachVectors(vector, buffer);
42+
for (int i = startIdx; i <= endIdx; i++) {
43+
buffer.copyFrom(i, 0, vector);
44+
int j = i - 1;
45+
while (j >= startIdx && comparator.compare(j, 0) > 0) {
46+
vector.copyFrom(j, j + 1, vector);
47+
j = j - 1;
48+
}
49+
vector.copyFrom(0, j + 1, buffer);
50+
}
51+
}
52+
53+
/**
54+
* Sorts the range of vector indices by insertion sort.
55+
*
56+
* @param indices the vector indices.
57+
* @param startIdx the start index of the range (inclusive).
58+
* @param endIdx the end index of the range (inclusive).
59+
* @param comparator the criteria for vector element comparison.
60+
* @param <V> the vector type.
61+
*/
62+
static <V extends ValueVector> void insertionSort(
63+
IntVector indices, int startIdx, int endIdx, VectorValueComparator<V> comparator) {
64+
for (int i = startIdx; i <= endIdx; i++) {
65+
int key = indices.get(i);
66+
int j = i - 1;
67+
while (j >= startIdx && comparator.compare(indices.get(j), key) > 0) {
68+
indices.set(j + 1, indices.get(j));
69+
j = j - 1;
70+
}
71+
indices.set(j + 1, key);
72+
}
73+
}
74+
}

algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthInPlaceVectorSorter.java

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
package org.apache.arrow.algorithm.sort;
1919

20+
import static org.junit.Assert.assertEquals;
2021
import static org.junit.Assert.assertTrue;
2122

2223
import org.apache.arrow.memory.BufferAllocator;
2324
import org.apache.arrow.memory.RootAllocator;
2425
import org.apache.arrow.vector.IntVector;
26+
import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
2527
import org.junit.After;
2628
import org.junit.Assert;
2729
import org.junit.Before;
@@ -114,4 +116,97 @@ public void testSortLargeIncreasingInt() {
114116
}
115117
}
116118
}
119+
120+
@Test
121+
public void testChoosePivot() {
122+
final int vectorLength = 100;
123+
try (IntVector vec = new IntVector("", allocator)) {
124+
vec.allocateNew(vectorLength);
125+
126+
// the vector is sorted, so the pivot should be in the middle
127+
for (int i = 0; i < vectorLength; i++) {
128+
vec.set(i, i * 100);
129+
}
130+
vec.setValueCount(vectorLength);
131+
132+
FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter();
133+
VectorValueComparator<IntVector> comparator = DefaultVectorComparators.createDefaultComparator(vec);
134+
135+
try (IntVector pivotBuffer = (IntVector) vec.getField().createVector(allocator)) {
136+
// setup internal data structures
137+
pivotBuffer.allocateNew(1);
138+
sorter.pivotBuffer = pivotBuffer;
139+
sorter.comparator = comparator;
140+
sorter.vec = vec;
141+
comparator.attachVectors(vec, pivotBuffer);
142+
143+
int low = 5;
144+
int high = 6;
145+
int pivotValue = vec.get(low);
146+
assertTrue(high - low + 1 < FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD);
147+
148+
// the range is small enough, so the pivot is simply selected as the low value
149+
sorter.choosePivot(low, high);
150+
assertEquals(pivotValue, vec.get(low));
151+
152+
low = 30;
153+
high = 80;
154+
pivotValue = vec.get((low + high) / 2);
155+
assertTrue(high - low + 1 >= FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD);
156+
157+
// the range is large enough, so the median is selected as the pivot
158+
sorter.choosePivot(low, high);
159+
assertEquals(pivotValue, vec.get(low));
160+
}
161+
}
162+
}
163+
164+
/**
165+
* Evaluates choosing pivot for all possible permutations of 3 numbers.
166+
*/
167+
@Test
168+
public void testChoosePivotAllPermutes() {
169+
try (IntVector vec = new IntVector("", allocator)) {
170+
vec.allocateNew(3);
171+
172+
FixedWidthInPlaceVectorSorter sorter = new FixedWidthInPlaceVectorSorter();
173+
VectorValueComparator<IntVector> comparator = DefaultVectorComparators.createDefaultComparator(vec);
174+
175+
try (IntVector pivotBuffer = (IntVector) vec.getField().createVector(allocator)) {
176+
// setup internal data structures
177+
pivotBuffer.allocateNew(1);
178+
sorter.pivotBuffer = pivotBuffer;
179+
sorter.comparator = comparator;
180+
sorter.vec = vec;
181+
comparator.attachVectors(vec, pivotBuffer);
182+
183+
int low = 0;
184+
int high = 2;
185+
186+
ValueVectorDataPopulator.setVector(vec, 11, 22, 33);
187+
sorter.choosePivot(low, high);
188+
assertEquals(22, vec.get(0));
189+
190+
ValueVectorDataPopulator.setVector(vec, 11, 33, 22);
191+
sorter.choosePivot(low, high);
192+
assertEquals(22, vec.get(0));
193+
194+
ValueVectorDataPopulator.setVector(vec, 22, 11, 33);
195+
sorter.choosePivot(low, high);
196+
assertEquals(22, vec.get(0));
197+
198+
ValueVectorDataPopulator.setVector(vec, 22, 33, 11);
199+
sorter.choosePivot(low, high);
200+
assertEquals(22, vec.get(0));
201+
202+
ValueVectorDataPopulator.setVector(vec, 33, 11, 22);
203+
sorter.choosePivot(low, high);
204+
assertEquals(22, vec.get(0));
205+
206+
ValueVectorDataPopulator.setVector(vec, 33, 22, 11);
207+
sorter.choosePivot(low, high);
208+
assertEquals(22, vec.get(0));
209+
}
210+
}
211+
}
117212
}

0 commit comments

Comments
 (0)