python - Mat plot lib 绘制图像而不是保存(facebook Detectron 代码)
问题描述
我试图让这个代码在函数被调用时绘制出来。发生了很多事情,也许我忽略了一些事情。我试图弄清楚什么是最好的绘图方式而不是存储图像。这里的大部分代码来自 facebook detectron ( https://github.com/facebookresearch/Detectron ),我试图编辑的只有部分是函数的结尾。
def vis_image(
im, im_name, output_dir, boxes, segms=None, keypoints=None, thresh=0.9,
kp_thresh=2, dpi=200, box_alpha=0.0, dataset=None, show_class=False,
ext='pdf'):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if isinstance(boxes, list):
boxes, segms, keypoints, classes = convert_from_cls_format(
boxes, segms, keypoints)
if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
return
dataset_keypoints, _ = keypoint_utils.get_keypoints()
if segms is not None and len(segms) > 0:
masks = mask_util.decode(segms)
color_list = colormap(rgb=True) / 255
kp_lines = kp_connections(dataset_keypoints)
cmap = plt.get_cmap('rainbow')
colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]
fig = plt.figure(frameon=False)
fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.axis('off')
fig.add_axes(ax)
ax.imshow(im)
# Display in largest to smallest order to reduce occlusion
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
sorted_inds = np.argsort(-areas)
mask_color_id = 0
for i in sorted_inds:
bbox = boxes[i, :4]
score = boxes[i, -1]
if score < thresh:
continue
# show box (off by default)
ax.add_patch(
plt.Rectangle((bbox[0], bbox[1]),
bbox[2] - bbox[0],
bbox[3] - bbox[1],
fill=False, edgecolor='g',
linewidth=0.5, alpha=box_alpha))
if show_class:
ax.text(
bbox[0], bbox[1] - 2,
get_class_string(classes[i], score, dataset),
fontsize=3,
family='serif',
bbox=dict(
facecolor='g', alpha=0.4, pad=0, edgecolor='none'),
color='white')
# show mask
if segms is not None and len(segms) > i:
img = np.ones(im.shape)
color_mask = color_list[mask_color_id % len(color_list), 0:3]
mask_color_id += 1
w_ratio = .4
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio
for c in range(3):
img[:, :, c] = color_mask[c]
e = masks[:, :, i]
_, contour, hier = cv2.findContours(
e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
for c in contour:
polygon = Polygon(
c.reshape((-1, 2)),
fill=True, facecolor=color_mask,
edgecolor='w', linewidth=1.2,
alpha=0.5)
ax.add_patch(polygon)
# show keypoints
if keypoints is not None and len(keypoints) > i:
kps = keypoints[i]
plt.autoscale(False)
for l in range(len(kp_lines)):
i1 = kp_lines[l][0]
i2 = kp_lines[l][1]
if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
x = [kps[0, i1], kps[0, i2]]
y = [kps[1, i1], kps[1, i2]]
line = plt.plot(x, y)
plt.setp(line, color=colors[l], linewidth=1.0, alpha=0.7)
if kps[2, i1] > kp_thresh:
plt.plot(
kps[0, i1], kps[1, i1], '.', color=colors[l],
markersize=3.0, alpha=0.7)
if kps[2, i2] > kp_thresh:
plt.plot(
kps[0, i2], kps[1, i2], '.', color=colors[l],
markersize=3.0, alpha=0.7)
# add mid shoulder / mid hip for better visualization
mid_shoulder = (
kps[:2, dataset_keypoints.index('right_shoulder')] +
kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
sc_mid_shoulder = np.minimum(
kps[2, dataset_keypoints.index('right_shoulder')],
kps[2, dataset_keypoints.index('left_shoulder')])
mid_hip = (
kps[:2, dataset_keypoints.index('right_hip')] +
kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
sc_mid_hip = np.minimum(
kps[2, dataset_keypoints.index('right_hip')],
kps[2, dataset_keypoints.index('left_hip')])
if (sc_mid_shoulder > kp_thresh and
kps[2, dataset_keypoints.index('nose')] > kp_thresh):
x = [mid_shoulder[0], kps[0, dataset_keypoints.index('nose')]]
y = [mid_shoulder[1], kps[1, dataset_keypoints.index('nose')]]
line = plt.plot(x, y)
plt.setp(
line, color=colors[len(kp_lines)], linewidth=1.0, alpha=0.7)
if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
x = [mid_shoulder[0], mid_hip[0]]
y = [mid_shoulder[1], mid_hip[1]]
line = plt.plot(x, y)
plt.setp(
line, color=colors[len(kp_lines) + 1], linewidth=1.0,
alpha=0.7)
'''
output_name = os.path.basename(im_name) + '.' + ext
fig.savefig(os.path.join(output_dir, '{}'.format(output_name)), dpi=dpi)
plt.close('all')
'''
plt.show()
plt.close('all')
解决方案
去掉plt.close('all')
线怎么办?
推荐阅读
- javascript - 更改函数以获得单独的 JS 文件 - AJAX、HTML 和 MySQL
- c# - 如何在 C# 中将 Excel 文件导入 DataGridView
- keras - 当目标有上限和下限时,在回归任务中使用什么激活函数
- r - 如何在 R 中乘以引号?
- php - LARAVEL:我想知道用户今天是否在数据库中输入数据并在刀片上显示是或否
- python - 在构建用于时间序列分析的数据框时使列名唯一的最快方法?
- hibernate - SpringBoot 和 H2 自动生成的字段
- python - Python:是否可以使用不带参数的 string.format 方法?
- javascript - 获取 iframe 当前地址
- go - 在 Go 中使用通道,我创建了一个返回地址的阶乘函数