python - 如何训练多输出 Seq2Seq keras 模型以联合预测序列及其相关分数
问题描述
我正在尝试训练一个 Seq2Seq keras 模型来预测输出序列分数以及输出序列。但是,我对此感到困惑,因为预测一维分数似乎与序列预测不兼容。我尝试根据输出序列的步数重复输出分数,但没有成功。我最后一次尝试是使用 TimeDistributed 包装器,但它也失败了。请帮助我解决这个问题,并了解我做错了什么。接下来,我将向您展示一个重现问题的代码示例。提前致谢(Tensorflow2.6,Python 3.9)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model
import numpy as np
import re, random, string
train_pairs = [('fence is a', ('[start] border [end]', 2.321928094887362)),
('add up column of number causes',
('[start] get sum [end]', 3.1699250014423126)),
('pluto defined as',
('[start] ninth planet from sun [end]', 1.5849625007211563)),
('sit down has prerequisite',
('[start] something to sit on [end]', 2.321928094887362)),
('eat has subevent', ('[start] make lot of noise [end]', 2.0)),
('hang glider is a', ('[start] minimal aircraft [end]', 2.0)),
('staircase used for', ('[start] go downstairs [end]', 2.321928094887362)),
('go to work has prerequisite', ('[start] open front door [end]', 2.0)),
('elastic band used for',
('[start] hold two or more object together [end]', 2.321928094887362)),
('condom is a', ('[start] call rubber [end]', 1.5849625007211563)),
('curiosity causes desire', ('[start] learn [end]', 2.321928094887362)),
('bird capable of',
('[start] build their nest on strong branch [end]', 1.5849625007211563)),
('join club motivated by goal', ('[start] find friend [end]', 2.0)),
('join club motivated by goal', ('[start] find friend [end]', 2.0)),
('start fire causes', ('[start] heat [end]', 2.0)),
('coffee has property', ('[start] have distinctive aroma [end]', 2.0)),
('read newspaper has a',
('[start] effect of learn about event [end]', 1.5849625007211563)),
('foot used for', ('[start] stand [end]', 2.321928094887362)),
('jello receives action',
('[start] make from hoof and connective tissue [end]', 1.5849625007211563)),
('ranch used for', ('[start] clean animal [end]', 2.0)),
('gain more land has subevent',
('[start] increase maintainance [end]', 1.5849625007211563))]
sequence_length = 10
n_epochs = 1
embedding_dim = 500
n_states = 5
max_features=15000
chunk_size = 5
strip_chars = string.punctuation
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")
def custom_standardization(input_string):
lowercase = tf.strings.lower(input_string)
return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")
input_vectorizer = layers.experimental.preprocessing.TextVectorization(
output_mode="int", max_tokens=max_features,
output_sequence_length=sequence_length, standardize=custom_standardization)
output_vectorizer = layers.experimental.preprocessing.TextVectorization(
output_mode="int", max_tokens=max_features,
output_sequence_length=sequence_length + 1, standardize=custom_standardization)
train_in_texts = [pair[0] for pair in train_pairs]
train_out_texts = [pair[1][0] for pair in train_pairs]
input_vectorizer.adapt(train_in_texts)
output_vectorizer.adapt(train_out_texts)
def format_dataset(in_phr, out_phr, label=None):
in_phr = input_vectorizer(in_phr)
out_phr = output_vectorizer(out_phr)
return ({"encoder_inputs": in_phr, "decoder_inputs": out_phr[:, :-1],}, (out_phr[:, 1:], label))
def make_dataset(pairs):
in_phr_texts, targets = zip(*pairs)
in_phr_texts = list(in_phr_texts)
out_phr_texts = [t[0] for t in targets]
labels = [[t[1]] for t in targets]
dataset = tf.data.Dataset.from_tensor_slices((in_phr_texts, out_phr_texts, labels))
dataset = dataset.batch(chunk_size)
dataset = dataset.map(format_dataset)
return dataset.shuffle(2048).prefetch(16).cache()
def build_rnn_encdec_verifier_model(
vocab_size, sequence_length, embedding_dims=100, n_states=10):
sub_pred_seq = layers.Input(shape=(sequence_length, ), name='encoder_inputs')
enc_embedding = layers.Embedding(
vocab_size + 1,
embedding_dims,
input_length = sequence_length,
mask_zero=True,
name="Enc_emb")(sub_pred_seq)
_, state_h, state_c = layers.LSTM(
n_states, return_state=True,
#unroll=True,
name="Enc_LSTM")(
enc_embedding)
encoder_states = [state_h, state_c]
decoder_inputs = keras.Input(shape=(sequence_length, ), name='decoder_inputs')
dec_embedding = layers.Embedding(
vocab_size + 1,
embedding_dims,
input_length = sequence_length,
mask_zero=True,
name="Dec_emb")(decoder_inputs)
decoder_out, last_state, last_cell = layers.LSTM(
n_states,
return_sequences=True,
return_state=True,
name="Dec_LSTM")(dec_embedding,
initial_state=encoder_states)
obj_out = layers.Dense(vocab_size,
name="obj_MLP",
)(decoder_out)
out_label = layers.TimeDistributed(layers.Dense(1,
name="labels_MLP",
))(decoder_out)
model = keras.Model([sub_pred_seq, decoder_inputs], [obj_out, out_label])
return model
train_ds = make_dataset(train_pairs)
model = build_rnn_encdec_verifier_model(
vocab_size=max_features,
embedding_dims=embedding_dim,
sequence_length=sequence_length,
n_states=n_states)
model.summary()
model.compile(optimizer='adam',
loss=[keras.losses.SparseCategoricalCrossentropy(), 'mae'],
metrics=['acc', 'mse'])
model.fit(train_ds,
epochs=n_epochs,
verbose=1)
这是日志输出:
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_inputs (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
decoder_inputs (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
Enc_emb (Embedding) (None, 10, 500) 7500500 encoder_inputs[0][0]
__________________________________________________________________________________________________
Dec_emb (Embedding) (None, 10, 500) 7500500 decoder_inputs[0][0]
__________________________________________________________________________________________________
Enc_LSTM (LSTM) [(None, 5), (None, 5 10120 Enc_emb[0][0]
__________________________________________________________________________________________________
Dec_LSTM (LSTM) [(None, 10, 5), (Non 10120 Dec_emb[0][0]
Enc_LSTM[0][1]
Enc_LSTM[0][2]
__________________________________________________________________________________________________
obj_MLP (Dense) (None, 10, 15000) 90000 Dec_LSTM[0][0]
__________________________________________________________________________________________________
time_distributed (TimeDistribut (None, 10, 1) 6 Dec_LSTM[0][0]
==================================================================================================
Total params: 15,111,246
Trainable params: 15,111,246
Non-trainable params: 0
__________________________________________________________________________________________________
Traceback (most recent call last):
File "/home/iarroyof/test.py", line 136, in <module>
model.fit(train_ds,
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py", line 1184, in fit
tmp_logs = self.train_function(iterator)
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 885, in __call__
result = self._call(*args, **kwds)
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 933, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 759, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3066, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3463, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3298, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1007, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 668, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 994, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py:853 train_function *
return step_function(self, iterator)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py:842 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:1286 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2849 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3632 _call_for_each_replica
return fn(*args, **kwargs)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py:835 run_step **
outputs = model.train_step(data)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py:788 train_step
loss = self.compiled_loss(
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/compile_utils.py:201 __call__
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/losses.py:142 __call__
return losses_utils.compute_weighted_loss(
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/utils/losses_utils.py:319 compute_weighted_loss
losses, _, sample_weight = squeeze_or_expand_dimensions( # pylint: disable=unbalanced-tuple-unpacking
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/utils/losses_utils.py:210 squeeze_or_expand_dimensions
sample_weight = tf.squeeze(sample_weight, [-1])
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/ops/array_ops.py:4537 squeeze_v2
return squeeze(input, axis, name)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/util/deprecation.py:549 new_func
return func(*args, **kwargs)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/ops/array_ops.py:4485 squeeze
return gen_array_ops.squeeze(input, axis, name)
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/ops/gen_array_ops.py:10198 squeeze
_, _, _op, _outputs = _op_def_library._apply_op_helper(
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py:748 _apply_op_helper
op = g._create_op_internal(op_type_name, inputs, dtypes=None,
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:599 _create_op_internal
return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/ops.py:3561 _create_op_internal
ret = Operation(
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/ops.py:2041 __init__
self._c_op = _create_c_op(self._graph, node_def, inputs,
/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/ops.py:1883 _create_c_op
raise ValueError(str(e))
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 10 for '{{node mean_absolute_error/weighted_loss/Squeeze}} = Squeeze[T=DT_FLOAT, squeeze_dims=[-1]](Cast_2)' with input shapes: [?,10].
解决方案
推荐阅读
- chef-infra - 文件中的厨师刀环境无法识别较旧的食谱版本
- web-scraping - 如何从 Squawka 中获取这些玩家评分
- python - 没有名为 googleapiclient.discovery 的模块
- android - 如果 StreetView 不适用于指定的 latlong,如何获取最近的可用 StreetView?
- swift - 如何获取系统范围的鼠标和键盘事件以查找空闲时间?
- vb.net - Web 自动化 - 如何将文本输入到由类定义的网站 (textarea) 上的富文本框中?
- ios - 如何以编程方式将笔尖添加到 UIStackView
- laravel - Laravel 本地化位置取决于变量
- android - 如何将 distributionSha256Sum 添加到 gradle Android Studio?
- java - optional.filter 的 Lambda 表达式编码变体