首页 > 解决方案 > 如何使用 TensorFlow 的子分类方法预测二进制分类的类?

问题描述

我正在使用子分类进行二进制分类Tensorflow。我的代码是:

class ChurnClassifier(Model):
    def __init__(self):
        super(ChurnClassifier, self).__init__()
        self.layer1 = layers.Dense(20, input_dim = 20, activation = 'relu')
        self.layer2 = layers.Dense(41, activation = 'relu')
        self.layer3 = layers.Dense(83, activation = 'relu')
        self.layer4 = layers.Dense(2,  activation = 'sigmoid')
        
    def call(self, inputs):
        x = self.layer1(inputs)
        x = self.layer2(x)
        x = self.layer3(x)
        return self.layer4(x)
        
ChurnClassifier = ChurnClassifier()
ChurnClassifier.compile(optimizer = 'adam',
                        loss=tf.keras.losses.CategoricalCrossentropy(),
                       metrics = ['accuracy'])

现在我安装了模型:

history = ChurnClassifier.fit(X_train_nur, Y_train_nur, 
          epochs=20, 
          batch_size=512,
          validation_data=(X_val_nur, Y_val_nur),
          shuffle=True)

现在,我想预测类别 0 或 1,所以我使用了代码 -prediction = ChurnClassifier.predict(X_val_nur)

现在我想看看有多少是0和1来计算TN,FN,TP,FP。所以我创建了一个数据框进行预测。代码-

pred_y = pd.DataFrame(prediction , columns=['pred_y']) 

但我得到以下DataFrame-

在此处输入图像描述

我的样本 X_train:

array([[2.02124594e+08, 3.63743942e+04, 2.12000000e+02, ...,
        4.30000000e+01, 0.00000000e+00, 1.00000000e+00],
       [4.93794595e+08, 6.66593354e+02, 4.22000000e+02, ...,
        2.60000000e+01, 0.00000000e+00, 1.00000000e+00],
       [7.28506124e+08, 1.17953696e+04, 1.14000000e+03, ...,
        2.50000000e+01, 0.00000000e+00, 1.00000000e+00],
       ...,
       [4.63797916e+08, 1.19273275e+03, 4.10000000e+02, ...,
        9.00000000e+00, 0.00000000e+00, 1.00000000e+00],
       [4.04285400e+08, 1.87350825e+04, 3.01000000e+02, ...,
        1.60000000e+01, 0.00000000e+00, 1.00000000e+00],
       [5.08433538e+08, 3.19289528e+03, 4.18000000e+02, ...,
        9.00000000e+00, 0.00000000e+00, 1.00000000e+00]])

我的样本 y_train-array([0, 0, 0, ..., 0, 0, 0], dtype=int64)

y_train_nur 只包含 0 和 1

什么问题??

提前致谢!

标签: tensorflowmachine-learningdeep-learningclassification

解决方案


推荐阅读