首页 > 解决方案 > TensorFlow:如何在 SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) 中定义输出形状

问题描述

输出张量定义如下:

// Grab the input tensor
const Tensor& I_tensor = context->input(0);
const Tensor& W_tensor = context->input(1);
auto Input = I_tensor.flat<float>();
auto Weight = W_tensor.flat<float>();
// OP_REQUIRES(context, iA_tensor.dims()==2 && iB_tensor.dims()==2);

int B = I_tensor.dim_size(1);
int nH= I_tensor.dim_size(2);
int nW= I_tensor.dim_size(3);
int C = I_tensor.dim_size(4);
int K = W_tensor.dim_size(2);

// Create an output tensor
Tensor* O_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{B, 2*nH, 2*nW, K}, &O_tensor));

我尝试在 SetShapeFn 中使用相同的代码,但它不起作用。

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
  // // const Tensor& I_tensor = c->input(0); //This doesn't work
  // // const Tensor& W_tensor = c->input(1); //This doesn't work

  ::tensorflow::shape_inference::ShapeHandle I = c->input(0);
  ::tensorflow::shape_inference::ShapeHandle W = c->input(1);
  //I have no idea how to get the desired shape from I and W

  return Status::OK();
})

我试图找到有关 ShapeHandle 和 InferenceContext 的更多信息,以获得所需的结果,但未能这样做。如果有人可以帮助我,我将不胜感激。

谢谢!

标签: pythontensorflow

解决方案


这是我用来获取尺寸信息和创建新尺寸信息的:

auto N = c->Dim(c->input(0), 0);
auto H = c->Dim(c->input(0), 1);
auto W = c->Dim(c->input(0), 2);
auto C = c->Dim(c->input(0), 3);

// save stride so we can use it to calculate output shape
std::vector<float> stride;
c->GetAttr("stride", &stride);
H = c->MakeDim(round(c->Value(H)/stride[0]));
W = c->MakeDim(round(c->Value(W)/stride[1]));
c->set_output(0, c->MakeShape({N, H, W, C}));

推荐阅读