tensorflow - 如何使用 TensorFlow 的子分类方法预测二进制分类的类?
问题描述
我正在使用子分类进行二进制分类Tensorflow
。我的代码是:
class ChurnClassifier(Model):
def __init__(self):
super(ChurnClassifier, self).__init__()
self.layer1 = layers.Dense(20, input_dim = 20, activation = 'relu')
self.layer2 = layers.Dense(41, activation = 'relu')
self.layer3 = layers.Dense(83, activation = 'relu')
self.layer4 = layers.Dense(2, activation = 'sigmoid')
def call(self, inputs):
x = self.layer1(inputs)
x = self.layer2(x)
x = self.layer3(x)
return self.layer4(x)
ChurnClassifier = ChurnClassifier()
ChurnClassifier.compile(optimizer = 'adam',
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics = ['accuracy'])
现在我安装了模型:
history = ChurnClassifier.fit(X_train_nur, Y_train_nur,
epochs=20,
batch_size=512,
validation_data=(X_val_nur, Y_val_nur),
shuffle=True)
现在,我想预测类别 0 或 1,所以我使用了代码 -prediction = ChurnClassifier.predict(X_val_nur)
现在我想看看有多少是0和1来计算TN,FN,TP,FP。所以我创建了一个数据框进行预测。代码-
pred_y = pd.DataFrame(prediction , columns=['pred_y'])
但我得到以下DataFrame-
我的样本 X_train:
array([[2.02124594e+08, 3.63743942e+04, 2.12000000e+02, ...,
4.30000000e+01, 0.00000000e+00, 1.00000000e+00],
[4.93794595e+08, 6.66593354e+02, 4.22000000e+02, ...,
2.60000000e+01, 0.00000000e+00, 1.00000000e+00],
[7.28506124e+08, 1.17953696e+04, 1.14000000e+03, ...,
2.50000000e+01, 0.00000000e+00, 1.00000000e+00],
...,
[4.63797916e+08, 1.19273275e+03, 4.10000000e+02, ...,
9.00000000e+00, 0.00000000e+00, 1.00000000e+00],
[4.04285400e+08, 1.87350825e+04, 3.01000000e+02, ...,
1.60000000e+01, 0.00000000e+00, 1.00000000e+00],
[5.08433538e+08, 3.19289528e+03, 4.18000000e+02, ...,
9.00000000e+00, 0.00000000e+00, 1.00000000e+00]])
我的样本 y_train-array([0, 0, 0, ..., 0, 0, 0], dtype=int64)
y_train_nur 只包含 0 和 1
什么问题??
提前致谢!
解决方案
推荐阅读
- c# - CompileAssemblyFromFile c# 属性 7.2 失败
- excel - 基于索引匹配的 SUM 单元格
- swift - 使用 SwiftyJSON 将数据附加到现有的 JSON 数组
- spring-rest - 如何在rest xml中删除处理程序和hibernateinitializer
- vb.net - 将图像的字节写入文本文件
- swift - FireStore如何做ClientSideJoin
- python - 在风格迁移中使用 L2 归一化 - 不涉及权重?
- flutter - 根据时间导航到页面
- java - 用 Java 播放奇异的 WAV 文件
- polymer - Polymer 3 - 使用 stylesFromTemplate