Как работает метод view() для тензора в Факеле
у меня путаница с методом view()
в следующем фрагменте кода.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
мое замешательство касается следующей строки.
x = x.view(-1, 16*5*5)
что значит
2 ответа:
функция просмотра предназначена для изменения тензора.
скажем, у вас есть тензор
import torch a = torch.range(1, 16)
a
тензор, который имеет 16 элементов от 1 до 16(в комплекте). Если вы хотите изменить этот тензор, чтобы сделать это4 x 4
тензор, то вы можете использоватьa = a.view(4, 4)
теперь
a
будет4 x 4
тензор.(Обратите внимание, что после изменения формы общее количество элементов должно оставаться неизменным. Изменение тензораa
до3 x 5
тензор не будет уместно)что означает параметр -1?
если есть какая-либо ситуация, что вы не знаете, сколько строк вы хотите, но уверены в количестве столбцов, то вы можете упомянуть его как -1(вы можете расширить это до тензоров с большим количеством измерений. Только одно из значений оси может быть -1). Это способ сообщить библиотеке; дайте мне тензор, который имеет эти многие столбцы, и вы вычислите соответствующее количество строк, которое необходимо для этого случаться.
это можно увидеть в коде нейронной сети, который вы дали выше. После строки
x = self.pool(F.relu(self.conv2(x)))
в функции "вперед" у вас будет карта объектов с 16 глубинами. Вы должны сгладить это, чтобы дать ему полностью подключенный слой. Поэтому вы говорите pytorch изменить тензор, который вы получили, чтобы иметь определенное количество столбцов, и сказать ему, чтобы он сам определил количество строк.рисуя сходство между numpy и pytorch,
view
аналогично и NumPy это изменить.
давайте сделаем несколько примеров, от простого к более сложному.
The
view
метод возвращает тензор с теми же данными, что иself
тензор (что означает, что возвращаемый тензор имеет такое же количество элементов), но с другой формой. Например:a = torch.arange(1, 17) # a's shape is (16,) a.view(4, 4) # output below 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 [torch.FloatTensor of size 4x4] a.view(2, 2, 4) # output below (0 ,.,.) = 1 2 3 4 5 6 7 8 (1 ,.,.) = 9 10 11 12 13 14 15 16 [torch.FloatTensor of size 2x2x4]
предполагая, что
-1
не является одним из параметров, когда вы умножить их вместе, результат должен быть равен количеству элементов в Тензоре. Если вы это сделаете:a.view(3, 3)
, он подниметRuntimeError
потому что форма (3 x 3) недопустима для ввода с 16 элементами. Другими словами: 3 х 3 не равно 16, а 9.можно использовать
-1
как один из параметров, которые вы передаете в функцию, но только один раз. Все, что происходит, это то, что метод будет делать математику для вас о том, как заполнить это измерение. Напримерa.view(2, -1, 4)
эквивалентноa.view(2, 2, 4)
. [16 / (2 x 4) = 2]обратите внимание, что вернулся тензор разделяет те же данные. Если вы вносите изменения в "вид" вы меняете исходные данные тензора:
b = a.view(4, 4) b[0, 2] = 2 a[2] == 3.0 False
теперь, для более сложного варианта использования. В документации говорится, что каждое новое измерение представления должно быть либо подпространством исходного измерения, либо только span d, d + 1, ..., d + k которые удовлетворяют следующему условию смежности, что для всех i = 0, ..., к - 1, шаг[i] = шаг[i + 1] x размер[i + 1]. В противном случае,
contiguous()
должен быть вызван прежде, чем тензор можно рассматривать. Например:a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2) a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4) # The commented line below will raise a RuntimeError, because one dimension # spans across two contiguous subspaces # a_t.view(-1, 4) # instead do: a_t.contiguous().view(-1, 4) # To see why the first one does not work and the second does, # compare a.stride() and a_t.stride() a.stride() # (24, 6, 2, 1) a_t.stride() # (24, 2, 1, 6)
обратите внимание, что для
a_t
,stride[0]!= Шаг[1] размер х[1] С 24 != 2 x 3