python - 从 Tensorflow 迁移到 PyTorch 时模型定义的注意事项
问题描述
在调试 tf 感到沮丧之后,我最近才切换到 PyTorch,并且明白它几乎完全等同于在 numpy 中编码。我的问题是我们可以在 PyTorch 模型中使用哪些允许的 python 方面(完全放在 GPU 上),例如。if-else 必须在 tensorflow 中按如下方式实现
a = tf.Variable([1,2,3,4,5], dtype=tf.float32)
b = tf.Variable([6,7,8,9,10], dtype=tf.float32)
p = tf.placeholder(dtype=tf.float32)
ps = tf.placeholder(dtype=tf.bool)
li = [None]*5
li_switch = [True, False, False, True, True]
for i in range(5):
li[i] = tf.Variable(tf.random.normal([5]))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
def func_0():
return tf.add(a, p)
def func_1():
return tf.subtract(b, p)
with tf.device('GPU:0'):
my_op = tf.cond(ps, func_1, func_0)
for i in range(5):
print(sess.run(my_op, feed_dict={p:li[i], ps:li_switch[i]}))
上述代码的 pytorch 结构将如何变化?如何将上面的变量和操作放在 GPU 上,并在 pytorch 中将列表输入并行化到我们的图形中?
解决方案
要在 PyTorch 中初始化您的张量a
和张量,请执行以下操作:b
a = torch.tensor([1,2,3,4,5], dtype=torch.float32)
b = torch.tensor([6,7,8,9,10], dtype=torch.float32)
但是,由于您需要它们完全在 GPU 上,因此您必须使用魔法.cuda()
功能。所以,它会是:
a = torch.tensor([1,2,3,4,5], dtype=torch.float32).cuda()
b = torch.tensor([6,7,8,9,10], dtype=torch.float32).cuda()
它将张量移动到 GPU
另一种初始化方式是:
a = torch.FloatTensor([1,2,3,4,5]).cuda()
b = torch.FloatTensor([6,7,8,9,10]).cuda()
如果我们需要生成我们使用的随机正态分布torch.randn
(也有torch.rand
一个均匀随机分布)。
li = torch.randn(5, 5)
(抓住这个bug,它必须被初始化cuda
,你不能对位于不同处理单元的张量进行操作,即CPU和GPU)
li = torch.randn(5, 5).cuda()
li_switch
初始化没有区别。
处理您的func_0
和的一种可能方法func_1
是将它们声明为
def func_0(li_value):
return torch.add(a, li_value)
def func_1(li_value):
return torch.sub(b, li_value)
然后,对于谓词函数调用,它可以像这样做一样简单:
for i, pred in enumerate(li_switch):
if pred:
func_0(li[i])
else:
func_1(li[i])
但是,我建议对您的操作进行矢量化并执行以下操作:
li_switch = torch.tensor([True, False, False, True, True])
torch.add(a, li[li_switch]).sum(dim=0)
torch.sub(b, li[~li_switch]).sum(dim=0)
这更加优化。
推荐阅读
- node.js - localhost:5000 未运行 MERN 应用程序-MIME 错误
- python - Django 项目中 WeazyPrint 生成的 PDF 中未显示的图像
- excel - 更新 Excel 中依赖于首先更新的另一个工作簿的数据
- java - Android Studio:膨胀类片段时出错(加载activity_main UI时程序崩溃)
- javascript - 对象有许多数组,遍历每个数组并更改一些值
- flutter - 如何处理散列用户凭据和 Parse 身份验证
- python - 仅从字典中获取最后一个元组
- swift - Swift - 找出即将到来的时间和当前时间之间的差异
- php - 使用多态关系从 2 个表中获取数据
- python - 无法定位硒元素: