tensorflow - tensorflow.python.framework.errors_impl 在恢复模型后失败
问题描述
尽管查看了许多其他 StackOverflow 页面,但很难使用 TensorFlow 1.15 恢复已保存的模型
基本上,它在调用pred = model.predict(data)后失败
def restoreModel(data, input_dims, network_settings,file_path_final):
tf.compat.v1.reset_default_graph()
try:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
init = tf.compat.v1.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
saver = tf.compat.v1.train.import_meta_graph(file_path_final+'Losv2_311-800.meta')
saver.restore(sess, tf.train.latest_checkpoint(file_path_final))
model = Model_DeepHit(sess, "DeepHit", input_dims, network_settings)
pred = model.predict(data)
class Model_DeepHit:
def __init__(self, sess, name, input_dims, network_settings):
self.sess = sess
self.name = name
# INPUT DIMENSIONS
self.x_dim = input_dims['x_dim']
self.num_Event = input_dims['num_Event']
self.num_Category = input_dims['num_Category']
# NETWORK HYPER-PARMETERS
self.h_dim_shared = network_settings['h_dim_shared']
self.h_dim_CS = network_settings['h_dim_CS']
self.num_layers_shared = network_settings['num_layers_shared']
self.num_layers_CS = network_settings['num_layers_CS']
self.active_fn = network_settings['active_fn']
self.initial_W = network_settings['initial_W']
self.reg_W = tf.contrib.layers.l2_regularizer(scale=1.0)
self.reg_W_out = tf.contrib.layers.l1_regularizer(scale=1.0)
self._build_net()
def predict(self, x_test, keep_prob=1.0):
return self.sess.run(self.out, feed_dict={self.x: x_test, self.mb_size: np.shape(x_test)[0], self.keep_prob: keep_prob})
输出
Traceback(最近一次调用最后):
文件“C:\tools\Python\Python37\lib\site-packages\tensorflow_core\python\client\session.py”,第 1365 行,在 _do_call
return fn(*args)
文件中C:\tools\Python\Python37\lib\site-packages\tensorflow_core\python\client\session.py”,第 1350 行,在 _run_fn 目标列表,run_metadata)
文件“C:\tools\Python\Python37\lib\site- packages\tensorflow_core\python\client\session.py",第 1443 行,在 _call_tf_sessionrun run_metadata)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value DeepHit/fully_connected_2/biases_1
[[{{node DeepHit/fully_connected_2/biases_1 /读}}]]
注意:DeepHit 类还有其他方法,但似乎缺少一些变量来提供类,但即使查看相关的 TensorFlow 1.15 版本,我也无法弄清楚
x_test: [[ 0.63193141 -0.19093631 1.32147613 1.5206112 -0.32124657 1.09037371 -0.56493425 0.94727136 0.27852626 -1.39992279 -0.77934678 -0.01374535 -0.19709776 -0.37450909 -0.43395308 -0.17253745 0.21124831 -0.85609694 -0.17252606 -0.37291207 -1.08927999 -0.18948054 -0.15836433 -0.28718442 -0.14664455]]
输入:Tensor("DeepHit/inputs_1:0", shape=(?, 133), dtype=float32)
num_layers: 2
h_dim: 300
h_fn: <function elu at 0x00000153294DE8C8>
o_dim: 300
o_fn: <function elu at 0x00000153294DE8C8
> : <function variance_scaling_initializer.._initializer at 0x0000015340DFFAE8>
keep_prob: Tensor("DeepHit/keep_probability_1:0", shape=(), dtype=float32)
w_reg: <function l2_regularizer..l2 at 0x000001534204D9D8>
目标:创建不同规格的FC网络
输入(张量):输入张量
num_layers:FCNet 中的层数
h_dim(int):隐藏单元的数量
h_fn:隐藏层的激活函数(默认值:tf.nn.relu)
o_dim (int) :输出单元的数量
o_fn :输出层的激活函数(默认值:无)
w_init :权重矩阵的初始化(默认值:Xavier)
keep_prob :保持概率 [0, 1] (如果没有,则不使用 dropout)
输入数据
dbfs ls -l dbfs:/FileStore/
文件 12845144 Losv2_311-800.data-00000-of-00001
文件 1756 Losv2_311-800.index
文件 330197 Losv2_311-800.meta
文件 260 Losv2_311_par.pkl
文件 137 检查点
解决方案
推荐阅读
- flutter - 无法在 Flutter 中用阴影重构颜色
- css - 如何使用引导程序 5 显示具有响应式设计的卡片
- android - 在 android 10/11 (api 29/30) 下是否仍然可以破坏性地修改您不拥有的文件?
- django - 视图...没有返回 HttpResponse 对象。它返回 None 代替 - django
- import - 如何在导入过程中高效地重命名多个文件或部分选择这些文件的名称?
- c - GCC 11 参数顺序触发误报 Wstringop-overflow,这是错误吗?
- algorithm - 对所有 i, j 有效地求和 max(Ai+Bj, Bi+Aj)
- reactjs - Material-UI 日历中的一个视图中的十二个月 - DatePicker 或任何其他包
- php - 是否可以从其他地方更改函数内部?
- html - 全宽三列行上的水平滚动条