tfx - 如何使用 Tensorflow Transform 动态计算 tf.one_hot 中深度参数的多类分类标签数量?
问题描述
我有一个多类 NLP 分类问题,我需要针对大约 1000 个类别标签训练大约 100 万个文本样本。随着未来的数据集被输入管道,这个唯一类别标签的数量将略有变化。
为此,我需要将 tf.one_hot 中的深度参数设置为该训练实例的动态确定的唯一标签的数量。
要知道唯一标签的全部数量,我知道我需要完整地遍历数据。所以,我被困的地方是如何计算这个数字。
我认为 tft.size 适合获得这个完整的通行证,但它似乎不起作用。当我硬编码 1000 时,您可以在下面看到它工作正常:
labels = inputs[LABEL_KEY]
sparse_labels_tokens = tft.compute_and_apply_vocabulary(labels, vocab_filename=LABEL_VOCAB_FILE_NAME)
dense_labels_tokens = tf.sparse.to_dense(sparse_labels_tokens)
#labels_count = tf.cast( tft.size(dense_labels_tokens), tf.int32 ) #FIXME
labels_count = 1000
labels_one_hot = tf.one_hot(dense_labels_tokens, depth=labels_count)
labels_indicators = tf.reduce_max(labels_one_hot, axis=1)
outputs[transformed_name(LABEL_KEY)] = labels_indicators
outputs[LABEL_KEY] = _fill_in_missing(inputs[LABEL_KEY])
给予:
# Iterate over the first few tfrecords and decode them.
for tfrecord in dataset.take(5):
serialized_example = tfrecord.numpy()
example = tf.train.Example()
example.ParseFromString(serialized_example)
pprint.pprint(example)```
feature {
key: "label_xf"
value {
float_list {
value: 0.0
value: 0.0
value: 1.0
value: 0.0
但是,如果我改为使用 tft.size 我会收到以下错误:
...~/.local/lib/python3.6/site-packages/tensorflow_transform/schema_inference.py in _infer_feature_schema_common(features, tensor_ranges, feature_annotations, global_annotations)
241 domains[name] = schema_pb2.IntDomain(
242 min=min_value, max=max_value, is_categorical=True)
--> 243 feature_spec = _feature_spec_from_batched_tensors(features)
244
245 schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains)
~/.local/lib/python3.6/site-packages/tensorflow_transform/schema_inference.py in _feature_spec_from_batched_tensors(tensors)
86 'Feature {} ({}) had invalid shape {} for FixedLenFeature: apart '
87 'from the batch dimension, all dimensions must have known size'
---> 88 .format(name, tensor, shape))
89 feature_spec[name] = tf.io.FixedLenFeature(shape.as_list()[1:],
90 tensor.dtype)
ValueError: Feature label_xf (Tensor("Max:0", shape=(None, None), dtype=float32)) had invalid shape (None, None) for FixedLenFeature: apart from the batch dimension, all dimensions must have known size
This cell will be skipped during export to pipeline.
我可以将深度硬编码到 1500 并交叉手指说标签数量永远不会超过这个值,但我不确定如果我这样做的话我是否能够和自己一起生活:(
解决方案
推荐阅读
- android - Android材质按钮采用颜色原色而不是颜色重音
- javascript - 如何根据条件更改进度条的颜色
- java - 将字符串中的连续数字值转换为数字格式
- ansible - Ansible items.key in when 条件
- android - 如何保存底部导航片段的状态 - 具有单个导航图的 Android 导航组件
- html - 搜索输入在展开时推送链接列表
- python - 如何绘制具有多个边缘属性的networkx图?
- c# - Fluent Ribbon 未在 Visual Studio 2019 中编译
- html - 如何垂直引导列表组文本
- python - 嵌套列表的最小值/最大值