RuntimeError: as_numpy_iterator() is not supported while tracing functions on a TensorFlow 2.x dataset in SageMaker

Question

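I am preparing a dataset in SageMaker Pipe mode to train a time-series model in SageMaker, using PipeModeDataset, a TensorFlow Dataset that reads a SageMaker Pipe mode channel. I am using an augmented manifest file in which each line contains an image location in S3 and a label. My model takes as input a batch of images (512 x 512 x 1) with a single label per batch. I thought of using the window function to bundle the images read from the pipe. The relevant part of the dataset-generation code is below.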

import numpy as np
import tensorflow as tf
from sagemaker_tensorflow import PipeModeDataset

def _input_fn(channel):
    """Returns a Dataset for reading from a SageMaker PipeMode channel."""
    features = {
        'image-ref': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([3], tf.int64),
    }
    
    def parse(record):
        parsed = tf.io.parse_single_example(record, features)
        image = tf.io.decode_png(parsed['image-ref'], channels=1, dtype=tf.uint8)
        image = tf.reshape(image, [512, 512, 1])
        label = parsed['label']
        return (image, label)

    ds = PipeModeDataset(channel, record_format='TFRecord', benchmark=True, benchmark_records_interval=100)
    ds = ds.map(parse)
    
    print ("PipeModeDataset print0 = " + str(ds))
    ds = ds.window(16, shift=1, drop_remainder=True)
    print ("PipeModeDataset print1 = " + str(ds))
    
    def window_func(window, label):
        window = window.batch(16, drop_remainder=True)
        label = label.batch(16, drop_remainder=True)
        
        print ("window batch is = " + str(window))
        print ("label batch is = " + str(label))
        
        window_np = np.stack(list(window.as_numpy_iterator()))
        label_np = np.stack(list(label.as_numpy_iterator())) # TODO: only get the last label
        
        return tf.data.Dataset.from_tensor_slices((window_np, label_np))
    
    ds = ds.flat_map(lambda window, label: window_func(window, label))
    ....
    ....
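To see what `window()` actually produces, here is a small sketch on hypothetical toy data (tiny 2 x 2 x 1 "images" and 3-element labels instead of the real 512 x 512 x 1 frames, and no PipeModeDataset). Each window is a pair of nested datasets rather than tensors; at the top level, where TensorFlow 2.x runs eagerly, those nested datasets can be pulled into NumPy with `as_numpy_iterator()`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 6 tiny "images" of shape (2, 2, 1) and 3-element labels,
# standing in for the real 512 x 512 x 1 frames and their labels.
raw_images = np.arange(6 * 4, dtype=np.uint8).reshape(6, 2, 2, 1)
raw_labels = np.arange(6 * 3, dtype=np.int64).reshape(6, 3)
ds = tf.data.Dataset.from_tensor_slices((raw_images, raw_labels))

# Sliding windows of 4 consecutive elements, advancing by 1 element each time.
windows = ds.window(4, shift=1, drop_remainder=True)

# Each window is a tuple of two nested Datasets (images, labels), not tensors.
first_images_ds, first_labels_ds = next(iter(windows))

# This runs eagerly at the top level, so as_numpy_iterator() is allowed here.
first_images = np.stack(list(first_images_ds.as_numpy_iterator()))
first_labels = np.stack(list(first_labels_ds.as_numpy_iterator()))
print(first_images.shape, first_labels.shape)  # (4, 2, 2, 1) (4, 3)
```

The same calls fail once they are moved into a function passed to `map` or `flat_map`, because that function is no longer executed eagerly.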

Currently I get the error below. How can I fix this? If there is a better approach, please suggest it.

PipeModeDataset print0 = <MapDataset shapes: ((512, 512, 1), (3,)), types: (tf.uint8, tf.int64)>
PipeModeDataset print1 = <WindowDataset shapes: (DatasetSpec(TensorSpec(shape=(512, 512, 1), dtype=tf.uint8, name=None), TensorShape([])), DatasetSpec(TensorSpec(shape=(3,), dtype=tf.int64, name=None), TensorShape([]))), types: (DatasetSpec(TensorSpec(shape=(512, 512, 1), dtype=tf.uint8, name=None), TensorShape([])), DatasetSpec(TensorSpec(shape=(3,), dtype=tf.int64, name=None), TensorShape([])))>
window batch is = <BatchDataset shapes: (16, 512, 512, 1), types: tf.uint8>
label batch is = <BatchDataset shapes: (16, 3), types: tf.int64>

RuntimeError: in user code:

    /opt/ml/code/train_on_pipemode.py:104 None  *
        ds = ds.flat_map(lambda window, label: window_func(window, label))
    /opt/ml/code/train_on_pipemode.py:96 window_func  *
        window_np = np.stack(list(window.as_numpy_iterator()))
    /usr/local/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:518 as_numpy_iterator  **
        raise RuntimeError("as_numpy_iterator() is not supported while tracing "

    RuntimeError: as_numpy_iterator() is not supported while tracing functions
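The traceback points at `window_func`, which `flat_map` traces into a TensorFlow graph; inside it, `window` and `label` are symbolic datasets, so they cannot be converted to NumPy. A minimal sketch of one possible workaround, assuming the goal is a (16-frame image stack, label) pair per example: keep everything as `tf.data` transformations and combine the batched window with `tf.data.Dataset.zip` instead of NumPy. The synthetic 8 x 8 x 1 records replace PipeModeDataset only so the sketch runs without SageMaker, and taking the last label with `skip(WINDOW - 1)` is an assumed way to address the question's TODO, not part of the original code.

```python
import numpy as np
import tensorflow as tf

WINDOW = 16  # frames per example, as in the question

# Hypothetical stand-in for the parsed PipeModeDataset: 20 (image, label) records.
# Images are shrunk to 8 x 8 x 1 here only to keep the sketch light.
num_records = 20
images = np.zeros((num_records, 8, 8, 1), dtype=np.uint8)
labels = np.arange(num_records * 3, dtype=np.int64).reshape(num_records, 3)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

ds = ds.window(WINDOW, shift=1, drop_remainder=True)

def window_func(window, label):
    # Inside flat_map this function is traced into a graph, so `window` and
    # `label` are symbolic datasets: combine them with tf.data ops, not NumPy.
    window = window.batch(WINDOW, drop_remainder=True)  # one (16, H, W, 1) tensor
    label = label.skip(WINDOW - 1)                      # keep only the last label
    return tf.data.Dataset.zip((window, label))

ds = ds.flat_map(window_func)
print(ds.element_spec)
```

Each resulting element is one 16-frame stack paired with one 3-element label, and the dataset can then be batched for training as usual (for example `ds.batch(batch_size)`). With `shift=1`, consecutive examples share 15 frames. If all 16 labels are wanted instead, `label.batch(WINDOW, drop_remainder=True)` can replace the `skip` call.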

An answer I found suggests enabling eager execution, but it is already enabled: tf.executing_eagerly() prints True. I am training with TensorFlow 2.x:

Tensorflow version: 2.3.1
Eager execution: True
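`tf.executing_eagerly()` returning True at the top level does not mean that functions passed to `tf.data` transformations run eagerly. The `tf.executing_eagerly` documentation lists "inside a transformation function for tf.dataset" as a case where it returns False, and the `tf.config.run_functions_eagerly` documentation notes that it has no effect on such functions. A small sketch that records the mode seen while `map` traces its function:

```python
import tensorflow as tf

eager_flags = []  # filled while tf.data traces the function below

def record_mode(x):
    # Plain Python side effect: runs when the function is traced, not per element.
    eager_flags.append(tf.executing_eagerly())
    return x

print("top level eager:", tf.executing_eagerly())
ds = tf.data.Dataset.range(3).map(record_mode)
print("inside map eager:", eager_flags)
```

The top-level flag is True while the flag recorded during tracing is False, which is why `as_numpy_iterator()` raises inside `window_func` even though eager execution is enabled.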

Tags: tensorflow, runtime-error, tensorflow2.0, tensorflow-datasets, amazon-sagemaker
