首页 > 解决方案 > 在另一个函数内部处理悬停事件

问题描述

我正在尝试把新代码集成到别人编写的现有代码中，但遇到了一些问题。现有代码使用 matplotlib 实现了一个 GUI 绘图工具，可以根据给定的输入文件绘制各种波形。我希望在鼠标悬停到图上的任意一条曲线时，弹出一个注释框显示这是哪条线（想象一张图上有 30 条线，根本无法把它们区分开）。我在下面这个问题中找到了这段代码（我参考的是第一个答案）：能否在 matplotlib 中将鼠标悬停在某个点上时显示标签？

这是代码:

import matplotlib.pyplot as plt
import numpy as np; np.random.seed(1)

# Demo data: 15 scatter points with x/y drawn from np.random.rand, a one-letter
# name per point ("A".."O"), and an integer colour class in 1..4 (randint's
# upper bound 5 is exclusive). These draws follow np.random.seed(1), so their
# order determines the values.
x = np.random.rand(15)
y = np.random.rand(15)
names = np.array(list("ABCDEFGHIJKLMNO"))
c = np.random.randint(1,5,size=15)

# Map the colour classes 1..4 onto the red-yellow-green colormap.
norm = plt.Normalize(1,4)
cmap = plt.cm.RdYlGn

fig,ax = plt.subplots()
sc = plt.scatter(x,y,c=c, s=100, cmap=cmap, norm=norm)

