python - 如何在 TensorFlow v2 中分析网络
问题描述
开发 DNN 的一个常见且重要的问题是哪些操作需要多长时间以及它们如何在设备和线程之间分布。
这曾经在 TensorFlow v1 中通过tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
传递给是可能的session.run()
,请参阅Can I measure the execution time of individual operations with TensorFlow?
但是在 V2 中没有更多的会话。相反,您构建和训练这样的模型:
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(loss='sparse_categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
model.fit(train_dataset, epochs=2)
我能找到的唯一选择profiler
是tensorflow_core.python.eager.profiler
. 这样,您将获得一个Trace ProtoBuf 对象,其中包含具有持续时间的事件。然而,我得到的事件被命名'Model', 'BatchV2', 'TensorSlice', 'Prefetch', 'MemoryCacheImpl', 'MemoryCache', 'TFRecord', 'Shuffle', 'Map', 'FlatMap', '_Send', 'ParallelMap', 'NotEqual', 'ParallelInterleaveV2', 'LogicalAnd'
并且与图层没有明确的关系。
如何为显示所有 Ops 的运行时和设备和线程的任何模型获得正确的跟踪?
解决方案
TensorFlow Profiler 需要超过 2.2 版本的 Tensorflow 和 Tensorboard。
1.安装“tensorboard_plugin_profile”
pip install -U tensorboard_plugin_profile
2.确认TensorFlow可以访问GPU
device_name = tf.test.gpu_device_name()
if not device_name:
raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))
3.定义张量板回调
logs = "[save path]/logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tboard_callback = tf.keras.callbacks.TensorBoard(log_dir = logs,
histogram_freq = 1, #option
profile_batch = 5) #option
在我的情况下,没有“profile_batch”选项,我得到了 100 个步数(即 epochs) ex ) tboard_callback = tf.keras.callbacks.TensorBoard(log_dir = logs)
4.在fit()处设置回调属性
history = model.fit(train_input, train_output,
batch_size=BATCH_SIZE, epochs=EPOCHS,
callbacks=[tboard_callback])
5.训练结束后,在终端运行tensorboard
tensorboard --logdir [save path]/logs/
不需要双引号
例如)张量板--logdir c\python\tb\logs
6. 打开浏览器并输入 'localhost:6006' > Profile > kernal_stats
7.检查“跟踪查看器”和“kernel_stats”
我没有 GPU,所以我无法获得 kernel_stats。
关于它的更多信息,我推荐这个页面。
https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras
推荐阅读
- mysql - 为什么 mysql 中的 null=null 为假?
- c++ - 散点图不考虑具有散点样式 ssDot 的自适应采样
- r - R函数中ggplot2的使用
- xslt-2.0 - XSLT 2.0 测试没有价值的标记化结果
- c - 已分配内存块上的 calloc 是否会调用重复分配?
- sql - 这是什么连接类型:JOIN [Status] ON [Name] = 'Acknowledged'
- ruby-on-rails - Rails:将嵌套属性传递给 vue.js
- html - 当内容大于其框架时,JasperReports 的 Html 组件正在缩小
- python - django-admin.py 在 Windows Server 2012 上不起作用
- r - quantstrat 策略 - 在每个周期结束时关闭