python-3.x - How to convert Tensorflow 2.0 SavedModel to TensorRT?
Problem description
I've trained a model in TensorFlow 2.0 and am trying to reduce prediction time when moving to production (on a server with GPU support). In TensorFlow 1.x I was able to speed up prediction by freezing the graph, but this has been deprecated as of TensorFlow 2. According to Nvidia's description of TensorRT, using TensorRT can speed up inference by up to 7x compared to TensorFlow alone. Source:
TensorFlow 2.0 with Tighter TensorRT Integration Now Available
I have trained my model and saved it to a .h5 file using TensorFlow's SavedModel format. Now I am following Nvidia's documentation to optimize the model for inference with TensorRT: TF-TRT 2.0 Workflow With A SavedModel.
When I run:
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
I get the error: ModuleNotFoundError: No module named 'tensorflow.python.compiler.tensorrt'
They give another example with TensorFlow 2.0 here: Examples. However, it imports the same module as above, and I get the same error.
Can anyone suggest how to optimize my model with TensorRT?
Solution
I've solved this issue. The problem was that I was testing the code on my local Windows machine, rather than on my AWS EC2 instance with GPU support.
It seems that tensorflow.python.compiler.tensorrt is included in tensorflow-gpu, but not in standard tensorflow. In order to convert the SavedModel instance with TensorRT, you need to use a machine with tensorflow-gpu. (I knew that this would be required to run the model, but hadn't realized it was needed to convert the model.)
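As a rough sketch of the above, the snippet below first checks whether the TF-TRT module can be imported on the current machine and only then attempts the conversion. The helper names `tf_trt_available` and `convert_saved_model` are made up for illustration. `TrtGraphConverterV2` with its `convert()` and `save()` methods is the converter class used in the TF-TRT workflow for TensorFlow 2.x, but treat the exact arguments as an assumption to check against your installed version.

```python
import importlib


def tf_trt_available() -> bool:
    """Return True if the TF-TRT converter module can be imported here."""
    try:
        importlib.import_module(
            "tensorflow.python.compiler.tensorrt.trt_convert"
        )
        return True
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return False


def convert_saved_model(input_dir: str, output_dir: str) -> None:
    """Convert a SavedModel directory with TF-TRT and save the result.

    Hypothetical helper: input_dir and output_dir are placeholder paths.
    """
    if not tf_trt_available():
        raise RuntimeError(
            "tensorflow.python.compiler.tensorrt is not importable; "
            "run the conversion on a machine with a GPU-enabled "
            "TensorFlow build"
        )
    # Import lazily so this file still loads on machines without TF-TRT.
    from tensorflow.python.compiler.tensorrt import trt_convert as trt

    converter = trt.TrtGraphConverterV2(input_saved_model_dir=input_dir)
    converter.convert()          # build the TensorRT-optimized graph
    converter.save(output_dir)   # write the converted SavedModel
```

On the GPU machine this would be called as, for example, `convert_saved_model("my_model", "my_model_trt")`, where both directory names are placeholders. Note that the converter expects a SavedModel directory, so a model stored as a .h5 file would presumably need to be loaded with `tf.keras.models.load_model` and re-exported with `tf.saved_model.save` first.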