For-петля в Тензоре в Тензорном потоке


Я новичок в tensorflow, я хочу сделать тензор, используя ряд условий if-else. Я просто не знаю, как это сделать.

В python, если тензор похож на [3,3,3], я могу использовать цикл for, как показано ниже:

for i in range(3):
   for j in range(3):
      for k in range(3):
         if tensor[i,j,k]>10:
            tensor[i,j,k]=tensor[i,j,k]-10
         elif tensor[i,j,k]<4:
            tensor[i,j,k]=tensor[i,j,k]+60

После этого я все еще хочу вычислить функции loos с помощью тензора, а затем перейти к следующему циклу для обучения. Кто-нибудь знает, как это сделать? Я знаю, как сделать это одним способом в течение сеанса. Но я не знаю, как это сделать в тренировочном цикле.

1 2

1 ответ:

Путь тензорного потока

Ваш конкретный пример легко векторизуется, поэтому нет реальной необходимости делать это через for-loop. Вот чистое тензорное решение:

x = tf.placeholder(shape=[3, 3], dtype=tf.float32)
cond1 = tf.where(x > 10, x - 10, tf.zeros_like(x))
cond2 = tf.where(x < 4, x + 60, tf.zeros_like(x))
cond3 = tf.where(tf.logical_and(x >= 4, x <= 10), x, tf.zeros_like(x))
y = cond1 + cond2 + cond3

py_func Путь

Если вам случайно придется выполнить мелкозернистую обработку, вы всегда можете вернуться к tf.py_func:

def process(tensor):
  mask1 = tensor > 10
  mask2 = tensor < 4
  tensor[mask1] -= 10
  tensor[mask2] += 60
  return tensor
z = tf.py_func(process, [x], tf.float32)

Объединение всего этого вместе

Полный запускаемый пример:

import tensorflow as tf

x = tf.placeholder(shape=[3, 3], dtype=tf.float32)

cond1 = tf.where(x > 10, x - 10, tf.zeros_like(x))
cond2 = tf.where(x < 4, x + 60, tf.zeros_like(x))
cond3 = tf.where(tf.logical_and(x >= 4, x <= 10), x, tf.zeros_like(x))
y = cond1 + cond2 + cond3

def process(tensor):
  mask1 = tensor > 10
  mask2 = tensor < 4
  tensor[mask1] -= 10
  tensor[mask2] += 60
  return tensor
z = tf.py_func(process, [x], tf.float32)

sample = [[10, 15, 25], [1, 2, 3], [4, 4, 10]]
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(y, feed_dict={x: sample}))
  print(sess.run(z, feed_dict={x: sample}))

Вывод:

[[10.  5. 15.]
 [61. 62. 63.]
 [ 4.  4. 10.]]
[[10.  5. 15.]
 [61. 62. 63.]
 [ 4.  4. 10.]]