python - Tensorflow通过张量迭代
问题描述
我使用 TensorFlow 1.13。但是,我收到一条错误消息,说除非我处于渴望模式,否则我无法遍历张量。有没有办法在不进入急切模式的情况下做到这一点?
with tf.Session(config=config) as sess:
context = tf.placeholder(tf.int32, [args.batch_size, None])
mask = tf.placeholder(tf.int32, [args.batch_size, 2])
output = model.model(hparams=hparams, X=context)
for batch_index in range(args.batch_size):
start = mask[batch_index][0]
end = mask[batch_index][1]
for i in range(start, end+1):
output['logits'][batch_index, i , context[batch_index,i]].assign(math.inf)
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=context[:, 1:], logits=output['logits'][:, :-1]))
解决方案
您可以尝试使用tf.while_loop吗?您可以尝试以下代码段(可能对您的代码稍作修改),看看它是否有效?
import tensorflow as tf
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
context = tf.placeholder(tf.int32, [args.batch_size, None])
mask = tf.placeholder(tf.int32, [args.batch_size, 2])
output = model.model(hparams=hparams, X=context)
for batch_index in [0,1,2,3]: #I have assumed a dummy list cz we can't iterate through a 'Dimension'
start = mask[batch_index][0]
end = mask[batch_index][1]
i = tf.constant(0)
while_condition = lambda i: (tf.less(i, end)) & (tf.math.greater_equal(i,start))
def body(i):
return output['logits'][batch, i , context[batch,i]].assign(math.inf)
r = tf.while_loop(while_condition, body, [i])
# for i in range(start, end+1):
# output['logits'][batch, i , context[batch,i]].assign(math.inf)
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=context[:, 1:], logits=output['logits'][:, :-1]))
推荐阅读
- java - 标记上的语法错误,错位的构造,
- arrays - 如何在 Jupyter Notebook 中查看完整的多维数组
- json - 使用 circe 递归地将 JSON 树转换为其他格式(XML、CSV 等)
- php - Blade模板中本地注册Vue组件的最佳实践(Laravel)
- python - python中的嵌套字典
- javascript - 从 NodeJS 中获取的子详细信息中遍历和构建 Tree
- firebase - Flutter Future:构建函数返回 null
- spring - 如何使用 spring-boot 为 Ldap 创建多个配置
- orm - 构建数据访问层
- flask - flask-admin:美化 json 字段