python - ValueError:使用数据集作为输入时不支持“sample_weight”参数
问题描述
我想训练一个 keras 模型并使用样本权重。我的数据源是 tf.data.dataset 类型。sample_weight
使用函数的参数时出现以下错误model.fit
。
ValueError: `sample_weight` argument is not supported when using dataset as input.
代码如下所示:
model.fit(tf_train_dataset,
epochs=epochs,
verbose=self.verbose,
batch_size=batch_size,
callbacks=callbacks,
sample_weight=sample_weights
steps_per_epoch=self.steps_per_epoch,
use_multiprocessing=True,
tf_train_dataset
由创建tf.data.Dataset.from_generator
。我如何传递每个样本的权重并将其应用于损失并最终进行训练?
解决方案
使用tf.data.Dataset
API 时,样本权重应该是数据集中的另一个元组,顺序如下:(input_batch, label_batch, sample_weight_batch)
.
虚拟示例:
import numpy as np
import tensorflow as tf
from sklearn.utils.class_weight import compute_sample_weight
x_train = np.random.randn(100,2)
y_train = np.random.randint(low = 0, high = 5, size = 100, dtype = np.int32)
weights = compute_sample_weight(class_weight='balanced', y=y_train)
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train, weights))
有关更多信息,您可以参考文档。
推荐阅读
- r - 如何在 PowerBI 中为 ML 运行 R 脚本?
- angular - 从角度服务中的字符串函数返回值
- python - django 中的 CSS 不更新
- amazon-web-services - 用于查询 AWS macie 的 PII 类型
- reactjs - 我如何使用一个主题来设置所有材质 UI 输入的样式,就好像它们有一个变体 ='outlined' 一样?
- mysql - 使 10k 查询 SQL 更快
- python - 我如何在opencv python中控制轮廓区域
- neo4j - 如何在知识图中对条件三元组进行建模
- android - 将有状态存储库注入 WorkManager Worker
- regex - 忽略组匹配中的字符