tensorflow - 如何随机旋转张量图像
问题描述
我想在预处理阶段使用“地图”并行旋转我的图像。
问题是每张图像都向相同的方向旋转(在生成一个随机数之后)。但我希望每个图像都有不同程度的旋转。
这是我的代码:
import tensorflow_addons as tfa
import math
import random
def rotate_tensor(image, label):
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
我试图在每次调用该函数时更改种子:
import tensorflow_addons as tfa
import math
import random
seed_num = 0
def rotate_tensor(image, label):
seed_num += 1
random.seed(seed_num)
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
但我得到:
UnboundLocalError: local variable 'seed_num' referenced before assignment
我使用 tf2,但我认为这并不重要(除了旋转图像的代码之外)。
编辑:我尝试了@Mehraban 的建议,但似乎 rotate_tensor 函数只被调用一次:
import tensorflow_addons as tfa
import math
import random
num_seed = 1
def rotate_tensor(image, label):
global num_seed
num_seed += 1
print(num_seed) #<---- print num_seed
random.seed(num_seed)
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
但它只打印一次“2”。所以我认为 rotate_tensor 被调用一次。
编辑 2 - 这是显示旋转图像的功能:
plt.figure(figsize=(12, 10))
for X_batch, y_batch in rotated_test_set.take(1):
for index in range(9):
plt.subplot(3, 3, index + 1)
plt.imshow(X_batch[index])
plt.title("Predict: {} | Actual: {}".format(class_names[y_test_proba_max_index[index]], class_names[y_batch[index]]))
plt.axis("off")
plt.show()
解决方案
问题在于如何生成随机数。尽管在处理 tensorflow 时random
应该使用模块,但您依赖于模块。tf.random
以下是从 tf 获取随机数时情况如何变化的演示:
import tensorflow as tf
import random
def gen():
for i in range(10):
yield [1.]
ds = tf.data.Dataset.from_generator(gen, (float))
def m1(d):
return d*random.random()
def m2(d):
return d*tf.random.normal([])
[d for d in ds.map(m2)]
[0.17368042,
1.5629852,
1.2372143,
1.8170034,
1.7040217,
-0.16738933,
-0.11567844,
-0.17949782,
-0.67811996,
-0.5391556]
[d for d in ds.map(m1)]
[0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798]
推荐阅读
- javascript - 更改响应式数据表中的输入值
- php - $_SESSION:未定义索引
- php - 如何使用类函数和init文件获取redis键值
- node.js - big-react-calendar:'未捕获的类型错误:无法读取未定义的属性'momentLocalizer'
- json - 如何将 NSArray 存储在 Userdefaults 中?
- hadoop - 数据节点故障后 hdfs 恢复
- python - 从Python中的对象属性获取对象
- python - 如何在 Python 中返回可迭代对象?
- c# - C# DispatcherTimer 问题
- python - 图表的 X-label 不可见