javascript - Tensorflow 与 Tensorflow JS 浮点算术计算的不同结果
问题描述
我已将 Tensorflow 模型转换为 Tensorflow JS 并尝试在浏览器中使用。在将输入输出输入模型进行推理之前,需要在输入输出上执行一些预处理步骤。我已经实现了与 Tensorflow 相同的这些步骤。问题是 TF JS 的推理结果与 Tensorflow 不同。所以我开始调试代码,发现TF JS预处理中浮点算术运算的结果与运行在带有GPU的Docker容器上的Tensorflow不同。TF JS 中使用的代码如下。
var tensor3d = tf.tensor3d(image,[height,width,1],'float32')
var pi= PI.toString();
if(bs == 14 && pi.indexOf('1') != -1 ) {
tensor3d = tensor3d.sub(-9798.6993999999995).div(7104.607118190255)
}
else if(bs == 12 && pi.indexOf('1') != -1) {
tensor3d = tensor3d.sub(-3384.9893000000002).div(1190.0708513300835)
}
else if(bs == 12 && pi.indexOf('2') != -1) {
tensor3d = tensor3d.sub(978.31200000000001).div(1092.2426342420442)
}
var resizedTensor = tensor3d.resizeNearestNeighbor([224,224]).toFloat()
var copiedTens = tf.tile(resizedTensor,[1,1,3])
return copiedTens.expandDims();
使用的 Python 代码块
ds = pydicom.dcmread(input_filename, stop_before_pixels=True)
if (ds.BitsStored == 12) and '1' in ds.PhotometricInterpretation:
normalize_mean = -3384.9893000000002
normalize_std = 1190.0708513300835
elif (ds.BitsStored == 12) and '2' in ds.PhotometricInterpretation:
normalize_mean = 978.31200000000001
normalize_std = 1092.2426342420442
elif (ds.BitsStored == 14) and '1' in ds.PhotometricInterpretation:
normalize_mean = -9798.6993999999995
normalize_std = 7104.607118190255
else:
error_response = "Unable to read required metadata, or metadata invalid.
BitsStored: {}. PhotometricInterpretation: {}".format(ds.BitsStored,
ds.PhotometricInterpretation)
error_json = {'code': 500, 'message': error_response}
self._set_headers(500)
self.wfile.write(json.dumps(error_json).encode())
return
normalization = Normalization(mean=normalize_mean, std=normalize_std)
resize = ResizeImage()
copy_channels = CopyChannels()
inference_data_collection.append_preprocessor([normalization, resize,
copy_channels])
规范化代码
def normalize(self, normalize_numpy, mask_numpy=None):
normalize_numpy = normalize_numpy.astype(float)
if mask_numpy is not None:
mask = mask_numpy > 0
elif self.mask_zeros:
mask = np.nonzero(normalize_numpy)
else:
mask = None
if mask is None:
normalize_numpy = (normalize_numpy - self.mean) / self.std
else:
raise NotImplementedError
return normalize_numpy
调整大小图像代码
from skimage.transform import resize
def Resize(self, data_group):
input_data = data_group.preprocessed_case
output_data = resize(input_data, self.output_dim)
data_group.preprocessed_case = output_data
self.output_data = output_data
CopyChannels 代码
def CopyChannels(self, data_group):
input_data = data_group.preprocessed_case
if self.new_channel_dim:
output_data = np.stack([input_data] * self.channel_multiplier, -1)
else:
output_data = np.tile(input_data, (1, 1, self.channel_multiplier))
data_group.preprocessed_case = output_data
self.output_data = output_data
示例输出左侧是带有 GPU 的 Docker 上的 Tensorflow,右侧是 TF JS:
每一步之后的结果实际上是不同的。
解决方案
可能有多种可能性会导致该问题。
1-python中使用的ops在js和python中的使用方式不同。如果是这种情况,使用完全相同的操作将解决这个问题。
2- python 库和浏览器画布可能会以不同方式读取张量图像。实际上,由于抗锯齿等某些操作,跨浏览器的画布像素并不总是具有相同的值......如this answer中所述。因此,操作的结果可能会有一些细微的差别。为了确保这是问题的根本原因,首先尝试打印 python 和 js 数组image
,看看它们是否相似。js和python中的3d张量很可能是不同的。
tensor3d = tf.tensor3d(image,[height,width,1],'float32')
在这种情况下,可以使用 python 库将图像转换为张量数组,而不是直接在浏览器中读取图像。并使用 tfjs 直接读取这个数组而不是图像。这样,输入张量在 js 和 python 中都是相同的。
3 - 这是一个 float32 精度问题。tensor3d 是使用 dtype 创建的float32
,根据使用的操作,可能存在精度问题。考虑这个操作:
tf.scalar(12045, 'int32').mul(tf.scalar(12045, 'int32')).print(); // 145082032 instead of 145082025
在 python 中将遇到相同的精度问题,如下所示:
a = tf.constant([12045], dtype='float32') * tf.constant([12045], dtype='float32')
tf.print(a) // 145082032
在 python 中,这可以通过使用int32
dtype 来解决。然而,由于 webgl 的float32
限制,不能在 tfjs 上使用 webgl 后端来完成同样的事情。在神经网络中,这个精度问题并不是什么大问题。为了摆脱它,可以使用setBackend('cpu')
例如慢得多的后端来更改后端。
推荐阅读
- r - 在 R 中有什么方法可以在一列中找到“文本”,将相同的值放在 R 数据框中的另一列中?
- excel - 如何将日期中的月份替换为特定月份?
- matlab - 为什么通过 lsqcurvefit 和 fminunc 函数获得的标准误差不同?
- python - Pandas DataFrame Groupby 两列并添加移动平均列
- ms-access - 如何在 MS Access 中添加验证或输入掩码?
- javascript - Fullcalendar v4:上一个按钮和浏览器后退按钮集成失败
- javascript - 无法将带有 "" 、 '' 和 /n 的字符串从 PHP 回显到 JavaScript
- excel - 是否可以将两个具有不同行和列的 Excel 电子表格组合在一起?
- macos - 在 OSX 中获取鼠标按钮状态
- java - BufferedReader 读取空行后等待 30 秒