python - 保存模型后得到错误的预测
问题描述
我的模型源代码:-
import numpy as np
import pandas as pd
dataset= pd.read_csv("heart900.csv")
X=dataset.iloc[:, :-1].values
Y=dataset.iloc[:, 13].values
from sklearn.impute import SimpleImputer
imputer = SimpleImputer(missing_values=np.nan, strategy='mean')
imputer= imputer.fit(X)
X= imputer.transform(X)
from sklearn.model_selection import train_test_split
X_Train, X_Test, Y_Train, Y_Test= train_test_split(X,Y, test_size=0.2, random_state=0)
from sklearn.preprocessing import StandardScaler
sc=StandardScaler()
X_Train=sc.fit_transform(X_Train)
X_Test=sc.fit_transform(X_Test)
import keras
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.models import load_model
import tensorflow as tf
import warnings
model=Sequential()
##First Hidden Layer
model.add(Dense(6, input_dim=13, activation='relu'))
##Second Hidden Layer
model.add(Dense(6, activation='relu'))
##Third Hidden Layer
model.add(Dense(6, activation='relu'))
##Output Layer
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(X_Train, Y_Train, epochs=150, batch_size=10)
new_pred= model.predict(sc.transform(np.array([[62,1,1,120,281,0,0,103,0,1.4,1,1,3]])))
new_pred= (new_pred>0.5)
print(new_pred )
model.save('keras_model.h5')`
这里的 O/P 是:- 假
以及用于访问已保存模型的我的源代码:-
from tensorflow.python.keras.models import load_model
import numpy as np
mp = load_model('keras_model.h5')
new_pre = mp.predict((np.array([[62,1,1,120,281,0,0,103,0,1.4,1,1,3]])))
new_pre = (new_pre>0.5)
print(new_pre)
这里的 O/P 是:- True(它应该是 False)
而且我已经尝试了所有可能的方法来保存 h5 模型,但结果预测每次都是错误的。请帮忙!!!!!
解决方案
您是否忘记在第二次预测时向输入数据添加转换?
mp = load_model('keras_model.h5')
new_pre = mp.predict(sc.transform((np.array([[62,1,1,120,281,0,0,103,0,1.4,1,1,3]]))))
new_pre = (new_pre>0.5)
print(new_pre)
推荐阅读
- windows - 为每个 Windows 事件日志写入文本文件
- javascript - ReactJS - 是从子组件提供参数更好,还是在父组件中提供参数 - 当回调相同时?
- typescript - 打字稿:如何使用父get方法获取子类实例
- tcl - 使用 Tcl 将 HTML5 画布转换为 PDF
- sql - 如何从表中检索最新数据
- java - Apache命令行解析器错误?
- xml - 如何使用 xslt 1.0 键计算而不是静态节点值?
- php - CakePHP 3中的BelongsToMany关联,如何在辅助字段中保持关系并保存关联?
- opengl - GLSL 几何着色器替换 glLineWidth
- angularjs - Mapbox-GL Angular - mgl-layer 通过 .json 文件添加符号系统