python - 为训练 Tensorflow 网络提供 spark 数据帧的最佳实践
问题描述
我想提供来自火花集群的数据,以训练深度网络。我在节点中没有 GPU,因此分布式 TensorFlow 或类似elephas
的软件包不是一个选项。
我想出了以下可以完成这项工作的生成器。它只是从 Spark 中检索下一批。为了处理批次,我添加了一个额外的列index
(这是简单的增量 id 列),并在每次调用下一批时对其进行过滤。
class SparkBatchGenerator(tfk.utils.Sequence):
def __init__(self, spark_df, batch_size, sample_count=None, feature_col='features', label_col='labels'):
w = Window().partitionBy(sf.lit('a')).orderBy(sf.lit('a'))
df = spark_df.withColumn('index', sf.row_number().over(w)).sort('index')
self.X = df.select([feature_col, 'index'])
self.y = df.select([label_col, 'index'])
self.data_count = sample_count if sample_count else spark_df.count()
self.feature_col = feature_col
self.label_col = label_col
self.batch_size = batch_size
def __len__(self):
return np.ceil(self.data_count /self.batch_size).astype(int)
def __getitem__(self, idx):
start, end = idx * self.batch_size, (idx + 1) * self.batch_size
batch_x = (
self.X.filter(f'index >= {start} and index < {end}')
.toPandas()[self.feature_col]
.apply(lambda x: x.toArray()).tolist()
)
batch_y = (
self.y.filter(f'index >= {start} and index < {end}')
.toPandas()[self.label_col].tolist()
)
return np.array(batch_x), np.array(batch_y)
这可行,但当然很慢,特别是在batch_size
很小的时候。我只是想知道是否有人有更好的解决方案。
解决方案
我曾经tf.data.Dataset
处理过这个。我可以缓冲来自 spark 的数据,然后将批量创建工作留给 tensorflow dataset api。现在更快了:
class MyGenerator(object):
def __init__(
self, spark_df, buffer_size, feature_col="features", label_col="labels"
):
w = Window().partitionBy(sf.lit("a")).orderBy(sf.lit("a"))
self.df = (
spark_df.withColumn("index", sf.row_number().over(w) - 1)
.sort("index")
.select([feature_col, "index", label_col])
)
self.feature_col = feature_col
self.label_col = label_col
self.buffer_size = buffer_size
def generate_data(self):
idx = 0
buffer_counter = 0
buffer = self.df.filter(
f"index >= {idx} and index < {self.buffer_size}"
).toPandas()
while len(buffer) > 0:
if idx < len(buffer):
X = buffer.iloc[idx][self.feature_col].toArray() / 255.0
y = buffer.iloc[idx][self.label_col]
idx += 1
yield X.reshape((28, 28)), y
else:
buffer = self.df.filter(
f"index >= {buffer_counter * self.buffer_size} "
f"and index < {(buffer_counter + 1) * self.buffer_size}"
).toPandas()
idx = 0
buffer_counter += 1
batch_size = 128
buffer_size = 4*1024
my_gen = MyGenerator(feature_df, buffer_size)
dataset = tf.data.Dataset.from_generator(my_gen.generate_data, output_types=(tf.float32, tf.int32))
dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE).batch(batch_size, drop_remainder=True)
推荐阅读
- javascript - 电子邮件中的图像不会与 expresshandlebars 和 nodemailer 一起显示
- scala - 当特征线性化覆盖方法时如何启用警告?
- c++ - 为什么 (C++) 类型是常规类型是个好主意?
- python - 在 DeepAR 训练作业期间解析 json 时出错 - 如何从目标中删除 Nan 值
- java - 在 Eclipse IDE 中创建 Maven 项目时出错
- sql - 创建 PL/SQL 过程和序列的问题
- kubernetes - 如何使kubectl输出一个带有合法json的jsonpath输出的地图?
- javascript - javascript如何通过组合第一个数组和它们的值在其他数组中来形成一个对象数组
- sql - 扩展存储过程对 SQL Server 的影响
- javascript - 仅通过接口 TypeScript 键入函数的参数