TPU training error: "No registered 'Cumsum' OpKernel for XLA_TPU_JIT devices compatible with node {{node RaggedConcat/Cumsum}}"

Problem description

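I'm trying to train a model on a TPU with a custom training step. Training works fine on a GPU but fails on the TPU. According to [ https://cloud.google.com/tpu/docs/tensorflow-ops ], I believe I am not using any unsupported TensorFlow ops, but the lists of supported and unsupported ops are not exhaustive, and I am using ops that appear in neither list.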

The error message mentions a ragged tensor. Only two lines of my code involve ragged tensors, and both are in train_step:

import tensorflow as tf


class CustomModel(tf.keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x = data

        # One-hot labels: [1, 0] marks a positive example, [0, 1] a negative one.
        positiveLabel = tf.constant([1, 0], dtype=tf.int32)
        negativeLabel = tf.constant([0, 1], dtype=tf.int32)

        # Repeat each label once per example in the batch, giving shape [batch, 2].
        batch_size = tf.shape(x[0])[0]
        pLabelBatch = tf.reshape(tf.tile(positiveLabel, [batch_size]), [batch_size, 2])
        nLabelBatch = tf.reshape(tf.tile(negativeLabel, [batch_size]), [batch_size, 2])

        y = (pLabelBatch, pLabelBatch, nLabelBatch, nLabelBatch)

        # Integer division keeps the target shape an integer tensor.
        batch_label = tf.reshape(y, (tf.size(y) // 2, 2))

        rs = tf.ragged.stack(x, axis=0)
        reg = rs.to_tensor()
        batch_input = tf.reshape(reg, (tf.shape(reg)[0]*tf.shape(reg)[1], tf.shape(reg)[2]))

        with tf.GradientTape() as tape:
            y_pred = self(batch_input, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(batch_label, y_pred, regularization_losses=self.losses)

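        # Compute gradients and apply them.
        # NOTE: `_minimize` is not defined in this snippet; it is presumably the
        # private Keras helper from `tensorflow.python.keras.engine.training`.
        _minimize(self.distribute_strategy, tape, self.optimizer, loss,
                self.trainable_variables)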
        # Update weights
        # self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss).
        # Use the flattened labels so their shape matches y_pred.
        self.compiled_metrics.update_state(batch_label, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

Specifically:

        rs = tf.ragged.stack(x, axis=0)
        reg = rs.to_tensor()

I can't find any information online about whether ragged tensors are supported on TPUs.

I'm trying to figure out how to fully interpret the error message.
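
For reference, the message can be read piece by piece: 'Cumsum' is the op type, XLA_TPU_JIT is the device the graph was being compiled for, and RaggedConcat/Cumsum is the name of the failing node, which suggests a Cumsum created inside the ragged stacking step. The sketch below is only an illustration, not a confirmed fix. It assumes the four tensors in `x` all have the same [batch, features] shape (the sample shapes are made up), in which case a plain `tf.concat` produces the same batch as the ragged path without adding any ragged ops to the graph.

```python
import tensorflow as tf

# Hypothetical stand-in for the four inputs unpacked in train_step.
# Assumption: all four share the same [batch, features] shape.
x = [tf.random.uniform((3, 5)) for _ in range(4)]

# Path from the question: stacking goes through a RaggedTensor first.
reg = tf.ragged.stack(x, axis=0).to_tensor()
ragged_batch = tf.reshape(
    reg, (tf.shape(reg)[0] * tf.shape(reg)[1], tf.shape(reg)[2]))

# Dense alternative: concatenate along the batch axis, which yields the
# same [4 * batch, features] tensor using only dense ops.
dense_batch = tf.concat(x, axis=0)

# Check that both paths agree element by element.
same = bool(tf.reduce_all(tf.equal(ragged_batch, dense_batch)))
print(same)
```

If the inputs can differ in length, the two paths are no longer equivalent, because `to_tensor()` pads the shorter ones; the inputs would then need to be padded to a common shape before concatenating.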

Tags: tensorflow, keras, tpu

Solution

