machine-learning - 使用预训练 ResNet50 网络的 OneClass SVM 模型
问题描述
我正在尝试构建用于图像识别的 OneClass 分类器。我找到了这篇文章,但是因为我没有完整的源代码,所以我并不完全理解我在做什么。
X_train, X_test, y_train, y_test = train_test_split(x, y, random_state=42)
# X_train (2250, 200, 200, 3)
resnet_model = ResNet50(input_shape=(200, 200, 3), weights='imagenet', include_top=False)
features_array = resnet_model.predict(X_train)
# features_array (2250, 7, 7, 2048)
pca = PCA(svd_solver='randomized', n_components=450, whiten=True, random_state=42)
svc = SVC(kernel='rbf', class_weight='balanced')
model = make_pipeline(pca, svc)
param_grid = {'svc__C': [1, 5, 10, 50], 'svc__gamma': [0.0001, 0.0005, 0.001, 0.005]}
grid = GridSearchCV(model, param_grid)
grid.fit(X_train, y_train)
我有 2250 张图像(食物而不是食物)200x200px 大小,我发送这些数据来预测ResNet50 模型的方法。结果是(2250, 7, 7, 2048)张量,有人知道这个维度是什么意思吗?
当我尝试运行 grid.fit 方法时,出现错误:
ValueError: Found array with dim 4. Estimator expected <= 2.
解决方案
这些是我可以做出的发现。
你得到的输出张量高于全局平均池化层。(请参阅resnet_model.summary()
以了解输入维度如何更改为输出维度)
对于一个简单的修复,在 resnet_model 之上添加一个平均池化 2d 层。(因此输出形状变为(2250,1,1,2048))
resnet_model = ResNet50(input_shape=(200, 200, 3), weights='imagenet', include_top=False)
resnet_op = AveragePooling2D((7, 7), name='avg_pool_app')(resnet_model.output)
resnet_model = Model(resnet_model.input, resnet_op, name="ResNet")
这通常存在于 ResNet50 本身的源代码中。基本上我们在 resnet50 模型上附加了一个 AveragePooling2D 层。最后一行将图层(第 2 行)和基线模型组合成一个模型对象。
现在输出维度(feature_array)将是(2250, 1, 1, 2048)
(因为添加了平均池化层)。
为了避免ValueError
你应该重塑这个 feature_array 到(2250, 2048)
feature_array = np.reshape(feature_array, (-1, 2048))
在问题程序的最后一行,
grid.fit(X_train, y_train)
你适合 X_train (在这种情况下是图像)。这里正确的变量是features_array
(这被认为是图像的摘要)。输入此行将纠正错误,
grid.fit(features_array, y_train)
要通过提取特征向量以这种方式进行更多微调,请看这里(使用神经网络训练而不是使用 PCA 和 SVM)。
希望这可以帮助!!
推荐阅读
- python - 在 Python 3.9 的类中使用列表时出错
- php - 在子包中构建 laravel 就绪的 ORM 模型?(什么依赖?)
- python - pca-valueError:无法将字符串转换为浮点数:'finalized'
- android - 我想为片段编写一个通知按钮,但我的代码有一些问题
- python - 使用 twinx() 时 X 轴日期范围发生变化
- c# - 如何从 ASP.NET MVC 5 中的不同控制器(站点范围)访问 cookie?
- angular - 在有角度的前端应用程序上隐藏我的私钥 RSA
- elasticsearch - 多行日志的顺序不正确
- rust - 在 Rust 中使用泛型、特征别名和构造函数
- azure-devops - 如果我是发起该部署的人,则自动授予我对发布的批准