首页 > 解决方案 > 将两个图形保存在同一个 png 文件中(matplotlib)

问题描述

我正在运行以下 python 函数:

def plot_keras_history(history, embeddings_dimension, batch_size, version_data_control): #where history =  model.fit()
    """

    :param history: 
    :return: 
    """
    # the history object gives the metrics keys. 
    # we will store the metrics keys that are from the training sesion.
    metrics_names = [key for key in history.history.keys() if not key.startswith('val_')]

    for i, metric in enumerate(metrics_names):

        # getting the training values
        metric_train_values = history.history.get(metric, [])

        # getting the validation values
        metric_val_values = history.history.get("val_{}".format(metric), [])

        # As loss always exists as a metric we use it to find the 
        epochs = range(1, len(metric_train_values) + 1)

        # leaving extra spaces to align with the validation text
        training_text = "   Training {}: {:.5f}".format(metric,
                                                        metric_train_values[-1])
        # metric
        fig1 = plt.gcf()

        plt.figure(i, figsize=(12, 6))

        plt.plot(epochs,
                 metric_train_values,
                 'b',
                 label=training_text)

        # if we validation metric exists, then plot that as well
        if metric_val_values:
            validation_text = "Validation {}: {:.5f}".format(metric,
                                                             metric_val_values[-1])

            plt.plot(epochs,
                     metric_val_values,
                     'g',
                     label=validation_text)

        # add title, xlabel, ylabe, and legend
        plt.title('Model Metric: {}'.format(metric))
        plt.xlabel('Epochs')
        plt.ylabel(metric.title())
        plt.legend()

        plt.draw()

        fig1.savefig(os.path.join(os.getcwd(), 'plot_two_figures_{0}_{1}_{2}.png'.format(embeddings_dimension, batch_size, version_data_control)), dpi=100)
        plt.show()
        plt.close()

上述函数在 Jupyter Notebook 的 iPyhton 单元格中打印:

在此处输入图像描述

但是 savefig png 文件只包含其中一个。

在此处输入图像描述

保存两张图时出现什么问题,并且只保存了其中一张?并且形状错误。它应该是 12x6。

按照评论中的建议将图形更改为 subplot() :

def plot_keras_history(history, embeddings_dimension, batch_size, version_data_control): #where history =  model.fit()
    """

    :param history: 
    :return: 
    """
    # the history object gives the metrics keys. 
    # we will store the metrics keys that are from the training sesion.
    metrics_names = [key for key in history.history.keys() if not key.startswith('val_')]

    fig = plt.gcf()
    fig, axs = plt.subplots(2)

    for i, metric in enumerate(metrics_names):

        # getting the training values
        metric_train_values = history.history.get(metric, [])

        # getting the validation values
        metric_val_values = history.history.get("val_{}".format(metric), [])

        # As loss always exists as a metric we use it to find the 
        epochs = range(1, len(metric_train_values) + 1)

        # leaving extra spaces to align with the validation text
        training_text = "   Training {}: {:.5f}".format(metric,
                                                        metric_train_values[-1])
        # metric
        plt.figure(i, figsize=(12, 6))

        axs[i].plot(epochs,
                 metric_train_values,
                 'b',
                 label=training_text)

        # if we validation metric exists, then plot that as well
        if metric_val_values:
            validation_text = "Validation {}: {:.5f}".format(metric,
                                                             metric_val_values[-1])

            axs[i].plot(epochs,
                     metric_val_values,
                     'g',
                     label=validation_text)

            # add title, xlabel, ylabe, and legend
            plt.title('Model Metric: {}'.format(metric))
            plt.xlabel('Epochs')
            plt.ylabel(metric.title())
            plt.legend()

            plt.draw()

            fig.savefig(os.path.join(os.getcwd(), 'model_one\\ploting_training_validation_performance_{0}_{1}_{2}.png'.format(embeddings_dimension, batch_size, version_data_control)), dpi=100)
            plt.show()
            plt.close()

,我收到以下错误:

在此处输入图像描述

我的代码的第三次更改(认为我做到了):

def plot_keras_history(history, embeddings_dimension, batch_size, version_data_control): #where history =  model.fit()
    """

    :param history: 
    :return: 
    """
    # the history object gives the metrics keys. 
    # we will store the metrics keys that are from the training sesion.
    metrics_names = [key for key in history.history.keys() if not key.startswith('val_')]

    fig = plt.gcf()
    fig, axs = plt.subplots(2)

    for i, metric in enumerate(metrics_names):

        # getting the training values
        metric_train_values = history.history.get(metric, [])

        # getting the validation values
        metric_val_values = history.history.get("val_{}".format(metric), [])

        # As loss always exists as a metric we use it to find the 
        epochs = range(1, len(metric_train_values) + 1)

        # leaving extra spaces to align with the validation text
        training_text = "   Training {}: {:.5f}".format(metric,
                                                        metric_train_values[-1])
        axs[i].plot(epochs,
                 metric_train_values,
                 'b',
                 label=training_text)

        # if we validation metric exists, then plot that as well
        if metric_val_values:
            validation_text = "Validation {}: {:.5f}".format(metric,
                                                             metric_val_values[-1])

            axs[i].plot(epochs,
                     metric_val_values,
                     'g',
                     label=validation_text)

        # add title, xlabel, ylabe, and legend
        plt.title('Model Metric: {}'.format(metric))
        plt.xlabel('Epochs')
        plt.ylabel(metric.title())
        plt.legend()

    fig.savefig(os.path.join(os.getcwd(), 'model_one\\ploting_training_validation_performance_{0}_{1}_{2}.png'.format(embeddings_dimension, batch_size, version_data_control)), dpi=100)
    plt.show()
    plt.draw()
    plt.close()

标准输出: 在此处输入图像描述

标签: pythonmatplotlibjupyter-notebook

解决方案


推荐阅读