,python,tensorflow"/>

首页 > 解决方案 > 如何连接 SparseTensor 类型的稀疏张量

问题描述

将一个稠密张量与 TensorFlow 的 tf.sets.set_intersection 操作返回的稀疏张量(SparseTensor)连接起来时,程序报错失败。

# Question code: reproduces the TypeError shown in the traceback below when
# trying to concatenate the SparseTensor returned by a set op with a dense tensor.
import tensorflow as tf
# NOTE(review): this session is never used; the `with tf.Session()` block
# below opens (and rebinds `sess` to) its own session.
sess = tf.Session()

# a = range(10) + ones(10) cast to int32, i.e. the values 1..10.
a = tf.add(tf.range(10), tf.cast(tf.ones([10]), dtype=tf.int32))
b = tf.constant([0, 1, 1, 0, 0, 1, 1, 0, 0, 1])

# This set appears to be sorted, but that is not documented behavior.
# `[None, :]` adds a leading axis of size 1 to each input; per the traceback,
# the op returns a SparseTensor whose dense_shape has 2 entries (rank 2).
s = tf.sets.set_intersection(a[None,:], b[None, :])
# Fails here with TypeError: tf.convert_to_tensor cannot convert a
# SparseTensor. Even once densified, the result is rank 2 while `a` is
# rank 1, so the leading axis must also be removed before concatenating.
s = tf.concat([a, tf.convert_to_tensor(s)], axis=0)
# NOTE(review): after tf.concat, `s` is a dense Tensor, which (unlike a
# SparseTensor) presumably has no `.values` attribute - this would fail next.
fsort = tf.contrib.framework.sort(s.values)

with tf.Session() as sess:
    print(type(s))
    # NOTE(review): sess.run returns a NumPy array here, which has no
    # `.values` attribute.
    print(sess.run(s).values)
    print(sess.run(fsort))

错误如下:

TypeError                                 Traceback (most recent call last)
/data00/tiger/jupyterhub_deploy/venv/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    526     try:
--> 527       str_values = [compat.as_bytes(x) for x in proto_values]
    528     except TypeError:

/data00/tiger/jupyterhub_deploy/venv/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in <listcomp>(.0)
    526     try:
--> 527       str_values = [compat.as_bytes(x) for x in proto_values]
    528     except TypeError:

/data00/tiger/jupyterhub_deploy/venv/lib/python3.6/site-packages/tensorflow/python/util/compat.py in as_bytes(bytes_or_text, encoding)
     60     raise TypeError('Expected binary or unicode string, got %r' %
---> 61                     (bytes_or_text,))
     62 

TypeError: Expected binary or unicode string, got <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f9da1a8bc88>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-107-a3e76f2038b5> in <module>
      7 # This set appears to be sorted, but that is not documented behavior.
      8 s = tf.sets.set_intersection(a[None,:], b[None, :])
----> 9 s = tf.concat([a, tf.convert_to_tensor(s)], axis=0)
     10 fsort = tf.contrib.framework.sort(s.values)
     11 

/data00/tiger/jupyterhub_deploy/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, preferred_dtype)
   1048       name=name,
   1049       preferred_dtype=preferred_dtype,
-> 1050       as_ref=False)
   1051 
   1052 

/data00/tiger/jupyterhub_deploy/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx)
   1144 
   1145     if ret is None:
-> 1146       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
   1147 
   1148     if ret is NotImplemented:

/data00/tiger/jupyterhub_deploy/venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    227                                          as_ref=False):
    228   _ = as_ref
--> 229   return constant(v, dtype=dtype, name=name)
    230 
    231 

/data00/tiger/jupyterhub_deploy/venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name, verify_shape)
    206   tensor_value.tensor.CopyFrom(
    207       tensor_util.make_tensor_proto(
--> 208           value, dtype=dtype, shape=shape, verify_shape=verify_shape))
    209   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    210   const_tensor = g.create_op(

/data00/tiger/jupyterhub_deploy/venv/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    529       raise TypeError("Failed to convert object of type %s to Tensor. "
    530                       "Contents: %s. Consider casting elements to a "
--> 531                       "supported type." % (type(values), values))
    532     tensor_proto.string_val.extend(str_values)
    533     return tensor_proto

TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor. Contents: SparseTensor(indices=Tensor("DenseToDenseSetOperation_17:0", shape=(?, 2), dtype=int64), values=Tensor("DenseToDenseSetOperation_17:1", shape=(?,), dtype=int32), dense_shape=Tensor("DenseToDenseSetOperation_17:2", shape=(2,), dtype=int64)). Consider casting elements to a supported type.

标签: python, tensorflow

解决方案


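不要使用 tf.convert_to_tensor(s),而应使用 tf.sparse.to_dense(s):tf.convert_to_tensor 无法转换 SparseTensor。由于集合运算的输入被加了一个大小为 1 的前导维度,to_dense 得到的结果是二维的(形状为 [1, N]),因此还需要用 tf.squeeze(..., axis=0) 去掉该维度,才能与一维的 a 连接。另外,我会将整个代码重写为: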

tf.reset_default_graph()

# a holds the values 1..10: range(10) shifted up by a vector of int32 ones.
ones_i32 = tf.cast(tf.ones([10]), dtype=tf.int32)
a = tf.add(tf.range(10), ones_i32)
b = tf.constant([0, 1, 1, 0, 0, 1, 1, 0, 0, 1])

# Both inputs get a leading axis of size 1 via `[None, :]`; the set op
# returns a SparseTensor, which tf.concat cannot consume directly.
# This set appears to be sorted, but that is not documented behavior.
common_sp = tf.sets.set_intersection(a[None, :], b[None, :])

# Densify the sparse result, then drop the size-1 leading axis so the
# intersection is rank 1, matching `a` for concatenation along axis 0.
common = tf.sparse.to_dense(common_sp)
common = tf.squeeze(common, axis=0)

s = tf.concat([a, common], axis=0)
fsort = tf.contrib.framework.sort(s)

with tf.Session() as sess:
    print(type(s))
    # Fetch both results in one run; each is printed as a NumPy array.
    s_val, fsort_val = sess.run([s, fsort])
    print(s_val)
    print(fsort_val)

推荐阅读