Как использовать функцию merge и switch из tensorflow?


merge и switch не могут быть открыты для использования обычными пользователями. И я искал исходный код:

Есть описание в merge:

Возвращает значение доступного элемента inputs.

Что это значит доступно ? Возвращается ли он switch? Это демо:

from tensorflow.python.ops import control_flow_ops

x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
y = control_flow_ops.merge([x_0, x_1, x_2, x_3])
with tf.Session() as sess:
    print(sess.run(y))
1 2

1 ответ:

switch

Начнем с рассмотрения функции control_flow_ops.switch:
x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
with tf.Session() as sess:
  print(sess.run(x_0))    # prints 2
  print(sess.run(x_3))    # prints 7

control_flow_ops.switch возвращает кортеж тензоров, но только один из них будет иметь значение (в зависимости от аргумента условия). В приведенном выше примере это x_0 = 2 из первого switch и x_3 = 7 из второго. Попытка оценить x_1 или x_2 приведет к Retval не имеет значения ошибка:

  sess.run(x_1)  # FAILS!
  sess.run(x_2)  # FAILS!
Другими словами, x_0 и x_3 являются доступными , в то время как x_1 или x_2 нет.

merge

control_flow_ops.merge выполняет обратную операцию: учитывая кортеж тензоров, он выбирает доступный. Именно, он возвращает именованный кортеж ["output", "value_index"] тензора, который имеет значение. Согласно текущему doc, входные данные должны содержать ровно один доступный тензор, это означает, что ваша демонстрация строго говоря не поддерживается и приводит к неопределенному поведению. Вот пример:

with tf.Session() as sess:
  print(sess.run(merge([x_0, x_1])))       # Merge(output=2, value_index=0)
  print(sess.run(merge([x_1, x_0])))       # Merge(output=2, value_index=1)
  print(sess.run(merge([x_2, x_3])))       # Merge(output=7, value_index=1)
  print(sess.run(merge([x_3, x_2])))       # Merge(output=7, value_index=0)
  print(sess.run(merge([x_0, x_1, x_2])))  # Merge(output=2, value_index=0)
  print(sess.run(merge([x_1, x_2, x_3])))  # Merge(output=7, value_index=2)

Обе эти функции могут быть удобны для управления. поток вычислений, например control_flow_ops.switch градиент реализуется через switch Сам (исходный код tensorflow).