keras - 如何在 Keras 模型中实现一些可训练的参数,例如 Pytorch 中的 nn.Parameters()?
问题描述
我只想用 Keras 在我的模型中实现一些可训练的参数。在 Pytorch 中,我们可以通过使用 torch.nn.Parameter() 来实现,如下所示:
self.a = nn.Parameter(torch.ones(8))
self.b = nn.Parameter(torch.zeros(16,8))
我认为通过在 pytorch 中执行此操作,它可以在模型中添加一些可训练的参数。现在我想知道,如何在keras中实现类似的操作?欢迎任何建议或建议!
谢谢!:)
ps 我只是在 Keras 中编写了一个自定义层,如下所示:
class Mylayer(Layer):
def __init__(self,input_dim,output_dim,**kwargs):
self.input_dim = input_dim
self.output_dim = output_dim
super(Mylayer,self).__init__(**kwargs)
def build(self):
self.kernel = self.add_weight(name='pi',
shape=(self.input_dim,self.output_dim),
initializer='zeros',
trainable=True)
self.kernel_2 = self.add_weight(name='mean',
shape=(self.input_dim,self.output_dim),
initializer='ones',
trainable=True)
super(Mylayer,self).build()
def call(self,x):
return x,self.kernel,self.kernel_2
我想知道我是否没有改变通过层的张量,我是否应该编写def compute_output_shape()
必要的函数?
解决方案
您需要在自定义层中创建可训练的权重:
class MyLayer(Layer):
def __init__(self, my_args, **kwargs):
#do whatever you need with my_args
super(MyLayer, self).__init__(**kwargs)
#you create the weights in build:
def build(self, input_shape):
#use the input_shape to infer the necessary shapes for weights
#use self.whatever_you_registered_in_init to help you, like units, etc.
self.kernel = self.add_weight(name='kernel',
shape=the_shape_you_calculated,
initializer='uniform',
trainable=True)
#create as many weights as necessary for this layer
#build the layer - equivalent to self.built=True
super(MyLayer, self).build(input_shape)
#create the layer operation here
def call(self, inputs):
#do whatever operations are needed
#example:
return inputs * self.kernel #make sure the shapes are compatible
#tell keras about the output shape of your layer
def compute_output_shape(self, input_shape):
#calculate the output shape based on the input shape and your layer's rules
return calculated_output_shape
现在在模型中使用您的图层。
如果您在 tensorflow 上使用 Eager Execution 并创建自定义训练循环,您可以使用与 PyTorch 几乎相同的方式工作,并且可以在层外创建权重tf.Variable
,将它们作为参数传递给梯度计算方法。
推荐阅读
- reactjs - JSX 根据 props 中包含的元素控制标题的可见性
- java - Java - 如何按升序将元素添加到数组列表并计算数组列表中的对数?
- c# - 如何调整 TableLayoutPanel 以按高度和重量包装内容?(动态)
- python - 排序字典,其值为字典列表的形式
- swift - 在 iOS 上制作 Admob 应用程序,出现线程 1 错误
- c++ - 如何在不创建新对象的情况下实现类运算符?
- bash - freeStyleJob Jenkins DSL 作业运行复杂的 bash 脚本
- typescript - NestJS 自定义装饰器返回未定义
- android - 使用- Fragment- Android -Mapbox 执行位置选择器
- python - 如何在 Python 的子类中访问超类变量?