tensorflow - 预测张量流
问题描述
在训练我的“参数”(w1,w2,Conv 网络中过滤器的权重)后,将它们保存为 parameters=sess.run(parameters)
我拍摄图像 img=[1,64,64,3],并将其传递给 mypredict(x,parameters) 函数进行预测,但它给出了错误。功能如下。任何关于出现问题的建议。
def forward_propagation(X,参数):
W1 = parameters['W1']
W2 = parameters['W2']
Z1 = tf.nn.conv2d(X,W1,strides=[1,1,1,1],padding='SAME')
A1 = tf.nn.relu(Z1)
P1 = tf.nn.max_pool(A1,ksize=[1,8,8,1],strides=[1,8,8,1],padding='SAME')
Z2 = tf.nn.conv2d(P1,W2,strides=[1,1,1,1],padding='SAME')
A2 = tf.nn.relu(Z2)
P2 = tf.nn.max_pool(A2,ksize=[1,4,4,1],strides=[1,4,4,1],padding='SAME')
P2 = tf.contrib.layers.flatten(P2)
Z3 = tf.contrib.layers.fully_connected(P2,num_outputs=6,activation_fn=None)
return Z3
def mypredict(X,par):
W1 = tf.convert_to_tensor(par["W1"])
W2 = tf.convert_to_tensor(par["W2"])
params = {"W1": W1,
"W2": W2}
x = tf.placeholder("float", [1,64,64,3])
z3 = forward_propagation_for_predict(x, params)
p = tf.argmax(z3)
sess = tf.Session()
prediction = sess.run(p, feed_dict = {x:X})
return prediction
我使用相同的函数“forward_propagation”来训练权重,但是当我传递单个图像时,它不起作用。
错误:
FailedPreconditionError Traceback(最近一次调用最后)/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args) 1138 try: -> 1139 return fn (*args) 1140 除了errors.OpError as e:
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata) 1120 feed_dict, fetch_list, target_list, -> 1121 状态,运行元数据)1122
/opt/conda/lib/python3.6/contextlib.py in exit (self, type, value, traceback) 88 try: ---> 89 next(self.gen) 90 除了 StopIteration:
/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in raise_exception_on_not_ok_status() 465 compat.as_text(pywrap_tensorflow.TF_Message(status)), --> 466 pywrap_tensorflow.TF_GetCode(status )) 467 最后:
FailedPreconditionError: Attempting to use uninitialized value fully_connected_1/biases [[Node: fully_connected_1/biases/read = IdentityT=DT_FLOAT, _class=["loc:@fully_connected_1/biases"], _device="/job:localhost/replica:0/task :0/cpu:0"]]
在处理上述异常的过程中,又出现了一个异常:
FailedPreconditionError Traceback (last last call last) in () ----> 1 pred=mypredict(t,pp) 2
在 mypredict(X, par) 49 50 sess = tf.Session() ---> 51 prediction = sess.run(p, feed_dict = {x:X}) 52 53 返回预测
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata) 787 try: 788 result = self._run(None, fetches , feed_dict, options_ptr, --> 789 run_metadata_ptr) 790 if run_metadata: 791 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata) 995 if final_fetches 或 final_targets: 996 results = self. _do_run(handle, final_targets, final_fetches, --> 997 feed_dict_string, options, run_metadata) 998 else: 999 results = []
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata) 1130 如果句柄为 None:1131 返回 self ._do_call(_run_fn, self._session, feed_dict, fetch_list, -> 1132 target_list, options, run_metadata) 1133 else: 1134 return self._do_call(_prun_fn, self._session, handle, feed_dict,
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args) 1150 除了 KeyError: 1151 pass -> 1152 raise type(e)(node_def , 操作, 消息) 1153 1154 def_extend_graph(self):
FailedPreconditionError:尝试使用未初始化的值fully_connected_1/biases
解决方案
您还必须从全连接层加载参数。
不过,我还是建议使用TensorFlow 的 Saver 和 Restore 功能。
为了节省,这是一个玩具示例:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000) # saving model after 1000 steps
存储以下文件:
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
所以为了恢复,你可以先重新创建网络,然后加载参数:
with tf.Session() as sess:
recreated_net = tf.train.import_meta_graph('my_test_model-1000.meta')
recreated_net.restore(sess, tf.train.latest_checkpoint('./'))
推荐阅读
- reactjs - react-router 开关正在停止状态以向下传递给子组件
- swift4 - 拐角半径仅在特定拐角处
- r - 如何防止marrangeGrob打开图形设备
- ios - 从 url (AVURLAsset) 加载后,我的肖像视频会旋转
- javascript - 正则表达式仅匹配值中的第一个字符
- javascript - 如何创建一个基本的点击事件?
- postgresql - 这两个查询如何与 and or 子句一起使用?
- rust - 使用 Serde 反序列化可能为空的字符串
- c# - C# 和 PowerShell 中的非统一 System.DateTime.Now 默认格式
- javascript - 使用 div 在地图外触发谷歌地图 infoWindow