How to use the CSR format in TensorFlow

Problem description

I found that TensorFlow exposes a DenseToCSRSparseMatrix API under tf.raw_ops. However, when I try a simple dense-to-CSR conversion, it raises the following exception:

Traceback (most recent call last):
  File "/Users/", line 33, in <module>
    print(sess.run(b))
  File "/Users/mac/sparse-exp/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 967, in run
    result = self._run(None, fetches, feed_dict, options_ptr,
  File "/Users/mac/sparse-exp/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1190, in _run
    results = self._do_run(handle, final_targets, final_fetches,
  File "/Users/mac/sparse-exp/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1368, in _do_run
    return self._do_call(_run_fn, feeds, fetches, targets, options,
  File "/Users/mac/sparse-exp/lib/python3.9/site-packages/tensorflow/python/client/session.py", line 1394, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Tensorflow type 21 not convertible to numpy dtype.

Here is my sample code:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ss = tf.constant([[1, 2, 3, 4, 5], [0, 0, 0, 2, 1]], dtype=tf.float32)
indecies = tf.where(tf.not_equal(ss, 0))

b = tf.raw_ops.DenseToCSRSparseMatrix(dense_input=ss, indices=indecies)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(b))

Tags: tensorflow

Solution
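A likely cause, based on the error message: type 21 in TensorFlow's DataType enum is DT_VARIANT, and DenseToCSRSparseMatrix returns its result as a variant handle rather than an ordinary numeric tensor. Session.run cannot convert a variant into a NumPy array, so fetching b directly fails even though the conversion itself succeeded. The sketch below keeps the question's graph-mode setup and fetches ordinary tensors derived from the handle instead, using tf.raw_ops.CSRSparseMatrixComponents and tf.raw_ops.CSRSparseMatrixToDense. The variable names (dense, csr, round_trip and the *_np results) are illustrative, and index=0 assumes a single, non-batched matrix. Treat it as one possible workaround, not the only way to inspect a CSR matrix.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

dense = tf.constant([[1, 2, 3, 4, 5], [0, 0, 0, 2, 1]], dtype=tf.float32)
# int64 coordinates of the non-zero entries, in row-major order
indices = tf.where(tf.not_equal(dense, 0))

# The result is a DT_VARIANT handle; it cannot be fetched with sess.run directly.
csr = tf.raw_ops.DenseToCSRSparseMatrix(dense_input=dense, indices=indices)

# Unpack the handle into ordinary tensors: row pointers, column indices, values.
# index=0 selects the first (and here only) matrix in the batch.
row_ptrs, col_inds, values = tf.raw_ops.CSRSparseMatrixComponents(
    csr_sparse_matrix=csr, index=0, type=tf.float32)

# Alternatively, convert the CSR matrix back to a dense tensor.
round_trip = tf.raw_ops.CSRSparseMatrixToDense(sparse_input=csr, type=tf.float32)

with tf.Session() as sess:
    row_ptrs_np, col_inds_np, values_np, dense_np = sess.run(
        [row_ptrs, col_inds, values, round_trip])

print("row_ptrs:", row_ptrs_np)
print("col_inds:", col_inds_np)
print("values:  ", values_np)
print("dense:\n", dense_np)
```

In a longer pipeline, the variant handle can be passed on to other CSR ops in tf.raw_ops, and only the final numeric results need to be fetched.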

