tensorflow - Google Colab 上的 Tensorflow-Keras 重现性问题
问题描述
我有一个可以在 Google Colab 上运行的简单代码(我使用 CPU 模式):
import numpy as np
import pandas as pd
## LOAD DATASET
datatrain = pd.read_csv("gdrive/My Drive/iris_train.csv").values
xtrain = datatrain[:,:-1]
ytrain = datatrain[:,-1]
datatest = pd.read_csv("gdrive/My Drive/iris_test.csv").values
xtest = datatest[:,:-1]
ytest = datatest[:,-1]
import tensorflow as tf
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.utils import to_categorical
## SET ALL SEED
import os
os.environ['PYTHONHASHSEED']=str(66)
import random
random.seed(66)
np.random.seed(66)
tf.set_random_seed(66)
from tensorflow.keras import backend as K
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
## MAIN PROGRAM
ycat = to_categorical(ytrain)
# build model
model = tf.keras.Sequential()
model.add(Dense(10, input_shape=(4,)))
model.add(Activation("sigmoid"))
model.add(Dense(3))
model.add(Activation("softmax"))
#choose optimizer and loss function
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
# train
model.fit(xtrain, ycat, epochs=15, batch_size=32)
#get prediction
classes = model.predict_classes(xtest)
#get accuration
accuration = np.sum(classes == ytest)/len(ytest) * 100
我已阅读设置以在此处创建可重现性代码Reproducible results using Keras with TensorFlow backend并将所有代码放在同一个单元格中。但是每次我运行该单元(使用 运行该单元shift + enter
)时,结果(例如损失)总是不同的。
在我的情况下,可以复制上面代码的结果,只要:
- 我使用“运行时”>“重新启动并运行所有”运行,或者,
- 我将该代码放在一个文件中并使用命令行 (
python3 file.py
)运行它
有什么我想念的东西可以在不重新启动运行时的情况下使结果重现吗?
解决方案
您还应该kernel_initializer
在Dense
图层中修复种子。因此,您的模型将如下所示:
model = tf.keras.Sequential()
model.add(Dense(10, kernel_initializer=keras.initializers.glorot_uniform(seed=66), input_shape=(4,)))
model.add(Activation("sigmoid"))
model.add(Dense(3, kernel_initializer=keras.initializers.glorot_uniform(seed=66)))
model.add(Activation("softmax"))
推荐阅读
- c - 从 pthread 向主线程发送消息
- php - Laravel 查询 - whereIn 或Where 组合
- reactjs - 如何重定向到受保护路由的登录页面?
- android - Flutter Tflite PlatformException 错误模型尚未正确提及
- c# - Asp.net 单控制器动作,根据路由数据具有不同的模型类型
- php - 将 ajax 标头中的 authtoken 发送到 Codeigniter REST API
- python - 如何在 jupyter notebook 上显示来自摄像头的视频
- abap - SELECT中的左右CP比较?
- android - 如何将 Safeargs 与多个 navGraph 一起使用
- java - 如何拆分 JsonElement