python - Splitting the MNIST dataset with tf.data
Problem description
I am trying to use tf.data to split the MNIST training set of 60000 images into 55000 training images and 5000 validation images.
When I run:
session_config = tf.ConfigProto(log_device_placement=False)
config = tf.estimator.RunConfig(tf_random_seed=230,
                                model_dir=chpt_dir_path,
                                save_checkpoints_steps=params["save_checkpoints_steps"],
                                keep_checkpoint_max=params["keep_checkpoint_max"],
                                session_config=session_config)
estimator = tf.estimator.Estimator(model_fn=model_fn, params=params, config=config)
train_dataset, valid_dataset = train_input_fn(args.DATA_DIR_PATH, params)
estimator.train(lambda: train_dataset)
The error is:
Tensor("Iterator:0", shape=(), dtype=resource) must be from the same graph as Tensor("PrefetchDataset:0", shape=(), dtype=variant).
The problem comes from this function:
def train_input_fn(data_dir_path, params):
    """Train input function for the MNIST dataset.

    Args:
        data_dir_path: (string) path to the data directory
        params: (Params) contains hyperparameters of the model (ex: `params.num_epochs`)
    """
    dataset = train(data_dir_path)
    dataset = dataset.shuffle(params["train_size"] + params["valid_size"], seed=416)  # whole dataset into the buffer
    train_dataset = dataset.take(params["train_size"])
    valid_dataset = dataset.skip(params["train_size"])
    train_dataset = train_dataset.batch(params["batch_size"])
    train_dataset = train_dataset.shuffle(params["train_size"])
    train_dataset = train_dataset.prefetch(1)  # make sure you always have one batch ready to serve
    valid_dataset = valid_dataset.batch(params["batch_size"])
    valid_dataset = valid_dataset.shuffle(params["valid_size"])
    valid_dataset = valid_dataset.prefetch(1)  # make sure you always have one batch ready to serve
    return train_dataset, valid_dataset
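I don't know how to fix this. Does anyone know how? Or is there a better way to split the data with tf.data?
The following code loads the MNIST dataset and creates the data pipeline.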
import gzip
import os
import shutil
import urllib.request

import tensorflow as tf


def download(data_dir_path, filename):
    """Download (and unzip) a file from the MNIST dataset if not already done."""
    filepath = os.path.join(data_dir_path, filename)
    if tf.gfile.Exists(filepath):
        return filepath
    if not tf.gfile.Exists(data_dir_path):
        tf.gfile.MakeDirs(data_dir_path)
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    url = "https://storage.googleapis.com/cvdf-datasets/mnist/" + filename + ".gz"
    zipped_filepath = filepath + ".gz"
    print("Downloading %s to %s" % (url, zipped_filepath))
    urllib.request.urlretrieve(url, zipped_filepath)
    with gzip.open(zipped_filepath, "rb") as f_in, open(filepath, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.remove(zipped_filepath)
    return filepath


def dataset(data_dir_path, images_file, labels_file):
    images_file_path = download(data_dir_path, images_file)
    labels_file_path = download(data_dir_path, labels_file)

    def decode_image(image):
        # Normalize from [0, 255] to [0.0, 1.0]
        image = tf.decode_raw(image, tf.uint8)
        image = tf.cast(image, tf.float32)
        image = tf.reshape(image, [784])
        return image / 255.0

    def decode_label(label):
        label = tf.decode_raw(label, tf.uint8)  # tf.string -> [tf.uint8]
        label = tf.reshape(label, [])  # label is a scalar
        return tf.to_int32(label)

    # Each image record is 28 * 28 bytes after a 16-byte file header
    images = tf.data.FixedLengthRecordDataset(images_file_path, 28 * 28, header_bytes=16)
    images = images.map(decode_image)
    # Each label record is 1 byte after an 8-byte file header
    labels = tf.data.FixedLengthRecordDataset(labels_file_path, 1, header_bytes=8)
    labels = labels.map(decode_label)
    return tf.data.Dataset.zip((images, labels))


def train(data_dir_path):
    """tf.data.Dataset object for MNIST training data."""
    return dataset(data_dir_path, "train-images-idx3-ubyte", "train-labels-idx1-ubyte")
Solution
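The original answer is not included here, so what follows is a sketch rather than the accepted fix. The error message is consistent with how `tf.estimator` works: `Estimator.train` builds its own `tf.Graph` and calls `input_fn` inside it, whereas the question builds `train_dataset` beforehand in the default graph, and `lambda: train_dataset` only hands that pre-built object over. Deferring dataset construction until `input_fn` is invoked should avoid the mismatch. `make_input_fn` and its `index` argument are hypothetical names introduced for illustration; `train_input_fn` is the function from the question.

```python
def make_input_fn(build_datasets, index, *args):
    """Return a zero-argument input_fn that builds the datasets only when called.

    build_datasets: callable returning a tuple of datasets, such as the
        question's train_input_fn, which returns (train_dataset, valid_dataset).
    index: position of the dataset to hand to the Estimator (hypothetical
        convention: 0 for training, 1 for validation).
    args: positional arguments forwarded to build_datasets.
    """
    def input_fn():
        # Nothing is built until the Estimator calls input_fn, so every
        # dataset op is created in the graph that is active at that moment.
        return build_datasets(*args)[index]

    return input_fn
```

With this helper, training would become `estimator.train(input_fn=make_input_fn(train_input_fn, 0, args.DATA_DIR_PATH, params))`, and `index=1` would select the validation dataset for `estimator.evaluate`. A plain `lambda: train_input_fn(args.DATA_DIR_PATH, params)[0]` has the same effect.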
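On the second part of the question, whether there is a better way to split with tf.data: `take` and `skip` only yield disjoint subsets if the shuffled order is identical every time the dataset is iterated. `Dataset.shuffle` reshuffles on each iteration by default, so a sketch of a stable split passes `reshuffle_each_iteration=False`. The helper name `split_dataset` and its parameters are assumptions made for this illustration; it relies only on the `shuffle`, `take`, and `skip` methods of `tf.data.Dataset`.

```python
def split_dataset(dataset, train_size, total_size, seed=416):
    """Split a tf.data.Dataset into (train, valid) parts that do not overlap.

    dataset: the full dataset, e.g. the result of train(data_dir_path).
    train_size: number of elements that go to the training part.
    total_size: number of elements in the full dataset, used as buffer size.
    seed: shuffle seed (416 mirrors the value used in the question).
    """
    # Shuffle once with a buffer that covers the whole dataset. Keeping the
    # order fixed across iterations is what makes take/skip a true partition.
    shuffled = dataset.shuffle(total_size, seed=seed, reshuffle_each_iteration=False)
    train_dataset = shuffled.take(train_size)  # first train_size elements
    valid_dataset = shuffled.skip(train_size)  # everything after them
    return train_dataset, valid_dataset
```

Inside `train_input_fn` this could stand in for the first `shuffle`/`take`/`skip` lines, for example `split_dataset(train(data_dir_path), params["train_size"], params["train_size"] + params["valid_size"])`. Any per-epoch shuffling would then be applied to the training part alone.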