首页 > 解决方案 > pytorch 中的 tf.cast 等价物?

问题描述

我是 PyTorch 的新手。TensorFlow 有一个 API tf.cast()tf.shape()tf.cast在 TensorFlow中具有特定用途,在火炬中有什么等价物吗?我有张量 x= tensor(shape(128,64,32,32)): tf.shape(x)创建尺寸为 1 x.shape的张量创建真实尺寸。我需要在火炬中使用tf.shape(x) 。

tf.cast的作用与仅仅改变Torch 中的张量dtype不同。

有没有人在 Torch/PyTorch 中有等效的 API。

标签: pythontensorflowpytorch

解决方案


查看PyTorch 文档

正如他们提到的:

print(x.dtype) # Prints "torch.int64", currently 64-bit integer type
x = x.type(torch.FloatTensor)
print(x.dtype) # Prints "torch.float32", now 32-bit float
print(x.float()) # Still "torch.float32"
print(x.type(torch.DoubleTensor)) # Prints "tensor([0., 1., 2., 3.], dtype=torch.float64)"
print(x.type(torch.LongTensor)) # Cast back to int-64, prints "tensor([0, 1, 2, 3])"

推荐阅读