python - 使用model.predict()后如何理解结果数组的维度
问题描述
我正在重复一个代码来检索项目,但是当我在model.predict函数中调试时,我发现这个函数的输入是维度(1、224、224、3),但输出是(1, 7, 7, 2048)。model.predict() 的结果不应该是一个 1D 数组,它给出对象属于每个类别而不是 4D 的概率吗?如何理解这个结果数组的维度?
model_features = model.predict(x, batch_size=1)
具体代码如下:(这只是整个代码的一部分,可能无法直接运行)
import keras.applications.resnet50
import numpy as np
import os
import pickle
import time
import vse
from keras.preprocessing import image
from keras.models import Model, load_model
model = keras.applications.resnet50.ResNet50(include_top=False)
model_extension == "resnet"
def extract_features_cnn(img_path):
"""Returns a normalized features vector for image path and model specified in parameters file """
print('Using model', model_extension)
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
if model_extension == "vgg19":
x = keras.applications.vgg19.preprocess_input(x)
elif model_extension == "vgg16":
x = keras.applications.vgg16.preprocess_input(x)
elif model_extension == "resnet":
x = keras.applications.resnet50.preprocess_input(x)
else:
print('Wrong model name')
model_features = model.predict(x, batch_size=1)
x = model_features[0]
total_sum = sum(model_features[0])
features_norm = np.array([val / total_sum for val in model_features[0]], dtype=np.float32)
if model_extension == "resnet":
print("reshaping resnet")
features_norm = features_norm.reshape(2048, -1)
return features_norm
解决方案
你的问题不够清楚,但我会尽可能多地解释你的问题。您的模型只有 ResNet,它只有卷积层,并且没有线性层,它可能导致表示类概率的结果。你的结果不是你想象的 4D。在您的输出形状中,即(1, 7, 7, 2048)
. 1 代表批量大小。这意味着您仅向网络提供了 1 张图像并获得 1 个结果。7s 代表您的输出大小,即 7x7。2048 代表你的输出通道。如果你想得到类的概率,你需要在 ResNet 网络的末端添加一个线性层。您可以使用参数添加它,include_top=True
并且可以使用参数指定类号classes=1000
。
这是文档。
推荐阅读
- ngrx - StoreModule.forRoot() - 如何在没有附加键的情况下返回对象
- javascript - 使用 lodash 按嵌套属性对对象数组进行排序
- java - 搜索将对象之间的双向链接转换为 JSON 格式的正确方法
- batch-file - 将变量保存在文本文件中
- excel - Power Query 仅在最近未刷新时才刷新源表
- javascript - Promise catch,如何返回一个新的替代 promise 继续?
- python-3.x - AWS ALB 背后的 Python Flask 端到端加密
- php - woocommerce 中的 get_attachments()
- docker - Kafka Client Timeout 60000ms 在分区位置确定之前过期
- python - python 中的 smtplib.server.sendmail 函数引发 UnicodeEncodeError: 'ascii' codec can't encode character