TensorFlow 2 model runs with a single output but fails with multiple outputs: Can not squeeze dim[2], expected a dimension of 1, got 3

Problem description

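When I use either output/loss combination on its own, the model runs fine, but it fails when I try to use both at once. If I simply leave one of the outputs out of the model definition and also remove the extra loss, it works.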

def prepare_dataset(ds, shuffle = False):
  # ds = ds.cache()
  if shuffle:
    ds = ds.shuffle(buffer_size=500)
  ds = ds.batch(BATCH_SIZE)
  ds = ds.prefetch(buffer_size=AUTOTUNE)
  return ds
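The helper above depends on BATCH_SIZE and AUTOTUNE, which the question defines elsewhere. A minimal self-contained sketch, assuming a hypothetical BATCH_SIZE of 4 and tf.data.AUTOTUNE (available in TF 2.4 and later), shows what it does to a toy dataset of 10 examples: shuffling reorders elements, batching groups them, and the last batch is smaller because drop_remainder is not set.

```python
import numpy as np
import tensorflow as tf

# Hypothetical values; the question does not show how these are defined.
BATCH_SIZE = 4
AUTOTUNE = tf.data.AUTOTUNE

def prepare_dataset(ds, shuffle=False):
    # Optionally shuffle individual elements before batching.
    if shuffle:
        ds = ds.shuffle(buffer_size=500)
    # Group consecutive elements into batches of BATCH_SIZE.
    ds = ds.batch(BATCH_SIZE)
    # Let tf.data prepare the next batch while the current one is consumed.
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds

# Toy data: 10 examples with 8 features each.
toy = tf.data.Dataset.from_tensor_slices(
    np.arange(80, dtype="float32").reshape(10, 8))
batches = list(prepare_dataset(toy, shuffle=True))
print([b.shape[0] for b in batches])  # [4, 4, 2]
```

Note that batching preserves each element's structure: a dataset of 3-tuples is still a dataset of 3-tuples after prepare_dataset, just with a leading batch dimension on every component.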

# ... model definition and dataset fetching omitted ...
model = tf.keras.Model(inputs, [output_sp, output_norms])

train_ds = tf.data.Dataset.from_tensor_slices((X_train_file, y_train_1, y_train_2))
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)

# Here I check the first element; all of its parts have the expected shapes

model.compile(
  optimizer='adam',
  # loss=tf.keras.losses.MeanSquare(),
  loss=[tf.keras.losses.MeanAbsoluteError(), tf.keras.losses.MeanSquaredError()],
  loss_weights=[.8, 1])


train_ds = prepare_dataset(train_ds, shuffle=True)
model.fit(
  augmented,
  # validation_data = val_ds,
  epochs=1,
  # batch_size is not passed: the dataset is already batched by prepare_dataset
  callbacks=[cp_callback]
)


Traceback (most recent call last):
  File "/home/scandy/Developer/RouxNN/models/normals_and_sp_transfer.py", line 245, in <module>
    model.fit(
  File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2942, in __call__
    return graph_function._call_flat(
  File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(
  File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  Can not squeeze dim[2], expected a dimension of 1, got 3
         [[{{node mean_absolute_error/weighted_loss/cond/else/_66/mean_absolute_error/weighted_loss/cond/cond/then/_387/mean_absolute_error/weighted_loss/cond/cond/Squeeze}}]]
         [[mean_squared_error/cond/then/_74/mean_squared_error/cond/cond/pivot_t/_400/_117]]
  (1) Invalid argument:  Can not squeeze dim[2], expected a dimension of 1, got 3
         [[{{node mean_absolute_error/weighted_loss/cond/else/_66/mean_absolute_error/weighted_loss/cond/cond/then/_387/mean_absolute_error/weighted_loss/cond/cond/Squeeze}}]]
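The failing node sits under weighted_loss, which points at sample weights. The Keras fit documentation says a tf.data dataset should yield either (inputs, targets) or (inputs, targets, sample_weights), so the element structure is worth inspecting. A minimal sketch, using hypothetical NumPy arrays x, y1 and y2 as stand-ins for X_train_file, y_train_1 and y_train_2, compares a flat 3-tuple dataset (as built in the question) with a nested (inputs, (target_1, target_2)) one via element_spec.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for X_train_file, y_train_1 and y_train_2.
x = np.zeros((6, 8), dtype="float32")
y1 = np.zeros((6, 1), dtype="float32")
y2 = np.zeros((6, 3), dtype="float32")

# Flat 3-tuple, as in the question: fit reads this as (x, y, sample_weight).
flat = tf.data.Dataset.from_tensor_slices((x, y1, y2))
# Nested structure: (inputs, (target_1, target_2)).
nested = tf.data.Dataset.from_tensor_slices((x, (y1, y2)))

print(len(flat.element_spec))    # 3
print(len(nested.element_spec))  # 2
```

Both datasets contain the same tensors; only the nesting differs, and that nesting is what decides whether the third component is treated as a second target or as per-sample weights.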

Tags: tensorflow, tensorflow2.0, tensorflow-datasets, tf.keras, multipleoutputs

Solution


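I was able to get it running by splitting the data into an input dataset and a target dataset and then zipping them back together. The most likely reason this works: when a tf.data dataset passed to model.fit yields a 3-tuple, Keras interprets it as (inputs, targets, sample_weights), so in the original pipeline y_train_2 was being used as sample weights rather than as the second target (consistent with the failing node being mean_absolute_error/weighted_loss/.../Squeeze). After zipping, each element has the structure (inputs, (target_1, target_2)), which Keras matches to the model's two outputs.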

So I replaced this

train_ds = tf.data.Dataset.from_tensor_slices((X_train_file, y_train_1, y_train_2))
train_ds = prepare_dataset(train_ds, shuffle=True)

with this

train_ds_in = tf.data.Dataset.from_tensor_slices((X_train_file))
train_ds_out = tf.data.Dataset.from_tensor_slices((y_train_1, y_train_2))

data_set = tf.data.Dataset.zip( (train_ds_in , train_ds_out) )
data_set = prepare_dataset(data_set)
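The zipped layout can be checked end to end with a toy model. This is a minimal sketch, not the author's model: the arrays, layer sizes, output names out_1/out_2 and batch size are all hypothetical, and numeric arrays stand in for the files that process_path loads in the question. It compiles two outputs with MAE and MSE, as in the question, and trains for one epoch on a dataset whose elements are (x, (y_1, y_2)).

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 32 inputs with 8 features and two targets.
rng = np.random.default_rng(0)
x = rng.random((32, 8)).astype("float32")
y_1 = rng.random((32, 1)).astype("float32")
y_2 = rng.random((32, 3)).astype("float32")

# A tiny two-output model; the layer sizes are arbitrary.
inputs = tf.keras.Input(shape=(8,))
hidden = tf.keras.layers.Dense(16, activation="relu")(inputs)
out_1 = tf.keras.layers.Dense(1, name="out_1")(hidden)
out_2 = tf.keras.layers.Dense(3, name="out_2")(hidden)
model = tf.keras.Model(inputs, [out_1, out_2])

# One loss per output, in the same order as the model's outputs.
model.compile(
    optimizer="adam",
    loss=[tf.keras.losses.MeanAbsoluteError(), tf.keras.losses.MeanSquaredError()],
    loss_weights=[0.8, 1.0],
)

# Zip inputs with a nested tuple of targets: elements are (x, (y_1, y_2)).
ds_in = tf.data.Dataset.from_tensor_slices(x)
ds_out = tf.data.Dataset.from_tensor_slices((y_1, y_2))
ds = tf.data.Dataset.zip((ds_in, ds_out)).batch(8)

# No batch_size argument: the dataset is already batched.
history = model.fit(ds, epochs=1, verbose=0)
```

If the flat 3-tuple dataset from the question is kept, an equivalent restructuring is ds.map(lambda x, y1, y2: (x, (y1, y2))), which regroups each element without splitting the source arrays.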
