首页 > 解决方案 > tf.nn.sigmoid_cross_entropy_with_logits 的尺寸问题:无法挤压 dim 1

问题描述

对于我的代码,基本上我使用的是这个MNIST 示例。我的图像不是 28x28 而是 120x50x3。我的标签不仅仅是 MNIST 中的数字,而且它们也是 120x50x3 的图像。某处我在我的代码中犯了一个错误。我认为这是 sigmoid_cross_entropy 函数的问题,它看起来需要一维的东西作为参数。在这一点上,我迷路了,希望能得到任何帮助,因为我已经花了几个小时来解决这个问题。

在发布代码之前,这是我得到的错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[1], expected a dimension of 1, got 18000 for 'remove_squeezable_dimensions/Squeeze' (op: 'Squeeze') with input shapes: [1024,18000].

这是我从中获取图像的功能:

def dataset(directory):
  """Download and parse MNIST dataset."""

  images_file = load_data("/Users/pics/images/")
  labels_file = load_data("/Users/pics/labels/")

  def decode_image(image):
    image = tf.cast(image, tf.float32)
    return image

  def decode_label(label):
    label = tf.cast(label, tf.float32)
    return label

  images = tf.data.Dataset.from_tensor_slices(images_file).map(decode_image)
  labels = tf.data.Dataset.from_tensor_slices(labels_file).map(decode_label)

  return tf.data.Dataset.zip((images, labels))


def train(directory):
  """tf.data.Dataset object for MNIST training data."""
  return dataset(directory)


def test(directory):
  """tf.data.Dataset object for MNIST test data."""
  return dataset(directory)

这是我的model_fn:

def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""

  model = create_model(features, labels, mode, params)
  image = features

  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=4.25e-5)

    # If we are running multi-GPU, we need to wrap the optimizer.
    if params.get('multi_gpu'):
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

    logits = model(image, training=True)

    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    loss = tf.reduce_mean(cross_entropy)

    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(LEARNING_RATE, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
    tf.identity(accuracy[1], name='train_accuracy')

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
  if mode == tf.estimator.ModeKeys.EVAL:

    logits = model(image, training=False)

    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
    loss = tf.reduce_mean(cross_entropy)

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(logits, axis=1)),
        })

这是完整的错误消息:

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "mnist.py", line 287, in <module>
    absl_app.run(main)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/absl/app.py", line 278, in run
    _run_main(main, args)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/absl/app.py", line 239, in _run_main
    sys.exit(main(argv))
  File "mnist.py", line 281, in main
    run_mnist(flags.FLAGS)
  File "mnist.py", line 263, in run_mnist
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 366, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1119, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1132, in _train_model_default
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1107, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "mnist.py", line 158, in model_fn
    labels=labels, predictions=tf.argmax(logits, axis=1))
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py", line 403, in accuracy
    predictions=predictions, labels=labels, weights=weights)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py", line 80, in _remove_squeezable_dimensions
    labels, predictions)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/ops/confusion_matrix.py", line 72, in remove_squeezable_dimensions
    labels = array_ops.squeeze(labels, [-1])
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 2556, in squeeze
    return gen_array_ops.squeeze(input, axis, name)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 7946, in squeeze
    "Squeeze", input=input, squeeze_dims=axis, name=name)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3414, in create_op
    op_def=op_def)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1756, in __init__
    control_input_ops)
  File "/Users/cezary/.pyenv/versions/3.6.4/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1592, in _create_c_op
    raise ValueError(str(e))
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 18000 for 'remove_squeezable_dimensions/Squeeze' (op: 'Squeeze') with input shapes: [1024,18000].

标签: pythontensorflow

解决方案


推荐阅读