python - Fix random seed for torchvision transforms
问题描述
I use some code similar to the following - for data augmentation:
from torchvision import transforms
#...
augmentation = transforms.Compose([
transforms.RandomApply([
transforms.RandomRotation([-30, 30])
], p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
])
During my testing I want to fix random values to reproduce the same random parameters each time I change the model training settings. How can I do it?
I want to do something similar to np.random.seed(0)
so each time I call random function with probability for the first time, it will run with the same rotation angle and probability. In other words, if I do not change the code at all, it must reproduce the same result when I rerun it.
Alternatively I can separate transforms, use p=1
, fix the angle min
and max
to a particular value and use numpy random numbers to generate results, but my question if I can do it keeping the code above unchanged.
解决方案
在__getitem__
您的数据集类中制作一个 numpy 随机种子。
def __getitem__(self, index):
img = io.imread(self.labels.iloc[index,0])
target = self.labels.iloc[index,1]
seed = np.random.randint(2147483647) # make a seed with numpy generator
random.seed(seed) # apply this seed to img transforms
if self.transform is not None:
img = self.transform(img)
random.seed(seed) # apply this seed to target transforms
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
推荐阅读
- excel - 如何在 Excel 中连接来自同一列和不同行的多个单元格?
- javascript - Discord.js 拍命令
- matlab - Matlab 符号工具箱中的 Subs 不计算值
- tcp - Hazelcast tcp-ip 配置集群:即使指定了集群名称,不需要的 IP 也会加入集群
- javascript - TypeError:无法获取 - 仅在 Cordova 应用程序上
- css - 如何让一个 div 覆盖所有可用的垂直空间,然后将它包裹在一个行方向的 flexbox 布局中?
- assembly - 为什么让一些寄存器调用者保存而另一些寄存器保存被调用者?为什么不让调用者保存它想要保存的所有内容?
- java - Firebase:对收到的通知进行 POST 请求
- r - 与条形图相比,误差线图太大?
- node.js - Nest 无法解析 WithdrawService 的依赖关系