python - 当我使用 model.predict() 解决多类问题时,内核死了
问题描述
我使用 MacPro M1 芯片,python:3.8.10,TensorFlow:2.4.0-rc0。
在标签为 y 的多类问题(softmax)中是一个稀疏数字,即使有一些警告也能很好地工作,但model.fit()
会导致内核立即死亡。model.evaluate()
model.predict()
但是,如果标签 y 是一个单热向量,则所有 、 和 都适用于多类(softmax model.fit()
)model.evaluate()
。 model.predict()
顺便说一句,所有model.fit()
, model.evaluate()
, 和 model.predict()
在二元类和回归问题中也能很好地工作。
这是我的代码(导致内核死机)及其结果:
## Define a MLP model
model = keras.models.Sequential([
keras.layers.Flatten(input_shape = [28,28]),
keras.layers.Dense(50, activation = "relu"),
keras.layers.Dense(20, activation = "relu"),
keras.layers.Dense(10, activation = "softmax")
])
model.compile(loss = "sparse_categorical_crossentropy",
optimizer = "sgd",
metrics = ["accuracy"]
)
history = model.fit(X_train, y_train, epochs = 30,
validation_data = (X_valid, y_valid))
print(X_test.shape)
model.evaluate(X_test[:1], y_test[:1])
输出:
(10000, 28, 28)
1/1 [==============================] - 0s 8ms/step - loss: 0.0219 - accuracy: 1.0000
Out[21]:
[0.02185755781829357, 1.0]
X_new = X_test[:1]
print(X_new.shape)
print(X_new.shape[0])
y_prob = model.predict(X_new)
输出:
(1, 28, 28)
1
WARNING:tensorflow:AutoGraph could not transform <function Model.make_predict_function.<locals>.predict_function at 0x165553dc0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: unsupported operand type(s) for -: 'NoneType' and 'int'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function Model.make_predict_function.<locals>.predict_function at 0x165553dc0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: unsupported operand type(s) for -: 'NoneType' and 'int'\
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
解决方案
推荐阅读
- python - 在python中模拟线程定时器事件超时间隔
- windows-7 - 程序无法启动,因为缺少 api-ms-win-downlevel-kernel32-l2-1-0.dll - UE4 更新后
- vlookup - 用于 VLOOKUP 和 SUMIF 公式的虚拟数组
- java - 通过 JavaWeb 在整个 MySQL 数据库中搜索关键字的推荐方法
- python - 服务器发送速度比客户端接收速度快
- jitpack - JitPack 构建失败并出现错误:获取容器状态超时
- solidity - Solidity 导入多个声明
- vue.js - Vue中的计算函数未定义
- javascript - 无法使用 AJAX 将数组从 JS 传递到 PHP
- javascript - 如何为 paypal api 显示 javascript 对象?