Быстрый способ инициализации тензора в torch7
Мне нужно инициализировать трехмерный тензор с зависимой от индекса функцией в torch7, т. е.
func = function(i,j,k) --i, j is the index of an element in the tensor
return i*j*k --do operations within func which're dependent of i, j
end
Затем я инициализирую трехмерный тензор A следующим образом:
for i=1,A:size(1) do
for j=1,A:size(2) do
for k=1,A:size(3) do
A[{i,j,k}] = func(i,j,k)
end
end
end
Но этот код работает очень медленно, и я обнаружил, что он занимает 92% от общего времени выполнения. Существуют ли более эффективные способы инициализации 3D тензора в torch7?1 ответ:
См. документацию по
Tensor:apply
Эти функции применяют функцию к каждому элементу тензора на который метод называется (self). Эти методы гораздо быстрее, чем использование цикла for В Lua.
Пример в docs инициализирует двумерный массив на основе его индекса i (в памяти). Ниже приведен расширенный пример для 3 измерений и ниже-для N-D тензоров. Используя метод применить гораздо, Много быстрее машина:
require 'torch' A = torch.Tensor(100, 100, 1000) B = torch.Tensor(100, 100, 1000) function func(i,j,k) return i*j*k end t = os.clock() for i=1,A:size(1) do for j=1,A:size(2) do for k=1,A:size(3) do A[{i, j, k}] = i * j * k end end end print("Original time:", os.difftime(os.clock(), t)) t = os.clock() function forindices(A, func) local i = 1 local j = 1 local k = 0 local d3 = A:size(3) local d2 = A:size(2) return function() k = k + 1 if k > d3 then k = 1 j = j + 1 if j > d2 then j = 1 i = i + 1 end end return func(i, j, k) end end B:apply(forindices(A, func)) print("Apply method:", os.difftime(os.clock(), t))
EDIT
Это будет работать для любого тензорного объекта:
function tabulate(A, f) local idx = {} local ndims = A:dim() local dim = A:size() idx[ndims] = 0 for i=1, (ndims - 1) do idx[i] = 1 end return A:apply(function() for i=ndims, 0, -1 do idx[i] = idx[i] + 1 if idx[i] <= dim[i] then break end idx[i] = 1 end return f(unpack(idx)) end) end -- usage for 3D case. tabulate(A, function(i, j, k) return i * j * k end)