python - Tensorflow 2.0:如何为 numpy 矩阵输入创建 feature_columns
问题描述
我了解如何在 Tensorflow 1.x 中执行此操作(链接在这里)
但是对于 Tensorflow 2.0,如何为 numpy 矩阵创建 feature_columns?
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import tensorflow as tf
X = iris['data']
y = iris['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
ds_train = tf.data.Dataset.from_tensor_slices((X_train, y_train))
ds_test = tf.data.Dataset.from_tensor_slices((X_test, y_test))
model = CustomModel(feature_columns, num_classes=y_train.shape[1])
model.compile()
model.compile('adam', loss='categorical_crossentropy', metrics='accuracy')
根据 CustomModel 的文档字符串,它要求feature_columns: The Tensorflow feature columns for the dataset.
我以 sklearn 的 iris 数据集为例。我知道 tensorflow2.0 有一个 iris 数据集。如果我使用那个数据集,我就不会有这个问题。但这不是重点。鉴于我有 numpy 矩阵,我想知道如何创建特征列以输入 tensorflow 模型。
解决方案
TensorFlow文档包含我们通常需要的所有示例。如果你想创建提升树模型,还有另一篇文章
我使用这个数据集只是为了展示如何使用 Pandas 数据框创建 tf.data
import tensorflow as tf
import pandas as pd
from sklearn.model_selection import train_test_split
seaflow_train = pd.read_csv(
"~/PycharmProjects/TensorFlow2/seaflow_21min.csv")
print(seaflow_train.head())
print(seaflow_train.columns)
seaflow_train['target'] = seaflow_train['pop']
seaflow_train = seaflow_train.drop(columns=['file_id', 'cell_id', 'time', 'd1', 'd2'])
train, test = train_test_split(seaflow_train, test_size=0.2)
train, val = train_test_split(train, test_size=0.2)
# A utility method to create a tf.data dataset from a Pandas Dataframe
def df_to_dataset(dataframe, shuffle=True, batch_size=32):
dataframe = dataframe.copy()
labels = dataframe.pop('target')
ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
if shuffle:
ds = ds.shuffle(buffer_size=len(dataframe))
ds = ds.batch(batch_size)
return ds
batch_size = 5 # A small batch sized is used for demonstration purposes
train_ds = df_to_dataset(train, batch_size=batch_size)
val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)
test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)
推荐阅读
- r - 如何在 R 中生成未来残差?
- bash - curl PUT 使用 auth token header 到 mesosphere 失败,没有 eval
- javascript - 创建反应应用程序未将 abab 模块编译为符合 ES5 的代码,导致 IE11 失败
- python - 了解 Python Arcade 中的类
- postgresql - Postgres在地图中连接不同的键
- android - 为什么 setPivotX() 只是替换视图?
- unity3d - 将 .NET 4.5 C# DLL 加载到 Unity 2018.2.5 后的 TypeLoadException
- c++ - 我想在 C++ 中读取一些多个字符,但它从不读取第二个字符
- javascript - 如果行中带有值触发器的任何单元格被更新,sendNotification 将继续发送电子邮件
- service - 如何检查由 NSSM(非吸吮服务管理器)创建的所有服务的列表?