tensorflow - Keras Lambda 层何时生成随机数?
问题描述
我想将简单的数据增强(输入向量乘以随机标量)应用于在 Keras 中实现的完全连接的神经网络。Keras 具有很好的图像增强功能,但尝试使用它对于我的输入(1-张量)来说似乎很尴尬和缓慢,其训练数据集适合我的计算机内存。
相反,我想象我可以使用 Lambda 层来实现这一点,例如:
x = Input(shape=(10,))
y = x
y = Lambda(lambda z: random.uniform(0.5,1.0)*z)(y)
y = Dense(units=5, activation='relu')(y)
y = Dense(units=1, activation='sigmoid')(y)
model = Model(x, y)
我的问题是什么时候会生成这个随机数。这是否会修复一个随机数:
- 整个训练过程?
- 每批?
- 每个训练数据点?
解决方案
使用它将创建一个根本不会改变的常量,因为random.uniform
它不是 keras 函数。您在图中将此操作定义为constant * tensor
并且因子将是常数。
您需要“来自 keras”或“来自 tensorflow”的随机函数。例如,您可以采取K.random_uniform((1,), 0.5, 1.)
.
这将按批次更改。您可以通过对这段代码进行大量训练来测试它,并查看损失的变化。
from keras.layers import *
from keras.models import Model
from keras.callbacks import LambdaCallback
import numpy as np
ins = Input((1,))
outs = Lambda(lambda x: K.random_uniform((1,))*x)(ins)
model = Model(ins,outs)
print(model.predict(np.ones((1,1))))
print(model.predict(np.ones((1,1))))
print(model.predict(np.ones((1,1))))
model.compile('adam','mae')
model.fit(np.ones((100000,1)), np.ones((100000,1)))
如果您希望它针对每个训练样本进行更改,则获取一个固定的批量大小并为每个样本生成一个带有随机数的张量:K.random_uniform((batch_size,), .5, 1.)
.
如果您在自己的生成器中执行此操作,您可能会获得更好的性能model.fit_generator()
,但是:
class MyGenerator(keras.utils.Sequence):
def __init__(self, inputs, outputs, batchSize, minRand, maxRand):
self.inputs = inputs
self.outputs = outputs
self.batchSize = batchSize
self.minRand = minRand
self.maxRand = maxRand
#if you want shuffling
def on_epoch_end(self):
indices = np.array(range(len(self.inputs)))
np.random.shuffle(indices)
self.inputs = self.inputs[indices]
self.outputs = self.outputs[indices]
def __len__(self):
leng,rem = divmod(len(self.inputs), self.batchSize)
return (leng + (1 if rem > 0 else 0))
def __getitem__(self,i):
start = i*self.batchSize
end = start + self.batchSize
x = self.inputs[start:end] * random.uniform(self.minRand,self.maxRand)
y = self.outputs[start:end]
return x,y
推荐阅读
- python - Using tkinter to produce n labels where n is variable
- reactjs - 删除 React 中的对象实例:“汽车”返回未定义
- python - 在 seaborn barplot 顶部显示计数
- mysql - 如何修复损坏的 MySQL 安装?
- python - 如果 Python 编写的脚本可以在标准 Spark 中启动,我们为什么需要 PySpark?
- javascript - 根据我的引导轮播的位置设置文本
- javascript - 如何制作一个在单击按钮一定次数时执行的功能?
- bash - BASH 脚本中的 gzip 可以在文件完全解压缩之前给出退出状态吗?如何防止这种情况?
- javascript - Ember 3 计算属性
- microsoft-graph-api - 权限不足,无法完成操作 - Graph API