python - 如何从特定类别中抽取批次?
问题描述
我想在一个 ImageNet 数据集(1000 个类,每个类大约 1300 张图像)上训练一个分类器。出于某种原因,我需要每批包含来自特定类的 64 个图像(作为int
或占位符提供)。如何使用最新的 TensorFlow 高效地做到这一点?
这是How to sample batch from only one class at each iteration的后续问题。
我目前的想法是使用tf.data.Dataset.filter
:
specific_class = 2 # as an example
dataset = tf.data.TFRecordDataset(filenames)
# __parser_fun__ produces datum tuple (x, y)
dataset = dataset.map(__parser_fun__, num_parallel_calls=num_threads)
dataset = dataset.shuffle(20000)
# print(dataset) gives <ShuffleDataset shapes: ((3, 128, 128), (1,)),
# types: (tf.float32, tf.int64)>
dataset = dataset.filter(lambda x, y: tf.equal(y[0], specific_class))
dataset = dataset.batch(64)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x_batch, y_batch = iterator.get_next()
一个小问题filter
是每次我想从新类中采样时都需要构造一个迭代器。
另一个想法是使用tf.contrib.data.rejection_resample
,但它在计算上似乎令人望而却步(或者是吗?)。
我想知道是否有其他有效的方法可以从特定类别中对批次进行抽样?
解决方案
从概念上讲,您的数据集由变量(要采样的标签)参数化。这是完全可行的!
急切地执行:
import numpy as np
import tensorflow as tf
tf.enable_eager_execution()
data = dict(
x=tf.constant([1., 2., 3., 4.]),
y=tf.constant([1, 2, 1, 2])
)
requested_label = tf.Variable(1)
dataset = (
tf.data.Dataset.from_tensor_slices(data)
.repeat()
.filter(lambda d: tf.equal(d["y"], requested_label)))
it = dataset.make_one_shot_iterator()
for i, datum in enumerate(it):
assert int(datum["y"]) == 1
assert float(datum["x"]) in [1., 3.]
if i > 5:
break
requested_label.assign(2)
for i, datum in enumerate(it):
assert int(datum["y"]) == 2
assert float(datum["x"]) in [2., 4.]
if i > 5:
break
图表构建:
import tensorflow as tf
graph = tf.Graph()
with graph.as_default():
data = dict(
x=tf.constant([1., 2., 3., 4.]),
y=tf.constant([1, 2, 1, 2])
)
requested_label = tf.Variable(1)
dataset = (
tf.data.Dataset.from_tensor_slices(data)
.repeat()
.filter(lambda d: tf.equal(d["y"], requested_label)))
it = dataset.make_initializable_iterator()
datum_tensors = it.get_next()
switch_label_op = requested_label.assign(2)
graph.finalize()
with tf.Session() as session:
session.run(requested_label.initializer) # label=1
session.run(it.initializer)
for _ in range(5):
datum = session.run(datum_tensors)
assert int(datum["y"]) == 1
assert float(datum["x"]) in [1., 3.]
session.run(switch_label_op) # label=2
for _ in range(5):
datum = session.run(datum_tensors)
assert int(datum["y"]) == 2
assert float(datum["x"]) in [2., 4.]
推荐阅读
- angular - Angular:在监视更改时注入 NgControl
- macos - niDAQmx base的下载页面不再列出macos?
- excel - VBA将多行插入表问题
- android - 无法写入我的 firebase 实时数据库
- angular - `simplebar` 不起作用,如何与 angular 集成
- r - 使用 blogdown 插件插入图像的问题
- git - 使用 revert commit 将功能分支重新定位到另一个功能分支
- javascript - 尝试使用 async/await 代替 browser.sleep,但它不起作用
- sql-server - “OPENROWSET”函数一直执行,没有任何输出
- ios - 如何使用 CoreNFC 检测 Mifare Plus 芯片