python - 如何正确拆分我的训练和推理解码器?
问题描述
我使用 contrib seq2seq API 创建了一个 Dataset API 管道和一个 seq2seq 编码器-解码器模型。
我正在创建两个不同的解码器(共享相同的权重):
- 一个训练我的模型,使用教师强制(TrainHelper)
- 一个测试我的模型,使用解码器输出作为输入(GreedyEmbeddingHelper)
但是,我不能使用我的模型,因为我调用了解码器函数两次:
- 提供训练数据集作为构建模型的参数
- 提供测试数据集作为构建模型的参数
两次调用该函数,我正在复制一些变量。
这是我的解码器函数,它创建了训练和推理解码器:
def decoder(target, hidden_state, encoder_outputs):
with tf.name_scope("decoder"):
# ... embedding the targets
decoder_inputs = embeddings(target)
decoder_gru_cell = tf.nn.rnn_cell.GRUCell(dec_units, name="gru_cell")
# Here I create the training decoder part
with tf.variable_scope("decoder"):
training_helper = tf.contrib.seq2seq.TrainingHelper(decoder_inputs, max_length)
training_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_gru_cell, training_helper, hidden_state)
training_decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder, max_length)
# And here I create the inference decoder part
with tf.variable_scope("decoder", reuse=True):
inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(...)
inference_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_gru_cell, inference_helper, hidden_state)
inference_decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(inference_decoder, max_length)
return training_decoder_outputs, inference_decoder_outputs
在这里我创建了我的模型:
def seq2_seq2_model(values, labels):
encoder_outputs, hidden_state = encoder(values)
training_decoder, inference_decoder = decoder(labels, hidden_state, encoder_outputs)
return training_decoder, inference_decoder
这是我的数据集,我分为训练部分和测试部分(大小 n_test):
values_dataset = tf.data.Dataset.from_tensor_slices(values)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
X_Y_dataset = tf.data.Dataset.zip((features_dataset, caption_dataset))
X_Y_test = X_Y_dataset.take(n_test).batch(n_test)
X_Y_train = X_Y_dataset.skip(n_test).batch(batch_size)
test_iterator = X_Y_test.make_initializable_iterator()
x_y_test_next = test_iterator.get_next()
train_iterator = X_Y_train.make_initializable_iterator()
x_y_train_next = train_iterator.get_next()
最后,我通过调用 seq2_seq2_model 来构建我的模型:
training_decoder_outputs, _ = seq2_seq2_model(*x_y_train_next)
_, inference_decoder_outputs = seq2_seq2_model(*x_y_test_next)
错误出现了,因为我创建了两次变量decoder_gru_cell。
ValueError: Variable decoder/decoder/attention_wrapper/gru_cell/gates/kernel already exists, disallowed.
我可以为重复的变量创建一个全局变量,但这似乎是纠正问题的一种肮脏方式。此外,我展示的代码是我的简化版本:我必须创建几个全局变量......
解决方案
我终于找到了。关键是使用可重新初始化的迭代器,以便将数据集切换为输入源。
data_iterator = tf.data.Iterator.from_structure(X_Y_test.output_types,
X_Y_train.output_shapes)
train_init_op = data_iterator.make_initializer(X_Y_train)
test_init_op = data_iterator.make_initializer(X_Y_test)
values, labels = data_iterator.get_next()
然后我们可以一步创建解码器:
training_decoder_outputs, inference_decoder_outputs = seq2_seq2_model(values, labels)
最后,我们使用 train_init_op 和 test_init_op 来指定我们是要使用训练数据集还是 test_dataset:
with tf.Session() as sess:
sess.run(init)
sess.run(train_init_op)
# Perform training...
sess.run(test_init_op)
# Perform inference...
推荐阅读
- tensorrt - TensorRT 警告:由于驱动程序或 nvrtc 不兼容,卷积 + 通用激活融合被禁用
- java - Java Image IO“无法读取输入文件”
- ios-simulator - 如何对在模拟器上运行 iphone 应用程序进行故障排除?
- javascript - 已解决:如何在 mongoose 中创建子文档?MongoDB、NodeJS
- python - 如何使用 python 请求抓取非 restful API?
- html - 为什么我的 h1 的内容会出现在一个固定的导航栏组件之上?
- python - 如何从方程 x+yx*y 中找到 x 和 y
- c - 收到 sigchld 后 Getline 停止工作
- android - Android导航组件在选项卡之间切换时不维护堆栈
- java - 有没有办法计算完成每个 completableFuture 所需的时间?