首页 > 解决方案 > 将参数传递给 tf.py_function

问题描述

我正在创建一个 tf.data.Dataset,我有几个需要传递参数的预处理函数。是否可以通过 tf.py_function() 将参数传递给函数?

我能看到的唯一方法是将我的预处理函数放在一个类中,这样我就可以通过 self.

例如:


class My_Dataset():
    def __init__(self, shape):
        self.shape = shape

    def resize(self, image):
        # Note I am just using resize as a dummy example
        # my actual preprocessing functions are more general 
        # and take several params
        return cv2.resize(image.numpy(), self.shape)

    def get_map_func(self, image, label):
        [image,] = tf.py_function(self.resize, [image], [tf.float32])
        image.set_shape(shape)
        return image, label

    def create_dataset(self, images_paths, labels):
        ds = ...
        ds = ds.map(my_dataset.get_map_func)
        return ds

my_dataset = My_Dataset( (512, 512, 3) )

ds = my_dataset.create_dataset(...)


但是有更好的方法吗?对于将类传递给多进程函数,我总是非常谨慎。据我了解,他们对这个过程感到厌烦,所以如果班级变得太大,那么它似乎总是会给我带来问题。

编辑:添加第二个问题..

在上面的示例中,my_dataset 对象的任何实例实际上是否存在于最终的 ds 中?例如,images_paths 是数百万个数十 MB 大的图像路径的列表。如果我在init时将 images_paths 和标签传递给类并将它们分配给 self,那么 ds 中是否会有一些大型对象需要在进程之间传递?

标签: tensorflowtensorflow-datasets

解决方案


您可以使用 lambda 函数传递参数

例如(说明性):-

all_encoded_data = corpus_ds.map(lambda x: generate_context_target_pairs_map_fxn(x,window_size, vocab_size))

因此,您可以在此处映射主列表并将 lambda 函数传递给映射器并调用您的 tf.py_function 包装器。如果您有多个值,只需在映射之前压缩它们。


推荐阅读