首页 > 解决方案 > 来自数据库的 TensorFlow DataLoader

问题描述

我的数据源是一个postgresql database. 我有多个表可以从中提取数据,并且我在固定的时间间隔内提取数据。我之前没有用 Tensorflow 做过太多的神经网络工作,之前我主要使用 Pytorch,这也是我第一次使用数据库而不是将完整的数据集读入内存。

我很好奇加载这些数据的一般管道应该是什么样子。

我已经构建了一个 API,用于从数据库中查询数据并将其转换为 2 个我打算转换为张量的数据帧。如果没有更好的方法,这是我的管道将是什么样子的草图。

batchsize = 4
interval = 60
n_batches = 10000
table_names = ['g1','g2','g3',...]
valid_time_starts = np.uniform(0,10)
model = init_model()

Xs = []
ys = []
for _ in range(n_batches):
   tables = np.random.select(table_names, batchsize)
   start_times = valid_time_starts(batch_size)
   for i, (time, tablename) in enumerate(zip(tables, start_times)):
       X, y = myFetchAPI(tablename, time, time+interval)
       X, y = tensorfy(X,y)
       Xs[i] = X
       ys[i] = y
   b_X, b_y = tf.concat(Xs), tf.concat(y)
   train(model, b_X, b_y)

并且train会运行一次迭代的训练。

我找到了这个文档,但它看起来无法处理更改表名和间隔时间段,并且需要进行大量处理才能将查询转换为正确的数据结构。(如果有问题可以详细说明。)

有没有更聪明的方法可以做到这一点(最好与延迟加载并行),或者其他解决方案没有比这更好的方法吗?

标签: pythondatabasetensorflow

解决方案


推荐阅读