tensorflow - 根据图形时间未知的大小拆分张量
问题描述
在张量流中,我想做以下事情:
- 接收 N 个一维张量
- 将它们连接为形状 [m] 的大一维张量
- 调用处理该张量并生成形状为 [m] 的张量的函数
- 将结果张量拆分为 N 个一维张量
但是在创建图形时,我不知道每个一维张量的大小,这会产生问题。这是我正在做的一个片段:
def stack(tensors):
sizes = tf.convert_to_tensor([t.shape[0].value for t in tensors])
tensor_stacked = tf.concat(tensors, axis=0)
res = my_function(tensor_stacked)
return tf.split(res, sizes, 0)
tensor_A = tf.placeholder(
tf.int32,
shape=[None],
name=None
)
tensor_B = tf.placeholder(
tf.int32,
shape=[None],
name=None
)
res = stack([tensor_A, tensor_B])
这将在带有消息的“concat”行上失败
TypeError:无法将类型对象转换为张量。内容:[无,无]。考虑将元素转换为支持的类型。
有什么办法可以在 tensorflow 中做到这一点?在图形时间,“大小”变量将始终包含未知大小,因为一维张量的长度永远未知
解决方案
好的,同时我找到了答案
显然它足以取代对 to 的tensor.shape[0]
调用tf.shape(tensor)[0]
所以现在我有:
def stack(tensors):
sizes = tf.convert_to_tensor([tf.shape(t)[0] for t in tensors])
print(sizes)
tensor_stacked = tf.concat(tensors, axis=0)
res = my_function(tensor_stacked)
return tf.split(res, sizes, 0)
推荐阅读
- vba - 在 MS Word 2016 中选择和重新定位和图像
- electron - 电子打印问题
- javascript - 如何从本地存储中存储和检索 html div 元素的位置?
- php - Laravel 响应是 String 而不是 int
- python - 在 Locust 类中使用 setattr
- searchkick - 如何将 searchkick(在 Rails 应用程序和/ Sidekiq 作业中)连接到多个弹性搜索集群而不踩到全局搜索配置?
- vim - 如何使gruvbox背景变黑?
- go - 结构接收器到外部函数
- python - 使用 Sympy 求解具有总和和索引的方程
- reactjs - 使用 react router v6 实现受保护的路由