python - 使用数组查找 TensorFlow 哈希表
问题描述
我正在尝试获得HashMap
与 tensorflow 一起使用的功能类型。int
当键和值是类型时,我让它工作。但是当它们是数组时,它会给出错误 -ValueError: Shapes (2,) and () are not compatible
在线default_value)
import numpy as np
import tensorflow as tf
input_tensor = tf.constant([1, 1], dtype=tf.int64)
keys = tf.constant(np.array([[1, 1],[2, 2],[3, 3]]), dtype=tf.int64)
values = tf.constant(np.array([[4, 1],[5, 1],[6, 1]]), dtype=tf.int64)
default_value = tf.constant(np.array([1, 1]), dtype=tf.int64)
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_value)
out = table.lookup(input_tensor)
with tf.Session() as sess:
table.init.run()
print(out.eval())
解决方案
不幸的是,tf.contrib.lookup.HashTable
仅适用于一维张量。这是一个tf.SparseTensor
s 的实现,当然只有当你的键是整数(int32 或 int64)张量时才有效。
对于值,我将两列存储在两个单独的张量中,但如果您有很多列,您可能只想将它们存储在一个大张量中,并将索引作为值存储在 one 中tf.SparseTensor
。
此代码(已测试):
import tensorflow as tf
lookup = tf.placeholder( shape = ( 2, ), dtype = tf.int64 )
default_value = tf.constant( [ 1, 1 ], dtype = tf.int64 )
input_tensor = tf.constant( [ 1, 1 ], dtype=tf.int64)
keys = tf.constant( [ [ 1, 2 ], [ 3, 4 ], [ 5, 6 ] ], dtype=tf.int64 )
values = tf.constant( [ [ 4, 1 ], [ 5, 1 ], [ 6, 1 ] ], dtype=tf.int64 )
val0 = values[ :, 0 ]
val1 = values[ :, 1 ]
st0 = tf.SparseTensor( keys, val0, dense_shape = ( 7, 7 ) )
st1 = tf.SparseTensor( keys, val1, dense_shape = ( 7, 7 ) )
x0 = tf.sparse_slice( st0, lookup, [ 1, 1 ] )
y0 = tf.reshape( tf.sparse_tensor_to_dense( x0, default_value = default_value[ 0 ] ), () )
x1 = tf.sparse_slice( st1, lookup, [ 1, 1 ] )
y1 = tf.reshape( tf.sparse_tensor_to_dense( x1, default_value = default_value[ 1 ] ), () )
y = tf.stack( [ y0, y1 ], axis = 0 )
with tf.Session() as sess:
print( sess.run( y, feed_dict = { lookup : [ 1, 2 ] } ) )
print( sess.run( y, feed_dict = { lookup : [ 1, 1 ] } ) )
将输出:
[4 1]
[1 1]
根据需要(查找键[ 1, 2 ]的值[ 4, 1 ]和 [ 1, 1 ] 的默认值[ 1, 1 ],它指向不存在的条目。)
推荐阅读
- padding - 如何调整美人鱼节点内的填充?
- python - 匹配所有“错误:”但不匹配“错误:0”
- c++ - 为什么平台之间的 inptr_t 行为不同?
- mysql - Codeigniter MySQL 按会话数据过滤
- java - 无法在 IntelliJ 的 Kotlin 项目中从模块创建 Jar
- python - 如何读取/导入/加载许多 .mat 文件以在 python 中进行训练?
- javascript - 如何使用 React 检查远程源是否可用?
- reactjs - 使用反应钩子聚焦输入
- javascript - 云功能未在数据库写入时触发
- html - 如何在重新加载时保持引导表的滚动位置?