tensorflow: ошибка общих переменных в простой сети LSTM
Я пытаюсь построить простейшую из возможных LSTM-сетей. Просто хочу, чтобы он предсказал следующее значение в последовательности np_input_data
.
import tensorflow as tf
from tensorflow.python.ops import rnn_cell
import numpy as np
num_steps = 3
num_units = 1
np_input_data = [np.array([[1.],[2.]]), np.array([[2.],[3.]]), np.array([[3.],[4.]])]
batch_size = 2
graph = tf.Graph()
with graph.as_default():
tf_inputs = [tf.placeholder(tf.float32, [batch_size, 1]) for _ in range(num_steps)]
lstm = rnn_cell.BasicLSTMCell(num_units)
initial_state = state = tf.zeros([batch_size, lstm.state_size])
loss = 0
for i in range(num_steps-1):
output, state = lstm(tf_inputs[i], state)
loss += tf.reduce_mean(tf.square(output - tf_inputs[i+1]))
with tf.Session(graph=graph) as session:
tf.initialize_all_variables().run()
feed_dict={tf_inputs[i]: np_input_data[i] for i in range(len(np_input_data))}
loss = session.run(loss, feed_dict=feed_dict)
print(loss)
Интерпретатор возвращает:
ValueError: Variable BasicLSTMCell/Linear/Matrix already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:
output, state = lstm(tf_inputs[i], state)
Что я делаю не так?
3 ответа:
Вызов к
lstm
здесь:for i in range(num_steps-1): output, state = lstm(tf_inputs[i], state)
Будет пытаться создавать переменные с одинаковым именем на каждой итерации, если вы не скажете иначе. Вы можете сделать это с помощью
tf.variable_scope
Первая итерация создает переменные, представляющие ваши параметры LSTM, и каждая последующая итерация (после вызоваwith tf.variable_scope("myrnn") as scope: for i in range(num_steps-1): if i > 0: scope.reuse_variables() output, state = lstm(tf_inputs[i], state)
reuse_variables
) будет просто искать их в области по имени.
Я столкнулся с аналогичной проблемой в TensorFlow В1.0.1 используя
tf.nn.dynamic_rnn
. Оказалось, что ошибка возникала только в том случае, если мне приходилось переучиваться или отменять тренировку в середине тренировки и перезапускать тренировочный процесс. В основном график не сбрасывался.Короче говоря, бросьте
tf.reset_default_graph()
в начале вашего кода и это должно помочь. По крайней мере, при использованииtf.nn.dynamic_rnn
и переподготовке.