Calling `model.predict()` from an externally attached function

Problem description

Using this as a reference, I came up with the following code:

import tensorflow as tf
from tensorflow.keras.applications.densenet import DenseNet121
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess_input
import numpy as np

# Target resize dimensions (int32, the dtype tf.image.resize uses for `size`)
resize_height = tf.constant(480, dtype=tf.int32)
resize_width = tf.constant(640, dtype=tf.int32)

@tf.function(input_signature=[tf.TensorSpec([None, None, 3], dtype=tf.uint8)])
def _preprocess(image_array):
    # Resize a single HxWx3 uint8 image (the result is float32)
    im_arr = tf.image.resize(image_array, (resize_height, resize_width))
    # Apply DenseNet's own input normalization
    im_arr = densenet_preprocess_input(im_arr)
    # Add a batch dimension: (H, W, 3) -> (1, H, W, 3)
    input_batch = tf.expand_dims(im_arr, axis=0)
    return input_batch

training_model = DenseNet121(include_top=True, weights='imagenet')

#Attach function to Model
training_model.preprocess = _preprocess

#Attach resize dimensions to Model
training_model.resize_height = resize_height
training_model.resize_width = resize_width

training_model.save("saved_model", overwrite=True)

It essentially attaches a method named `preprocess` to the DenseNet121 `tf.keras.Model`, so that I can later use it for prediction:

pred_model = tf.keras.models.load_model('saved_model')

#download image
image_path = tf.keras.utils.get_file("cat.jpg", "https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg")
# Load the image and convert it to a uint8 NumPy array
image_array = np.array(tf.keras.preprocessing.image.load_img(path=image_path))

#call the custom function bound to the model
preprocessed_image = pred_model.preprocess(image_array)

result = pred_model.predict(preprocessed_image)
print(np.argmax(result, axis=-1), np.amax(result, axis=-1))

My question:

How can I call the model's predict method from within the preprocessing function, so that

preprocessed_image = pred_model.preprocess(image_array)
result = pred_model.predict(preprocessed_image)

can become

result = pred_model.preprocess_predict(image_array)

Tags: python, python-3.x, tensorflow, keras, tf.keras

Solution
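A minimal sketch of one possible approach, under stated assumptions. `model.predict()` is a high-level endpoint that manages its own batching loop and `tf.function`, and Keras does not allow it to be called inside another `tf.function`. The combined function below therefore calls the model object directly (`model(x, training=False)`) on the output of `_preprocess` and is attached to the model the same way the question attaches `preprocess`. Assumptions: the name `_preprocess_predict` is hypothetical; the resize target is 224x224 because DenseNet121 with `include_top=True` expects 224x224x3 inputs; `weights=None` is used only to skip the ImageNet download; and a random array stands in for the downloaded cat image. The save/load round trip is not covered here. For the attached function to be exported, it would have to be attached before `training_model.save(...)`, as in the question.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.densenet import DenseNet121
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess_input

# DenseNet121 with include_top=True expects 224x224x3 inputs.
RESIZE_HEIGHT, RESIZE_WIDTH = 224, 224

# weights=None only avoids the ImageNet download in this sketch; use 'imagenet' in practice.
model = DenseNet121(include_top=True, weights=None)


@tf.function(input_signature=[tf.TensorSpec([None, None, 3], dtype=tf.uint8)])
def _preprocess(image_array):
    # Resize one HxWx3 uint8 image (result is float32), normalize, add a batch dim.
    im_arr = tf.image.resize(image_array, (RESIZE_HEIGHT, RESIZE_WIDTH))
    im_arr = densenet_preprocess_input(im_arr)
    return tf.expand_dims(im_arr, axis=0)


# Hypothetical name: combines preprocessing and inference in one graph function.
@tf.function(input_signature=[tf.TensorSpec([None, None, 3], dtype=tf.uint8)])
def _preprocess_predict(image_array):
    # Call the model directly instead of model.predict(), which manages its own
    # tf.function and cannot be nested inside this one.
    return model(_preprocess(image_array), training=False)


# Attach both functions to the model, as the question does with `preprocess`.
model.preprocess = _preprocess
model.preprocess_predict = _preprocess_predict

if __name__ == "__main__":
    # Random stand-in for the decoded cat.jpg (any HxWx3 uint8 array works).
    image_array = np.random.randint(0, 256, size=(213, 320, 3), dtype=np.uint8)
    result = model.preprocess_predict(image_array)
    print(np.argmax(result, axis=-1), np.amax(result, axis=-1))
```

On the in-memory model, `model.preprocess_predict(image_array)` returns a tensor of shape (1, 1000) holding class probabilities; call `.numpy()` on it if a NumPy array is needed.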

