首页 > 解决方案 > ValueError:使用数据集作为输入时不支持“sample_weight”参数

问题描述

我想训练一个 keras 模型并使用样本权重。我的数据源是 tf.data.dataset 类型。sample_weight使用函数的参数时出现以下错误model.fit

ValueError: `sample_weight` argument is not supported when using dataset as input.

代码如下所示:

model.fit(tf_train_dataset,
          epochs=epochs,
          verbose=self.verbose,
          batch_size=batch_size,
          callbacks=callbacks,
          sample_weight=sample_weights
          steps_per_epoch=self.steps_per_epoch,
          use_multiprocessing=True,

tf_train_dataset由创建tf.data.Dataset.from_generator。我如何传递每个样本的权重并将其应用于损失并最终进行训练?

标签: pythontensorflowmachine-learningkeras

解决方案


使用tf.data.DatasetAPI 时,样本权重应该是数据集中的另一个元组,顺序如下:(input_batch, label_batch, sample_weight_batch).

虚拟示例:

import numpy as np
import tensorflow as tf
from sklearn.utils.class_weight import compute_sample_weight

x_train = np.random.randn(100,2)
y_train = np.random.randint(low = 0, high = 5, size = 100, dtype = np.int32)
weights = compute_sample_weight(class_weight='balanced', y=y_train)

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train, weights))

有关更多信息,您可以参考文档


推荐阅读