tensorflow - 结合一个冻结模型的预测输出进行预测,然后将这些预测用作我正在训练的模型的损失?
问题描述
我想将一个冻结模型的预测输出组合到另一个模型的训练阶段。
我尝试过使用不同的图表会话,但它会在训练阶段重置默认图表。
predictions = model1.model(input1, input2, mode)
predictions2 = model2.predict(predictions)
loss1 = mean_squared_error(predictions, labels)
loss2 = mean_squared_error(input2, predictions2)
total_loss = loss1+loss2
optimizer.minimize(total_loss)
ValueError: Tensor Tensor("output_layer/BiasAdd:0", shape=(?, 100), dtype=float32) 不是该图的元素
解决方案
我刚刚想通了!
在 tensorflow 的估计器框架中,将模型加载到估计器空间的“model_fn”中,属性为:“keras_model.trainable=False”
例如片段:
def model_fn(inputs):
.......
#some operations
.......
model2=load_model('frozen_model.h5')
model2.trainable=False
model2.summary()
predictions= model2(inputs=predictions)
推荐阅读
- powershell - 为什么 ScriptBlock 将 Hashtable 转换为 Collection 以及如何避免它?
- maven - 如何理解 Maven .m2 文件夹?
- caching - 如何使用兵马俑服务器实现集群缓存?
- python - 我如何将多处理应用到这个 Python3 脚本?
- memory-leaks - Pytorch:如何知道是否确实需要使用正在使用的 GPU 内存或是否存在内存泄漏
- c# - 如何使用 .NET 通用主机将 REST API 添加到用 .NET Framework 4.8 编写的控制台应用程序?
- ios - 您好,我正在使用 Collection View 单元格,我想记录其中有多少被单击“变为真”并返回下一个视图控制器
- ruby - 如何为每个 Enumerable 方法创建别名?
- r - R - 访问在函数参数中作为字符串传递的数据框列名
- javascript - 如何强制更新(重新渲染)React 组件?