首页 > 解决方案 > 将数据集从亚马逊 S3 加载到 EC2 上的 jupyter 笔记本

问题描述

我想尝试使用 AWS 进行深度学习的图像分割。我将数据存储在 Amazon S3 上,我想从在 Amazon EC2 实例上运行的 Jupyter Notebook 访问它。

我正计划使用 Tensorflow 进行分割,因此对我来说使用 Tensorflow 自己提供的选项(https://www.tensorflow.org/deploy/s3)似乎是合适的,因为我觉得最终我希望我的数据能够以 tf.Dataset 的格式表示。然而,这对我来说并不完全奏效。我尝试了以下方法:

filenames = ["s3://path_to_first_image.png", "s3://path_to_second_image.png"]
dataset = tf.data.TFRecordDataset(filenames)

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
for i in range(2):
    print(sess.run(next_element))

我收到以下错误:

OutOfRangeError: End of sequence
 [[Node: IteratorGetNext_6 = IteratorGetNext[output_shapes=[[]], output_types=[DT_STRING], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_6)]]

我对 tensorflow 很陌生,最近才开始尝试一些 AWS 的东西,所以我希望我的错误对于有更多经验的人来说是显而易见的。我将不胜感激任何帮助或建议!也许这甚至是错误的方式,我最好使用 boto3 之类的东西(也偶然发现它,但认为 tf 在我的情况下更合适)或其他东西?

PS Tensorflow 还建议使用以下部分测试设置:

from tensorflow.python.lib.io import file_io
print (file_io.stat('s3://path_to_image.png'))

对我来说,这会导致Object doesn't exist错误,尽管该对象肯定存在并且如果我使用它会在其他列表中列出

for obj in s3.Bucket(name=MY_BUCKET_NAME).objects.all():
print(os.path.join(obj.bucket_name, obj.key))

我也填写了我的凭据/.aws/credentials。这里可能有什么问题?

标签: amazon-web-servicestensorflowamazon-s3amazon-ec2jupyter-notebook

解决方案


不是对您的问题的直接回答,但我仍然注意到为什么您无法使用 Tensorflow 加载数据。

文件名中.png的文件不是.tfrecord二进制存储格式的文件格式。所以,tf.data.TFRecordDataset(filenames)不应该工作。

我认为以下将起作用。注意:这是TF2的,不知道TF1是否一样。可以在 TensorFlow 的网站tensorflow 示例中找到类似的示例

步骤1

使用 .将文件加载到 TensorFlow 数据集中tf.data.Dataset.list_files

import tensorflow as tf

list_ds = tf.data.Dataset.list_files(filenames)

第2步

使用 ; 创建一个将应用于数据集中每个元素的函数map;这将对 TF 数据集中的每个元素使用该函数。

def process_path(file_path):
    '''reads the path and returns an image.''' 
    # load the raw data from the file as a string
    byteString = tf.io.read_file(file_path)
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_png(byteString, channels=3)
    return img

dataset = list_ds.map(preprocess_path)

第 3 步

查看图片。

import matplotlib.pyplot as plt

for image in dataset.take(1): plt.imshow(image)


推荐阅读