首页 > 解决方案 > 如何训练多输出 Seq2Seq keras 模型以联合预测序列及其相关分数

问题描述

我正在尝试训练一个 Seq2Seq keras 模型来预测输出序列分数以及输出序列。但是,我对此感到困惑,因为预测一维分数似乎与序列预测不兼容。我尝试根据输出序列的步数重复输出分数,但没有成功。我最后一次尝试是使用 TimeDistributed 包装器,但它也失败了。请帮助我解决这个问题,并了解我做错了什么。接下来,我将向您展示一个重现问题的代码示例。提前致谢(Tensorflow2.6,Python 3.9)

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import  plot_model
import numpy as np
import re, random, string


train_pairs = [('fence is a', ('[start] border [end]', 2.321928094887362)),
 ('add up column of number causes',
  ('[start] get sum [end]', 3.1699250014423126)),
 ('pluto defined as',
  ('[start] ninth planet from sun [end]', 1.5849625007211563)),
 ('sit down has prerequisite',
  ('[start] something to sit on [end]', 2.321928094887362)),
 ('eat has subevent', ('[start] make lot of noise [end]', 2.0)),
 ('hang glider is a', ('[start] minimal aircraft [end]', 2.0)),
 ('staircase used for', ('[start] go downstairs [end]', 2.321928094887362)),
 ('go to work has prerequisite', ('[start] open front door [end]', 2.0)),
 ('elastic band used for',
  ('[start] hold two or more object together [end]', 2.321928094887362)),
 ('condom is a', ('[start] call rubber [end]', 1.5849625007211563)),
 ('curiosity causes desire', ('[start] learn [end]', 2.321928094887362)),
 ('bird capable of',
  ('[start] build their nest on strong branch [end]', 1.5849625007211563)),
 ('join club motivated by goal', ('[start] find friend [end]', 2.0)),
('join club motivated by goal', ('[start] find friend [end]', 2.0)),
 ('start fire causes', ('[start] heat [end]', 2.0)),
 ('coffee has property', ('[start] have distinctive aroma [end]', 2.0)),
 ('read newspaper has a',
  ('[start] effect of learn about event [end]', 1.5849625007211563)),
 ('foot used for', ('[start] stand [end]', 2.321928094887362)),
 ('jello receives action',
  ('[start] make from hoof and connective tissue [end]', 1.5849625007211563)),
 ('ranch used for', ('[start] clean animal [end]', 2.0)),
 ('gain more land has subevent',
  ('[start] increase maintainance [end]', 1.5849625007211563))]

sequence_length = 10
n_epochs = 1
embedding_dim = 500
n_states = 5
max_features=15000
chunk_size = 5

strip_chars = string.punctuation
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")

def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")
input_vectorizer = layers.experimental.preprocessing.TextVectorization(
    output_mode="int", max_tokens=max_features,
    output_sequence_length=sequence_length, standardize=custom_standardization)

output_vectorizer = layers.experimental.preprocessing.TextVectorization(
    output_mode="int", max_tokens=max_features,
    output_sequence_length=sequence_length + 1, standardize=custom_standardization)

train_in_texts = [pair[0] for pair in train_pairs]
train_out_texts = [pair[1][0] for pair in train_pairs]

input_vectorizer.adapt(train_in_texts)
output_vectorizer.adapt(train_out_texts)

def format_dataset(in_phr, out_phr, label=None):
    in_phr = input_vectorizer(in_phr)
    out_phr = output_vectorizer(out_phr)
    return ({"encoder_inputs": in_phr, "decoder_inputs": out_phr[:, :-1],}, (out_phr[:, 1:], label))
def make_dataset(pairs):
    in_phr_texts, targets = zip(*pairs)
    in_phr_texts = list(in_phr_texts)
    out_phr_texts = [t[0] for t in targets]
    labels = [[t[1]] for t in targets]
    dataset = tf.data.Dataset.from_tensor_slices((in_phr_texts, out_phr_texts, labels))

    dataset = dataset.batch(chunk_size)
    dataset = dataset.map(format_dataset)
    return dataset.shuffle(2048).prefetch(16).cache()


