首页 > 解决方案 > 如何重塑 Tensorflow 数据集中的数据?

问题描述

我正在编写一个数据管道,将成批的时间序列序列和相应的标签输入到需要 3D 输入形状的 LSTM 模型中。我目前有以下内容:

def split(window):
    return window[:-label_length], window[-label_length]

dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

得到的形状for x, y in dataset.take(1): x.shape是 (32, 20),其中 32 是批量大小,20 是序列长度,但我需要 (32, 20, 1) 的形状,其中附加维度表示特征。

我的问题是如何重塑,最好是在缓存数据之前split传递给函数的函数中?dataset.map

标签: pythonpandastensorflowdeep-learningreshape

解决方案


这很容易。在您的拆分功能中执行此操作

def split(window):
    return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]

推荐阅读