python - ValueError: Tensor Tensor("mrcnn_detection/Reshape_1:0", shape=(1, 100, 6), dtype=float32) is not an element of this graph.（张量不是该图的元素）
问题描述
我试图在 Flask 中使用 Mask R-CNN（mrcnn）来检测形状，但是得到了以下错误：
Traceback (most recent call last):
  File "C:\Users\KushalM\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3505, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  ...
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("mrcnn_detection/Reshape_1:0", shape=(1, 100, 6), dtype=float32) is not an element of this graph.
这是我的代码:
app.py
import os
import sys

import numpy as np
from flask import Flask, redirect, render_template, request

app = Flask(__name__)

# The Mask R-CNN package ("mrcnn") lives in ./libraries under the working
# directory; it has to be on sys.path before anything imports mrcnn.
HOME_DIR = os.getcwd()
LIB_DIR = os.path.join(HOME_DIR, "libraries")
sys.path.insert(0, LIB_DIR)
print(os.listdir(LIB_DIR))  # debug aid: show what LIB_DIR contains

# These imports depend on the sys.path entry added above.
from mrcnn.config import Config
import mrcnn.utils as utils
import mrcnn.model as modellib
import mrcnn.visualize as visualize
from mrcnn.model import log
import skimage
import tensorflow as tf
from tensorflow.python.keras import backend as k

# Imported only after LIB_DIR is on sys.path, so that detect.py may itself
# import mrcnn at module level.
from detect import getImage

# Side length (pixels) of the square network input.
image_size = 64
# Base anchor template; scaled by image_size // 16 (== 4 for 64 px) to give
# the RPN anchor sizes in pixels.
rpn_anchor_template = (1, 2, 4, 8, 16)
rpn_anchor_scales = tuple(i * (image_size // 16) for i in rpn_anchor_template)
class ShapesConfig(Config):
    """Mask R-CNN configuration for this model: background + 4 classes.

    NAME is still "shapes" (as in the original shapes sample); it presumably
    only labels logs/checkpoints — confirm before renaming it.
    """

    NAME = "shapes"

    # One GPU with one image per GPU, so the batch size
    # (GPU_COUNT * IMAGES_PER_GPU) is 1.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Number of classes including background: background + 4 foreground
    # classes (the 5 entries of class_names).
    NUM_CLASSES = 1 + 4

    # Use small square images for speed.
    IMAGE_MAX_DIM = image_size
    IMAGE_MIN_DIM = image_size

    # Use smaller anchors because the images and objects are small.
    RPN_ANCHOR_SCALES = rpn_anchor_scales

    # Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE = 32

    STEPS_PER_EPOCH = 50
    # Keras expects an integer step count; "/" produced the float 2.5.
    # Clamp to at least one step in case STEPS_PER_EPOCH is lowered below 20.
    VALIDATION_STEPS = max(1, STEPS_PER_EPOCH // 20)
# Build the base configuration and print all of its settings for inspection.
config = ShapesConfig()
config.display()


class InferenceConfig(ShapesConfig):
    """Configuration used when running predictions.

    Same as ShapesConfig, but explicitly pins a batch of one image
    (GPU_COUNT * IMAGES_PER_GPU == 1).
    """

    IMAGES_PER_GPU = 1
    GPU_COUNT = 1


inference_config = InferenceConfig()
ROOT = os.getcwd()

# model_dir is a directory for Mask R-CNN logs/checkpoints (presumably also
# where model.find_last() looks for trained weights), not a weights file.
MODEL_DIR = os.path.join(ROOT, "logs")

# Recreate the model in inference mode.
model = modellib.MaskRCNN(mode="inference",
                          config=inference_config,
                          model_dir=MODEL_DIR)

# Trained weights to load. Alternatively: model_path = model.find_last()
# NOTE(review): the file is named like the 81-class COCO weights while
# NUM_CLASSES is 5; if it really is the COCO file, the class-specific head
# layers will not match — confirm this is the custom-trained file.
model_path = os.path.join(ROOT, "weights", "mask_rcnn_coco.h5")

# Load trained weights, matching layers by name.
print("Loading weights from ", model_path)
model.load_weights(model_path, by_name=True)

# Build the Keras predict function now, in the thread that loaded the model.
# Otherwise it is built lazily on the first model.detect() call inside a
# Flask request thread, which under TF 1.x raises
# "Tensor ... is not an element of this graph".
# _make_predict_function is a private Keras 2.x API.
model.keras_model._make_predict_function()

# Display names indexed by class id; 'BG' (id 0) is the background.
class_names = ['BG', 'Std Paneer', 'Non Std Paneer', 'Std Chicken', 'Non Std Chicken']

# Directory where uploaded images are written.
app.config["IMAGE_UPLOADS"] = './static'
@app.route("/upload-image", methods=["GET", "POST"])
def upload_image():
    """Serve the upload page and process an uploaded image.

    GET, or a POST without a usable ``image`` file, renders ``upload.html``.
    A POST with an ``image`` file runs ``getImage`` (detection) on it, saves
    the upload under ``app.config["IMAGE_UPLOADS"]`` and renders ``view.html``
    with the saved file name.
    """
    if request.method == "POST" and request.files:
        upload = request.files.get("image")
        # The filename comes from the client: keep only its last path
        # component so it cannot point outside the upload directory.
        # (A FileStorage is falsy when no file was selected.)
        file_name = os.path.basename(upload.filename) if upload else ""
        if file_name not in ("", ".", ".."):
            # Only the side effects of getImage are used; its return value is
            # not needed to save the upload.
            getImage(upload)
            # getImage reads the upload stream; rewind it so save() writes
            # the whole file instead of an empty one.
            upload.stream.seek(0)
            upload.save(os.path.join(app.config["IMAGE_UPLOADS"], file_name))
            print("Image saved")
            return render_template("view.html", image_name=file_name)
    return render_template("upload.html")
if __name__ == "__main__":
    # Development server. The reloader is disabled so this module (and the
    # model it loads at import time) is not executed again in a reloader
    # child process.
    app.run(use_reloader=False, debug=True)
detect.py
def getImage(image, model=None, class_names=None):
    """Run Mask R-CNN detection on one image, display it, and return results.

    Args:
        image: anything ``skimage.io.imread`` accepts (a path or a readable
            file-like object such as a Flask upload).
        model: a Mask R-CNN model in inference mode. Defaults to a
            module-level ``model`` in this file, if one is defined.
        class_names: display names indexed by class id. Defaults to a
            module-level ``class_names`` in this file, if one is defined.

    Returns:
        The detection dict for the image (``results[0]`` of ``model.detect``);
        its ``rois``, ``masks``, ``class_ids`` and ``scores`` entries are used
        below.

    Raises:
        NameError: if ``model`` or ``class_names`` is neither passed nor
            defined at module level.
    """
    # Imported at call time so they resolve after the caller has put the
    # Mask R-CNN package on sys.path.
    import skimage.color
    import skimage.io
    from mrcnn import visualize

    if model is None:
        model = _module_global("model")
    if class_names is None:
        class_names = _module_global("class_names")

    image = skimage.io.imread(image)
    print(image.shape)
    # The detector expects 3-channel RGB: promote grayscale images and drop
    # an alpha channel (common in PNG uploads).
    if image.ndim != 3:
        image = skimage.color.gray2rgb(image)
    if image.shape[-1] == 4:
        image = image[..., :3]

    # Run detection on a batch of one image.
    results = model.detect([image], verbose=1)
    r = results[0]
    print(len(r['class_ids']))
    # NOTE(review): display_instances draws with matplotlib; inside a Flask
    # request thread this may block or fail depending on the backend — confirm.
    visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
                                class_names, r['scores'])
    return r


def _module_global(name):
    """Return module-level ``name`` or raise NameError if it is not defined."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(f"name {name!r} is not defined") from None
解决方案
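在加载权重之后立即调用 `_make_predict_function()`，提前构建 Keras 的预测函数。否则它通常会在 Flask 的工作线程中第一次调用 `model.detect()` 时才被惰性构建，而该线程所处的默认图/会话与加载模型时不同，于是抛出上述 ValueError（注意：`_make_predict_function` 是 TF 1.x / Keras 2.x 的私有 API）。即改成：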
model.load_weights(model_path, by_name=True)
model.keras_model._make_predict_function()
推荐阅读
- cuda - cuda.jit 矩阵乘法崩溃
- tabulator - 如何通过 ajax:Response 从返回中访问制表列以进行格式化
- javascript - 如何在 React Native 中放置一个不随 ScrollView 移动的静态按钮？
- javascript - 在 JavaScript 数组中搜索与字符串模式匹配的所有项目的计数
- python - Pandas 数据读取器时间限制
- jquery - 在 jQuery 中设置的编辑器值未传递给控制器
- c - 不明白为什么来自 APUE 的代码片段会取消链接到客户端 unix 域套接字的文件
- php - “消息”:“密钥路径 \”file:///app/storage/oauth-private.key\“在 Laravel 6 中不存在或不可读”
- javascript - TypeError: XXXXXXX 不是函数（使用 Jest 测试时）
- java - 构建apk后找不到Java静态接口方法