TensorFlow confusion matrix: "Shapes of all inputs must match"

Problem description

I have the following code, which should produce a confusion matrix:

predGenerator = generator(X_test, Y_test, batchSize=50)
predictions = model.predict_generator(predGenerator.mainGen(), steps=50)
print(tf.math.confusion_matrix(labels=tf.argmax(predGenerator.actualLabels, 1), predictions=tf.argmax(predictions, 1))) # Will produce a confusion matrix
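For reference, `tf.math.confusion_matrix` pairs each entry of `labels` with the entry at the same position in `predictions`, so after the two `tf.argmax` calls both vectors must have the same length. The NumPy sketch below is a hypothetical stand-in, not TensorFlow's implementation: it builds the same table (rows are true classes, columns are predicted classes) and makes the length requirement explicit.

```python
import numpy as np

def confusion_matrix(labels, predictions, num_classes=None):
    """Count (true, predicted) class pairs; rows = true class, columns = predicted class."""
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    if labels.shape != predictions.shape:
        # Every label needs a prediction at the same position; in the question,
        # tf.math.confusion_matrix reports this case as "Shapes of all inputs must match".
        raise ValueError(f"labels {labels.shape} and predictions {predictions.shape} differ")
    if num_classes is None:
        num_classes = int(max(labels.max(), predictions.max())) + 1
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    # Unbuffered add: repeated (true, predicted) pairs are each counted.
    np.add.at(matrix, (labels, predictions), 1)
    return matrix

print(confusion_matrix([0, 1, 1, 2], [0, 1, 2, 2]))
```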

The generator class is as follows:

class generator:
    def __init__(self, X, Y, batchSize):
        self.X = X
        self.Y = Y
        self.batchSize = batchSize
        self.index = 0
        self.actualLabels = np.array([])

    def __genActualLabels__(self, Y_batch):
        try:
            #Creates an actualLabels attribute for when we want to create our confusion matrix.
            self.actualLabels = np.vstack([self.actualLabels, Y_batch])
        except:
            #If actualLabels is still empty, vstack raises an error, so start a new array from this batch instead.
            self.actualLabels = np.array(Y_batch) 

    def genBatch(self): #Will return the next batch each time. 
        for image in self.X[self.index * self.batchSize : (self.index + 1) * self.batchSize]:
            try:
                X_batch = np.vstack((X_batch, [imageToRGB(image, True)]))
            except:
                X_batch = np.array([imageToRGB(image, True)])

        Y_batch = np.array(self.Y[self.index * self.batchSize : (self.index + 1) * self.batchSize])
        self.__genActualLabels__(Y_batch) #Appends Y_batch to the actualLabels.
        self.index+=1 #Increments index so that the next batch is returned next time.
        return (X_batch.astype(np.float32), Y_batch.astype(np.float32))

    def mainGen(self): #The main loop which will continuously yield each batch.
        while True:
            yield self.genBatch()
            if self.index > (len(self.X) - 1)/self.batchSize:
                self.index = 0
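The batching logic above can be written more compactly. This is a hypothetical rewrite, not the original code: `batch_generator` and its `preprocess` argument (a stand-in for the question's `imageToRGB(image, True)`) are assumptions. It builds each batch with a single `np.stack` call instead of a `try`/`except` around `np.vstack`, computes the number of batches per pass with `math.ceil`, and deliberately records no labels as a side effect.

```python
import math
import numpy as np

def batch_generator(X, Y, batch_size, preprocess=lambda x: x):
    """Yield (X_batch, Y_batch) forever, wrapping around at the end of the data.

    preprocess is a hypothetical hook standing in for imageToRGB(image, True).
    """
    steps_per_pass = math.ceil(len(X) / batch_size)  # last batch may be smaller
    while True:
        for step in range(steps_per_pass):
            sl = slice(step * batch_size, (step + 1) * batch_size)
            # One np.stack call replaces the repeated vstack inside try/except.
            X_batch = np.stack([preprocess(x) for x in X[sl]]).astype(np.float32)
            Y_batch = np.asarray(Y[sl], dtype=np.float32)
            yield X_batch, Y_batch
```

With 5 samples and `batch_size=2` it yields batches of 2, 2 and 1 samples and then starts over, which matches the wrap-around condition in `mainGen`.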

However, when I run this code, I get this error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [3038] != values[1].shape = [2500] [Op:Pack] name: stack
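One plausible cause, assuming standard Keras behaviour: `predict_generator` reads batches ahead of time into a queue (its `max_queue_size` parameter, 10 by default), so the generator can be advanced more than `steps` times. Because `genBatch` appends to `actualLabels` when a batch is *produced* rather than when it is predicted, `actualLabels` can end up longer than `predictions` (here 3038 rows versus 50 × 50 = 2500). The deterministic simulation below is a simplified model of that effect, not Keras's actual enqueuer; `RecordingGenerator` and `consume_with_prefetch` are hypothetical names.

```python
from collections import deque

class RecordingGenerator:
    """Hypothetical stand-in for the question's generator: records labels as a side effect."""

    def __init__(self, labels, batch_size):
        self.labels = labels
        self.batch_size = batch_size
        self.index = 0
        self.recorded = []

    def batches(self):
        n_batches = -(-len(self.labels) // self.batch_size)  # ceiling division
        while True:
            start = self.index * self.batch_size
            batch = self.labels[start:start + self.batch_size]
            self.recorded.extend(batch)  # recorded when produced, not when consumed
            self.index = (self.index + 1) % n_batches
            yield batch

def consume_with_prefetch(batch_iter, steps, queue_size):
    """Simplified prefetching consumer: keeps the queue topped up to queue_size batches."""
    queue = deque()
    consumed = []
    for _ in range(steps):
        while len(queue) < queue_size:
            queue.append(next(batch_iter))
        consumed.extend(queue.popleft())
    return consumed

rg = RecordingGenerator(list(range(1000)), batch_size=50)
consumed = consume_with_prefetch(rg.batches(), steps=10, queue_size=10)
print(len(consumed), len(rg.recorded))  # 500 950
```

With `queue_size=1` both lengths are 500, consistent with the mismatch disappearing when nothing is read ahead.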

However, this exact code has worked for me before, so I don't know what makes it work on some runs and fail on others.
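If read-ahead is the cause, the number of extra batches depends on thread timing, which would explain why identical code sometimes works. One way to avoid depending on that side effect is to keep each batch's labels next to its predictions. The sketch below assumes `predict_fn` is any callable mapping a batch of inputs to class scores (for a Keras model, `model.predict_on_batch` is one candidate) and that `Y` is one-hot encoded; `predict_with_labels` is a hypothetical name, and preprocessing such as `imageToRGB` is left out.

```python
import numpy as np

def predict_with_labels(predict_fn, X, Y, batch_size):
    """Run predict_fn batch by batch; return (true_ids, pred_ids) of equal length."""
    true_ids, pred_ids = [], []
    for start in range(0, len(X), batch_size):
        # Preprocessing (e.g. the question's imageToRGB) would be applied here.
        X_batch = np.asarray(X[start:start + batch_size], dtype=np.float32)
        Y_batch = np.asarray(Y[start:start + batch_size])
        scores = predict_fn(X_batch)
        # Labels and predictions for the same batch are stored together.
        pred_ids.append(np.argmax(scores, axis=1))
        true_ids.append(np.argmax(Y_batch, axis=1))
    return np.concatenate(true_ids), np.concatenate(pred_ids)
```

The two returned arrays always have the same length, so they can be passed directly to `tf.math.confusion_matrix(labels=true_ids, predictions=pred_ids)`.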

Tags: python, numpy, tensorflow, keras, confusion-matrix

Solution