# One shared annotation for the whole scatter: the text box sits 20 points
# right and 20 points above its anchor (textcoords="offset points"). It starts
# hidden; the hover callback moves, relabels and shows it.
annot = ax.annotate("", xy=(0,0), xytext=(20,20),textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->"))
annot.set_visible(False)

def update_annot(ind):
    """Move the shared annotation to the hovered point(s) and relabel it.

    ``ind`` is the dict returned as the second value of ``sc.contains(event)``;
    its ``"ind"`` entry lists the indices of the scatter points under the
    cursor. The annotation is anchored at the first of those points, its text
    is "<indices>, <names>", and its box takes that point's colormap colour
    at 40% opacity.
    """
    hits = ind["ind"]
    first = hits[0]
    annot.xy = sc.get_offsets()[first]
    annot.set_text("{}, {}".format(" ".join(str(n) for n in hits),
                                   " ".join(names[n] for n in hits)))
    patch = annot.get_bbox_patch()
    patch.set_facecolor(cmap(norm(c[first])))
    patch.set_alpha(0.4)


def hover(event):
    """Show the annotation while the cursor is over a scatter point.

    Registered for the figure's "motion_notify_event". Only events inside
    ``ax`` are considered. Over a point, the annotation is updated and shown.
    Off every point, it is hidden, but only if it was visible, so that idle
    mouse movement does not trigger redraws.
    """
    vis = annot.get_visible()
    if event.inaxes == ax:
        cont, ind = sc.contains(event)
        if cont:
            update_annot(ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
        elif vis:
            annot.set_visible(False)
            fig.canvas.draw_idle()

# Have matplotlib call hover(event) on every mouse move over the canvas.
fig.canvas.mpl_connect("motion_notify_event", hover)

# Display the figure so the GUI can deliver the motion events.
plt.show()

在现有代码中，ax 是在绘图函数内部定义的。如果函数不是很长，我会把它整个贴出来；这里先给出一个片段（片段下方是我从上面示例中移植过来的部分代码）：

            else:
                # No subplot layout configured: use a single '111' subplot and
                # remember it under this figure/subplot key.
                print ('The label is: %s' % label)
                ax = plt.subplot('111')
                axesDict[labelKey] = ax
            #end if
#
            # Annotation box on the axes currently bound to `ax`. With
            # xytext=None matplotlib reuses xy as xytext, and because
            # textcoords="offset points" the text ends up offset (-20, 20)
            # points from the anchor.
            annot = ax.annotate("", xy=(-20,20), xytext=(None),textcoords="offset points",
                                bbox=dict(fc="b"),
                                arrowprops=dict(arrowstyle="->"))
            annot.set_visible(True)

            # NOTE(review): `hover` is registered directly, so matplotlib calls it
            # with only the event argument. It cannot see this local `annot` or
            # `ax` unless they are bound in, e.g. through a wrapper lambda.
            fig.canvas.mpl_connect("motion_notify_event", hover)
#

问题是我不知道如何把 ax 传给 hover 函数，因为 hover 是通过 mpl_connect 注册的回调，matplotlib 调用它时只会传入 event 这一个参数。

我对 Python 非常非常陌生，处理这么大规模的现有代码一直很有挑战。也许我的实现思路本身就有问题，欢迎随时指出！我肯定还会遇到更多问题，但这是一个好的开始。提前感谢您的帮助和时间。

编辑：下面是这个很长的绘图函数（这只是我正在处理的第一部分）：

    def plotData(self, refreshPlotAxes = False):

        if len(self.waveformObjectList) == 0:
            print ('no waveforms to plot')
            return
        #end if

        startFigureNumber = self.startFigureNumber
        nextFigureNumber = startFigureNumber

        if self.fileDataTypeMode == 'ascii':
            markerArray = self.defaultMarkerArray
        else:
            markerArray = ['']

        waveformIndexList = self.getFilteredWaveformObjectIndexList()

        ###################### First Plot #############################

        if self.plotFreqResp:
            firstLoop = True
            markerIndex = 0
            #which labels are in each figure
            xAxisLabelDictionary = {}
            yAxisLabelDictionary = {}
            subplotDictionary = {}   #subplots for each figure
            plotAxisDictionary = {}  #plot axis for every subplot
            #at the moment, I don't support multiple figures and multiple subplots at the same time,
            #but I might someday
            logXDictDict = {}
            logYDictDict = {}

            plotFilename = 'blank_freqresp.png'

            plotAxisList = []
            numberOfFigures = 0
            numberOfSubPlots = 0
            numberOfLabels = 0

            #set up the plots
            axesDict = {}

            labelList = []
            for waveformObj in self.waveformObjectList:
                label = waveformObj.label
                labelPieces = label.split('_')
                labelList.append(labelPieces)
            #end for waveformObj

            commonLabelPieces = []
            if len(labelList) > 1:
                labelPieces0 = labelList[0]
                for labelPiece in labelPieces0:
                    isCommon = True
                    for labelPieces in labelList:
                        if labelPieces.count(labelPiece) == 0:
                            isCommon = False
                            break
                        #end if
                    #end for
                    if isCommon:
                        commonLabelPieces.append(labelPiece)
                    #end if
                #end for labelPiece
            #end if

            for waveformIndex in waveformIndexList:
                waveformObj = self.waveformObjectList[waveformIndex]

                plotFilename = waveformObj.filename
                [plotFilename, ext] = os.path.splitext(plotFilename)
                plotFilename += '_freqresp.png'

                if firstLoop or (self.plot1SeparatePlots and not self.plot1SubPlots):
                    currentFigureNumber = nextFigureNumber
                    logXDictDict[currentFigureNumber] = {}
                    logYDictDict[currentFigureNumber] = {}
                    nextFigureNumber += 1
                    numberOfFigures += 1
                    figureTxt = 'Figure %d - %s' % (currentFigureNumber, self.appTitle)
                    fig = plt.figure(figureTxt, figsize=self.cwPlotSize)
                #end if

                label = waveformObj.getLabel(shortLabel = self.shortLabel, includeXLabel = self.showXInLabel)
                shortLabel = waveformObj.getLabel(shortLabel = True, includeXLabel = self.showXInLabel)

                if self.enableShortenedLabels:
                    label = waveformObj.label
                    labelPieces = label.split('_')
                    uniqueLabelPieces = []
                    for labelPiece in labelPieces:
                        if commonLabelPieces.count(labelPiece) == 0:
                            uniqueLabelPieces.append(labelPiece)
                        #end if
                    #end for
                    label = '_'.join(uniqueLabelPieces)
                    label += '(' + shortLabel + ')'
                #end if

                try:
                    if waveformObj.hasReference():
                        label += '%s%s @ %s' % (waveformObj.referenceWaveformOperation, waveformObj.referenceWaveform, waveformObj.referenceWaveformFreq)
                    #end if
                except:
                    pass

                [xAxisLabel, yAxisLabel] = waveformObj.axisLabels()[0:2]
                if xAxisLabel == 'none':
                    xAxisLabel = waveformObj.getDataLabels()[0]
                if yAxisLabel == 'none' or yAxisLabel == 'mag':
                    yAxisLabel = waveformObj.getDataLabels()[1]

##                print ('data labels = %s' % str([xAxisLabel, yAxisLabel]))
##                print ('shortLabel = %s' % shortLabel
##                print ('label = %s' % label

                #when there is just one subplot (the default), it's designated '111'
                subplotString = '1'
                logX = self.logHorizontalAxis
                dbY = self.dBVerticalAxis
                if self.plot1SubPlots:
                    subplotString = '000'
                    for subplotNum in self.plot1SubPlotDict['filter'].keys():
                        matchList = self.plot1SubPlotDict['filter'][subplotNum]
                        for matchItem in matchList:
                            if re.search(matchItem, shortLabel):
                                subplotString = subplotNum
                                break
                            #end if
                        #end for
                    #end for

                    if subplotString == '000':
                        firstLoop = False
                        continue

                    try:
                        logX = self.plot1SubPlotDict['xlog'][subplotString]
                    except:
                        pass

                    try:
                        dbY = self.plot1SubPlotDict['ydb'][subplotString]
                    except:
                        pass

                #end if

#                if waveformObj.yUnits.lower().count('db'):
#                    yData = waveformObj.getNormalizeddBVector()
#                    logY = False
                if waveformObj.yUnits.lower().count('bits') or \
                     waveformObj.yUnits.lower().count('data'):
                    yData = waveformObj.getMagnitudeVector()
                    logY = False
                    dbY = False
                    forceLinearYAxis = True
                else:
                    forceLinearYAxis = False
                    if dbY:
                        yData = waveformObj.getNormalizeddBVector(self.absoluteValueForDB)
                        logY = False
                    else:
                        yData = waveformObj.getNormalizedMagnitudeVector()
                        logY = self.logVerticalAxis
                    #end if
                #end if

                fData = waveformObj.getFreqVector()

                labelKey = str(currentFigureNumber) + '_' + subplotString

                if not labelKey in xAxisLabelDictionary:
                    xAxisLabelDictionary[labelKey] = []
                if not labelKey in yAxisLabelDictionary:
                    yAxisLabelDictionary[labelKey] = []
                if not currentFigureNumber in subplotDictionary:
                    subplotDictionary[currentFigureNumber] = []

                xAxisLabelDictionary[labelKey].append(xAxisLabel)
                yAxisLabelDictionary[labelKey].append(yAxisLabel)

                plot1FormatMatchesKey = False
                for key in self.plot1Format.keys():

                    if re.search(key, waveformObj.yLabel) or re.search(key, waveformObj.label):
                        plot1FormatMatchesKey = True
                        break
                    elif re.search(key, label):
                        plot1FormatMatchesKey = True
                        break
                    #end if
                #end for key

                if plot1FormatMatchesKey:
                    pltFormatText = self.plot1Format[key][0]
                    pltLineWidth = self.plot1Format[key][1]
                    pltMarkerSize = self.plot1Format[key][2]
                    allowLabel = self.plot1Format[key][3]
                    if len(self.plot1Format[key]) > 4:
                        markerColor = self.plot1Format[key][4]
                    else:
                        markerColor = -1

                    if pltFormatText is None:
                        pltFormatText = markerArray[markerIndex]+'-'
                        markerIndex += 1
                    if pltLineWidth < 0:
                        pltLineWidth = self.defaultLineWidth
                    if pltMarkerSize < 0:
                        pltMarkerSize = self.defaultMarkerSize
                    if not allowLabel:
                        label = ''
                    if markerColor != -1:
                        markerEdgeColor = None
                        markerEdgeWidth = self.defaultMarkerEdgeWidth
                        markerFaceColor = markerColor
                    else:
                        markerEdgeColor = None
                        markerEdgeWidth = self.defaultMarkerEdgeWidth
                        markerFaceColor = None
                    #end if

                else:
                    pltFormatText = markerArray[markerIndex] + self.defaultLinePattern
                    markerIndex += 1
                    pltLineWidth = self.defaultLineWidth
                    pltMarkerSize = self.defaultMarkerSize
                    markerEdgeColor = None
                    markerEdgeWidth = self.defaultMarkerEdgeWidth
                    markerFaceColor = None
                #end if

                if markerIndex >= len(markerArray):
                    markerIndex = 0

                if labelKey in axesDict:
                    try:
                        plt.sca(axesDict[labelKey])
                    except:
                        print ('something went wrong with subplot label %s' % labelKey)
                        print ('probably due to overlapping subplots.')
                        print ('make adjustments to the figInfoDict items')
                    #end try
                elif self.plot1SubPlots:
                    gridShape = self.plot1SubPlotDict['gridShape']
                    subplotInfo = self.plot1SubPlotDict['figInfoDict'][subplotString]
                    ax = plt.subplot2grid(gridShape, subplotInfo[0], subplotInfo[1], subplotInfo[2])
                    axesDict[labelKey] = ax
                else:
                    print ("Made it inside else condition")
                    print ('The label is: %s' % label)
                    ax = plt.subplot('111')
                    axesDict[labelKey] = ax
                #end if

    #
                # Attach a hover annotation to the axes this trace was routed to.
                # `ax` is only reassigned when a new axes is created above; when
                # labelKey was already in axesDict, `ax` still refers to whichever
                # axes was created last, so look up the correct one explicitly.
                traceAx = axesDict[labelKey]
                # The anchor is a placeholder (presumably the hover handler moves
                # it to the cursor); the text sits 20 points left of and 20 points
                # above the anchor.
                annot = traceAx.annotate("", xy=(0, 0), xytext=(-20, 20), textcoords="offset points",
                                         bbox=dict(fc="b"),
                                         arrowprops=dict(arrowstyle="->"))
                # Start hidden; the handler shows it when the cursor is on this trace.
                annot.set_visible(False)

                # Bind the *current* annot and label as default arguments. A plain
                # `lambda x: hover(x, annot, label)` closes over the variables, not
                # their values, so once the loop finishes every handler would see
                # the last waveform's annotation and label.
                # NOTE(review): hover presumably also needs this trace's Line2D to
                # test `line.contains(event)`; pass it the same way once available.
                h = lambda event, annot=annot, label=label: hover(event, annot, label)

                fig.canvas.mpl_connect("motion_notify_event", h)
#

格式化绘图

for p in range(numberOfFigures):
figureNumber = p + startFigureNumber

figureTxt = 'Figure %d - %s' % (figureNumber, self.appTitle)
plt.figure(figureTxt)

if not figureNumber in subplotDictionary:
    continue

for subplotString in subplotDictionary[figureNumber]:

    labelKey = str(figureNumber) + '_' + subplotString
    try:
        plt.sca(axesDict[labelKey])
    except:
        print ('something went wrong with subplot label %s' % labelKey)
        print ('probably due to overlapping subplots.')
        print ('make adjustments to the figInfoDict items')
        continue
    #end try
    #plt.subplot(subplotString)
    plotAxis = plotAxisDictionary[labelKey]
    #print ('start misc plot settings';
    plt.grid(self.plot1Grid, 'both')

    plot1YticksList = self.plot1YticksList
    plot1XticksList = self.plot1XticksList
    plot1YLimits = self.cwPlotYLimits
    plot1XLimits = self.cwPlotXLimits
    vcursors = []

    logX = logXDictDict[figureNumber][subplotString]
    logY = logYDictDict[figureNumber][subplotString]

    enablePlotXLabel = True
    legendEnable = True

    if self.plot1SubPlots:
        if not logY:
            try:
                plot1YticksList = self.plot1SubPlotDict['yticks'][subplotString]
            except:
                pass
        else:
            plot1YticksList = []
        #end if

        if not logX:
            try:
                plot1XticksList = self.plot1SubPlotDict['xticks'][subplotString]
            except:
                pass
        else:
            plot1XticksList = []
        #end if

        try:
            plot1YLimits = self.plot1SubPlotDict['ylimits'][subplotString]
        except:
            pass

        try:
            plot1XLimits = self.plot1SubPlotDict['xlimits'][subplotString]
        except:
            pass

        try:
            vcursors = self.plot1SubPlotDict['vcursors'][subplotString]
        except:
            pass

        try:
            enablePlotXLabel = self.plot1SubPlotDict['xLabelEnable'][subplotString]
        except:
            pass
        #end

        try:
            legendEnable = self.plot1SubPlotDict['legendEnable'][subplotString]
        except:
            pass
        #end

    #end if

    if logY:
        for tick in plot1YticksList:
            if tick <= 0:
                plot1YticksList = []
                break
            #end if
        #end for
        if len(plot1YLimits) == 2:
            if plot1YLimits[0] <= 0:
                plot1YLimits = []
            #end if
        #end if
    #end if

    if len(plot1YticksList):
        plt.yticks(plot1YticksList)
    if len(plot1XticksList):
        plt.xticks(plot1XticksList)

    if plotAxis == (0.0,1.0,0.0,1.0) or refreshPlotAxes:
        if len(plot1YLimits) == 2:
            plt.ylim(plot1YLimits)

        if len(plot1XLimits) == 2:
            plt.xlim(plot1XLimits)
    else:
        plt.axis(plotAxis)
    #end if

    if len(vcursors):
        ylimits = plt.ylim()
        for x in vcursors:
            plt.plot([x,x], ylimits, self.vcursorFormatText, linewidth = self.vcursorWidth)

    yAxisLabelListSet = list(set(yAxisLabelDictionary[labelKey]))
    if len(yAxisLabelListSet) == 1:
        yAxisLabel = yAxisLabelDictionary[labelKey][0]
    elif len(yAxisLabelListSet) > 1:
        yAxisLabel = yAxisLabelListSet[0]
        for buf in yAxisLabelListSet[1:]:
            yAxisLabel += ',' + buf
        #end for
    else:
        yAxisLabel = ''
    #end if

    xAxisLabelListSet = list(set(xAxisLabelDictionary[labelKey]))
    if len(xAxisLabelListSet) == 1:
        xAxisLabel = xAxisLabelDictionary[labelKey][0]
    elif len(xAxisLabelListSet) > 1:
        xAxisLabel = xAxisLabelListSet[0]
        for buf in xAxisLabelListSet[1:]:
            xAxisLabel += ',' + buf
        #end for
    else:
        xAxisLabel = ''
    #end if

    if not forceLinearYAxis:
        if dbY:
            if not waveformObj.yUnits.lower().count('db'):
                yAxisLabel += ' (dB)'
        else:
            yAxisLabel += ' (lin)'
    #end if

    plt.ylabel(yAxisLabel)
    if enablePlotXLabel:
        plt.xlabel(xAxisLabel)
    else:
        xtickList = plt.xticks()[0]
        plt.xticks(xtickList, '')
    #end if

    prop=matplotlib.font_manager.FontProperties(size=self.legendFontSize)
    if self.shortLabel:
        plt.title(waveformObj.filename, fontsize=12)
    #end if
    if self.cwPlotLegend and legendEnable:
        plt.legend(loc=self.plot1LegendLocation,prop=prop,borderpad=0.3,labelspacing=0.1,handletextpad=0,numpoints=self.numLegendPoints)
    #end if
    #print ('done'

#end for subplotString

plt.draw()

if self.savePlotAsImage:
    plt.savefig(plotFilename, format='png')

标签: python matplotlib

解决方案


你可以这样做:

# Bind `ax` as a default argument so the handler keeps the axes that existed
# when it was connected. A plain `lambda x: hover(x, ax)` looks `ax` up only
# when the event fires (late binding), which inside a plotting loop that
# reassigns `ax` means every handler sees the last axes.
h = lambda event, ax=ax: hover(event, ax)
fig.canvas.mpl_connect("motion_notify_event", h)

然后将 hover 函数的定义改为：

def hover(event, ax):
    """Handle a matplotlib "motion_notify_event" for the axes *ax*.

    ``event`` is the mouse event matplotlib passes to the callback; ``ax``
    is supplied by the wrapper lambda that is handed to ``mpl_connect``.
    """
    ...

推荐阅读