javascript - Writing custom InstantLayerNormalization in tensorflow js
问题描述
I am trying to implement a deep learning model in the browser, and this requires porting some custom layers; one of them is an instant layer normalization. Below is the piece of code that is supposed to work, but it's a bit old. I get this error:
Uncaught (in promise) ReferenceError: initializer is not defined at InstantLayerNormalization.build
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
<script>
// Custom TensorFlow.js layer intended to perform an "instant" layer
// normalization: call() normalizes the inputs over the last axis
// (mean/variance taken with axis=[-1]) and then scales by `gamma` and shifts
// by `beta`.
// NOTE(review): this is the code exactly as posted; it does not run. The
// reported failure is "ReferenceError: initializer is not defined" inside
// build(). Several statements still use Python TensorFlow spellings, flagged
// inline.
class InstantLayerNormalization extends tf.layers.Layer
{
// Name handed to tf.serialization.registerClass via the class.
static className = 'InstantLayerNormalization';
// Small constant added to the variance before the square root in call().
epsilon = 1e-7
// Scale weight, created in build() with the 'ones' initializer.
gamma;
// Shift weight, created in build() with the 'zeros' initializer.
beta;
// Forwards the layer config unchanged to the tf.layers.Layer constructor.
constructor(config)
{
super(config);
}
// Returns the base-class config as-is; `epsilon` is not added to it.
getConfig()
{
const config = super.getConfig();
return config;
}
// Creates the trainable `gamma` and `beta` weights for the given input shape.
// NOTE(review): `input_shape` is wrapped in a tensor here; the weight shape
// presumably has to be a plain array of numbers — confirm against the
// tfjs Layer.addWeight documentation.
build(input_shape)
{
let shape = tf.tensor(input_shape);
// initialize gamma
// NOTE(review): `self` and `add_weight` are the Python spellings (JS: `this`
// and `addWeight`). Inside a JS call, `name=value` is an assignment
// expression, not a keyword argument; `initializer` is never declared, which
// matches the ReferenceError reported for build().
self.gamma = self.add_weight(shape=shape,
initializer='ones',
trainable=true,
name='gamma')
// initialize beta (same Python-style call as for gamma above)
self.beta = self.add_weight(shape=shape,
initializer='zeros',
trainable=true,
name='beta')
}
// Normalizes `inputs` over the last axis, then applies gamma (scale) and
// beta (shift), and returns the result.
// NOTE(review): `tf.math.reduce_mean`, keyword-style `axis=`/`keepdims=`,
// `True`, `self`, and the arithmetic operators applied directly to tensors are
// Python TensorFlow idioms; the locals (`mean`, `variance`, `std`, `outputs`)
// are also assigned without a declaration. Presumably this needs tfjs ops
// and `this` instead — confirm against the tfjs API.
call(inputs){
mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)
variance = tf.math.reduce_mean(tf.math.square(inputs - mean), axis=[-1], keepdims=True)
// epsilon keeps the square root away from zero when the variance is zero
std = tf.math.sqrt(variance + self.epsilon)
outputs = (inputs - mean) / std
outputs = outputs * self.gamma
outputs = outputs + self.beta
return outputs
}
// Logs and returns the class name.
// NOTE(review): this getter shares its name with the static field
// `className` declared at the top of the class, and the bare identifier
// `className` is not a variable declared anywhere in this class — presumably
// one of the two declarations should be removed; confirm.
static get className() {
console.log(className);
return className;
}
}
// Registers the class with the tfjs serialization registry (presumably so the
// layer can be resolved by its className when a saved model is loaded).
tf.serialization.registerClass(InstantLayerNormalization);
</script>
解决方案
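The methods of the inherited class `tf.layers.Layer` are not called properly.
- `self` in Python is `this` in JS.
- `add_weight` is rather `addWeight`.
- Here is the signature of the `addWeight` method: `addWeight(name, shape, dtype?, initializer?, regularizer?, trainable?, constraint?)`. Please note that JS has no `variable=value` format for function arguments: arguments are positional (the closest JS equivalent to named arguments is passing an object and using destructuring assignment), and `variable=value` inside a call is just an assignment expression.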
// instead of this (Python-style keyword arguments; in JS each `name=value`
// is an assignment expression and nothing is passed by name)
self.gamma = self.add_weight(shape=shape, initializer='ones', trainable=true, name='gamma')
// it should rather be: addWeight takes positional arguments in the order
// (name, shape, dtype, initializer, regularizer, trainable). `shape` must be a
// plain array of numbers (not a tf.tensor), and the initializer must be an
// Initializer object such as tf.initializers.ones() — a bare string like
// 'ones' is not accepted here because addWeight calls initializer.apply(...).
this.gamma = this.addWeight('gamma', shape, undefined, tf.initializers.ones(), undefined, true)
推荐阅读
- python - 关键错误:所有 [Index[Columns] 都不在列中
- javascript - 长按按钮重复烧瓶功能
- javascript - 猫鼬查询的行为方式很奇怪
- javascript - 如果 Id 匹配,反应如何替换名称
- indexing - 为使用 Qgis 中的聚合函数连接的每个值创建索引
- c++ - 三向比较运算符的结果类型重载比较运算符
- mitmproxy - 如何在 Mitmproxy 插件中发出网络请求
- c# - Internet 连接检查 API (wininet.dll) 不起作用 | C# WPF
- git - 如何 git checkout “下一个” 提交
- java - 如何使用 Java 客户端代码获取 OAuth2.0 访问令牌并使用此令牌调用 RESTFul Web 服务