首页 > 解决方案 > Estimator.train() 和 .predict() 对于小数据集来说太慢了

问题描述

我正在尝试实现一个 DQN,该 DQN 在同一模型上进行多次调用,Estimator.train()每个调用Estimator.predict()都有少量示例。但是每次调用至少需要几百毫秒到一秒以上,这与 1-20 等小数字的示例数量无关。

我认为这些延迟是由于重建图表并在每次调用时保存检查点造成的。有没有办法将相同的图形和参数保留在内存中以进行快速的训练预测迭代或以其他方式加快速度?

标签: pythontensorflow-estimator

解决方案


转换为 atf.keras.Model而不是Estimator,并使用tf.keras.Model.fit()而不是Estimator.train(). fit()没有 train() 那样的固定延迟。Keraspredict()也没有。


推荐阅读