python - 如何通过沿某个轴交错来合并两个 3D 张量?
问题描述
假设我想通过沿一个特定轴交错来合并两个 3D 张量流张量 a 和 b。例如,张量 a 的形状为 (3,3,2),张量 b 的形状为 (3,2,2)。我想创建一个沿轴 1 交错的张量 c,从而产生形状为 (3,5,2) 的张量。
例子:
a = [[[1,1],[2,2],[3,3]],
[[4,4],[5,5],[6,6]],
[[7,7],[8,8],[9,9]]]
b = [[[10,10],[11,11]],
[[12,12],[13,13]],
[[14,14],[15,15]]]
c = [[[1,1],[10,10],[2,2],[11,11],[3,3]],
[[4,4],[12,12],[5,5],[13,13],[6,6]],
[[7,7],[14,14],[8,8],[15,15],[9,9]]]
解决方案
您可以先重新排序列的索引。
import tensorflow as tf
a = [[[1,1],[2,2],[3,3]],
[[4,4],[5,5],[6,6]],
[[7,7],[8,8],[9,9]]]
b = [[[10,10],[11,11]],
[[12,12],[13,13]],
[[14,14],[15,15]]]
a_tf = tf.constant(a)
b_tf = tf.constant(b)
a_tf_column = tf.range(a_tf.shape[1])*2
# [0 2 4]
b_tf_column = tf.range(b_tf.shape[1])*2+1
# [1 3]
column_indices = tf.concat([a_tf_column,b_tf_column],axis=-1)
# Before TF v1.13
column_indices = tf.contrib.framework.argsort(column_indices)
## From TF v1.13
# column_indices = tf.argsort(column_indices)
# [0 3 1 4 2]
然后你应该为tf.gather_nd()
.
column,row = tf.meshgrid(column_indices,tf.range(a_tf.shape[0]))
combine_indices = tf.stack([row,column],axis=-1)
# [[[0,0],[0,3],[0,1],[0,4],[0,2]],
# [[1,0],[1,3],[1,1],[1,4],[1,2]],
# [[2,0],[2,3],[2,1],[2,4],[2,2]]]
最后,您应该连接 and 的值a
并b
使用tf.gather_nd()
来获得结果。
combine_value = tf.concat([a_tf,b_tf],axis=1)
result = tf.gather_nd(combine_value,combine_indices)
with tf.Session() as sess:
print(sess.run(result))
# [[[1,1],[10,10],[2,2],[11,11],[3,3]],
# [[4,4],[12,12],[5,5],[13,13],[6,6]],
# [[7,7],[14,14],[8,8],[15,15],[9,9]]]
推荐阅读
- c++ - libcurl 无法正确下载图像文件
- flutter - 使用 Aqueduct 提供的文件没有 Content-Length 标头
- angular - 托管在 Pivotal Cloud Foundry 中的 Angular 应用程序中的安全凭证
- kubernetes - 创建 Kubernetes 存储 Google Compute Engine
- dataframe - 在apache spark SQL中,如何在窗口函数中使用collect_list时删除重复的行?
- java - Else condition confusion on Scanner class
- google-bigquery - 如何禁止将重复行加载到 BigQuery?
- java - Spring Configuration 类放置最佳实践
- kubernetes-helm - HELM _helpers.tpl:如何检查 _helpers.tpl 中是否定义了模板块?
- javascript - 获取模式的左上角坐标(jquery kendo-ui)