javascript - Tensorflow 模型未在 javascript 中训练
问题描述
TensorFlow.js 版本1.4.0。
描述我正在尝试用 javascript 训练模型,但模型无法学习或收敛。我采用了相同的模型和相同的数据,我在程序的 python 版本中使用了它们,所以我希望模型能够在同一阶段学习。相反,该模型未能改进,并且在第一次运行后验证准确度保持不变。python 模型能够达到约 70% 的准确率,而 javascript 模型在 50 个 epoch 后几乎不能达到 5% 以上。如果您想使用相同的数据,则 URL 可以工作。
重现错误的代码Python代码:
checkpoint = ModelCheckpoint('best_models/model--{val_accuracy:03f}--{epoch:03d}-{accuracy:03f}.h5', verbose=1, monitor='val_accuracy',save_best_only=True, mode='auto')
X_train_raw = requests.get("http://tb-test.chatbotech.com/info/get-training-arrays").json().get("xTrain")
X_test_raw = requests.get("http://tb-test.chatbotech.com/info/get-training-arrays").json().get("xTest")
y_train_raw = requests.get("http://tb-test.chatbotech.com/info/get-training-arrays").json().get("yTrain")
y_test_raw = requests.get("http://tb-test.chatbotech.com/info/get-training-arrays").json().get("yTest")
X_train = np.array(ast.literal_eval(X_train_raw))
X_test = np.array(ast.literal_eval(X_test_raw))
y_train_hot = np.array(ast.literal_eval(y_train_raw))
y_test_hot = np.array(ast.literal_eval(y_test_raw))
max_pad_length = 220
model = Sequential()
model.add(Conv2D(128, kernel_size=(8, 48), activation='relu', input_shape=(20, max_pad_length, 1)))
model.add(MaxPooling2D(pool_size=(3, 120)))
model.add(Dropout(0.2))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(30, activation='softmax'))
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy'])
history = model.fit(X_train, y_train_hot, batch_size=20, epochs=2000, verbose=1, validation_data=(X_test, y_test_hot),callbacks=[checkpoint])
Javascript代码:
async function getData() {
const dataReq = await fetch('http://tb-test.chatbotech.com/info/get-training-arrays');
const trainData = await dataReq.json();
return trainData;
}
async function run() {
// Load and plot the original input data that we are going to train on.
const data = await getData();
console.log(data);
const model = createModel();
// More code will be added below
model.fit(tf.tensor(JSON.parse(data.xTrain), [230, 20, 220, 1], 'float32'), tf.tensor(JSON.parse(data.yTrain), [230, 30]), { shuffle: false, epochs: 2000, validationData: [tf.tensor(JSON.parse(data.xTest), [154, 20, 220, 1], 'float32'), tf.tensor(JSON.parse(data.yTest), [154, 30])], callbacks: {
async onEpochEnd(epoch, logs) {
console.log(logs);
},
onBatchEnd(batch, logs) {
console.log(logs);
console.log(batch);
}}});
}
function createModel() {
const model = tf.sequential();
model.add(tf.layers.conv2d({filters: 128, kernelSize: [8, 48], activation: 'relu', inputShape: [20, 220, 1], strides: [1, 1], padding: 'valid'}));
model.add(tf.layers.maxPooling2d({poolSize: [3, 120], strides: [3, 120]}));
model.add(tf.layers.dropout({rate: 0.2}));
model.add(tf.layers.dense({units: 128, activation: 'relu'}));
model.add(tf.layers.dropout({rate: 0.3}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 30, activation: 'softmax'}));
model.compile({loss: tf.metrics.categoricalCrossentropy, optimizer: tf.train.adadelta(1, 0.95, 1e-07 ), metrics: ['accuracy']});
return model;
}
document.addEventListener('DOMContentLoaded', run);
解决方案
您也可以尝试这种架构。我认为这将有助于提高您的模型准确性:
async function* data() {
while (true) {
for (i in train) {
// this function return tensor data
}
}
}
async function* labels() {
while (true) {
for (i in train) {
// this function return tensor label
}
}
}
async function initModel() {
// model write here
}
(async function () {
const xs = tf.data.generator(data);
const ys = tf.data.generator(labels);
const model = await initModel();
model.summary();
await model.fit(xs, ys,{
epochs: 5
batchesPerEpoch: 5
});
})()
提供您的反馈。
推荐阅读
- reactjs - 如何在 saga 中测试 firebase 功能?
- reactjs - 如何在 reactjs 的 CKEditor5 工具栏中添加下划线选项和对齐选项
- javascript - 获取触发事件的复选框的 ID
- django - 可以将用户/配置文件链接到先前创建的名称,该名称通过外键链接到模型
- spyder - 如何在 Spyder 4 中禁用代码折叠?
- python - 使用烧瓶更新 Web 应用程序中的变量
- java - TimerTask Scheduler 在静态引用其他类对象后不起作用
- swift - 如何在 Swift 中嵌入文件?
- c++ - 使用 `emplace_back` 而不是 `push_back` 时没有缩小警告
- excel - Excel 数据中的变音符号