python - 如何将自定义 python 函数应用于 Tensorflow 中的配对数据?
问题描述
我希望计算 TensorFlow 中两个张量之间的自定义函数(在我的例子中是 Hausdorff 距离)。但是,我仍然想念tf.map
函数的使用。
特别是,我希望在批处理的元素之间计算一个自定义函数,这个函数是一个 python 函数。当我想使用该功能时,我会:
# get a batch of images and ground truth segmentation masks from the dataset
input_data, y_true = from_dataset(...)
# compute predicted segmentation mask:
y_pred = myCNN(input_data)
# compute hausdorff distance for each element of the batch:
list_of_hds = sess.run(hausdorff_distance(y_pred, y_true))
我目前对 Hausdorff 距离的实现:
from medpy.metric.binary import hd
import tensorflow as tf
def hausdorff_distance(mask1, mask2):
def _py_hd(m1, m2):
where1 = np.argwhere(m1)
where2 = np.argwhere(m2)
return hd(where1, where2)
tf_hd = tf.map_fn(lambda el:
tf.py_function(func=_py_hd, inp=[el[0], el[1]],
Tout=[tf.float32, tf.float32], name='hausdorff_distance'),
elems=[mask1, mask2])
return tf_hd
但是,如果我做对了,这个实现就是错误的。事实上,它会将 HD 完全应用于 mask1 和 mask2。相反,我想获得批次的每个元素之间的 HD 列表。在实践中,我想要列表:l = [HD(mask1[0], mask2[0]), HD(mask1[1], mask2[1]), ... HD(mask1[N], mask2[N])]
。
我想念什么?我误解了 的功能tf.map
吗?
谢谢你,G。
PS这个实现使用TensorFlow 1.14,虽然我猜TensorFlow > 2应该是类似的。
编辑: 我找到了一个可能的解决方案,我将其留在下面的评论中。任何帮助仍然非常受欢迎:)
解决方案
我找到了一个可能的解决方案,我希望它是正确的。我把它留在这里,让谁感兴趣。
from medpy.metric.binary import hd
import tensorflow as tf
def hausdorff_distance(mask1, mask2):
"""Compute the average Hausdorff distance for the patient (in pixels), between mask1 and mask2."""
def _py_hd(m1, m2):
"""Python function to compute HD between the two n-dimensional masks"""
m1, m2 = np.array(m1), np.array(m2)
num_elems = len(m1)
assert len(m2) == num_elems
# remove last channel, if it is == 1:
if len(m1.shape) == 4 and m1.shape[-1] == 1:
m1, m2 = np.squeeze(m1, axis=-1), np.squeeze(m2, axis=-1)
return hd(m1, m2)
# map _py_hd(.) to every element on the batch axis:
tf_hd = tf.py_function(func=_py_hd, inp=[mask1, mask2],
Tout=[tf.float32], name='hausdorff_distance'),
# return the average HD in the batch:
return tf.reduce_mean(tf_hd)
推荐阅读
- javascript - React Table:如何对每个holeOne 到holeNine 求和(求和)值,并将总和显示为out:41 中的值?
- android - Flutter:“使用静态访问无法访问实例成员'playing'。”
- typescript - 无法理解这些泛型(?)定义中发生了什么,包括。类型参数列表
- python - 是否可以使 matplotlib 图形轴等比例缩放?
- clang - 获取 Clang AST 节点类型的可靠方法
- reactjs - ReactJS RSuite 3 未正确渲染组件
- javascript - 使用 Axios 和 Vue 将更改保存到后端 API 的正确方法
- python - TensorFlow 训练 - “批量大小”和 tf.unpack - 解包非“批量大小”的动态值?
- gcc - 错误:注册后出现垃圾 `bswapl eax movl %eax'
- docker - 有没有办法恢复 docker 卷?