python - Keras,Tensorflow:评估时如何在自定义层中设置断点(调试)?
问题描述
我只想在自定义层内做一些数值验证。
假设我们有一个非常简单的自定义层:
class test_layer(keras.layers.Layer):
def __init__(self, **kwargs):
super(test_layer, self).__init__(**kwargs)
def build(self, input_shape):
self.w = K.variable(1.)
self._trainable_weights.append(self.w)
super(test_layer, self).build(input_shape)
def call(self, x, **kwargs):
m = x * x # Set break point here
n = self.w * K.sqrt(x)
return m + n
和主程序:
import tensorflow as tf
import keras
import keras.backend as K
input = keras.layers.Input((100,1))
y = test_layer()(input)
model = keras.Model(input,y)
model.predict(np.ones((100,1)))
如果我在该行设置断点调试m = x * x
,程序在执行的时候会在这里暂停y = test_layer()(input)
,这是因为图建好了,call()
方法被调用了。
但是当我使用model.predict()
它来赋予它真正的价值,并且想看看它是否正常工作时,它不会停在这条线上m = x * x
我的问题是:
call()
仅在构建计算图时才调用方法吗?(提供真实价值时不会调用它?)如何在层内调试(或在哪里插入断点)以在给它实际值输入时查看变量的值?
解决方案
在 TensorFlow 2 中,您现在可以向 TensorFlow Keras 模型/层添加断点,包括在使用拟合、评估和预测方法时。但是,您必须在调用断点处的调试器中可用的张量值model.run_eagerly = True
之后添加。model.compile()
例如,
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
class SimpleModel(Model):
def __init__(self):
super().__init__()
self.dense0 = Dense(2)
self.dense1 = Dense(1)
def call(self, inputs):
z = self.dense0(inputs)
z = self.dense1(z) # Breakpoint in IDE here. =====
return z
x = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
model0 = SimpleModel()
y0 = model0.call(x) # Values of z shown at breakpoint. =====
model1 = SimpleModel()
model1.run_eagerly = True
model1.compile(optimizer=Adam(), loss=BinaryCrossentropy())
y1 = model1.predict(x) # Values of z *not* shown at breakpoint. =====
model2 = SimpleModel()
model2.compile(optimizer=Adam(), loss=BinaryCrossentropy())
model2.run_eagerly = True
y2 = model2.predict(x) # Values of z shown at breakpoint. =====
注意:这是在 TensorFlow 中测试的2.0.0-rc0
。
推荐阅读
- javascript - jquery on change 和 input 函数运行多次
- python - UnicodeDecodeError:“utf-8”编解码器无法解码位置 3 中的字节 0x97:无效的起始字节
- vue.js - 带有图像标签的 vue.js 动态路由器链接包装
- c - 关于打包结构的大小
- javascript - 我的 ng-repeat 只显示最新的循环
- mysql - 按mysql中的特定单词排序
- c - 重命名和删除文件仅在第一次迭代中不起作用 (C)
- javascript - 在 ReactJS 中将新的子 DOM 附加到父级时面临的问题
- python - 类型错误:float() 参数必须是字符串或数字:float(ab[1:])
- android - 如何在不从 Android Studio 重新上传的情况下在设备上重新运行应用程序的调试版本?