python - 用于放大图像的 TensorRT 未达到预期结果
问题描述
几周以来,我一直在与 TensorRT(TensorRT 4 for python)抗争。为了让 TensorRT 运行,我通过了很多问题。来自 NVIDIA 的示例代码非常适合我: TensorRT MNIST 示例
现在,我在 tensorflow(一个非常简单的网络)中创建了自己的网络,用于放大图像,比如说(在 HWC 中)320x240x3 到 640x480x3。通常的方法是创建一个冻结图并运行基于 Tensorflow 的推理器,这给了我预期的结果但不是通过使用 TensorRT。
我有一种奇怪的感觉,我通过将图像输入 GPU 内存而做错了(这可能是关于 pycuda 和/或 TensorRT 的问题)。
最坏的情况是 TensorRT 通过优化过程破坏了我的网络。
我希望有人对挽救我的生命有一点想法。这是我的 Tensorflow 模型(我只是包装了函数):
net = conv2d(input,
64,
k_size=3,
activation=tf.nn.relu,
name='conv1')
net = deconv2d(net,
3,
k_size=5,
activation=tf.tanh,
stride=self.params.resize_factor,
scale=self.params.resize_factor,
name='deconv')
这是我的推理器的重要片段:
import tensorrt as trt
import uff
from tensorrt.parsers import uffparser
import pycuda.driver as cuda
import numpy as np
...
def _init_infer(self, uff_model):
g_logger = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)
parser = uffparser.create_uff_parser()
parser.register_input(self.input_node, (self.channels, self.height, self.width), 0)
parser.register_output(self.output_node)
self.engine = trt.utils.uff_to_trt_engine(g_logger, uff_model, parser, self.max_batch_size,
self.max_workspace_size)
parser.destroy()
self.runtime = trt.infer.create_infer_runtime(g_logger)
self.context = self.engine.create_execution_context()
self.output = np.empty(self.output_size, dtype=self.dtype)
# create CUDA stream
self.stream = cuda.Stream()
# allocate device memory
self.d_input = cuda.mem_alloc(self.channels * self.max_batch_size * self.width *
self.height * self.output.dtype.itemsize)
self.d_output = cuda.mem_alloc(self.output_size * self.output.dtype.itemsize)
self.bindings = [int(self.d_input), int(self.d_output)]
def infer(self, input_batch, batch_size=1):
# transfer input data to device
cuda.memcpy_htod_async(self.d_input, input_batch, self.stream)
# execute model
self.context.enqueue(batch_size, self.bindings, self.stream.handle, None)
# transfer predictions back
cuda.memcpy_dtoh_async(self.output, self.d_output, self.stream)
# synchronize threads
self.stream.synchronize()
return self.output
和可执行代码片段:
...
# create trt inferencer
trt_inferencer = TensorRTInferencer(params=params)
img = [misc.imread('./test_images/lion.png')]
img[0] = normalize(img[0])
img = img[0]
# inferencing method
result = trt_inferencer.infer(img)
result = inormalize(result, dtype=np.uint8)
result = result.reshape(1, params.height * 2, params.width * 2, 3)
...
比较奇怪的结果 :( 放大的狮子 TensorRT、Tensorflow、Original
解决方案
我现在终于明白了。问题是输入图像和输出的维度和顺序错误。对于遇到同样问题的每个人,这是采用的可执行代码段,取决于我的初始化:
...
# create trt inferencer
trt_inferencer = TensorRTInferencer(params=params)
img = [misc.imread('./test_images/lion.png')]
img[0] = normalize(img[0])
img = img[0]
img = np.transpose(img, (2, 0, 1))
img = img.ravel()
# inferencing method
result = trt_inferencer.infer(img)
result = inormalize(result, dtype=np.uint8)
result = np.reshape(result, newshape=[3, params.height * 2, params.width * 2])
result = np.transpose(result, (1, 2, 0))
...
推荐阅读
- nginx - Nginx 在 100-Continue 后没有收到第二个请求,401 响应
- django - 使用 nginx gunicorn 在 digitalocean 上部署 django react
- javascript - Boilerplate Github repo 不会运行 npm install,ERR!在“...compile-2.0.0-beta.4.t”附近解析时 JSON 输入意外结束
- c# - “堆叠”画布元素
- javascript - Chart.js:在饼图的 yAxes 上汇总活动/显示数据集中的项目数
- php - IF(ISSET($_FILES['imgfile'])) 不检查文件是否正在发布
- python - 是否有适用于 VMware 兼容性指南的 API?
- mongodb - 我如何在 mongo 中执行此 SQL 查询
- javascript - Ajax 在表中显示数据库中的数据,但数据保存为用逗号分隔的字符串?
- liferay - 从 liferay 中删除 facebook cookie