python - 无法将所有权重保存到 .csv 文件
问题描述
我正在使用 tensorflow 2.4 上的 mnist 数据集运行一个简单的 CNN,并且我想将所有权重保存到 .csv 文件中,我尝试过的代码如下
for layer in model_train.layers:
weights = layer.get_weights()
print(weights)
np.savetxt('weight.csv' , weights , fmt='%s', delimiter=',')
和
weight = model_train.get_weights()
np.savetxt('weight.csv' , weight , fmt='%s', delimiter=',')
和我发现的这个方法
上面的所有结果最终都保存了类似这样的 .csv 文件,它不会保存层中的每一个权重,有没有办法可以保存每一个权重?还是我做错了什么?
我的模型
model_train = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10)
])
model_train.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()]
)
model_train.fit(
train_images,
train_labels,
epochs=5,
batch_size=480,
validation_split=0.2
)
model_train.save('./mnist_model.h5') # creates a HDF5 file 'my_model.h5'
weight = model_train.get_weights()
np.savetxt('weight.csv' , weight , fmt='%s', delimiter=',')
解决方案
推荐阅读
- mysql - 执行 Liquibase changeLog 时,触发器在 Amazon Aurora 中不起作用
- python - 无法将多个变量引导到正确的列表中
- angular - Angular Firebase - 从集合中提取数据
- java - 无休止的循环-while循环内的scanner.hasNextInt()
- mongodb - 按天对文档进行分组并在 mongodb 中计算一个字段
- php - 使用变量变量获取部分选项卡,并使用 php 将其设置为新变量
- node.js - 使用 Node.js 和 mangoose 对 mangoDB 集合进行排序时出错
- c - 努力理解为什么线程只运行一段时间然后停止
- android-fragments - Kotlin Retrofit2 Enqueue“未解决的参考入队”
- node.js - 有没有办法从 nestjs-redis 客户端获取 ioredis 实例?