python - STFT 和 DWT 输入数据的深度学习参数
问题描述
我在 STFT 数据和离散小波变换数据上创建 CNN 模型。我想在 python 的 2 个输入数据上获得我的深度学习模型的权重和偏差的数量。怎么做 ??
任何帮助,将不胜感激。
代码:
def createModel():
with tf.device("cpu"):
input_shape=(1, 22, 5, 3844)
model = Sequential()
model.add(Conv3D(16, (22, 5, 5), strides=(1, 2, 2), padding='same',activation='relu',data_format= "channels_first", input_shape=input_shape))
model.add(keras.layers.MaxPooling3D(pool_size=(1, 2, 2),data_format= "channels_first", padding='same'))
model.add(BatchNormalization())
model.add(Conv3D(32, (1, 3, 3), strides=(1, 1,1), padding='same',data_format= "channels_first", activation='relu'))#incertezza se togliere padding
model.add(keras.layers.MaxPooling3D(pool_size=(1,2, 2),data_format= "channels_first", ))
model.add(BatchNormalization())
model.add(Conv3D(64, (1,3, 3), strides=(1, 1,1), padding='same',data_format= "channels_first", activation='relu'))#incertezza se togliere padding
model.add(keras.layers.MaxPooling3D(pool_size=(1,2, 2),data_format= "channels_first",padding='same' ))
model.add(BatchNormalization())
model.add(Dense(64, input_dim=64,kernel_regularizer=regularizers.l2(0.0001), activity_regularizer=regularizers.l1(0.0001)))
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(256, activation='sigmoid'))
model.add(Dropout(0.5))
model.add(Dense(2, activation='softmax'))
opt_adam = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss='categorical_crossentropy', optimizer=opt_adam, metrics=['accuracy'])
return model
解决方案
您应该做的第一件事是安装 h5py
pip install h5py
然后你可以在这个文件中探索 keras 模型
import h5py
f = h5py.File('mytestfile.hdf5', 'r')
# layer names of your model
list(f.keys())
# you can use this layers as index
d = f['dense']['dense_1']['kernel:0']
推荐阅读
- python - 访问像 `arr[arr>5]` 这样的 NumPy 数组是如何工作的?
- javafx - Java FX TextField 模糊
- asp.net-mvc - 为什么即使在注销后 User.Identity.IsAuthenticated 也始终为真
- php - Laravel 脚本中遇到的非数字值
- json - 语法错误,意外的keyword_end,需要')'
- php - 记录到数据库 BLOB 列长时间运行的 php 脚本的内存有效方式
- google-analytics - 如何在 Google Analytics 中分析多个查询参数
- angular - 角度 github 构建并非所有图像都获取应用程序 URL(和其他内容)
- javascript - 在网页上显示 Discord 服务器成员数?
- python - matplotlib.pyplot.savefig 需要很长时间