python - 顺序（Sequential）模型 `predict()` 方法的正确用法
问题描述
我是使用 TensorFlow 进行机器学习的新手。我在下面的代码中构建了一个模型。该模型训练和测试成功。我的数据集如下所示:
[![在此处输入图像描述][1]][1]
模型训练好后,我想手动输入一些数据进行测试,比如这样:
test_row = [57, 1, 0, 140, 192, 0, 1, 148, 0, 0.4, 1, 0, 1, 1]  # 14 raw values; presumably one heart.csv row in column order (dataset image not visible) — confirm
但是，当我先用 `numpy.array` 将该列表转换为 numpy 数组：
np_array = numpy.array(test_row)  # 1-D float64 array of shape (14,); the column names are lost
（这是按照 Stack Overflow 上某个帖子的说明做的），然后调用
result = model.predict(np_array)  # fails: the model's first layer is DenseFeatures, which expects a dict keyed by feature name, not an array
来预测结果时，却得到了一个错误。我认为自己对 `predict()` 方法的用法不正确，但在这上面花了 5 个小时，仍然没有找到好的解决方法。
# Load the heart-disease CSV and build the TensorFlow feature columns for the
# model's DenseFeatures input layer. Only four columns are registered in
# feature_columns: cp, restecg, slope (one-hot categorical) and thalach (numeric).
file_name = "heart.csv"
data=pd.read_csv(file_name) # load the dataset into a DataFrame
feature_columns = [] # feature columns combined into the model's input layer
# cp: categorical. Convert values to str so they match the string vocabulary below.
data["cp"] = data["cp"].apply(str)# represent cp values as strings, e.g. 0 -> '0'
cp = tf.feature_column.categorical_column_with_vocabulary_list(
'cp', ['0', '1', '2', '3'])# maps each string to an integer id within this vocabulary
cp_one_hot = tf.feature_column.indicator_column(cp) # one-hot encodes that id
feature_columns.append(cp_one_hot)
# restecg: categorical, same treatment as cp
data["restecg"] = data["restecg"].apply(str)# represent restecg values as strings
restecg = tf.feature_column.categorical_column_with_vocabulary_list(
'restecg', ['0', '1', '2'])# string -> integer id
restecg_one_hot = tf.feature_column.indicator_column(restecg) # id -> one-hot vector
feature_columns.append(restecg_one_hot)
# thalach: used as-is as a numeric feature (no encoding)
thalach = tf.feature_column.numeric_column("thalach")
feature_columns.append(thalach)
# slope: categorical, same treatment as cp
# NOTE(review): the vocabulary is '1'-'3'. Any other value (e.g. '0', if the
# CSV uses 0-based codes) is out of vocabulary and encodes to an all-zero
# one-hot vector. Confirm against the values actually present in heart.csv.
data["slope"] = data["slope"].apply(str)# represent slope values as strings
slope = tf.feature_column.categorical_column_with_vocabulary_list(
'slope', ['1', '2', '3'])# string -> integer id
slope_one_hot = tf.feature_column.indicator_column(slope) # id -> one-hot vector
feature_columns.append(slope_one_hot)
def create_data_set(self, df, size=32, *, shuffle=True):
    """Turn a labelled DataFrame into a batched ``tf.data.Dataset``.

    Args:
        self: unused; kept because callers invoke this as a method
            (``self.create_data_set(...)``).
        df: DataFrame holding a ``'target'`` label column plus feature
            columns. It is not modified; a copy is taken.
        size: batch size.
        shuffle: if True (the original behaviour), shuffle all rows before
            batching. Pass False for evaluation/prediction so that the
            output order matches ``df``.

    Returns:
        A dataset yielding ``(features, labels)`` batches, where ``features``
        is a dict of column name -> tensor. This is the input format a
        ``DenseFeatures`` first layer expects.
    """
    df = df.copy()  # pop() below would otherwise mutate the caller's frame
    labels = df.pop('target')
    dataset = tf.data.Dataset.from_tensor_slices((dict(df), labels))
    if shuffle:
        # A buffer as large as the frame gives a full uniform shuffle.
        dataset = dataset.shuffle(buffer_size=len(df))
    return dataset.batch(size)
RANDOM_SEED = 42
# A fixed random_state makes the split reproducible; without it,
# train_test_split selects different rows on every run.
train, test = train_test_split(data, test_size=0.2, random_state=RANDOM_SEED)
# NOTE(review): `self` is not defined at module level. This snippet appears to
# be lifted from a method of the enclosing class; confirm in context.
train_set = self.create_data_set(train)
test_set = self.create_data_set(test)

# Build the model. Its first layer is DenseFeatures, so the model consumes a
# dict of {feature_name: tensor} (as produced by create_data_set), not a plain
# array. predict() must be given the same dict format.
model = tf.keras.models.Sequential([
    tf.keras.layers.DenseFeatures(feature_columns=feature_columns),
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dense(units=1, activation='sigmoid'),  # single probability output
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train on train_set only. test_set is passed as validation_data, so it is
# only used to report metrics after each epoch.
# NOTE(review): the same split serves as both validation and test set, so no
# untouched hold-out data remains for a final evaluation.
# use_multiprocessing was dropped: Keras honours it only for generator /
# keras.utils.Sequence inputs, so it was a no-op for a tf.data.Dataset (and
# newer Keras versions reject the argument).
model.fit(train_set, validation_data=test_set, epochs=100)
更新：导致错误的代码
# Feature names the trained model reads. They must match the feature columns
# built for training: the categorical ones use string vocabularies, so they
# are fed as strings; the numeric one is fed as a float.
_CATEGORICAL_FEATURES = ("cp", "restecg", "slope")
_NUMERIC_FEATURES = ("thalach",)


def _as_category(value):
    """Return *value* as a vocabulary string: '1', 1 and '1.0' all become '1'."""
    number = float(value)
    return str(int(number)) if number.is_integer() else str(value)


def start_predic(input_to_predict, predictor=None):
    """Predict the target probability for one record entered in the UI.

    The model's first layer is ``DenseFeatures``, which accepts only a dict
    mapping feature names to tensors. Passing a bare numpy array is what
    raised "We expected a dictionary here". This function builds that dict
    for a batch containing a single record.

    Args:
        input_to_predict: mapping of CSV column name to value, typically
            strings from the UI form (e.g. ``{'cp': '0', 'thalach': '148'}``).
            Keys other than the model's features are ignored. The mapping is
            not modified.
        predictor: object with a Keras-style ``predict`` method. Defaults to
            the module-level ``model``.

    Returns:
        The result of ``predictor.predict``. For the Sequential model with a
        single sigmoid unit, that is an array of shape (1, 1).

    Raises:
        ValueError: a model feature is missing or None, or a value is not
            numeric.
    """
    if predictor is None:
        predictor = model  # the module-level model built and trained above

    missing = [name for name in _CATEGORICAL_FEATURES + _NUMERIC_FEATURES
               if input_to_predict.get(name) is None]
    if missing:
        raise ValueError(f"missing value for model feature(s): {missing}")

    # Each value is a length-1 array, i.e. a batch holding one record.
    features = {name: np.array([_as_category(input_to_predict[name])])
                for name in _CATEGORICAL_FEATURES}
    features.update({name: np.array([float(input_to_predict[name])], dtype=np.float32)
                     for name in _NUMERIC_FEATURES})
    return predictor.predict(features)


# The caller shown in the traceback calls ``start_predict``; accept both spellings.
start_predict = start_predic
更新：错误回溯（traceback）
Traceback (most recent call last):
File "C:\ProgramData\Anaconda3\lib\site-packages\flask\app.py", line 2447, in wsgi_app
response = self.full_dispatch_request()
File "C:\ProgramData\Anaconda3\lib\site-packages\flask\app.py", line 1952, in full_dispatch_request
rv = self.handle_user_exception(e)
File "C:\ProgramData\Anaconda3\lib\site-packages\flask\app.py", line 1821, in handle_user_exception
reraise(exc_type, exc_value, tb)
File "C:\ProgramData\Anaconda3\lib\site-packages\flask\_compat.py", line 39, in reraise
raise value
File "C:\ProgramData\Anaconda3\lib\site-packages\flask\app.py", line 1950, in full_dispatch_request
rv = self.dispatch_request()
File "C:\ProgramData\Anaconda3\lib\site-packages\flask\app.py", line 1936, in dispatch_request
return self.view_functions[rule.endpoint](**req.view_args)
File "C:\Users\adm\DiagnosisSystem\app\routes.py", line 35, in index
result = modules.heart_predict.start_predict(symptom_dict) #import data from UI to the model
File "C:\Users\adm\DiagnosisSystem\diagnosis\HeartDiagnosisSystem.py", line 171, in start_predict
result =self.model.predict(input_array_for_prediction)
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\keras\engine\training.py", line 1751, in predict
tmp_batch_outputs = self.predict_function(iterator)
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\def_function.py", line 885, in __call__
result = self._call(*args, **kwds)
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\def_function.py", line 933, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\def_function.py", line 759, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\function.py", line 3066, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\function.py", line 3463, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\function.py", line 3298, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\framework\func_graph.py", line 1007, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\def_function.py", line 668, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\framework\func_graph.py", line 994, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\keras\engine\training.py:1586 predict_function *
return step_function(self, iterator)
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\keras\engine\training.py:1576 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\distribute\distribute_lib.py:1286 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\distribute\distribute_lib.py:2849 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\distribute\distribute_lib.py:3632 _call_for_each_replica
return fn(*args, **kwargs)
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\keras\engine\training.py:1569 run_step **
outputs = model.predict_step(data)
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\keras\engine\training.py:1537 predict_step
return self(x, training=False)
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\keras\engine\base_layer.py:1037 __call__
outputs = call_fn(inputs, *args, **kwargs)
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\keras\engine\sequential.py:383 call
outputs = layer(inputs, **kwargs)
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\keras\engine\base_layer.py:1037 __call__
outputs = call_fn(inputs, *args, **kwargs)
C:\Users\adm\AppData\Roaming\Python\Python38\site-packages\keras\feature_column\dense_features.py:158 call **
raise ValueError('We expected a dictionary here. Instead we got: ',
ValueError: ('We expected a dictionary here. Instead we got: ', <tf.Tensor 'ExpandDims:0' shape=(None, 1) dtype=float32>)
127.0.0.1 - - [22/Oct/2021 19:03:41] "POST / HTTP/1.1" 500 -
[1]: https://i.stack.imgur.com/xCgYB.png
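解决方案
模型的第一层是 `DenseFeatures`，它只接受“特征名 → 张量”的字典作为输入，不接受 numpy 数组（即上面回溯中的 “We expected a dictionary here”）。因此预测时应构造形如 `{'cp': np.array(['0']), 'restecg': np.array(['1']), 'slope': np.array(['1']), 'thalach': np.array([148.0])}` 的字典，再调用 `model.predict(该字典)`。其中分类特征要用与训练词表一致的字符串，每个值都是长度为 1 的数组，表示一批只含一条样本。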
推荐阅读
- reactjs - 如何在 nextjs 中使用不同的 .env 文件?
- scala - 如何迭代地将 Dataframe 带回 Spark 中的驱动程序
- wso2 - 如何在 Ballerina 中获取变量的类型?
- c# - .Net Core Blazor 发生类型错误,我试图转换它,但它不起作用
- ruby-on-rails - Rails 6 Font Awesome 5 生产问题
- mysql - MySQL多对多关系计数
- python - 有没有办法对python pandas中的投资组合标准差进行矢量化
- python - Python - 使用路径打开文件
- security - JWT - 刷新令牌和安全改进
- javascript - 未捕获的 TypeError: $(...).ThreeSixty 不是函数