python-3.x - 在 tensorflow 2.1.0 中强制执行急切
问题描述
我是 tensorflow 和深度学习的新手。
我创建了一个自定义损失函数,但似乎在自定义损失函数中,未启用急切执行。下面是我的自定义损失函数(它不起作用):
def custom_error_finder(y_actual,y_pred):
print(tf.executing_eagerly())
count = 0
qw = tf.py_function((y_actual).numpy())
ya = ((y_actual[0].numpy()).decode())
yp = ((y_pred[0].numpy()).decode())
for i,j in ya,yp:
if i!=j:
count = count+1
mse = pow(count,2)/len(ya)
return mse
让我难过的是,在这个函数之外,每当我运行时print(tf.executing_eagerly())
,它都会返回 true,但在函数内部,它会返回 false。
我已经尝试了所有我能找到的修复:
- 传入 model.compile run_eagerly = True
() 函数
-model.run_eagerly() = True
在编译函数之后添加
-tf.compat.v1.enable_eager_execution()
在损失函数中运行以至少强制急切执行一次。
以上修复均无效。
解决方案
我能够重现如下问题。您可以从这里下载我在程序中使用的数据集。我print("tf.executing_eagerly() Results")
在程序中添加了语句来跟踪更改。
代码 -
%tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)
import numpy as np
from numpy import loadtxt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import backend as K
print("tf.executing_eagerly() Results")
print("Before loading dataset :",tf.executing_eagerly())
# load pima indians dataset
dataset = np.loadtxt("/content/pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# define model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
print("After building model :",tf.executing_eagerly())
def weighted_binary_crossentropy(y_true, y_pred):
print("In loss function :",tf.executing_eagerly())
return K.mean(K.binary_crossentropy(y_pred, y_true))
# compile model
model.compile(loss=weighted_binary_crossentropy, optimizer='adam', metrics=['accuracy'])
print("After compiling model :",tf.executing_eagerly())
# Fit the model
model.fit(X, Y, epochs=1, batch_size=150, verbose=0)
# evaluate the model
scores = model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
输出 -
2.2.0
tf.executing_eagerly() Results
Before loading dataset : True
After building model : True
After compiling model : True
In loss function : False
In loss function : False
In loss function : False
accuracy: 34.90%
解决方案 -根据文档。它提到,
run_eagerly - 指示模型是否应该急切运行的可设置属性。急切地运行意味着您的模型将像 Python 代码一样逐步运行。您的模型可能会运行得更慢,但通过单步调用各个层调用,您应该可以更轻松地对其进行调试。默认情况下,我们将尝试将您的模型编译为静态图以提供最佳执行性能。
如果我们修改model.compile
withrun_eagerly = True
参数,我们可以解决这个问题。下图是修改后的model.compile
代码,
model.compile(loss=weighted_binary_crossentropy, run_eagerly = True, optimizer='adam', metrics=['accuracy'])
固定代码 -
%tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)
import numpy as np
from numpy import loadtxt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import backend as K
print("tf.executing_eagerly() Results")
print("Before loading dataset :",tf.executing_eagerly())
# load pima indians dataset
dataset = np.loadtxt("/content/pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# define model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
print("After building model :",tf.executing_eagerly())
def weighted_binary_crossentropy(y_true, y_pred):
print("In loss function :",tf.executing_eagerly())
return K.mean(K.binary_crossentropy(y_pred, y_true))
# compile model
model.compile(loss=weighted_binary_crossentropy, run_eagerly = True, optimizer='adam', metrics=['accuracy'])
print("After compiling model :",tf.executing_eagerly())
# Fit the model
model.fit(X, Y, epochs=1, batch_size=150, verbose=0)
# evaluate the model
scores = model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
输出 -
2.2.0
tf.executing_eagerly() Results
Before loading dataset : True
After building model : True
After compiling model : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
In loss function : True
accuracy: 34.90%
希望这能回答你的问题。快乐学习。
推荐阅读
- http - 为什么 YouTube API 查询增长如此之快?
- hibernate - Hibernate 映射关联到 Map
- java - java - 如何在java中将字符串(从json对象解组)再次编组为json?
- ruby-on-rails - 无法访问端口 3000 上的 puma http 站点
- python - ctypes 库如何实现基本数据类型乘法来生成数组?
- javascript - 从状态映射数组内的数组
- algorithm - 迭代贪心算法解释
- java - 代理和装饰器,以防止对缓存对象的 setter 调用
- css - 网格下的显示流(基础框架)
- asp.net - 我可以在我的计算机上验证图像的 URL