Tensorflow: обновляются веса нетренируемых слоев модели


У меня есть обученная модель, которая создана с использованием Keras. На этой модели я хочу применить обучение переносу, заморозив все, кроме последнего сверточного слоя. Однако, когда я подгоняю модель после замораживания слоев, я замечаю, что некоторые из замороженных слоев имеют разный вес. Как мне этого избежать?

Я попытался заморозить всю модель с помощью model.trainable = False, но это также не сработало.

Я использую python 3.5.0, tensorflow 1.0.1 и Keras 2.0.3


Пример скрипт

import os
import timeit
import datetime
import numpy as np
from keras.layers.core import Activation, Reshape, Permute
from keras.layers.convolutional import Convolution2D, MaxPooling2D, UpSampling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.optimizers import Adam
from keras import models
from keras import backend as K
K.set_image_dim_ordering('th')

def conv_model(input_shape, data_shape, kern_size, filt_size, pad_size,
                               maxpool_size, n_classes, compile_model=True):
    """
    Create a small conv neural network
    input_shape: input shape of the images
    data_shape: 1d shape of the data
    kern_size: Kernel size used in all convolutional2d layers
    filt_size: Filter size of the first and last convolutional2d layer
    pad_size: size of padding
    maxpool_size: Pool size of all maxpooling2d and upsampling2d layers
    n_classes: number of output classes
    compile_model: True if the model should be compiled

    output: Keras deep learning model
    """
    #keep track of compilation time
    start_time = timeit.default_timer()
    model = models.Sequential()
    # Add a noise layer to get a denoising autoencoder. This helps avoid overfitting
    model.add(ZeroPadding2D(padding=(pad_size, pad_size), input_shape=input_shape))

    #Encoding layers
    model.add(Convolution2D(filt_size, kern_size, kern_size, border_mode='valid'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(maxpool_size, maxpool_size)))
    model.add(UpSampling2D(size=(maxpool_size, maxpool_size)))
    model.add(ZeroPadding2D(padding=(pad_size, pad_size)))
    model.add(Convolution2D(filt_size, kern_size, kern_size, border_mode='valid'))
    model.add(BatchNormalization())
    model.add(Convolution2D(n_classes, 1, 1, border_mode='valid'))
    model.add(Reshape((n_classes, data_shape), input_shape=(n_classes,)+input_shape[1:]))
    model.add(Permute((2, 1)))
    model.add(Activation('softmax'))

    if compile_model:
        model.compile(loss="categorical_crossentropy", optimizer='adam', metrics=["accuracy"])
    print('Model compiled in {0} seconds'.format(datetime.timedelta(seconds=round(
          timeit.default_timer() - start_time))))
    return model

if __name__ == '__main__':
    #Create some random training data
    train_data = np.random.randint(0, 10, 3*512*512*20, dtype='uint8').reshape(-1, 3, 512, 512)
    train_labels = np.random.randint(0, 1, 7*512*512*20, dtype='uint8').reshape(-1, 512*512, 7)
    #Get dims of the data
    data_dims = train_data.shape[2:]
    data_shape = np.prod(data_dims)
    #Create initial model
    initial_model = conv_model((train_data.shape[1], train_data.shape[2], train_data.shape[3]),
                               data_shape, 3, 4, 1, 2, train_labels.shape[-1])
    #Train initial model on first part of the training data
    initial_model.fit(train_data[0:10], train_labels[0:10], verbose=2)
    #Store initial weights
    initial_weights = initial_model.get_weights()

    #Create transfer learning model
    transf_model = conv_model((train_data.shape[1], train_data.shape[2], train_data.shape[3]),
                              data_shape, 3, 4, 1, 2, train_labels.shape[-1], False)
    #Set transfer model weights
    transf_model.set_weights(initial_weights)
    #Set all layers trainable to False (except final conv layer)
    for layer in transf_model.layers:
        layer.trainable = False
    transf_model.layers[9].trainable = True
    print(transf_model.layers[9])
    #Compile model
    transf_model.compile(loss="categorical_crossentropy", optimizer=Adam(lr=1e-4),
                         metrics=["accuracy"])
    #Train model on second part of the data
    transf_model.fit(train_data[10:20], train_labels[10:20], verbose=2)
    #Store transfer model weights
    transf_weights = transf_model.get_weights()

    #Check where the weights have changed
    for i in range(len(initial_weights)):
        update_w = np.sum(initial_weights[i] != transf_weights[i])
        if update_w != 0:
            print(str(update_w)+' updated weights for layer '+str(transf_model.layers[i]))
2 2

2 ответа:

Как только вы скомпилировали свою модель-вы потеряли свои предыдущие веса, так как они были пересчитаны. Вам нужно сначала перенести их, установить веса, чтобы они не поддавались обучению, а затем скомпилировать его:

#Compile model
transf_model.set_weights(initial_weights)

#Set all layers trainable to False (except final conv layer)
for layer in transf_model.layers:
    layer.trainable = False

transf_model.layers[9].trainable = True

transf_model.compile(loss="categorical_crossentropy", optimizer=Adam(lr=1e-4),\
                     metrics=["accuracy"])

В противном случае-веса изменялись бы по мере их пересчета.

EDIT:

Модель должна быть скомпилирована после изменений - потому что во время компиляции keras устанавливает все обучаемые / не обучаемые веса в списке, который в дальнейшем не изменяется.

Следует обновить Керрас к водоснабжении В2.1.3

Эта проблема только что решена, и эта самая последняя функция замораживания слоев BatchNormalization теперь доступна в недавнем выпуске:

Обучаемый атрибут в BatchNormalization теперь отключает обновления пакетной статистики (т. е. если trainable == False, то слой теперь будет работать на 100% в режиме вывода).

Причина ошибки:

В предыдущих версиях дисперсия и среднее значение параметры слоев BatchNormalization не могли установить untrainable , и это не сработало, хотя вы сидели layer.trainable = False.

Теперь это работает!