python - Tensorflow:TypeError:Fetch参数None的类型无效
问题描述
我正在尝试运行这个简单的程序来计算梯度,但遇到了与 None 有关的错误:
import tensorflow as tf
import numpy as np
# Reproduction script from the question: builds a one-layer tanh model, computes
# the gradients of a sparse softmax cross-entropy w.r.t. all trainable variables,
# then fetches and prints them.
# NOTE(review): indentation was lost when this listing was pasted; presumably the
# statements after `with sess.as_default():` belong inside that block, and the
# print statements inside the `for`/`if` — confirm against the original code.
batch_size = 5
dim = 3
hidden_units = 8
sess = tf.Session()
with sess.as_default():
x = tf.placeholder(dtype=tf.float32, shape=[None, dim], name="x")
y = tf.placeholder(dtype=tf.int32, shape=[None], name="y")
w = tf.Variable(initial_value=tf.random_normal(shape=[dim, hidden_units]), name="w")
b = tf.Variable(initial_value=tf.zeros(shape=[hidden_units]), name="b")
# w is [dim, hidden_units], so logits has hidden_units (8) columns: the valid
# class ids for the loss below are 0 .. hidden_units - 1.
logits = tf.nn.tanh(tf.matmul(x, w) + b)
# NOTE(review): in TF 1.x this function must be called with named arguments
# (labels=y, logits=logits); a positional call raises ValueError. The traceback
# shows a different line 31 (`outnet = sess.run(...)`), so the code that was
# actually executed may not match this listing exactly.
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y,name="xentropy")
# define model end
# begin training
optimizer = tf.train.GradientDescentOptimizer(1e-5)
# compute_gradients returns (None, var) for any listed variable the loss does not
# depend on. tf.trainable_variables() returns every trainable variable in the
# default graph. The traceback comes from an IPython cell (<ipython-input-14-...>),
# so re-running the cell likely left earlier copies of w and b in the graph. Those
# copies are disconnected from this cross_entropy and get None gradients.
# Passing [w, b] here, or calling tf.reset_default_graph() first, avoids that.
grads_and_vars = optimizer.compute_gradients(cross_entropy, tf.trainable_variables())
# generate data
data = np.random.randn(batch_size, dim)
# NOTE(review): labels are drawn from [0, 10), but logits only has hidden_units (8)
# classes, so labels 8 and 9 are out of range for the sparse softmax loss.
# Probably intended: np.random.randint(0, hidden_units, size=batch_size).
labels = np.random.randint(0, 10, size=batch_size)
# tf.initialize_all_variables is deprecated; tf.global_variables_initializer replaces it.
sess.run(tf.initialize_all_variables())
# Session.run rejects None fetches ("Fetch argument None has invalid type"), so any
# (None, var) pair in grads_and_vars makes this call fail. Filter out pairs whose
# gradient is None before fetching.
gradients_and_vars = sess.run(grads_and_vars, feed_dict={x:data, y:labels})
for g, v in gradients_and_vars:
# This None check comes too late: the None entries already break sess.run above.
if g is not None:
# NOTE(review): these are Python 2 print statements, but the traceback is from
# python3.5, where they are a SyntaxError; use print(...) instead.
print "****************this is variable*************"
print "variable's shape:", v.shape
print v
print "****************this is gradient*************"
print "gradient's shape:", g.shape
print g
sess.close()
错误信息:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-14-8096b2e21e06> in <module>()
29
30 sess.run(tf.initialize_all_variables())
---> 31 outnet = sess.run(grads_and_vars, feed_dict={x:data, y:labels})
32 # print(gradients_and_vars)
33 # if g is not None:
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
893 try:
894 result = self._run(None, fetches, feed_dict, options_ptr,
--> 895 run_metadata_ptr)
896 if run_metadata:
897 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1107 # Create a fetch handler to take care of the structure of fetches.
1108 fetch_handler = _FetchHandler(
-> 1109 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1110
1111 # Run request and get response.
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
411 """
412 with graph.as_default():
--> 413 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
414 self._fetches = []
415 self._targets = []
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
231 elif isinstance(fetch, (list, tuple)):
232 # NOTE(touts): This is also the code path for namedtuples.
--> 233 return _ListFetchMapper(fetch)
234 elif isinstance(fetch, dict):
235 return _DictFetchMapper(fetch)
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
338 """
339 self._fetch_type = type(fetches)
--> 340 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
341 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
342
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in <listcomp>(.0)
338 """
339 self._fetch_type = type(fetches)
--> 340 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
341 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
342
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
231 elif isinstance(fetch, (list, tuple)):
232 # NOTE(touts): This is also the code path for namedtuples.
--> 233 return _ListFetchMapper(fetch)
234 elif isinstance(fetch, dict):
235 return _DictFetchMapper(fetch)
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
338 """
339 self._fetch_type = type(fetches)
--> 340 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
341 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
342
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in <listcomp>(.0)
338 """
339 self._fetch_type = type(fetches)
--> 340 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
341 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
342
//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
228 if fetch is None:
229 raise TypeError('Fetch argument %r has invalid type %r' %
--> 230 (fetch, type(fetch)))
231 elif isinstance(fetch, (list, tuple)):
232 # NOTE(touts): This is also the code path for namedtuples.
TypeError: Fetch argument None has invalid type <class 'NoneType'>
为什么会出现此错误?是版本问题吗?
解决方案
如果计算图中某个变量与损失之间没有显式连接,那么该变量的梯度会返回 `None`。在您的代码中,所有声明的变量似乎都与损失相连,因此可能是计算图中包含了从别处加载的变量(例如在同一个 IPython 会话中重复运行代码时遗留的变量)。您可以使用:
# List every variable in the default graph. Stale copies left over from earlier
# runs typically appear with uniquified names such as "w_1:0" / "b_1:0".
# tf.all_variables() is deprecated; tf.global_variables() replaces it.
print([v.name for v in tf.global_variables()])
然后检查计算图中是否只包含预期的变量。
可以尝试这样做:
sess.run(tf.global_variables_initializer())
# compute_gradients yields (None, var) for any variable the loss does not depend
# on, and Session.run cannot fetch None, so drop those pairs before fetching.
# (To avoid stale variables altogether, call tf.reset_default_graph() before
# building the model, or pass [w, b] instead of tf.trainable_variables().)
valid_grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
gradients_and_vars = sess.run(valid_grads_and_vars, feed_dict={x: data, y: labels})
print(gradients_and_vars)
推荐阅读
- r - 如何创建按变量分层的高级ggplot?
- http - HTTP SEARCH 方法是否标准化?
- solr - solr搜索中的两个csv
- python - 如何使用 python 脚本从文件中获取特定数据?
- javascript - 在打字稿中的一组对象上展开运算符
- java - 从多个模块初始化静态变量时,它们在 Junit 中如何工作?
- javascript - 在 java 脚本中将 Unix 时间转换为其他 GMT
- python - 使滚动条在填充了超出画布尺寸的文本的画布内工作
- python - Pandas - 将两个数据帧中的最近事件与条件连接起来
- python - 我有一个数据框,我想为每一行申请循环,如果多列的条件将数据存储在新数据框中