首页 > 解决方案 > 在不使用 tf.session() 的情况下获取 Tensorflow model.predict 的值

问题描述

所以我有一个用 tensorflow 的新估计器 API 编写的预测网络

我这样称呼预测:

prediction = classifier_mycolumn.predict(cust_train_fn)

其中classifier_mycolumn定义为:

classifier_mycolumn = tf.estimator.Estimator(
        model_fn = model_function,
        params = {'feature_columns': my_feature_columns,
        'n_classes': 1,},
        model_dir=current_model)

只是一个标准的 tf.estimator.Estimator 东西。

cust_train_fn实际上不是一个函数,它是网络应该接收并运行推理的单个数组。

网络应该返回 0 或 1(最后使用 softmax 的简单分类) - 但每次我尝试打印时prediction- 我只是得到<generator object Estimator.predict at 0x7fe0e8630c50>

有什么办法可以得到预测的实际结果?例如,网络返回的 0 或 1?!非常感谢您的阅读。

标签: pythontensorflowpredictiontensorflow-estimator

解决方案


推荐阅读