首页 > 解决方案 > 如何在 tf.data.Dataset.map 中使用预训练的 keras 模型进行推理?

问题描述

我有一个预先训练的模型,我正在尝试构建另一个模型,将前一个模型的输出作为输入。我不想端到端地训练模型,只想使用第一个模型进行推理。第一个模型是使用tf.data.Dataset管道训练的,我的第一个倾向是将模型集成为dataset.map()管道尾部的另一个操作,但我遇到了问题。我在这个过程中遇到了 20 个不同的错误,每一个都与前一个无关。批量标准化层似乎尤其是一个痛点。

以下是说明该问题的最小入门示例。它是用 R 编写的,但也欢迎用 python 回答。

我正在使用 tensorflow-gpu 版本 1.13.1 和 kerastf.keras

library(reticulate)
library(tensorflow)
library(keras)
library(tfdatasets)
use_implementation("tensorflow")

model_weights_path <- 'model-weights.h5'

arr <- function(...) 
  np_array(array(seq_len(prod(unlist(c(...)))), unlist(c(...))), dtype = 'float32')

new_model <- function(load_weights = TRUE) {
  model <- keras_model_sequential() %>% 
    layer_conv_1d(5, 5, activation = 'relu', input_shape = shape(150, 10)) %>%
    layer_batch_normalization() %>%
    layer_flatten() %>%
    layer_dense(10, activation = 'softmax')
  if (load_weights)
    load_model_weights_hdf5(model, model_weights_path)
  freeze_weights(model)
  model
}

if(!file.exists(model_weights_path)) {
  model <- new_model(FALSE) 
  save_model_weights_hdf5(model, model_weights_path)
}

model <- new_model()

data <- arr(20, 150, 10)
ds <- tfdatasets::tensors_dataset(data) %>% 
  dataset_repeat()

ds2 <- ds %>% 
  dataset_map(function(x) {
    model(x)
  })

try(nb <- next_batch(ds2))

sess <- k_get_session()
it <- make_iterator_initializable(ds2)
sess$run(iterator_initializer(it))
nb <- it$get_next()

try(sess$run(nb))

sess$run(tf$initialize_all_variables())

try(sess$run(nb))

标签: pythonrtensorflowkerastensorflow-datasets

解决方案


也许这不会直接回答你的问题,因为我不熟悉 R。但我最近使用tf.data.

generate_images函数.map使用经过训练的生成器模型进行映射并生成新图像。

gen_model = tf.keras.models.load_model(artifact_dir+'/'+generators[-1], compile=False)

NOISE_DIM = 100

def generate_images(l):
    # generate images using the trained generator
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])
    images = gen_model(noise)

    # prepare the images for resize_and_preprocess function
    images = tf.squeeze(images, axis=-1)
    images = images*0.5+0.5
    images = tf.image.convert_image_dtype(images, dtype=tf.uint8)

    return images

genloader = tf.data.Dataset.from_tensors([1])

genloader = (
    genloader
    .map(generate_images, num_parallel_calls=AUTO)
    .map(resize_and_preprocess, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

关于批量标准化,它在训练和推理阶段表现不同。training=False在基于 Python 的 TensorFlow 中,使用具有批量标准化层的预训练模型时需要通过。


推荐阅读