Why does this Keras OneHot layer implementation perform differently from one-hot encoded training data?

Problem description

I want to train a convnet to classify more than 240,000 documents into roughly 2,000 classes. To do this I take the first 60 words of each document and convert them to indices. I tried to implement a OneHot layer in Keras to avoid memory problems, but the model performs much worse than a model fed data that was already one-hot encoded. What is the real difference?

Apart from the extra one-hot Lambda layer, the shapes and parameter counts reported by the model summaries are similar. I used the OneHot function described here: https://fdalvi.github.io/blog/2018-04-07-keras-sequential-onehot/

from keras import backend as K
from keras.layers import Lambda

def OneHot(input_dim=None, input_length=None):
    # input_dim refers to the eventual length of the one-hot vector
    # (e.g. vocab size)
    # input_length refers to the length of the input sequence
    # Check if inputs were supplied correctly
    if input_dim is None or input_length is None:
        raise TypeError("input_dim or input_length is not set")

    # Helper method (not inlined for clarity)
    def _one_hot(x, num_classes):
        return K.one_hot(K.cast(x, 'uint8'),
                         num_classes=num_classes)

    # Final layer representation as a Lambda layer
    return Lambda(_one_hot,
                  arguments={'num_classes': input_dim},
                  input_shape=(input_length,))
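As a quick sanity check of what this Lambda layer is meant to do, the same operation can be mimicked in plain NumPy (a sketch, independent of Keras): an index tensor of shape (batch, input_length) should expand to (batch, input_length, input_dim).

```python
import numpy as np

def one_hot_np(x, num_classes):
    # NumPy mimic of K.one_hot, for checking shapes and values offline.
    out = np.zeros(x.shape + (num_classes,), dtype='float32')
    np.put_along_axis(out, x[..., None], 1.0, axis=-1)
    return out

batch = np.array([[0, 2, 1]])               # one sequence of 3 token indices
encoded = one_hot_np(batch, num_classes=4)
print(encoded.shape)                        # -> (1, 3, 4)
print(encoded[0, 1])                        # -> [0. 0. 1. 0.]
```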

# Model A: the Keras model using the OneHot function
from keras.models import Sequential
from keras.layers import (Conv1D, Dense, Dropout,
                          GlobalAveragePooling1D, MaxPooling1D)
from keras.callbacks import ModelCheckpoint
from keras.optimizers import Adam

model = Sequential()
model.add(OneHot(input_dim=model_max,
                 input_length=input_length))
model.add(Conv1D(256, 6, activation='relu'))
model.add(Conv1D(64, 3, activation='relu'))
model.add(MaxPooling1D(3))
model.add(Conv1D(128, 3, activation='relu'))
model.add(Conv1D(128, 3, activation='relu'))
model.add(GlobalAveragePooling1D())
model.add(Dropout(0.5))
model.add(Dense(labels_max, activation='softmax'))
checkpoint = ModelCheckpoint('model-best.h5', verbose=1,
                             monitor='val_loss', save_best_only=True,
                             mode='auto')
model.compile(optimizer=Adam(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Model B: the same model fed data already converted to one-hot
model = Sequential()
model.add(Conv1D(256, 6, activation='relu',
                 input_shape=(input_length, model_max)))
model.add(Conv1D(64, 3, activation='relu'))
model.add(MaxPooling1D(3))
model.add(Conv1D(128, 3, activation='relu'))
model.add(Conv1D(128, 3, activation='relu'))
model.add(GlobalAveragePooling1D())
model.add(Dropout(0.5))
model.add(Dense(labels_max, activation='softmax'))
checkpoint = ModelCheckpoint('model-best.h5', verbose=1,
                             monitor='val_loss', save_best_only=True,
                             mode='auto')
model.compile(optimizer=Adam(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

Model B performs much better, reaching up to 60% validation accuracy, but it easily runs into memory errors. Model A is much faster but tops out at only 25% validation accuracy. I expected them to perform similarly. What am I missing here? Thanks!
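The memory pressure on model B is easy to estimate with back-of-the-envelope arithmetic (vocab_size below is a hypothetical placeholder, since the question's model_max is not given):

```python
# Rough size of the fully materialized one-hot input (model B) versus the
# integer index input (model A). vocab_size is an assumed placeholder value.
docs = 240_000
seq_len = 60
vocab_size = 20_000            # hypothetical; the actual model_max is unknown
bytes_per_float32 = 4

onehot_bytes = docs * seq_len * vocab_size * bytes_per_float32
index_bytes = docs * seq_len * 4  # int32 indices, as model A consumes

print(onehot_bytes / 2**30)    # ~1073 GiB for the one-hot tensor
print(index_bytes / 2**20)     # ~55 MiB for plain indices
```

Even generous batching only divides the first number; keeping the data as indices and expanding inside the network, as model A does, is the standard way around it.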

Tags: python, keras, one-hot-encoding

Solution
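One detail worth checking in the OneHot helper above, assuming model_max exceeds 256: K.cast(x, 'uint8') wraps token indices modulo 256, so with a realistic vocabulary most words would be one-hot encoded as the wrong token, which alone could explain the accuracy gap. The wrap is easy to demonstrate in plain NumPy:

```python
import numpy as np

# uint8 holds only 0..255; larger token indices silently wrap modulo 256,
# so distinct words collide on the same one-hot position.
indices = np.array([3, 255, 256, 300])
print(indices.astype('uint8'))  # -> [  3 255   0  44]
```

If that is the cause here, casting to a wider type such as 'int32' inside _one_hot would keep the indices intact. This is offered as a plausible diagnosis to verify, not a confirmed one.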

