python - CNTK Python API:加载模型后访问图层
问题描述
加载模型后我无法访问图层。
我创建的模型如下:
def create_model(vocab_dim, hidden_dim):
input_seq_axis1 = Axis('inputAxis1')
input_sequence_before = sequence.input_variable(shape=vocab_dim, sequence_axis=input_seq_axis1, is_sparse = use_sparse)
input_sequence_after = sequence.input_variable(shape=vocab_dim, sequence_axis=input_seq_axis1, is_sparse = use_sparse)
e=Sequential([
C.layers.Embedding(hidden_dim),
Stabilizer()
],name='Embedding')
a = Sequential([
e,
C.layers.Recurrence(C.layers.LSTM(hidden_dim//2),name='ForwardRecurrence'),
],name='ForwardLayer')
b = Sequential([
e,
C.layers.Recurrence(C.layers.LSTM(hidden_dim//2),go_backwards=True),
],name='BackwardLayer')
latent_vector = C.splice(a(input_sequence_before), b(input_sequence_after))
bias = C.layers.Parameter(shape = (vocab_dim, 1), init = 0, name='Bias')
weights = C.layers.Parameter(shape = (vocab_dim, hidden_dim), init = C.initializer.glorot_uniform(), name='Weights')
z = C.times_transpose(weights, latent_vector,name='Transpose') + bias
z = C.reshape(z, shape = (vocab_dim))
return z
然后我加载模型:
def load_my_model(vocab_dim, hidden_dim):
z=load_model("models/lm_epoch0.dnn")
input_sequence_before = z.arguments[0]
input_sequence_after = z.arguments[1]
a=z.ForwardLayer
b=z.BackwardLayer
latent_vector = C.splice(a(input_sequence_before), b(input_sequence_after))
我收到一个错误:TypeError("argument ForwardRecurrence 的类型 SequenceOver[inputAxis1][Tensor[100]] 与传递的变量的类型 SequenceOver[inputAxis1][SparseTensor[50000]] 不兼容",)
看起来名称引用的层 (z.ForwardLayer) 表示来自层立即输入的函数。如何计算“latent_vector”(我需要这个变量来创建交叉熵和损失函数以继续训练)?
解决方案
根据错误,与 ForwardLayer 的预期 (100) 相比,您的输入 seq 的尺寸太大 (5000)。
当您通过 选择节点 ForwardLayer 时z.ForwardLayer
,您只能选择那个非常特定的节点/层,而不是与其连接的计算图的层/节点/其余部分。
你应该这样做a = C.combine([z.ForwardLayer.owner])
,你应该没事。
推荐阅读
- javascript - 使用javascript在azure blob上传中暂停和恢复选项
- android - 如何在 Android 中设置环境变量?
- go - 如何将 protoc-gen-go gzipped FileDescriptorProto 显示为纯文本?
- javascript - 以编程方式更改选择范围后如何显示选择光标?
- react-native - 反应原生 | 'react-native run-android' 命令需要运行两次才能编译
- javascript - 是否可以使用 AJAX 在 jsp 上不断显示来自 Servlet 的数据?
- python - 如何在不复制代码的情况下将相同的例外应用于多个函数?
- javascript - 如何在 Chartist.JS 中为标签添加逗号
- python - 在入口小部件上意外捕获 GTK 多输入
- javascript - 无法启动 Vue CLI 项目