python - TensorFlow:如何在 SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) 中定义输出形状
问题描述
输出张量定义如下:
// Grab the input tensor
const Tensor& I_tensor = context->input(0);
const Tensor& W_tensor = context->input(1);
auto Input = I_tensor.flat<float>();
auto Weight = W_tensor.flat<float>();
// OP_REQUIRES(context, iA_tensor.dims()==2 && iB_tensor.dims()==2);
int B = I_tensor.dim_size(1);
int nH= I_tensor.dim_size(2);
int nW= I_tensor.dim_size(3);
int C = I_tensor.dim_size(4);
int K = W_tensor.dim_size(2);
// Create an output tensor
Tensor* O_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{B, 2*nH, 2*nW, K}, &O_tensor));
我尝试在 SetShapeFn 中使用相同的代码,但它不起作用。
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
// // const Tensor& I_tensor = c->input(0); //This doesn't work
// // const Tensor& W_tensor = c->input(1); //This doesn't work
::tensorflow::shape_inference::ShapeHandle I = c->input(0);
::tensorflow::shape_inference::ShapeHandle W = c->input(1);
//I have no idea how to get the desired shape from I and W
return Status::OK();
})
我试图找到有关 ShapeHandle 和 InferenceContext 的更多信息,以获得所需的结果,但未能这样做。如果有人可以帮助我,我将不胜感激。
谢谢!
解决方案
这是我用来获取尺寸信息和创建新尺寸信息的:
auto N = c->Dim(c->input(0), 0);
auto H = c->Dim(c->input(0), 1);
auto W = c->Dim(c->input(0), 2);
auto C = c->Dim(c->input(0), 3);
// save stride so we can use it to calculate output shape
std::vector<float> stride;
c->GetAttr("stride", &stride);
H = c->MakeDim(round(c->Value(H)/stride[0]));
W = c->MakeDim(round(c->Value(W)/stride[1]));
c->set_output(0, c->MakeShape({N, H, W, C}));
推荐阅读
- python - Python:如何编写带有可选参数的“主”函数
- java - 如何禁用 iTextRenderer“createPDF”的警告输出
- java - 如何从外循环开始迭代?
- flutter - Flutter FirebaseMessaging - StreamSubscription 未正确处理
- vb.net - 如何使标签或图片框不那么不透明?
- qemu - 对模拟代码中字节序的重要性感到非常困惑
- azure - Azure 服务总线消息计数中 -- 和 0 之间的差异
- c - (GStreamer) rtmpsink 无法正常打开?
- csv - Bash - awk 比较未按预期工作
- python - Pandas - 根据其他列因子计数更新列值