首页 > 解决方案 > 如果 __name__ == '__main__' 在内部训练神经网络和没有它之间的区别

问题描述

如果有人能解释这段代码之间的区别,我将不胜感激 -

models = (encoder, decoder)
data = (x_test, y_test)

# VAE loss = mse_loss or xent_loss + kl_loss
reconstruction_loss = mse(inputs, outputs)
reconstruction_loss *= original_dim
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')

# train the autoencoder
vae.fit([x_train,y_train_1hot],
    epochs=epochs,
    batch_size=batch_size,
    validation_data=([x_test,y_test], None))

和这个 -

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load h5 model trained weights"
    parser.add_argument("-w", "--weights", help=help_)
    help_ = "Use mse loss instead of binary cross entropy (default)"
    parser.add_argument("-m",
                        "--mse",
                        help=help_, action='store_true')
    args = parser.parse_args(args=[])
    models = (encoder, decoder)
    data = (x_test, y_test)

    # VAE loss = mse_loss or xent_loss + kl_loss
    if args.mse:
        reconstruction_loss = mse(inputs, outputs)
    else:
        reconstruction_loss = binary_crossentropy(inputs,
                                                  outputs)

    reconstruction_loss *= original_dim
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')

    #if args.weights:
        #vae.load_weights(args.weights)
    #else:
        # train the autoencoder
    vae.fit([x_train,y_train_1hot],
        epochs=epochs,
        batch_size=batch_size,
        validation_data=([x_test,y_test_1hot], None))
        #vae.save_weights('vae_mlp_mnist.h5')

据我了解,它们都是相同的,并且我没有保存权重,并且已经注释了加载权重代码,即使重建看起来相同,两者的验证损失也是不同的。我不明白为什么。

标签: python

解决方案


如果您运行第二个片段应该没有区别python <yourmodule> --mse(否则,您的第二个片段默认为二进制交叉熵,而您的第一个片段始终使用 mse)。

除此之外,if __name__ == '__main__'守卫仅更改脚本的行为,以便仅在运行模块(文件)时执行包含的代码。即,当导入你的模块时,第一个片段会被执行,第二个不会。


推荐阅读