Чередование tf.данные.Набор данных


Я пытаюсь использовать tf.данные.Dataset для чередования двух наборов данных, но возникли проблемы с этим. Приведем такой простой пример:

ds0 = tf.data.Dataset()
ds0 = ds0.range(0, 10, 2)
ds1 = tf.data.Dataset()
ds1 = ds1.range(1, 10, 2)
dataset = ...
iter = dataset.make_one_shot_iterator()
val = iter.get_next()

Что такое ..., чтобы получить результат, подобный 0, 1, 2, 3...9?

Похоже на dataset.interleave() было бы уместно, но я не смог сформулировать утверждение таким образом, чтобы оно не приводило к ошибке.

1 3

1 ответ:

Маттскарпино находится на правильном пути всвоем комментарии . Вы можете использовать Dataset.zip() вместе с Dataset.flat_map() чтобы сгладить многоэлементный набор данных:

ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)

# Zip combines an element from each input into a single element, and flat_map
# enables you to map the combined element into two elements, then flattens the
# result.
dataset = tf.data.Dataset.zip((ds0, ds1)).flat_map(
    lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(
        tf.data.Dataset.from_tensors(x1)))

iter = dataset.make_one_shot_iterator()
val = iter.get_next()

Сказав это, ваша интуиция об использовании Dataset.interleave() это довольно разумно. Мы исследуем способы, которыми вы можете сделать это более легко.


ПС. В качестве альтернативы, вы можете использовать Dataset.interleave() для решения задачи, если измените способ определения ds0 и ds1:

dataset = tf.data.Dataset.range(2).interleave(
    lambda x: tf.data.Dataset.range(x, 10, 2), cycle_length=2, block_length=1)