TensorFlow, почему есть 3 файла после сохранения модели?


читать docs, Я сохранил модель в TensorFlow, вот мой демо-код:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

но после этого, я обнаружил, что есть 3 файлы

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

и я не могу восстановить модель восстановления model.ckpt файл, так как такого файла нет. Вот мой код

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Итак, почему есть 3 файлов?

4 69

4 ответа:

попробуйте это:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

метод TensorFlow save сохраняет три вида файлов, потому что он хранит графической структуры отдельно переменной значения. Элемент .meta файл описывает сохраненную структуру графика, поэтому вам нужно импортировать ее перед восстановлением контрольной точки (в противном случае он не знает, каким переменным соответствуют сохраненные значения контрольных точек).

в качестве альтернативы, вы можете сделать это:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

хотя там нет файла с именем model.ckpt, вы все еще ссылаетесь на сохраненную контрольную точку этим именем при ее восстановлении. Из saver.py исходный код:

пользователи должны взаимодействовать только с указанным пользователем префиксом... вместо любого физического пути.

  • мета файл: описывает сохраненную структуру графа, включает GraphDef, SaverDef и т. д.; затем применить tf.train.import_meta_graph('/tmp/model.ckpt.meta'), восстановим Saver и Graph.

  • индекс: это строка-строка неизменяемая таблица (tensorflow::table::Table). Каждый ключ-это имя тензора, а его значение-сериализованный BundleEntryProto. Каждый BundleEntryProto описывает метаданные тензора: какой из файлов "данных" содержит содержимое тензора, смещение в этот файл, контрольная сумма, некоторые вспомогательные данные и т. д.

  • файл данных: это коллекция TensorBundle, сохраните значения всех переменных.

я восстанавливаю обученные вложения слов из Word2Vec tensorflow учебник.

Если вы создали несколько контрольных точек:

например, созданные файлы выглядят так

модель.ckpt-55695.data-00000-of-00001

модель.ckpt-55695.индекс

модель.ckpt-55695.мета

попробуй такое

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

при вызове restore_session():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")

Если вы обучали CNN с отсевом, например, вы могли бы сделать это:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})