def build_rnn_encdec_verifier_model(
        vocab_size, sequence_length, embedding_dims=100, n_states=10):
    sub_pred_seq = layers.Input(shape=(sequence_length, ), name='encoder_inputs')
    enc_embedding = layers.Embedding(
                                    vocab_size + 1,
                                    embedding_dims,
                                     input_length = sequence_length,
                                     mask_zero=True,
                                     name="Enc_emb")(sub_pred_seq)
    _, state_h, state_c = layers.LSTM(
                                    n_states, return_state=True,
                                    #unroll=True,
                                    name="Enc_LSTM")(
                                        enc_embedding)
    encoder_states = [state_h, state_c]
    decoder_inputs = keras.Input(shape=(sequence_length, ), name='decoder_inputs')
    dec_embedding = layers.Embedding(
                                    vocab_size + 1,
                                    embedding_dims,
                                     input_length = sequence_length,
                                     mask_zero=True,
                                     name="Dec_emb")(decoder_inputs)
    decoder_out, last_state, last_cell = layers.LSTM(
                            n_states,
                            return_sequences=True,
                            return_state=True,
                            name="Dec_LSTM")(dec_embedding,
                                             initial_state=encoder_states)
    obj_out = layers.Dense(vocab_size,
                            name="obj_MLP",
                            )(decoder_out)
    out_label = layers.TimeDistributed(layers.Dense(1,
                            name="labels_MLP",
                            ))(decoder_out)
    model = keras.Model([sub_pred_seq, decoder_inputs], [obj_out, out_label])
    return model


train_ds = make_dataset(train_pairs)

model = build_rnn_encdec_verifier_model(
    vocab_size=max_features,
    embedding_dims=embedding_dim,
    sequence_length=sequence_length,
    n_states=n_states)
model.summary()
model.compile(optimizer='adam',
              loss=[keras.losses.SparseCategoricalCrossentropy(), 'mae'],
              metrics=['acc', 'mse'])

model.fit(train_ds,
        epochs=n_epochs,
        verbose=1)

这是日志输出:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
encoder_inputs (InputLayer)     [(None, 10)]         0
__________________________________________________________________________________________________
decoder_inputs (InputLayer)     [(None, 10)]         0
__________________________________________________________________________________________________
Enc_emb (Embedding)             (None, 10, 500)      7500500     encoder_inputs[0][0]
__________________________________________________________________________________________________
Dec_emb (Embedding)             (None, 10, 500)      7500500     decoder_inputs[0][0]
__________________________________________________________________________________________________
Enc_LSTM (LSTM)                 [(None, 5), (None, 5 10120       Enc_emb[0][0]
__________________________________________________________________________________________________
Dec_LSTM (LSTM)                 [(None, 10, 5), (Non 10120       Dec_emb[0][0]
                                                                 Enc_LSTM[0][1]
                                                                 Enc_LSTM[0][2]
__________________________________________________________________________________________________
obj_MLP (Dense)                 (None, 10, 15000)    90000       Dec_LSTM[0][0]
__________________________________________________________________________________________________
time_distributed (TimeDistribut (None, 10, 1)        6           Dec_LSTM[0][0]
==================================================================================================
Total params: 15,111,246
Trainable params: 15,111,246
Non-trainable params: 0
__________________________________________________________________________________________________
Traceback (most recent call last):
  File "/home/iarroyof/test.py", line 136, in <module>
    model.fit(train_ds,
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py", line 1184, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 885, in __call__
    result = self._call(*args, **kwds)
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 759, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3066, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3463, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3298, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1007, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 668, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 994, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py:853 train_function  *
        return step_function(self, iterator)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py:842 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:1286 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2849 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3632 _call_for_each_replica
        return fn(*args, **kwargs)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py:835 run_step  **
        outputs = model.train_step(data)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/training.py:788 train_step
        loss = self.compiled_loss(
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/engine/compile_utils.py:201 __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/losses.py:142 __call__
        return losses_utils.compute_weighted_loss(
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/utils/losses_utils.py:319 compute_weighted_loss
        losses, _, sample_weight = squeeze_or_expand_dimensions(  # pylint: disable=unbalanced-tuple-unpacking
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/keras/utils/losses_utils.py:210 squeeze_or_expand_dimensions
        sample_weight = tf.squeeze(sample_weight, [-1])
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/ops/array_ops.py:4537 squeeze_v2
        return squeeze(input, axis, name)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/util/deprecation.py:549 new_func
        return func(*args, **kwargs)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/ops/array_ops.py:4485 squeeze
        return gen_array_ops.squeeze(input, axis, name)
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/ops/gen_array_ops.py:10198 squeeze
        _, _, _op, _outputs = _op_def_library._apply_op_helper(
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py:748 _apply_op_helper
        op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:599 _create_op_internal
        return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/ops.py:3561 _create_op_internal
        ret = Operation(
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/ops.py:2041 __init__
        self._c_op = _create_c_op(self._graph, node_def, inputs,
    /home/iarroyof/anaconda3/envs/tf25/lib/python3.9/site-packages/tensorflow/python/framework/ops.py:1883 _create_c_op
        raise ValueError(str(e))

    ValueError: Can not squeeze dim[1], expected a dimension of 1, got 10 for '{{node mean_absolute_error/weighted_loss/Squeeze}} = Squeeze[T=DT_FLOAT, squeeze_dims=[-1]](Cast_2)' with input shapes: [?,10].

标签: pythonpython-3.xtensorflowdeep-learning

解决方案


推荐阅读