首页 > 解决方案 > Keras/tensorflow:无法从损失函数调用 tf.Print()

问题描述

有以下损失函数:

def weightedLoss(originalLossFunc, weightsList):

    def lossFunc(true, pred):

        axis = -1 #if channels last 
        #axis=  1 #if channels first


        #argmax returns the index of the element with the greatest value
        #done in the class axis, it returns the class index    
        classSelectors = K.argmax(true, axis=axis) 

        #considering weights are ordered by class, for each class
        #true(1) if the class index is equal to the weight index   
        classSelectors = [K.equal(i, classSelectors) for i in range(len(weightsList))]

        #casting boolean to float for calculations  
        #each tensor in the list contains 1 where ground true class is equal to its index 
        #if you sum all these, you will get a tensor full of ones. 
        classSelectors = [K.cast(x, K.floatx()) for x in classSelectors]

        #for each of the selections above, multiply their respective weight
        weights = [sel * w for sel,w in zip(classSelectors, weightsList)] 

        #sums all the selections
        #result is a tensor with the respective weight for each element in predictions
        weightMultiplier = weights[0]
        for i in range(1, len(weights)):
            weightMultiplier = weightMultiplier + weights[i]


        #make sure your originalLossFunc only collapses the class axis
        #you need the other axes intact to multiply the weights tensor
        loss = originalLossFunc(true,pred) 
        weightMultiplier = tf.Print(weightMultiplier, [weightMultliplier], "loss weightage")
        loss = loss * weightMultiplier
        #weightMultiplier = tf.Print(weightMultiplier, [weightMultliplier], "loss weightage") ---location 2
        return loss
    return lossFunc

现在在该函数内部,我有一个打印语句来打印权重向量。在当前位置,网络不打印任何内容,尽管我认为这会导致网络将其包含在其计算图中。然后,我将它向下移动了一行并尝试了,但这也不起作用。我究竟做错了什么?我在任何时候都没有收到错误。

标签: pythontensorflowkeras

解决方案


推荐阅读