python - 在 numpy 而不是 tensorflow 中计算 keras 度量
问题描述
我正在尝试在 NumPy 而不是 TensorFlow 中计算 Keras 指标。
由于您通常不需要通过纯度量的梯度流,因此可以在 NumPy 中计算度量。
我扩展tf.keras.metrics.Metrics
并覆盖了该update_state()
方法。在那里,我得到y_true
并y_pred
作为tensorflow.python.framework.ops.Tensor
带有形状的类型(None, 64, 64, 64, 6)
不幸的是,我无法将张量转换为 NumPy 数组。我认为是因为度量标准已预编译,因此还没有可用的值(这就是为什么第一个形状为无)?
我尝试使用y_true.eval()
,y_true.numpy()
提供了一个会话,例如,y_true.eval(session=session)
使用session=tf.compat.v1.Session()
or tf.compat.v1.get_default_session()
- none 工作。
如何计算 NumPy 中的指标?不幸的是,我不能只使用 TensorFlow 函数重新实现所有 NumPy 函数,因为我想使用某个包。
我在用着tensorflow 2.5, keras 2.4.3
解决方案
将您的功能包装在tf.py_function
https://www.tensorflow.org/api_docs/python/tf/py_function
Keras 想预先构建一个图,但 NumPy 要求计算是急切的。py_function
允许在实际值可用时稍后进行计算。
推荐阅读
- ios - 天蓝色 devOps 管道 xcode
- javascript - 过滤嵌套数组对象以在 FlatList 中使用
- google-cloud-platform - Vertex AI 中的 GPU 访问
- python - 使用 tkinter 在“画廊”中显示图像?
- django - pipenv 更新所有依赖项带来重大变化
- docker - 在容器内时 pip 无法取轮子
- google-sheets - 谷歌表格动态查询导入多个标签?
- azure - 无法将更新反应应用程序部署到应用程序服务
- google-tag-manager - 为什么 Google 相关的 URL 有时会出现如此多的超时?
- python - 使用处理文件中的行的 multiprocessing.pool 时,全局/共享计数器的困难(Python)