python - Tensorflow 2.0:如何像使用 PyTorch 一样完全自定义 Tensorflow 训练循环?
问题描述
我以前用过Tensorflow
很多,但后来Pytorch
因为调试起来容易得多。我发现的好处PyTorch
是我必须编写自己的训练循环,这样我就可以单步执行代码并找出错误。我可以毫无困难地启动pdb
并检查张量形状和变换等。
因为Tensorflow
我一直在使用该model.fit()
函数,所以我收到的任何错误消息都像是 6 页 C 代码,其中错误消息没有给我任何指示,但它在 python 代码中。由于它是静态图,因此用户无法逐步执行该model.fit()
功能,这确实减慢了我的开发过程。但是,我正在考虑Tensorflow
再次使用,我想知道用户是否可以逐步通过自定义训练循环并查看张量形状等,或者甚至自定义训练循环是否被编译为静态图,因此用户无法通过它?
我确实在谷歌上搜索了这个问题,但所有自定义训练循环的教程都Tensorflow
侧重于针对高级用户的自定义循环,例如,如果您想在训练时应用一些异国情调的回调,或者如果您想应用一些条件逻辑。因此,是否容易通过自定义训练循环这个简单的问题没有得到回答。
任何帮助表示赞赏。谢谢。
解决方案
这几乎是我能做到的定制和简单的骨头。我还使用了子类层。
import tensorflow as tf
import tensorflow_datasets as tfds
ds = tfds.load('iris', split='train', as_supervised=True)
train = ds.take(125).shuffle(125).batch(1)
test = ds.skip(125).take(25).shuffle(25).batch(1)
class Dense(tf.Module):
def __init__(self, in_features, out_features, activation, name=None):
super().__init__(name=name)
self.activation = activation
self.w = tf.Variable(
tf.initializers.GlorotUniform()([in_features, out_features]), name='weights')
self.b = tf.Variable(tf.zeros([out_features]), name='biases')
def __call__(self, x):
y = tf.matmul(x, self.w) + self.b
return self.activation(y)
class SequentialModel(tf.Module):
def __init__(self, name):
super().__init__(name=name)
self.dense1 = Dense(in_features=4, out_features=16, activation=tf.nn.relu)
self.dense2 = Dense(in_features=16, out_features=32, activation=tf.nn.relu)
self.dense3 = Dense(in_features=32, out_features=3, activation=tf.nn.softmax)
def __call__(self, x):
x = self.dense1(x)
x = self.dense2(x)
x = self.dense3(x)
return x
model = SequentialModel(name='sequential_model')
loss_object = tf.losses.SparseCategoricalCrossentropy(from_logits=False)
def compute_loss(model, x, y):
out = model(x)
loss = loss_object(y_true=y, y_pred=out)
return loss, out
def get_grad(model, x, y):
with tf.GradientTape() as tape:
loss, out = compute_loss(model, x, y)
gradients = tape.gradient(loss, model.trainable_variables)
return loss, gradients, out
optimizer = tf.optimizers.Adam()
verbose = "Epoch {:2d} Loss: {:.3f} TLoss: {:.3f} Acc: {:=7.2%} TAcc: {:=7.2%}"
for epoch in range(1, 10 + 1):
train_loss = tf.constant(0.)
train_acc = tf.constant(0.)
test_loss = tf.constant(0.)
test_acc = tf.constant(0.)
for n_train, (x, y) in enumerate(train, 1):
loss_value, grads, out = get_grad(model, x, y)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
train_loss += loss_value
train_acc += tf.metrics.sparse_categorical_accuracy(y, out)[0]
for n_test, (x, y) in enumerate(test, 1):
loss_value, _, out = get_grad(model, x, y)
test_loss += loss_value
test_acc += tf.metrics.sparse_categorical_accuracy(y, out)[0]
print(verbose.format(epoch,
tf.divide(train_loss, n_train),
tf.divide(test_loss, n_test),
tf.divide(train_acc, n_train),
tf.divide(test_acc, n_test)))
推荐阅读
- sql - SELECT 语句中的多个条件
- python - 使用 pip 安装 keras 时出错?
- ssh - 我想知道我必须使用 ssh 登录终端的 ip
- git - `git remote add --mirror=fetch` 会与 `git clone --mirror` 生成相同的仓库吗?
- javascript - 以角度禁用另一个组件的表单字段
- javascript - Dropzone.js如何向上传的图像添加带有值的隐藏输入
- r - 使用 system() 打开带有 URL 参数的本地 index.html
- python - 如何将存储为数组的 Dataframe 值转换为列表
- python - AttributeError:模块“ocrmypdf”没有属性“ocr”
- javascript - div 内容的更改不会反映在浏览器中