python - 从 CPP 程序用 Python 执行神经网络代码
问题描述
有一个 MNIST 神经网络的 Python 代码,我们用它来识别脚本中的古吉拉特语字符。为此,我们借助 C++ 语言的 QT 框架开发了一个 GUI。目前,神经网络代码在 python 中执行,但我们在通过 C++ 执行 MNIST 模型代码时遇到了问题。我找到了一个代号为 Keras2CPP 的 GitHub,它有一个 dump_to_simple_cpp.py 文件。该文件接受 weights.h5 权重文件和 model_json 文件,然后生成一个 dumped.nnet 文件。dump_to_simple_cpp.py 文件如下:
import numpy as np
np.random.seed(1337)
from keras.models import Sequential, model_from_json
import json
import argparse
np.set_printoptions(threshold=np.inf)
parser = argparse.ArgumentParser(description='This is a simple script
to dump Keras model into simple format suitable for porting into pure
C++ model')
parser.add_argument('-a', '--architecture', help="JSON with model
architecture", required=True)
parser.add_argument('-w', '--weights', help="Model weights in HDF5
format", required=True)
parser.add_argument('-o', '--output', help="Ouput file name",
required=True)
parser.add_argument('-v', '--verbose', help="Verbose", required=False)
args = parser.parse_args()
print('Read architecture from', args.architecture)
print('Read weights from', args.weights)
print('Writing to', args.output)
arch = open(args.architecture).read()
model = model_from_json(arch)
model.load_weights(args.weights)
model.compile(loss='categorical_crossentropy', optimizer='adadelta')
arch = json.loads(arch)
with open(args.output, 'w') as fout:
fout.write('layers ' + str(len(model.layers)) + '\n')
layers = []
for ind, l in enumerate(arch["config"]):
if args.verbose:
print(ind, l)
fout.write('layer ' + str(ind) + ' ' + l['class_name'] + '\n') #line number: 33
if args.verbose:
print(str(ind), l['class_name'])
layers += [l['class_name']]
if l['class_name'] == 'Convolution2D':
W = model.layers[ind].get_weights()[0]
if args.verbose:
print(W.shape)
fout.write(str(W.shape[0]) + ' ' + str(W.shape[1]) + ' ' + str(W.shape[2]) + ' ' + str(W.shape[3]) + ' ' + l['config']['border_mode'] + '\n')
for i in range(W.shape[0]):
for j in range(W.shape[1]):
for k in range(W.shape[2]):
fout.write(str(W[i,j,k]) + '\n')
fout.write(str(model.layers[ind].get_weights()[1]) + '\n')
if l['class_name'] == 'Activation':
fout.write(l['config']['activation'] + '\n')
if l['class_name'] == 'MaxPooling2D':
fout.write(str(l['config']['pool_size'][0]) + ' ' + str(l['config']['pool_size'][1]) + '\n')
if l['class_name'] == 'Dense':
W = model.layers[ind].get_weights()[0]
if args.verbose:
print(W.shape)
fout.write(str(W.shape[0]) + ' ' + str(W.shape[1]) + '\n')
for w in W:
fout.write(str(w) + '\n')
fout.write(str(model.layers[ind].get_weights()[1]) + '\n')
在以下命令的帮助下执行上述代码:
python3 dump_to_simple_cpp.py -a model_json -w model.h5 -o dumped.nnet
抛出以下错误:
File "dump_to_simple_cpp.py", line 33, in <module>
fout.write('layer ' + str(ind) + ' ' + l['class_name'] + '\n')
TypeError: string indices must be integers
谁能帮我解决上述错误?
解决方案
推荐阅读
- python - ConnectionPatch 在交叉轴时中断 constrained_layout
- html - 在 EDGAR 网站上查找快速搜索文本框的 value 属性
- c++ - 获取访问冲突读取位置但不确定原因
- android - 带有png源的androidx小部件AppCompatImageView不会在棒棒糖设备中膨胀?
- powershell - 为什么我的 DateTime 对象的行为不同?
- c# - C# (.NET) WebAPI:在调用堆栈中显示更多项目
- sql - 如何遍历 CTE 中的每个值
- c++ - Visual Studio 2019:为什么我的 IDE 无法检查标头语法?
- java - Java 文件中的安全密钥,例如 API 密钥等免受黑客攻击
- javascript - JavaScript 智能感知无法正常工作