Why does the sample dataset change the speed of model prediction?

Problem description

I've been playing with sklearn's RandomForestClassifier and there is something I don't understand.

I created a model that classifies each pixel according to whether or not it belongs to an object, in this case a pear. Something like this:

[image]

To train this model, I used these three images:

[image]

Here is the puzzling part: I created two different models, each using a different labeled mask for the same training images:

Model 1:

[image]

Model 2:

[image]

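Both models reach acceptable accuracy, but the first model takes about 0.5 seconds to predict 10 new images, while model 2 takes about 4 seconds to predict the same 10 images. How is this possible?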
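Since the two models differ only in their label masks, one thing worth checking is how many distinct label values each mask contains. The training code in this post assigns the grayscale values of the labeled image directly to `df["Labels"]`, and scikit-learn treats every distinct value in `y` as a separate class. The following is only a sketch of such a check: `label_value_counts` and the two tiny masks are hypothetical and not part of the original post, and whether the class count actually explains the timing gap cannot be determined from the question alone.

```python
from collections import Counter


def label_value_counts(mask_rows):
    """Count how often each pixel value occurs in a 2-D mask given as rows of ints."""
    return Counter(value for row in mask_rows for value in row)


# Hypothetical 3x4 masks, for illustration only.
# A mask with hard edges: every pixel is exactly 0 or 255.
hard_mask = [
    [0, 0, 255, 255],
    [0, 0, 255, 255],
    [0, 0, 255, 255],
]
# A mask with soft or compressed edges: many near-0 and near-255 values.
soft_mask = [
    [0, 12, 243, 255],
    [0, 7, 250, 255],
    [3, 0, 255, 251],
]

print("distinct labels (hard mask):", len(label_value_counts(hard_mask)))
print("distinct labels (soft mask):", len(label_value_counts(soft_mask)))
```

For a NumPy array such as `train_img_labeled`, the same information is available from `np.unique(train_img_labeled, return_counts=True)`, or by passing `train_img_labeled.tolist()` to the helper above.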

Edit: added the code suggested by @TYZ and more information

import time

import cv2
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# generate_features and calculate_model_output are the asker's own helpers;
# their definitions are not shown in the post.


def train_model(image_path, labeled_image_path, test_split=0.3):
    train_img = cv2.imread(image_path)  # Load training image
    train_img = cv2.cvtColor(train_img, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
    # Keep a single channel (try each channel and select the best).
    # Note: after the RGB conversion, index 2 is BLUE; green would be index 1.
    train_img = train_img[:, :, 2]

    df = generate_features(train_img)  # DataFrame containing all features

    train_img_labeled = cv2.imread(labeled_image_path)  # Load labeled image
    train_img_labeled = cv2.cvtColor(train_img_labeled, cv2.COLOR_BGR2GRAY)
    df["Labels"] = train_img_labeled.reshape(-1)  # One label per pixel

    Y = df["Labels"].values
    X = df.drop(labels=["Labels"], axis=1)

    # Split the data into train and test
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=test_split, random_state=20)

    # Create new model
    model = RandomForestClassifier(n_estimators=5, random_state=25, max_depth=1, n_jobs=1)

    # Train the model
    model.fit(X_train, y_train)
    return model


def test(model):
    image_paths = [f"test/scaled/test_{i}.jpeg" for i in range(1, 11)]

    # Start the timer once, before the loop, so the total covers all 10 images.
    # (Resetting it inside the loop would time only the last image.)
    start_time = time.time()

    for image_path in image_paths:
        result = calculate_model_output(model, image_path, show=False)

    print("--- %s seconds ---" % (time.time() - start_time))
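When timing a loop like the one in `test`, the start timestamp has to be taken once, before the loop; a timestamp taken inside the loop makes the final figure cover only the last iteration. Below is a minimal sketch of a timing helper that records both the batch total and the per-item durations. The name `time_batch` is hypothetical and not part of the original post, and `time.perf_counter` is used here as an assumption about what suits short measurements best.

```python
import time


def time_batch(process, items):
    """Call process(item) for each item; return (total_seconds, per_item_seconds)."""
    per_item = []
    batch_start = time.perf_counter()  # taken once, never reset inside the loop
    for item in items:
        item_start = time.perf_counter()
        process(item)
        per_item.append(time.perf_counter() - item_start)
    total = time.perf_counter() - batch_start
    return total, per_item
```

With the asker's helper, it could be called as `time_batch(lambda p: calculate_model_output(model, p, show=False), image_paths)`. Comparing the per-item durations of the two models would show whether the slowdown is spread evenly across images or concentrated in a few of them.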

[image]

Tags: python, machine-learning, scikit-learn, object-detection, random-forest
