python - 张量流中不规则/变化的批量大小?
问题描述
我有一个 tensorflow 数据集,并希望对它进行批处理,以使批次的大小不同 - 例如将示例分组为批次,其大小由值向量而不是固定值定义。
有没有办法在张量流中做到这一点?
对于一个没有固定批量大小的网络,喂不规则批量是否会成为问题?
提前致谢!
解决方案
答案是肯定的。model.fit() 方法允许将生成器传递给它,该生成器将生成随机大小的批次。
x_train_batches = ... # some list of data batches of uneven length
y_train_batches = ... # some list of targets of SAME lengths as x_train_batches
def train_gen(x_train_batches, y_train_batches):
i = 0
num_batches = len(x_train_batches)
while True:
yield (x_train_batches[i%num_batches], y_train_batches[i%num_batches])
i += 1
model.fit(train_gen(x_train_batches, y_train_batches), epochs=epochs, steps_per_epoch=NUM_BATCHES)
另一种更优雅的方法是子类tf.keras.utils.Sequence
化并将其提供给模型:
class UnevenSequence(keras.utils.Sequence):
def __init__(self, x_batches, y_batches):
# x_batches, y_batches are lists of uneven batches
self.x, self.y = x_batches, y_batches
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
batch_x = self.x[idx]
batch_y = self.y[idx]
return (batch_x, batch_y)
my_uneven_sequence = UnevenSequence(x_train_batches, y_train_batches)
model.fit(my_uneven_sequence, epochs=10)
推荐阅读
- timer - RabbitMQ 中延迟消息的已知问题
- vue.js - Adding image and text to Vue Select dropdown
- c# - System.IO.IOException:'进程无法访问文件'File',因为它正被另一个进程使用。'
- javascript - React 16.4 - 手动表单输入填充以及来自 getDerivedStateFromProps 的更新?
- abap - 自定义F4的新纪录
- sqlite - 是否可以获得在颤振中为 Android 模拟器创建的 SQLite 数据库的 GUI?
- mysql - 仅从 mysql 表中选择最后一个日志,使用多列的 DISTINCT
- javascript - 使用鼠标单击按钮和使用开发工具单击有什么区别?
- swift - @propertyWrapper 段错误
- ios - Visual Studio Xamarin - 使用 web 服务构建 realease iOS 项目的问题