tensorflow - 如何合并 TensorFlow Dataset 列?
问题描述
我有一个 Keras 模型,它采用形状为 (n, 288, 1) 的输入层,其中 288 是特征数。我正在使用 TensorFlow 数据集tf.data.experimental.make_batched_features_dataset
,我的输入层将是 (n, 1, 1),这意味着它一次为模型提供一个特征。如何制作形状为 (n, 288, 1) 的输入张量?我的意思是如何在一个张量中使用我的所有功能?
解决方案
您可以在 Keras 输入层中指定输入的形状。这里有一个示例代码演示与演示相同的虚拟数据。
import tensorflow as tf
## Creating dummy data for demo
def make_sample():
return tf.random.normal([288, 1])
n_samples = 100
samples = [make_sample() for _ in range(n_samples)]
labels = [tf.random.uniform([1]) for _ in range(n_samples)]
# Use tf.data to create dataset
batch_size = 4
dataset = tf.data.Dataset.from_tensor_slices((samples, labels))
dataset = dataset.batch(batch_size)
# Build keras function model
inputs = tf.keras.Input(shape=[288, 1], name='input')
x = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs=[inputs], outputs=[x])
# Compile loss and optimizer
model.compile(loss='mse', optimizer='sgd', metrics=['mae'])
model.fit(dataset, epochs=1)
推荐阅读
- mysql - 是否可以为每个 KEY 制作 UNIQUE VALUE?
- python - OCR tesseract 改善详细背景的结果
- c++ - 为什么通过打印星形矩形来移动第一行?
- r - 无法使用 tidy 从 aov 中获得额外的 p 值
- angular - 如何保护阵列以供 Firebase Firestore 中的用户访问?
- vue.js - 使用 AXIOS GET 和 VUE JS 检索 CSV 内容 - 返回超时错误
- javascript - 在反应中,我提出了 CORS 策略错误,所以我该如何处理它
- php - wopress 分页不适用于下一页或预览页面
- java - ForkJoinTask - 加入()与调用()
- r - 在 R 中一次创建多个具有特殊长度的向量