python - 来自数据库的 TensorFlow DataLoader
问题描述
我的数据源是一个postgresql database
. 我有多个表可以从中提取数据,并且我在固定的时间间隔内提取数据。我之前没有用 Tensorflow 做过太多的神经网络工作,之前我主要使用 Pytorch,这也是我第一次使用数据库而不是将完整的数据集读入内存。
我很好奇加载这些数据的一般管道应该是什么样子。
我已经构建了一个 API,用于从数据库中查询数据并将其转换为 2 个我打算转换为张量的数据帧。如果没有更好的方法,这是我的管道将是什么样子的草图。
batchsize = 4
interval = 60
n_batches = 10000
table_names = ['g1','g2','g3',...]
valid_time_starts = np.uniform(0,10)
model = init_model()
Xs = []
ys = []
for _ in range(n_batches):
tables = np.random.select(table_names, batchsize)
start_times = valid_time_starts(batch_size)
for i, (time, tablename) in enumerate(zip(tables, start_times)):
X, y = myFetchAPI(tablename, time, time+interval)
X, y = tensorfy(X,y)
Xs[i] = X
ys[i] = y
b_X, b_y = tf.concat(Xs), tf.concat(y)
train(model, b_X, b_y)
并且train
会运行一次迭代的训练。
我找到了这个文档,但它看起来无法处理更改表名和间隔时间段,并且需要进行大量处理才能将查询转换为正确的数据结构。(如果有问题可以详细说明。)
有没有更聪明的方法可以做到这一点(最好与延迟加载并行),或者其他解决方案没有比这更好的方法吗?
解决方案
推荐阅读
- android - Kotlin SAM 到 Lambda 的转换不起作用
- ampps - 在 Ampps 上安装 PHP-7.2(mac 版)
- git - 我可以在 Android Studio 中显示 Git 的 Tag 名称吗?
- linux - `socket.close()` 和 `socket.fromfd()` 之间用于关闭与套接字关联的文件描述符的区别?
- android - 如何在 Android 中读写字符设备(如 /dev/ttyS0)
- android - com.android.builder.internal.aapt.v2.Aapt2InternalException: AAPT2 aapt2-3.4.1-5326820-windows 守护进程
- javascript - 如何按条件和空值始终排在最后的 3 个字段对对象数组进行排序
- c++ - 与在构造函数中将非常量左值绑定到右值相关的错误
- python - 无法安装模块 - AttributeError: 'NoneType' 对象没有属性 'loader'
- weblogic12c - 如何在 jython 中使用 WLST 部署应用程序之前在 weblogic.xml 中指定会话描述符