python - 如果 __name__ == '__main__' 在内部训练神经网络和没有它之间的区别
问题描述
如果有人能解释这段代码之间的区别,我将不胜感激 -
models = (encoder, decoder)
data = (x_test, y_test)
# VAE loss = mse_loss or xent_loss + kl_loss
reconstruction_loss = mse(inputs, outputs)
reconstruction_loss *= original_dim
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
# train the autoencoder
vae.fit([x_train,y_train_1hot],
epochs=epochs,
batch_size=batch_size,
validation_data=([x_test,y_test], None))
和这个 -
if __name__ == '__main__':
parser = argparse.ArgumentParser()
help_ = "Load h5 model trained weights"
parser.add_argument("-w", "--weights", help=help_)
help_ = "Use mse loss instead of binary cross entropy (default)"
parser.add_argument("-m",
"--mse",
help=help_, action='store_true')
args = parser.parse_args(args=[])
models = (encoder, decoder)
data = (x_test, y_test)
# VAE loss = mse_loss or xent_loss + kl_loss
if args.mse:
reconstruction_loss = mse(inputs, outputs)
else:
reconstruction_loss = binary_crossentropy(inputs,
outputs)
reconstruction_loss *= original_dim
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
#if args.weights:
#vae.load_weights(args.weights)
#else:
# train the autoencoder
vae.fit([x_train,y_train_1hot],
epochs=epochs,
batch_size=batch_size,
validation_data=([x_test,y_test_1hot], None))
#vae.save_weights('vae_mlp_mnist.h5')
据我了解,它们都是相同的,并且我没有保存权重,并且已经注释了加载权重代码,即使重建看起来相同,两者的验证损失也是不同的。我不明白为什么。
解决方案
如果您运行第二个片段应该没有区别python <yourmodule> --mse
(否则,您的第二个片段默认为二进制交叉熵,而您的第一个片段始终使用 mse)。
除此之外,if __name__ == '__main__'
守卫仅更改脚本的行为,以便仅在运行模块(文件)时执行包含的代码。即,当导入你的模块时,第一个片段会被执行,第二个不会。
推荐阅读
- flutter - google_maps_flutter 在发行版中不起作用
- html - 如何使引导崩溃响应?
- database - 如何在文件浏览中放置多个文件
- ios - uitableview 单元格第一次仍然无法正确删除
- matlab - 创建一个忽略 NaN 条目的新数组
- c# - .Net Standard 或 .Net Core 中的“服务器违反协议。Section=ResponseHeader Detail=CR 必须后跟 LF”
- php - “/Users/l2sap/Documents/laravel-dev/projects/fourth/storage/logs”中没有现有目录,无法创建:权限被拒绝
- php - 如何将使用 tcpdf 创建的动态创建的 pdf 转换为使用 imagick 的图像?
- windows - Ninja 将任何 add_subdirectory(foo) 预先添加到所有路径中,导致使用 vcpkg 在 Windows 上的 FindFirstFileExA 处出错
- mysql - MySQL:有没有一种直接的方法可以确保给定字符的每个实例实际上都是同一个字符?