tensorflow - multi_gpu_model : 'NoneType' 类型的对象没有 len()
问题描述
使用 keras multi_gpu_model 时出现此错误。如果我消除这一行,代码运行良好。此外,使用 CNN 模型它可以正常工作,只是在密集网络时它会产生错误。你能帮我解决这个问题吗?谢谢。
import numpy as np
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.layers import LSTM, BatchNormalization,Flatten
from keras.utils.vis_utils import model_to_dot
from keras.optimizers import adam
from keras.models import load_model
import pylab
from sklearn.model_selection import train_test_split
from keras.utils import multi_gpu_model
from scipy.io import wavfile
X=np.ones(10000)
y=np.zeros(100)
x_train=X
y_train=y
x_train=np.array(x_train)
y_train=np.array(y_train)
x_train.shape=(1,10000)
y_train.shape=(1,100)
model = Sequential()
model.add(Dense(500,activation = 'tanh'))
model.add(Dense(450, activation = 'tanh'))
model.add(Dense(412, activation = 'tanh'))
model.add(Dense(100, activation = 'tanh'))
opt = adam(lr=0.002, decay=1e-6)
model = multi_gpu_model(model, gpus=4)
model.compile(loss='mae', optimizer=opt, metrics=['accuracy'])
model.fit(x_train,y_train,epochs=50, batch_size = 40000)
Error: Traceback (most recent call last):
File "p.py", line 37, in <module>
model = multi_gpu_model(model, gpus=4)
File "/home/ENG/benipas1/anaconda3/envs/new/lib/python3.7/site-packages/keras/utils/multi_gpu_utils.py", line 203, in multi_gpu_model
for i in range(len(model.outputs)):
TypeError: object of type 'NoneType' has no len()
解决方案
问题在这里:
model = Sequential()
model.add(Dense(500,activation = 'tanh'))
您没有为第一层提供输入形状,因此模型的输出完全未定义并且model.outputs
为无。如果您向第一层提供输入形状,则定义输出并且它应该可以正常工作。您可能正在为您的 CNN 模型提供输入形状,这就是它起作用的原因:
model.add(Dense(500,activation = 'tanh', input_shape=(something,)))
推荐阅读
- java - Java中的匹配纸牌游戏
- java - 是否可以在 Vaadin 14 和 RouterLink 中添加 uri 的#fragment 部分?
- c++ - 需要帮助了解 Stockfish 中使用的 uci.h 文件
- ios - 邮递员 GET 请求的 Alamofire 等效项
- php - 为什么我的前端 Wordpress 登录表单卡在重定向循环中?
- c++ - 分配 C++ 后指针的行为不符合预期
- android - Android谷歌地图,点击创建标记并立即开始拖动它
- regex - 如何使用带有正则表达式/特殊字符的 lstlisting 包?
- python - PyTorch - 调整图像大小的原因是什么以及如何确定最佳尺寸?
- amazon-web-services - 如何启用 (https) SSL 证书 AWS EC2 托管站点