首页 > 解决方案 > Tensorflow 数据集,如何在每个批次上使用自定义窗口来提供训练数据?

问题描述

我有一个数据集,它是tf.data.Dataset. 我想要做的是提供一个自定义范围数据,它是每个批次的一组标记。例如,如果我的一个训练数据集是 [0,1,2,3,4,5],那么我想为第一批提供 [1,2,3],然后为 [3,4,5]第二批。有什么方法可以控制如何将训练数据提供给 tensorflow 模型?

标签: pythontensorflowkerastensorflow-datasets

解决方案


假设您tf.data.Dataset的定义如下:

train_dataset = tf.data.Dataset.from_tensor_slices(YOUR_DATA).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

并且你循环通过你的train_dataset结果批次说32。根据模型期望的输入形式,您可以拆分批次:

for batch in dataset:
  train_step(batch) 


@tf.function
def train_step(batch):
  batch1, batch2 = tf.split(batch, 2, 0)

请注意,您的批次在第一个轴上分为两个切片(通常是您的批次的大小)。在此之后,您可以简单地将这些切片提供给您的模型。

另一个想法是尝试用切片符号切片你的张量(你的批次) :

rank_3_tensor = tf.constant([
                   [[0, 1, 2, 3, 4],
                    [5, 6, 7, 8, 9]],
                   [[10, 11, 12, 13, 14],
                    [15, 16, 17, 18, 19]],
                   [[20, 21, 22, 23, 24],
                    [25, 26, 27, 28, 29]],])
print(rank_3_tensor[0:3,:,:])
# Tensor("strided_slice:0", shape=(3, 2, 5), dtype=int32)

或者

import numpy as np

sample_size = 201
D = 5
tensor = tf.constant(np.array(range(sample_size * D * D)).reshape([sample_size, D, D]))
batches_of_n = 3
for i in range(0, tensor.shape[0], batches_of_n):
    print(tensor[i:i+batches_of_n,: :])

我想你应该已经明白了。


推荐阅读