python - 对齐上下轴的刻度坐标
问题描述
我想构建一个图形,在其底部 x 轴上具有一组刻度,在其顶部 x 轴上具有另一组与底部刻度对齐的刻度。特别是在我的情况下,这些是批次和时期。对于底部的每个n
批处理点(不一定是刻度),我希望在顶部有一个时代刻度。考虑这个例子:
import numpy as np
import matplotlib.pyplot as plt
batches = np.arange(1,101)
epoch_ends = batches[[(i*10)-1 for i in range(1,11)]]
accuracy = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : x/len(batches))
loss = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : 1 - (x/len(batches)))
fig, ax1 = plt.subplots( nrows=1, ncols=1 )
ax2 = ax1.twinx()
ax3 = ax1.twiny()
ax1.set_xlabel('batches')
ax1.set_xticks(np.arange(1, len(batches)+1, 9))
ax1.set_ylabel('accuracy')
ax1.grid()
ax2.set_ylabel('loss')
ax2.set_yticklabels(np.linspace(3, 10, 9))
ax3.set_xlabel('epochs')
ax3.set_xticks(epoch_ends)
ax3.set_xticklabels(range(1, len(epoch_ends)+1))
acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')
loss_plt = ax2.plot(batches, loss, color='blue', label='loss')
lns = acc_plt+loss_plt
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, loc=2)
plt.show()
batches
并epoch_ends
分别看起来像这样
[ 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
91 92 93 94 95 96 97 98 99 100]
[ 10 20 30 40 50 60 70 80 90 100]
所以我希望纪元刻度 1 与批次 x-coorrdiante 10、2 与 20 等对齐。
但正如你在图片中看到的那样,它们并没有排成一行。
我需要在我的代码中进行哪些更改才能使其正常工作?
解决方案
这是对齐它们的一种方法。思路如下:
ax1
首先使用 ( )在下 x 轴上绘制数据- 然后使用设置上限 x 轴与下 x 轴相同
ax3.set_xlim(ax1.get_xlim())
- 然后将上 x 轴的刻度设置在对应于下 x 轴值 (10, 20, 30, ..., 90, 100) 的位置
- 最后,使用 重命名刻度标签
ax3.set_xticklabels()
。
这是代码:我正在用注释替换代码中已经存在的部分#
。
# imports and define data and compute accuracy and loss here
# Initiate figure and axis objects here
ax1.set_xlabel('batches')
ax1.set_xticks(np.arange(1, len(batches)+1, 9))
ax1.set_ylabel('accuracy')
ax1.grid()
acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')
loss_plt = ax2.plot(batches, loss, color='blue', label='loss')
ax2.set_ylabel('loss')
ax2.set_yticklabels(np.linspace(3, 10, 9))
new_tick_locations = np.arange(1, 11)*10
new_tick_labels = range(1, 11)
ax3.set_xlabel('epochs')
ax3.set_xlim(ax1.get_xlim())
ax3.set_xticks(new_tick_locations)
ax3.set_xticklabels(new_tick_labels)
# Set legends and show the plot
推荐阅读
- azure - Cortana 仍在呼叫我已删除的机器人
- python - 我们如何修复以下错误“InternalError:无法将元素作为字节获取”?
- vue.js - 如何覆盖 vue cli-service 入口设置
- ios - 如何将uiview录制为带有音频iOS swift的视频
- android - 使用来自 kotlinx-coroutine-test 的类时出错
- android - Android:ImageView src 忽略父元素的 layout_weight 属性
- angular - 如何使用带角度的物化步进器
- php - 致命错误:未捕获的错误:在 bool 上调用成员函数 RowCount()
- python - 排列中间数组长度为 M 的 N 个列表的所有可能性
- mysql - 插入空值