python - 如何在 tf.data.Dataset.map 中使用预训练的 keras 模型进行推理?
问题描述
我有一个预先训练的模型,我正在尝试构建另一个模型,将前一个模型的输出作为输入。我不想端到端地训练模型,只想使用第一个模型进行推理。第一个模型是使用tf.data.Dataset
管道训练的,我的第一个倾向是将模型集成为dataset.map()
管道尾部的另一个操作,但我遇到了问题。我在这个过程中遇到了 20 个不同的错误,每一个都与前一个无关。批量标准化层似乎尤其是一个痛点。
以下是说明该问题的最小入门示例。它是用 R 编写的,但也欢迎用 python 回答。
我正在使用 tensorflow-gpu 版本 1.13.1 和 kerastf.keras
library(reticulate)
library(tensorflow)
library(keras)
library(tfdatasets)
use_implementation("tensorflow")
model_weights_path <- 'model-weights.h5'
arr <- function(...)
np_array(array(seq_len(prod(unlist(c(...)))), unlist(c(...))), dtype = 'float32')
new_model <- function(load_weights = TRUE) {
model <- keras_model_sequential() %>%
layer_conv_1d(5, 5, activation = 'relu', input_shape = shape(150, 10)) %>%
layer_batch_normalization() %>%
layer_flatten() %>%
layer_dense(10, activation = 'softmax')
if (load_weights)
load_model_weights_hdf5(model, model_weights_path)
freeze_weights(model)
model
}
if(!file.exists(model_weights_path)) {
model <- new_model(FALSE)
save_model_weights_hdf5(model, model_weights_path)
}
model <- new_model()
data <- arr(20, 150, 10)
ds <- tfdatasets::tensors_dataset(data) %>%
dataset_repeat()
ds2 <- ds %>%
dataset_map(function(x) {
model(x)
})
try(nb <- next_batch(ds2))
sess <- k_get_session()
it <- make_iterator_initializable(ds2)
sess$run(iterator_initializer(it))
nb <- it$get_next()
try(sess$run(nb))
sess$run(tf$initialize_all_variables())
try(sess$run(nb))
解决方案
也许这不会直接回答你的问题,因为我不熟悉 R。但我最近使用tf.data
.
该generate_images
函数.map
使用经过训练的生成器模型进行映射并生成新图像。
gen_model = tf.keras.models.load_model(artifact_dir+'/'+generators[-1], compile=False)
NOISE_DIM = 100
def generate_images(l):
# generate images using the trained generator
noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])
images = gen_model(noise)
# prepare the images for resize_and_preprocess function
images = tf.squeeze(images, axis=-1)
images = images*0.5+0.5
images = tf.image.convert_image_dtype(images, dtype=tf.uint8)
return images
genloader = tf.data.Dataset.from_tensors([1])
genloader = (
genloader
.map(generate_images, num_parallel_calls=AUTO)
.map(resize_and_preprocess, num_parallel_calls=AUTO)
.prefetch(AUTO)
)
关于批量标准化,它在训练和推理阶段表现不同。training=False
在基于 Python 的 TensorFlow 中,使用具有批量标准化层的预训练模型时需要通过。
推荐阅读
- python - Django:为什么我必须在我的 AppConfig 中导入信号(它似乎没有工作)
- javascript - 如何在 React-Quill 中设置字符长度
- postgresql - Postgres字符串与前导空格的比较
- html - 用于发送与我的表格格式相同的电子邮件的 Google 应用脚本
- php - 无法在 Symfony 2.3 中更新新的数据库字段
- php - 如何修复 PHP 数组爆炸结果在表单选择中仅显示最后一个值?
- unit-testing - 为什么从 npm 运行 Jest 时,我的覆盖率为 0?
- swift - Swift firebase 查询有序搜索
- r - 如何将 app.R 的输入传递给其他脚本并返回对象?
- node.js - Babel 编译错误 SyntaxError: Unexpected token when using spread operator