python-3.x - AttributeError: 'list' 对象没有属性 'argmax' 和 numpy.AxisError: 轴 2 超出维度 1 数组的范围
问题描述
我有一个预测 keras NMT 模型的功能,它工作正常,但它读取文件,然后将预测保存到另一个文件
但我想要用户输入并做出预测
完整代码
from keras.preprocessing.sequence import pad_sequences
from keras.models import load_model
import sys
import warnings
import argparse
from seq2seq_utils import *
ap = argparse.ArgumentParser()
ap.add_argument('-max_len', type=int, default=95)
ap.add_argument('-vocab_size', type=int, default=31500)
args = vars(ap.parse_args())
MAX_LEN = args['max_len']
VOCAB_SIZE = args['vocab_size']
def load_test_data(user_input, X_word_to_ix, max_len):
user_input = input()
X = [text_to_word_sequence(x)[::-1] for x in user_input.split('\n') if 0 < len(x) <= max_len]
for i, sentence in enumerate(X):
for j, word in enumerate(sentence):
if word in X_word_to_ix:
X[i][j] = X_word_to_ix[word]
elif word in X_word_to_ix is None:
X[i][j] = None
else:
X[i][j] = X_word_to_ix['UNK']
return X
model = load_model('model.h5')
model.get_weights()
X, X_vocab_len, X_word_to_ix, X_ix_to_word, y, y_vocab_len, y_word_to_ix, y_ix_to_word = load_data('english.txt',
'french.txt',
MAX_LEN,
VOCAB_SIZE)
saved_weights = find_checkpoint_file('.')
print('please enter the value')
user_input1 = input()
if len(saved_weights) == 0:
print("The network hasn't been trained! Program will exit...")
sys.exit()
else:
X_test = load_test_data(user_input1, X_word_to_ix, MAX_LEN)
print(type(X_test))
X_test = pad_sequences(X_test, maxlen=4, dtype='int32')
print(type(X_test))
arr=np.array(X_test)
model.load_weights(saved_weights)
predictions = np.argmax(model.predict(arr))
# predictions = np.argmax(list(model.predict(X_test)))
print(type(predictions))
# predictions = np.argmax(model.predict(X_test),axis=0)
sequences = []
print('1')
for prediction in predictions:
print('2')
sequence = ' '.join([y_ix_to_word[index] for index in prediction if index > 0])
print(sequence)
l = sequences.append(sequence)
print(l)
np.savetxt('test_result.txt', sequences, fmt='%s')
AttributeError:“列表”对象没有属性“argmax”
在处理上述异常的过程中,又出现了一个异常:
numpy.AxisError:轴 2 超出维度 1 数组的范围
为什么?并感谢您的帮助
解决方案
推荐阅读
- spring - 使用 Kotlin、SpringBoot 和 Mockk 的 POST 方法出错
- mysql - Telnet 在一个网络上工作,但在尝试在端口 3306 上连接 EC2 实例时无法在另一个网络上工作
- javascript - 多个具有不同值的不同按钮,但只使用第一个?
- api - 你如何从带有 curl 的 Flask 应用程序中获得响应?
- r - R通过带有facet_wrap和数据子集的函数调用为ggplot传递参数
- scala - 使用 Circe 解码消息时,是否可以从 DecodingFailure 中提取无效值
- bots - Discord Bot 发布频道链接
- algorithm - 找到一种时间复杂度为 O(n + k*log(k)) 的整数排序算法
- javascript - 将字符串推送到 for 循环期间创建的数组的开头和结尾
- swift - 当 self 尚未初始化时,如何在属性包装器中使用现有属性?(斯威夫特用户界面)