python - 如何删除 TensorFlow 自定义操作实例?
问题描述
我正在使用 TensorFlow 中的自定义操作教程。
我在 C++ 中实现了自定义操作并创建了一个共享库。我正在使用 python tf.load_op_library 函数调用加载它。然后我使用 session.run() 调用自定义操作。
自定义操作工作正常。但我无法弄清楚何时调用自定义的析构函数。
即使我在自定义操作的析构函数中有一个打印语句,它也永远不会被打印出来。似乎自定义操作实例永远不会被破坏。
这是预期的行为吗?如果是这种情况,有没有办法通知 tensorflow 我已使用自定义操作完成?
请注意,仅当我们使用 tf.placeholders 时才会注意到此行为。如果我将输入矩阵设为常数,即
x=zero_out_module.zero_out([[1, 2], [3, 4]])
然后调用自定义操作的析构函数。
C++ 中的自定义操作实现
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
printf("Constructor: ZeroOutOp\n"); fflush(stdout);
}
~ZeroOutOp() override {
printf("Destructor: ZeroOutOp\n"); fflush(stdout);
}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output_flat(i) = 0;
}
// Preserve the first input value if possible.
if (N > 0) output_flat(0) = input(0);
}
};
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
用于测试自定义操作的 Python 代码
import tensorflow as tf
import numpy as np
zero_out_module = tf.load_op_library('./lib/zero_out/zero_out.so')
inmat=np.array([[1,2],[3,4]])
with tf.device("/cpu:0"):
input_mat = tf.placeholder(tf.int32)
x=zero_out_module.zero_out(to_zero=input_mat)
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True, allow_soft_placement=False))
print(sess.run(x, feed_dict={input_mat:inmat}))
解决方案
推荐阅读
- python - 使用带有 conda 环境的 docker 部署到 Heroku 后无法访问烧瓶应用程序
- python - 使用 for 循环引用变量
- angular - AmCharts4:导出菜单不提供 CSV、XLSX 和 JSON 选项
- javascript - 为什么 yAxisBar 没有显示正确的数据,D3
- python - django form.is_valid 返回 false
- gremlin - Gremlin - 如何展平分组输出
- excel - 无法从 Mac 刷新 Excel 中的外部连接
- php - 更改后续行的类
- javascript - 反应原生 redux 传奇 TypeError
- asp.net - 在 Kendo ASP.NET MVC 网格中格式化日期列