python - How to implement the DenseNet architecture with CIFAR-100 in Keras?
Problem description
How can I implement the DenseNet architecture on CIFAR-100 in Keras? The DenseNet models in Keras seem to be available only with ImageNet weights. How can I use them with CIFAR-100?
Solution
The following example shows how to use DenseNet121 with CIFAR-100. Note that I am using the Keras API that ships with TensorFlow (tf.keras).
import tensorflow as tf
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
# load the CIFAR-100 data, split between train and test sets
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
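# DenseNet121's ImageNet weights expect inputs normalized the way
# densenet.preprocess_input does it (scale to [0, 1], then subtract the
# ImageNet mean and divide by its standard deviation), not just /255
x_train = tf.keras.applications.densenet.preprocess_input(x_train.astype('float32'))
x_test = tf.keras.applications.densenet.preprocess_input(x_test.astype('float32'))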
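# create the base pre-trained model; 32x32 (the CIFAR-100 image size)
# is the smallest input size Keras accepts for DenseNet
base_model = DenseNet121(weights='imagenet', include_top=False, input_shape=(32, 32, 3))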
# add a global spatial average pooling layer
x = base_model.output
x = GlobalAveragePooling2D()(x)
# let's add a fully-connected layer
x = Dense(1024, activation='relu')(x)
# and a classification layer for the 100 CIFAR-100 classes; it has no
# activation, so it outputs logits (hence from_logits=True below)
predictions = Dense(100)(x)
# this is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)
# first: train only the top layers (which were randomly initialized)
# i.e. freeze all convolutional layers
for layer in base_model.layers:
    layer.trainable = False
# compile the model (should be done *after* setting layers to non-trainable)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='rmsprop', loss=loss, metrics=['accuracy'])
# train the model on the new data for a few epochs
model.fit(x_train,y_train,epochs=5, validation_data=(x_test,y_test), verbose=1,batch_size=128)
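If you would rather not start from the ImageNet weights at all, the same DenseNet121 constructor can also build a randomly initialized network whose built-in head has one softmax unit per CIFAR-100 class. The sketch below assumes TensorFlow 2.x; build_densenet_from_scratch is a hypothetical helper name, and the Adam optimizer is an illustrative choice rather than something the answer above prescribes.

```python
import tensorflow as tf
from tensorflow.keras.applications import DenseNet121


def build_densenet_from_scratch(num_classes=100, input_shape=(32, 32, 3)):
    """Hypothetical helper: randomly initialized DenseNet121 for CIFAR-100.

    weights=None means no ImageNet weights are downloaded, so Keras allows
    include_top=True with a custom input shape and a custom number of classes.
    """
    model = DenseNet121(
        weights=None,
        include_top=True,
        input_shape=input_shape,
        classes=num_classes,
    )
    # The built-in head already ends in a softmax, so the loss receives
    # probabilities, not logits (from_logits=False).
    model.compile(
        optimizer="adam",  # illustrative choice
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=["accuracy"],
    )
    return model
```

Usage with the arrays loaded above: model = build_densenet_from_scratch(), then model.fit(x_train, y_train, epochs=5, batch_size=128, validation_data=(x_test, y_test)). Because the whole network is trained rather than only a new head, this is a different trade-off from the frozen-base approach in the answer.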
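You can also fine-tune the model: I trained it with the original base_model weights kept frozen (the ImageNet weights of base_model were not updated). For fine-tuning, unfreeze some of the top layers and train again, typically with a low learning rate. I also recommend reading about ImageDataGenerator, which augments the training images and can improve accuracy on the test set.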
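The two suggestions above (fine-tuning and augmentation) can be sketched as follows, assuming TensorFlow 2.x and the model and base_model objects from the answer's code. unfreeze_top_layers and augmented_dataset are hypothetical helper names; unfreezing the last 30 layers, the SGD learning rate of 1e-5, and keeping BatchNormalization layers frozen are illustrative choices. For augmentation the sketch swaps in Keras preprocessing layers (RandomFlip, RandomTranslation) inside a tf.data pipeline instead of ImageDataGenerator, which current TensorFlow documentation marks as not recommended for new code.

```python
import tensorflow as tf


def unfreeze_top_layers(model, base_model, num_layers=30, learning_rate=1e-5):
    """Hypothetical helper: make the last `num_layers` layers of `base_model`
    trainable, then recompile `model` for fine-tuning."""
    # Setting trainable on the base model propagates to all of its layers.
    base_model.trainable = True
    for layer in base_model.layers[:-num_layers]:
        layer.trainable = False
    # Illustrative choice: keep BatchNormalization layers frozen. In TF2 Keras a
    # non-trainable BN layer also runs in inference mode, so it keeps using its
    # ImageNet moving statistics.
    for layer in base_model.layers[-num_layers:]:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False
    # Recompile so the new trainable flags take effect; the small learning
    # rate limits how far the pretrained weights move.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


def augmented_dataset(x, y, batch_size=128):
    """Hypothetical helper: shuffled, batched tf.data pipeline that applies
    random horizontal flips and small random shifts to each batch."""
    augmenter = tf.keras.Sequential([
        tf.keras.layers.RandomFlip("horizontal"),
        # Shift by up to 10% of the height and width in each direction.
        tf.keras.layers.RandomTranslation(0.1, 0.1),
    ])
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.shuffle(buffer_size=10000).batch(batch_size)
    # training=True makes the random layers actually transform the images.
    ds = ds.map(lambda xb, yb: (augmenter(xb, training=True), yb),
                num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)
```

With the objects from the answer, fine-tuning would then be unfreeze_top_layers(model, base_model) followed by model.fit(augmented_dataset(x_train, y_train), epochs=5, validation_data=(x_test, y_test)). Only the training data is augmented; the test set is passed through unchanged.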
Hope this helps.