Finding the means of the three classes the data was assigned to in LDA

Question

I am looking at example code that computes Fisher's LDA based on Q = W^{-1}B.

The data is imported as follows:

library(tidyverse)  # provides read_csv(), mutate() and %>%

aircraft = read_csv(file = "aircraft.csv") %>%
  mutate( Period = factor( Period ))

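I have the following example of Fisher's LDA. It computes Q using solve(W, B), then finds the first eigenvector of Q, and then assigns each observation to the class whose projected mean is closest: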

W1 = cov( dplyr::select( dplyr::filter( aircraft, Period == 1 ), -Year, -Period ) )
W2 = cov( dplyr::select( dplyr::filter( aircraft, Period == 2 ), -Year, -Period ) )
W3 = cov( dplyr::select( dplyr::filter( aircraft, Period == 3 ), -Year, -Period ) )

W = W1 + W2 + W3

mu1 = colMeans( dplyr::select( dplyr::filter( aircraft, Period == 1 ), -Year, -Period ) ) 
mu2 = colMeans( dplyr::select( dplyr::filter( aircraft, Period == 2 ), -Year, -Period ) ) 
mu3 = colMeans( dplyr::select( dplyr::filter( aircraft, Period == 3 ), -Year, -Period ) )  
mu = rbind( mu1, mu2, mu3 )

B = ( 3 - 1 ) * cov (mu ) 

Q = solve( W, B )

eta = eigen( Q )$vectors[,1]

XX = dplyr::select( aircraft, -Year, -Period ) 

XXproj = as.matrix( XX ) %*% as.matrix( eta ) 
muP1 = t( as.matrix( mu1 ) ) %*% as.matrix( eta )
muP2 = t( as.matrix( mu2 ) ) %*% as.matrix( eta )
muP3 = t( as.matrix( mu3 ) ) %*% as.matrix( eta )

tXXproj = t( XXproj ) 
m1 = as.data.frame( tXXproj )  - muP1 
m2 = as.data.frame( tXXproj )  - muP2
m3 = as.data.frame( tXXproj )  - muP3
mm = rbind( abs( m1 ), abs( m2 ), abs( m3 ) ) 
classes = sapply( mm, which.min ) 

classified = data.frame( assigned = classes, aircraft )

xtabs( ~ assigned + Period , data = classified )

The command str(classified) produces the following output:

'data.frame':   709 obs. of  9 variables:
 $ assigned: int  1 1 1 1 1 1 1 1 1 1 ...
 $ Year    : int  14 14 14 15 15 15 15 16 16 16 ...
 $ Period  : Factor w/ 3 levels "1","2","3": 1 1 1 1 1 1 1 1 1 1 ...
 $ Power   : num  82 82 224 164 119 ...
 $ Span    : num  12.8 11 17.9 14.5 12.9 ...
 $ Length  : num  7.6 9 10.3 9.8 7.9 ...
 $ Weight  : num  1070 830 2200 1946 1190 ...
 $ Speed   : int  105 145 135 138 140 177 113 230 175 106 ...
 $ Range   : int  400 402 500 500 400 350 402 700 525 300 ...

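I want to find the means of the three classes that the data was assigned to. This sounds like it should be simple; however, I am inexperienced with R and not sure how to do it. I think the apply and select functions are relevant here, but I'm not sure.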

I was also able to run my own LDA using the relevant R function:

lda.0 = lda( Period ~ Power + Span + Length + Weight + Speed + Range, data = aircraft )
preds.0 = predict( lda.0 )$class
xtabs( ~ preds.0 + aircraft$Period )

I want to do the same thing for my implementation (find the means of the three classes) as for the example above.

The command str(predict( lda.0 )) produces the following output:

List of 3
 $ class    : Factor w/ 3 levels "1","2","3": 1 1 1 1 1 1 1 1 1 1 ...
 $ posterior: num [1:709, 1:3] 0.712 0.659 0.67 0.665 0.69 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:709] "1" "2" "3" "4" ...
  .. ..$ : chr [1:3] "1" "2" "3"
 $ x        : num [1:709, 1:2] -1.469 -0.988 -1.504 -1.22 -1.385 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:709] "1" "2" "3" "4" ...
  .. ..$ : chr [1:2] "LD1" "LD2"

So, for both of these cases, what is a good way to find the means of the three classes that the data was assigned to?

The full data set is too large to include in this post, so I have included a smaller version of it (the first 100 rows, as dput output):

structure(list(Year = c(14L, 14L, 14L, 15L, 15L, 15L, 15L, 16L, 
16L, 16L, 16L, 16L, 16L, 16L, 16L, 16L, 16L, 16L, 16L, 17L, 17L, 
17L, 17L, 17L, 17L, 17L, 17L, 17L, 17L, 17L, 17L, 17L, 17L, 17L, 
18L, 18L, 18L, 18L, 18L, 18L, 18L, 18L, 19L, 19L, 20L, 20L, 20L, 
20L, 21L, 21L, 21L, 22L, 22L, 22L, 22L, 22L, 23L, 23L, 23L, 23L, 
23L, 23L, 23L, 23L, 23L, 24L, 24L, 24L, 24L, 24L, 25L, 25L, 25L, 
25L, 25L, 25L, 25L, 26L, 26L, 26L, 26L, 26L, 26L, 26L, 26L, 26L, 
26L, 27L, 27L, 27L, 27L, 27L, 27L, 27L, 27L, 28L, 28L, 28L, 28L, 
28L), Period = c(1L, 3L, 3L, 1L, 2L, 1L, 3L, 2L, 1L, 3L, 2L, 
3L, 1L, 2L, 1L, 1L, 1L, 2L, 1L, 3L, 1L, 3L, 3L, 2L, 1L, 1L, 1L, 
1L, 3L, 2L, 1L, 1L, 3L, 2L, 1L, 2L, 1L, 1L, 2L, 1L, 2L, 1L, 3L, 
2L, 3L, 1L, 1L, 2L, 3L, 1L, 3L, 2L, 1L, 2L, 1L, 1L, 1L, 3L, 2L, 
2L, 3L, 1L, 3L, 1L, 3L, 2L, 1L, 1L, 2L, 1L, 3L, 1L, 1L, 2L, 2L, 
3L, 3L, 1L, 3L, 1L, 2L, 1L, 2L, 1L, 1L, 3L, 1L, 1L, 2L, 3L, 2L, 
1L, 2L, 1L, 1L, 1L, 3L, 2L, 1L, 2L), Power = c(82, 82, 223.6, 
164, 119, 74.5, 74.5, 279.5, 82, 67, 112, 149, 119, 119, 238.5, 
205, 82, 119, 194, 336, 558.9, 287, 388, 164, 194, 194, 186.3, 
119, 119, 89.4, 126.7, 149, 119, 536.6, 402, 298, 298, 342.8, 
536, 223.6, 521.6, 186.3, 238.5, 287, 335.3, 335.3, 335.3, 335.3, 
335.3, 335.3, 357.7, 313, 782.6, 298, 670.6, 223.5, 335.3, 391, 
391, 436, 391, 436, 171.4, 350, 298, 223.6, 298, 634, 223.5, 
864.4, 760, 503.5, 63.3, 357.7, 812, 335.3, 298, 298, 335.3, 
298, 317, 231, 335.3, 432, 918, 745.2, 424.8, 372.6, 782, 626, 
544, 335.3, 372.6, 373, 391.2, 864, 894, 179, 74.5, 391.2), Span = c(12.8, 
11, 17.9, 14.5, 12.9, 7.5, 11.13, 14.3, 7.8, 11, 11.7, 12.8, 
8.5, 13.3, 14.9, 12, 9.4, 15.95, 16.74, 22.2, 23.4, 14.3, 23.72, 
11.9, 14.4, 14.4, 9.7, 8, 9.4, 14.55, 9.1, 8.11, 9.5, 20.73, 
22.8, 38.4, 14, 26.5, 30.48, 9.7, 15.5, 9.1, 14.17, 10.1, 14.8, 
15.62, 14.05, 14.05, 14.8, 15.24, 14, 12.24, 27.2, 8.84, 22.86, 
7.7, 9.5, 9.8, 15.93, 15.93, 15.93, 15.93, 13.08, 15.21, 8.94, 
9.6, 10.8, 13.72, 8.9, 26.72, 25, 9.6, 8.84, 11.58, 17.3, 12.5, 
12.1, 12.09, 9.8, 15.3, 9.08, 17.75, 15.3, 15.15, 27.4, 22, 13.7, 
10.3, 22.76, 22.25, 17.25, 11, 12, 9.5, 14.15, 20.4, 20.4, 14.5, 
8.84, 11.35), Length = c(7.6, 9, 10.35, 9.8, 7.9, 6.3, 8.28, 
9.4, 6.7, 8.3, 8, 8.7, 7.4, 9.6, 8.9, 7.9, 6.2, 10.25, 10.77, 
10.9, 12.6, 9.4, 11.86, 9.8, 9.2, 8.9, 8, 6.5, 6.95, 9.83, 7.3, 
6.38, 8.5, 13.27, 13.5, 20.85, 9.2, 14.33, 19.16, 6.5, 9.7, 8.1, 
9.68, 7.7, 10.8, 11.89, 10.97, 11.28, 9.5, 11.42, 11, 7.3, 18.2, 
7.01, 18.08, 6.8, 6.8, 7.1, 11.5, 11.5, 11.5, 11.5, 9.27, 9.78, 
6.17, 6.4, 7.32, 10.74, 6.9, 18.97, 15.1, 7.06, 7.17, 9.5, 10.55, 
8.38, 8.7, 8.81, 6.7, 9.42, 5.99, 10.27, 10.22, 11, 19.8, 14.63, 
11.2, 6.56, 14.88, 13.81, 12.6, 7, 7.5, 7.2, 9.91, 14.8, 15, 
9.8, 7.17, 8.94), Weight = c(1070, 830, 2200, 1946, 1190, 653, 
930, 1575, 676, 920, 1353, 1550, 888, 1275, 1537, 1292, 611, 
1350, 1700, 3312, 4920, 1510, 3625, 900, 1665, 1640, 1081, 625, 
932, 1378, 886, 902, 1070, 5670, 3636, 12925, 2107, 4770, 6060, 
1192, 1900, 1050, 2155, 1379, 2858, 3380, 2290, 2290, 2347, 3308, 
2630, 1333, 10000, 1351, 6250, 885, 1531, 1438, 3820, 3820, 3820, 
3820, 1905, 2646, 1151, 1266, 1575, 2383, 860, 7983, 6200, 1484, 
567, 1867, 4350, 1935, 1823, 2253, 1487, 2220, 1244, 2700, 2280, 
3652, 8165, 5500, 3568, 1414, 5875, 5460, 4310, 1500, 1795, 1628, 
2449, 6900, 6900, 1900, 567, 2102), Speed = c(105L, 145L, 135L, 
138L, 140L, 177L, 113L, 230L, 175L, 106L, 140L, 170L, 175L, 157L, 
183L, 201L, 209L, 145L, 120L, 135L, 152L, 176L, 140L, 190L, 175L, 
175L, 205L, 196L, 165L, 146L, 175L, 222L, 159L, 166L, 158L, 146L, 
185L, 120L, 157L, 226L, 205L, 230L, 161L, 251L, 171L, 206L, 171L, 
171L, 235L, 161L, 145L, 245L, 183L, 214L, 180L, 220L, 237L, 254L, 
169L, 169L, 169L, 169L, 153L, 183L, 261L, 245L, 235L, 200L, 246L, 
174L, 180L, 319L, 146L, 251L, 230L, 290L, 230L, 233L, 250L, 255L, 
233L, 175L, 230L, 180L, 145L, 185L, 196L, 298L, 183L, 198L, 195L, 
300L, 270L, 297L, 225L, 212L, 195L, 197L, 146L, 296L), Range = c(400L, 
402L, 500L, 500L, 400L, 350L, 402L, 700L, 525L, 300L, 560L, 550L, 
250L, 450L, 700L, 600L, 175L, 450L, 450L, 450L, 600L, 800L, 500L, 
600L, 600L, 600L, 600L, 400L, 250L, 400L, 350L, 547L, 450L, 1770L, 
800L, 2365L, 925L, 400L, 1205L, 580L, 600L, 600L, 684L, 402L, 
563L, 644L, 885L, 885L, 800L, 440L, 557L, 750L, 3600L, 500L, 
805L, 330L, 600L, 628L, 1640L, 1640L, 1640L, 1640L, 604L, 1046L, 
644L, 500L, 600L, 1046L, 550L, 1585L, 650L, 917L, 515L, 805L, 
750L, 1110L, 772L, 1127L, 500L, 850L, 523L, 850L, 900L, 700L, 
668L, 700L, 1706L, 600L, 1385L, 1000L, 902L, 600L, 500L, 450L, 
579L, 1125L, 1300L, 660L, 515L, 756L)), row.names = c(NA, 100L
), class = "data.frame")

Tags: r, machine-learning, statistics, linear-discriminant

Answer


Since I don't have access to your aircraft.csv, and you didn't mention which lda function you used, I put together an example on iris using MASS::lda:

library(tidyverse)
library(MASS)

lda.0 = lda( Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, data = iris )
preds.0 = predict( lda.0 )$class
xtabs( ~ preds.0 + iris$Species )
str(predict( lda.0 ))
lda.0$means

The last call gives you the class means.
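Note that lda.0$means holds the means of the true groups (here Species) in the training data, not the means of the predicted classes. If you want the means grouped by the class each row was assigned to, one option is to aggregate the predictors by preds.0. The following is a minimal sketch on iris using base R's aggregate; the name assigned_means is only chosen for this illustration:

```r
library(MASS)

# Fit LDA on iris and take the predicted class for every row.
lda.0 <- lda(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
             data = iris)
preds.0 <- predict(lda.0)$class

# Means of the four predictors, grouped by the *predicted* class.
# `assigned_means` is a name made up for this sketch.
assigned_means <- aggregate(iris[, 1:4],
                            by = list(assigned = preds.0),
                            FUN = mean)
print(assigned_means)

# For comparison: means grouped by the true Species labels.
print(lda.0$means)
```

The two tables will usually be close but not identical, because a few rows are assigned to a class other than their true one.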

You can compute the means of the assigned classes by using the "classified" data.frame from your example:

classified %>% group_by(assigned) %>% summarize(meanSL=mean(Sepal.Length),meanSW=mean(Sepal.Width), meanPL=mean(Petal.Length), meanPW=mean(Petal.Width))
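The call above uses the iris column names. For the aircraft data, the same idea would be applied to the columns shown in your str(classified) output (Power, Span, Length, Weight, Speed, Range). Here is a minimal base R sketch of that. It assumes a toy stand-in for classified built from the first six rows of your sample data, with made-up assigned values, since the real ones come from your LDA step:

```r
# Toy stand-in for `classified`: first six rows of the sample data.
# The `assigned` values here are invented purely for illustration.
classified <- data.frame(
  assigned = c(1L, 1L, 2L, 2L, 3L, 3L),
  Year     = c(14L, 14L, 14L, 15L, 15L, 15L),
  Period   = factor(c(1, 3, 3, 1, 2, 1)),
  Power    = c(82, 82, 223.6, 164, 119, 74.5),
  Span     = c(12.8, 11, 17.9, 14.5, 12.9, 7.5),
  Length   = c(7.6, 9, 10.35, 9.8, 7.9, 6.3),
  Weight   = c(1070, 830, 2200, 1946, 1190, 653),
  Speed    = c(105L, 145L, 135L, 138L, 140L, 177L),
  Range    = c(400L, 402L, 500L, 500L, 400L, 350L)
)

# Mean of every predictor within each assigned class.
class_means <- aggregate(
  cbind(Power, Span, Length, Weight, Speed, Range) ~ assigned,
  data = classified,
  FUN  = mean
)
print(class_means)
```

With your real classified data.frame you would skip the toy construction and run only the aggregate call. If you prefer dplyr (version 1.0 or later), classified %>% group_by(assigned) %>% summarise(across(Power:Range, mean)) should give the same table.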
