首页 > 解决方案 > Pytorch:nn.Parameter 到 tensorflow 变量

问题描述

我正在尝试将 pytorch 代码转换为 tensorflow 代码。我无法找到如何在 tensorflow 中转换参数,我想使用 tf.get_variable 但不确定。

Pytorch 代码如下所示:

def gen_adj(A):
    D = torch.pow(A.sum(1).float(), -0.5)
    D = torch.diag(D)
    adj = torch.matmul(torch.matmul(A, D).t(), D)
    return adj

self.A = Parameter(torch.from_numpy(_adj).float())
inp = inp[0]
adj = gen_adj(self.A).detach()
x = self.new_fucn(inp, adj)

所以对于这条线

self.A = Parameter(torch.from_numpy(_adj).float())

我尝试的是:

self.A = tf.get_variable(name="self.A", 
                                 shape=[_adj.shape[0],_adj.shape[1]], 
                                 initializer=tf.constant_initializer(np.array(_adj)), 
                                 trainable=False
                                )

和 get_adj 函数是这样的:

def tensorflow_adj(A):

    D = tf.pow(tf.reduce_sum(A, 1),-0.5)
    D = tf.diag_part(D)
    adj = tf.matmul( tf.transpose(tf.matmul(A,D),[1,0]),D)

    return adj

如果我的转换代码正确,如果有人能给我一些建议,我将不胜感激。

谢谢 !

标签: pythontensorflowkerasdeep-learningpytorch

解决方案


推荐阅读