python - Pytorch:nn.Parameter 到 tensorflow 变量
问题描述
我正在尝试将 pytorch 代码转换为 tensorflow 代码。我无法找到如何在 tensorflow 中转换参数,我想使用 tf.get_variable 但不确定。
Pytorch 代码如下所示:
def gen_adj(A):
D = torch.pow(A.sum(1).float(), -0.5)
D = torch.diag(D)
adj = torch.matmul(torch.matmul(A, D).t(), D)
return adj
self.A = Parameter(torch.from_numpy(_adj).float())
inp = inp[0]
adj = gen_adj(self.A).detach()
x = self.new_fucn(inp, adj)
所以对于这条线
self.A = Parameter(torch.from_numpy(_adj).float())
我尝试的是:
self.A = tf.get_variable(name="self.A",
shape=[_adj.shape[0],_adj.shape[1]],
initializer=tf.constant_initializer(np.array(_adj)),
trainable=False
)
和 get_adj 函数是这样的:
def tensorflow_adj(A):
D = tf.pow(tf.reduce_sum(A, 1),-0.5)
D = tf.diag_part(D)
adj = tf.matmul( tf.transpose(tf.matmul(A,D),[1,0]),D)
return adj
如果我的转换代码正确,如果有人能给我一些建议,我将不胜感激。
谢谢 !
解决方案
推荐阅读
- matrix - 从 MacAulay2 中的序列构造向量
- react-native - 为什么我的图像(和子组件)没有在 React Native 上使用 ImageBackground 显示?
- python-3.x - 如何删除google colab上本地上传的文件?
- mysql - SQL - 带条件的 min(date)
- javascript - 如何在 Javascript 中创建全局变量的新实例
- apache-spark - 如何在其他集群上向纱线提交作业?
- php - App Engine 标准:超时给了我 500 错误
- c# - 有没有办法直接从 sql 加载文件,我需要一种方法用 sql 中的其他图像替换图像
- java - 将动画的 layout_weight 从 0.3 扩展到 0.8
- react-native - 在 UIManager 中找不到 RNTMap