首页 > 解决方案 > 如何在张量中找到与子字符串匹配的元素

问题描述

我有一个案例,我需要知道在训练期间提取示例的文件名中的前缀是什么。

在我的数据集生成器中,我有一个当前文件名的张量'sourceA_stuff.tfrecords'or'sourceB_stuff.tfrecords'等​​。我想确定张量中的哪个元素sources = ['sourceA', 'sourceB']与文件名的前缀匹配,并将该元素的索引作为源标签传递。如果没有急切执行,我会遇到问题,如果可以避免,我真的不想使用急切执行。最小示例如下(见底部注释):

filename = tf.cast('sourceA_stuff.tfrecords', tf.string)
sources = ['sourceA', 'sourceB']
for i in range(len(sources)):
    if sources[i] in filename:
        source = tf.cast(i, tf.int32)
        break

TypeError:张量对象仅在启用急切执行时才可迭代。要迭代此张量,请使用tf.map_fn.

问题是我无法弄清楚如何使用子字符串匹配tf.map_fn来基本上模拟where查询,而且我无法找到一个好方法来解决我正在尝试做的事情而无需迭代。

也试过:

source = [i for i in range(len(sources)) if source[i] in filename]

同样的交易。

注意:现在在我的电脑上测试这个有问题。如有必要,将更新修复程序。

标签: pythontensorflow

解决方案


以下应该工作。

import tensorflow as tf

filename = tf.cast('sourceB_stuff.tfrecords', tf.string)
sources = tf.constant(['sourceA.+', 'sourceB.+'])

tf_label = tf.argmax(tf.cast(tf.map_fn(lambda x: tf.strings.regex_full_match(filename, x), sources, dtype=tf.bool), tf.int32))

with tf.Session() as sess:

  print(sess.run(tf_label))

需要注意的一点:

  • TensorFlow 仍然没有startswith()字符串操作类型。所以我能找到的最接近的regex_full_match意思是你需要一个匹配你要比较的完整字符串的正则表达式。

推荐阅读