首页 > 解决方案 > TensorFlow 预训练模型导入错误

问题描述

我正在研究 tensorflow 中的对象检测模型。我有一个文件model.py

from PIL import Image   
import cv2   
import numpy as np   
import tensorflow as tf   
from .squeezenet import SqueezeNet

save_path = "sqnet/squeezenet.ckpt"
sess = tf.Session()
model = SqueezeNet(save_path=save_path, sess=sess)

class Finder(object):
    def __init__(self, image_path):
        self.image_path = image_path

    def predict(self):
        image = process(self.image_path)
        ans = sess.run(model.classifier, feed_dict={model.image: 
                       image})
        return ans


def process(path):
    image = Image.open(path)
    # image.show()
    image = np.array(image)
    image = cv2.resize(image, dsize=(224, 224), 
                       interpolation=cv2.INTER_CUBIC)
    image = image.reshape((1, 224, 224, 3))
    #print(image.shape)
    #img = Image.fromarray(image, 'RGB')
    return image


image_path = "/home/jatin/ai.jpeg"

object_detector = Finder(image_path)

ans = object_detector.predict()

print(np.argmax(ans))

sess.close()

我有一个与我有文件的文件sqnet一起命名的文件夹。但是运行它会给出错误:model.pysquuezenet.cpkt

InvalidArgumentError(参见上面的回溯):TensorSliceReader 构造函数不成功:无法在 sqnet/squeezenet.ckpt 上获取匹配文件:未找到:sqnet;没有这样的文件或目录。

可能是什么问题?

标签: pythontensorflow

解决方案


对我来说似乎是一个简单的 IO 错误。您是否尝试过使用绝对路径?

save_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'sqnet', 'squeezenet.ckpt')

推荐阅读