Самый быстрый способ сортировки в Python (без cython)
У меня есть проблема, когда я должен отсортировать очень большой массив(shape - 7900000X4X4) с помощью пользовательской функции. Я использовал sorted
, но на сортировку ушло больше часа. Мой код был примерно таким.
def compare(x,y):
print('DD '+str(x[0]))
if(np.array_equal(x[1],y[1])==True):
return -1
a = x[1].flatten()
b = y[1].flatten()
idx = np.where( (a>b) != (a<b) )[0][0]
if a[idx]<0 and b[idx]>=0:
return 0
elif b[idx]<0 and a[idx]>=0:
return 1
elif a[idx]<0 and b[idx]<0:
if a[idx]>b[idx]:
return 0
elif a[idx]<b[idx]:
return 1
elif a[idx]<b[idx]:
return 1
else:
return 0
def cmp_to_key(mycmp):
class K:
def __init__(self, obj, *args):
self.obj = obj
def __lt__(self, other):
return mycmp(self.obj, other.obj)
return K
tblocks = sorted(tblocks.items(),key=cmp_to_key(compare))
Это сработало, но я хочу, чтобы это завершилось в считанные секунды. Я не думаю, что какая-либо прямая реализация в python может дать мне необходимую производительность, поэтому я попробовал cython. Мой код Цитона таков, и он довольно прост.
cdef int[:,:] arrr
cdef int size
cdef bool compare(int a,int b):
global arrr,size
cdef int[:] x = arrr[a]
cdef int[:] y = arrr[b]
cdef int i,j
i = 0
j = 0
while(i<size):
if((j==size-1)or(y[j]<x[i])):
return 0
elif(x[i]<y[j]):
return 1
i+=1
j+=1
return (j!=size-1)
def sorted(np.ndarray boxes,int total_blocks,int s):
global arrr,size
cdef int i
cdef vector[int] index = xrange(total_blocks)
arrr = boxes
size = s
sort(index.begin(),index.end(),compare)
return index
Этот код в cython занял 33 секунды! Цитон-это решение проблемы, но я ... я ищу некоторые альтернативные решения, которые могут работать непосредственно на python. Например, numba. Я попробовал Numba, но не получил удовлетворительных результатов. Пожалуйста, помогите!2 ответа:
Трудно дать ответ без рабочего примера. Я предполагаю, что arrr в вашем коде Cython был 2D-массивом, и я предполагаю, что размер был
size=arrr.shape[0]
Реализация Numba
import numpy as np import numba as nb from numba.targets import quicksort def custom_sorting(compare_fkt): index_arange=np.arange(size) quicksort_func=quicksort.make_jit_quicksort(lt=compare_fkt,is_argsort=False) jit_sort_func=nb.njit(quicksort_func.run_quicksort) index=jit_sort_func(index_arange) return index def compare(a,b): x = arrr[a] y = arrr[b] i = 0 j = 0 while(i<size): if((j==size-1)or(y[j]<x[i])): return False elif(x[i]<y[j]): return True i+=1 j+=1 return (j!=size-1) arrr=np.random.randint(-9,10,(7900000,8)) size=arrr.shape[0] index=custom_sorting(compare)
Это дает 3,85 s для генерируемых тестовых данных. Но скорость алгоритма сортировки сильно зависит от данных....
Простой Пример
import numpy as np import numba as nb from numba.targets import quicksort #simple reverse sort def compare(a,b): return a > b #create some test data arrr=np.array(np.random.rand(7900000)*10000,dtype=np.int32) #we can pass the comparison function quicksort_func=quicksort.make_jit_quicksort(lt=compare,is_argsort=True) #compile the sorting function jit_sort_func=nb.njit(quicksort_func.run_quicksort) #get the result ind_sorted=jit_sort_func(arrr)
Эта реализация примерно на 35% медленнее, чем np.argsort, но это также распространено в использовании np.argsort в скомпилированный код.
Если я правильно понимаю ваш код, то порядок, который вы имеете в виду, является стандартным порядком, только он начинается с
0
, оборачивается в+/-infinity
и завершается в-0
. Кроме того, мы имеем простой лексикографический порядок слева направо.Теперь, если Ваш массив dtype является целочисленным, обратите внимание на следующее: Из-за дополнительного представления негативов приведение вида к unsigned int делает ваш порядок стандартным порядком. Кроме того, если мы используем кодировку big endian, эффективная лексикографическая упорядочение может быть достигнуто путем приведения вида к
Приведенный ниже код показывает, что с помощью примераvoid
dtype.10000x4x4
Этот метод дает тот же результат, что и ваш код Python.Он также проверяет его на примере
7,900,000x4x4
(используя массив, а не дикт). На моем скромном ноутбуке этот метод занимает8
секунд.import numpy as np def compare(x, y): # print('DD '+str(x[0])) if(np.array_equal(x[1],y[1])==True): return -1 a = x[1].flatten() b = y[1].flatten() idx = np.where( (a>b) != (a<b) )[0][0] if a[idx]<0 and b[idx]>=0: return 0 elif b[idx]<0 and a[idx]>=0: return 1 elif a[idx]<0 and b[idx]<0: if a[idx]>b[idx]: return 0 elif a[idx]<b[idx]: return 1 elif a[idx]<b[idx]: return 1 else: return 0 def cmp_to_key(mycmp): class K: def __init__(self, obj, *args): self.obj = obj def __lt__(self, other): return mycmp(self.obj, other.obj) return K def custom_sort(a): assert a.dtype==np.int64 b = a.astype('>i8', copy=False) return b.view(f'V{a.dtype.itemsize * a.shape[1]}').ravel().argsort() tblocks = np.random.randint(-9,10, (10000, 4, 4)) tblocks = dict(enumerate(tblocks)) tblocks_s = sorted(tblocks.items(),key=cmp_to_key(compare)) tblocksa = np.array(list(tblocks.values())) tblocksa = tblocksa.reshape(tblocksa.shape[0], -1) order = custom_sort(tblocksa) tblocks_s2 = list(tblocks.items()) tblocks_s2 = [tblocks_s2[o] for o in order] print(tblocks_s == tblocks_s2) from timeit import timeit data = np.random.randint(-9_999, 10_000, (7_900_000, 4, 4)) print(timeit(lambda: data[custom_sort(data.reshape(data.shape[0], -1))], number=5) / 5)
Пример вывода:
True 7.8328493310138585