python - 我将如何测量/记录 keras / tensorflow 的人工神经网络算法的总训练时间?
问题描述
我想记录训练 ANN 算法的总时间,但我不知道如何。我得到每个步骤的通常训练对话,例如:
7/7 [==============================] - 0s 2ms/step - loss: 0.2611 - accuracy: 0.9816
有没有可以为我做这件事的实用程序?这是因为我想研究更复杂的算法是否值得考虑可能更长的训练时间。
解决方案
最好的方法是创建一个自定义回调。下面的代码将在培训完成后打印出培训持续时间。
class TRA(keras.callbacks.Callback):
def __init__(self):
super(TRA, self).__init__()
def on_train_begin(self, logs=None):
self.start_time= time.time()
def on_train_end(self, logs=None):
stop_time=time.time()
tr_duration= stop_time- self.start_time
hours = tr_duration // 3600
minutes = (tr_duration - (hours * 3600)) // 60
seconds = tr_duration - ((hours * 3600) + (minutes * 60))
msg = f'training elapsed time was {str(hours)} hours, {minutes:4.1f} minutes, {seconds:4.2f} seconds)'
print(msg)
然后使用此代码
callbacks=[TRA()]
history=model.fit(... callbacks=callbacks ....)
在 model.fit 中包含如上所示的参数回调
推荐阅读
- google-sheets - 谷歌表格的最后一行
- python - 用于从 dict 中提取所有 url 的正则表达式,如字符串
- vba - 错误 5941 VBA Word 2016 - 保存文件,从表中获取名称
- ios - 从 Android 迁移到 iOS AWS IoT
- sql-server - Powershell、DataTable、ExecuteReader、查询返回的零行
- c# - ASP.NET Web 项目模板包含 2 个用于导入 Microsoft.WebApplication.targets $VSToolsPath) 与 $(MSBuildExtensionsPath32) 的条目
- spring-mvc - 如何在 Spring Data REST 中创建 RESTful 查询语言
- php - 为什么在调用包装短代码时,短代码内的 wp_video_shortcode 会回显?
- python - 在 spark sql SELECT 查询中执行 .show() 时出错
- magento2 - 有什么方法可以为 Magento 2 中的所有简单产品创建自定义选项?