LDA and QDA are parametric statistical methods for classification that assume each data cluster follows a multivariate normal (MVN) distribution.
LDA assumes that each cluster has the same covariance matrix. We can
use the lda function from the MASS package to
fit an LDA model.
salmon <- read.table("datasets/salmon.txt")
print(dim(salmon))
## [1] 100   3
head(salmon)
##   SalmonOrigin Freshwater Marine
## 1       Alaska        108    368
## 2       Alaska        131    355
## 3       Alaska        105    469
## 4       Alaska         86    506
## 5       Alaska         99    402
## 6       Alaska         87    423
plot(salmon[, -1], col = as.factor(salmon[, 1]))
We can do something a bit fancier using the ellipse
package.
library(ellipse)
## Warning: package 'ellipse' was built under R version 4.4.2
## 
## Attaching package: 'ellipse'
## The following object is masked from 'package:graphics':
## 
##     pairs
plot(salmon[, c(2, 3)], col = as.factor(salmon[, 1]), xlim = c(50, 190), ylim = c(290, 530))
lines(ellipse(cov(salmon[c(1:50), c(2, 3)]), centre = colMeans(salmon[c(1:50), c(2, 3)]), level = c(0.5)))
lines(ellipse(cov(salmon[c(51:100), c(2, 3)]), centre = colMeans(salmon[c(51:100), c(2, 3)]), level = 0.5), col = 2)
Before we use LDA, we need to split the data into training and test sets. We don’t need a validation set because LDA has no hyperparameters to tune, so there is nothing to tune on a separate dataset. We will use the test set only for evaluating the fitted model.
The first 50 observations are from Alaska and the next 50 are from Canada.
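The split used in this section takes fixed index ranges. As a hedged alternative, a randomised split that keeps 10 test fish per origin could be sketched as follows. The seed, the `origin` vector, and the index names are assumptions made for illustration; only the ordering (rows 1 to 50 Alaska, rows 51 to 100 Canada) comes from the text.

```r
set.seed(1)  # arbitrary seed, only so the sketch is reproducible

# Stand-in for salmon[, 1]: rows 1-50 are Alaska, rows 51-100 are Canada
origin <- rep(c("Alaska", "Canada"), each = 50)

# Draw 10 test rows at random within each origin (a stratified split)
rows_by_origin <- split(seq_along(origin), origin)
test_idx <- unlist(lapply(rows_by_origin, function(idx) sample(idx, 10)),
                   use.names = FALSE)
train_idx <- setdiff(seq_along(origin), test_idx)

# Both sets stay balanced: 40 + 40 for training, 10 + 10 for testing
table(origin[train_idx])
table(origin[test_idx])
```

With the real data, `salmon[train_idx, ]` and `salmon[test_idx, ]` would play the roles of `train` and `test`. The fixed split used in the rest of this section is: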
train <- salmon[c(1:40, 51:90), ]
test <- salmon[c(41:50, 91:100), ]
library(MASS)
lsol <- lda(train[, c(2, 3)], grouping = train[, 1])
print(lsol)
## Call:
## lda(train[, c(2, 3)], grouping = train[, 1])
## 
## Prior probabilities of groups:
## Alaska Canada 
##    0.5    0.5 
## 
## Group means:
##        Freshwater  Marine
## Alaska    100.550 422.275
## Canada    138.625 368.650
## 
## Coefficients of linear discriminants:
##                    LD1
## Freshwater  0.04390178
## Marine     -0.01806237
We can access this information from the lsol object
with lsol$prior, lsol$means, and so on. Note that the pooled covariance
matrix is not provided, but we can calculate it using the
cov function.
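For reference, the pooled estimate is the degrees-of-freedom-weighted average of the two within-group sample covariance matrices. The symbols are notation assumed here, not taken from the lda output: \(S_A\) and \(S_C\) are the sample covariances of the Alaskan and Canadian training fish, and \(n_A\), \(n_C\) are the group sizes.

```latex
S_{\text{pooled}} = \frac{(n_A - 1)\, S_A + (n_C - 1)\, S_C}{n_A + n_C - 2}
```

The code that follows evaluates this expression term by term.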
# Select rows by the origin column (column 1) only
alaska <- train[train[, 1] == "Alaska", c(2, 3)]
canada <- train[train[, 1] == "Canada", c(2, 3)]
n_alaska <- dim(alaska)[1]
n_canada <- dim(canada)[1]
pooled_cov <- ((n_alaska - 1) * cov(alaska) + (n_canada - 1) * cov(canada)) / (n_alaska + n_canada - 2)
pooled_cov
##            Freshwater     Marine
## Freshwater  322.22147  -15.24744
## Marine      -15.24744 1087.44968
We can use the predict function to classify new observations,
including the test set. Let’s first try a salmon with a Freshwater recording of 120 and a Marine
recording of 380.
predict(lsol, c(120, 380))
## $class
## [1] Canada
## Levels: Alaska Canada
## 
## $posterior
##         Alaska    Canada
## [1,] 0.3132047 0.6867953
## 
## $x
##            LD1
## [1,] 0.2973989
Let’s evaluate the model using the test set.
predict(lsol, test[, c(2, 3)])
## $class
##  [1] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [11] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## Levels: Alaska Canada
## 
## $posterior
##          Alaska       Canada
## 41  0.999934575 6.542453e-05
## 42  0.998909821 1.090179e-03
## 43  0.999641196 3.588039e-04
## 44  0.997267179 2.732821e-03
## 45  0.991071121 8.928879e-03
## 46  0.990434148 9.565852e-03
## 47  0.973525192 2.647481e-02
## 48  0.998445913 1.554087e-03
## 49  0.999459094 5.409062e-04
## 50  0.999593904 4.060962e-04
## 91  0.073753358 9.262466e-01
## 92  0.172305247 8.276948e-01
## 93  0.068420264 9.315797e-01
## 94  0.019825308 9.801747e-01
## 95  0.061697460 9.383025e-01
## 96  0.001990077 9.980099e-01
## 97  0.042753089 9.572469e-01
## 98  0.048058245 9.519418e-01
## 99  0.002611083 9.973889e-01
## 100 0.205956271 7.940437e-01
## 
## $x
##            LD1
## 41  -3.6492357
## 42  -2.5833037
## 43  -3.0045114
## 44  -2.2345978
## 45  -1.7837951
## 46  -1.7574513
## 47  -1.3653479
## 48  -2.4488379
## 49  -2.8489705
## 50  -2.9575969
## 91   0.9584339
## 92   0.5944260
## 93   0.9890377
## 94   1.4774783
## 95   1.0309356
## 96   2.3550095
## 97   1.1774384
## 98   1.1310284
## 99   2.2519041
## 100  0.5111346
How much uncertainty is there in the classification of the data points?
table(predict(lsol, test[, c(2, 3)])$class, test[, 1])
##         
##          Alaska Canada
##   Alaska     10      0
##   Canada      0     10
The model correctly classified all 20 test observations, and it was
very confident about most of them. The exceptions are observations 92 and 100, whose posterior probabilities for Canada are only about 0.83 and 0.79.
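To see where the LD1 coefficients, the LD1 score, and the posterior probabilities come from, the calculation can be sketched by hand. All numbers below are copied from the print(lsol) and pooled_cov output in this section. The variable names are illustrative, and the steps mirror the usual two-class LDA computation rather than the internal code of MASS.

```r
# Group means, priors and pooled covariance, copied from the output above
means <- rbind(Alaska = c(Freshwater = 100.550, Marine = 422.275),
               Canada = c(Freshwater = 138.625, Marine = 368.650))
prior <- c(Alaska = 0.5, Canada = 0.5)
pooled_cov <- matrix(c(322.22147, -15.24744,
                       -15.24744, 1087.44968), nrow = 2,
                     dimnames = list(colnames(means), colnames(means)))

# Discriminant direction: pooled_cov^{-1} (mean_Canada - mean_Alaska),
# rescaled so that the within-group variance along it equals 1.
# (The overall sign of the direction is a convention.)
d <- means["Canada", ] - means["Alaska", ]
w <- solve(pooled_cov, d)
ld1 <- w / sqrt(drop(crossprod(w, pooled_cov %*% w)))

# LD1 score: centre at the prior-weighted mean, then project onto ld1
centre <- colSums(prior * means)
new_fish <- c(Freshwater = 120, Marine = 380)
score <- sum((new_fish - centre) * ld1)

# LD1 scores of the two group means
group_scores <- drop(sweep(means, 2, centre) %*% ld1)

# Posterior: prior times a unit-variance Gaussian density around each group score
unnorm <- prior * exp(-0.5 * (score - group_scores)^2)
posterior <- unnorm / sum(unnorm)

round(ld1, 4)
round(score, 4)
round(posterior, 4)
```

Up to rounding of the copied inputs, this should agree with the coefficients printed by print(lsol) and with the $x and $posterior values returned by predict(lsol, c(120, 380)).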
We can easily perform leave-one-out cross validation by setting the
CV parameter in the lda function to
TRUE.
lsol_cv <- lda(salmon[, c(2, 3)], grouping = salmon[, 1], CV = TRUE)
lsol_cv
## $class
##   [1] Canada Canada Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
##  [11] Alaska Canada Canada Alaska Alaska Alaska Alaska Alaska Alaska Alaska
##  [21] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Canada
##  [31] Alaska Canada Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
##  [41] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
##  [51] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
##  [61] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
##  [71] Alaska Canada Canada Canada Canada Canada Canada Canada Canada Canada
##  [81] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
##  [91] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## Levels: Alaska Canada
## 
## $posterior
##           Alaska       Canada
## 1   3.948409e-01 6.051591e-01
## 2   1.175753e-02 9.882425e-01
## 3   9.949318e-01 5.068154e-03
## 4   9.999539e-01 4.613785e-05
## 5   9.302454e-01 6.975460e-02
## 6   9.945646e-01 5.435388e-03
## 7   9.944957e-01 5.504346e-03
## 8   9.912424e-01 8.757632e-03
## 9   9.988077e-01 1.192294e-03
## 10  9.335866e-01 6.641344e-02
## 11  8.829710e-01 1.170290e-01
## 12  9.950019e-02 9.004998e-01
## 13  9.950019e-02 9.004998e-01
## 14  9.052959e-01 9.470410e-02
## 15  6.247805e-01 3.752195e-01
## 16  9.021690e-01 9.783100e-02
## 17  8.912112e-01 1.087888e-01
## 18  5.421235e-01 4.578765e-01
## 19  9.701981e-01 2.980188e-02
## 20  9.754263e-01 2.457366e-02
## 21  5.884179e-01 4.115821e-01
## 22  9.975001e-01 2.499890e-03
## 23  9.869529e-01 1.304709e-02
## 24  9.736560e-01 2.634404e-02
## 25  9.969952e-01 3.004795e-03
## 26  9.986092e-01 1.390781e-03
## 27  7.421947e-01 2.578053e-01
## 28  9.772908e-01 2.270923e-02
## 29  9.981309e-01 1.869084e-03
## 30  2.694046e-01 7.305954e-01
## 31  7.470959e-01 2.529041e-01
## 32  4.514649e-01 5.485351e-01
## 33  9.990434e-01 9.566242e-04
## 34  9.993426e-01 6.574269e-04
## 35  9.999630e-01 3.698873e-05
## 36  9.725100e-01 2.748998e-02
## 37  9.995363e-01 4.636552e-04
## 38  9.871862e-01 1.281376e-02
## 39  9.837609e-01 1.623907e-02
## 40  9.976475e-01 2.352489e-03
## 41  9.999751e-01 2.490113e-05
## 42  9.991973e-01 8.027494e-04
## 43  9.997861e-01 2.139083e-04
## 44  9.976951e-01 2.304917e-03
## 45  9.917228e-01 8.277233e-03
## 46  9.910411e-01 8.958916e-03
## 47  9.728086e-01 2.719136e-02
## 48  9.987883e-01 1.211663e-03
## 49  9.996480e-01 3.520384e-04
## 50  9.997472e-01 2.527688e-04
## 51  4.505982e-01 5.494018e-01
## 52  5.373865e-03 9.946261e-01
## 53  5.686601e-04 9.994313e-01
## 54  5.431215e-03 9.945688e-01
## 55  6.857915e-04 9.993142e-01
## 56  2.253202e-01 7.746798e-01
## 57  2.473720e-02 9.752628e-01
## 58  1.231348e-02 9.876865e-01
## 59  9.331176e-03 9.906688e-01
## 60  2.314767e-03 9.976852e-01
## 61  1.488799e-02 9.851120e-01
## 62  1.108498e-01 8.891502e-01
## 63  1.229294e-02 9.877071e-01
## 64  6.432612e-04 9.993567e-01
## 65  6.576817e-05 9.999342e-01
## 66  1.110309e-02 9.888969e-01
## 67  6.576817e-05 9.999342e-01
## 68  4.880101e-01 5.119899e-01
## 69  5.345196e-03 9.946548e-01
## 70  1.002757e-02 9.899724e-01
## 71  9.673576e-01 3.264243e-02
## 72  1.328459e-03 9.986715e-01
## 73  8.407354e-02 9.159265e-01
## 74  1.018754e-02 9.898125e-01
## 75  1.370577e-01 8.629423e-01
## 76  5.617608e-02 9.438239e-01
## 77  1.134551e-01 8.865449e-01
## 78  2.325374e-02 9.767463e-01
## 79  2.745928e-01 7.254072e-01
## 80  1.487332e-01 8.512668e-01
## 81  1.526065e-02 9.847394e-01
## 82  1.701460e-03 9.982985e-01
## 83  6.751387e-03 9.932486e-01
## 84  6.706322e-04 9.993294e-01
## 85  7.792248e-02 9.220775e-01
## 86  6.867805e-02 9.313219e-01
## 87  2.373124e-01 7.626876e-01
## 88  4.619613e-02 9.538039e-01
## 89  7.075930e-04 9.992924e-01
## 90  3.433555e-03 9.965664e-01
## 91  4.280047e-02 9.571995e-01
## 92  1.130620e-01 8.869380e-01
## 93  4.127517e-02 9.587248e-01
## 94  9.901025e-03 9.900990e-01
## 95  3.495695e-02 9.650431e-01
## 96  7.623827e-04 9.992376e-01
## 97  2.459784e-02 9.754022e-01
## 98  2.772808e-02 9.722719e-01
## 99  1.029729e-03 9.989703e-01
## 100 1.631781e-01 8.368219e-01
## 
## $call
## lda(x = salmon[, c(2, 3)], grouping = salmon[, 1], CV = TRUE)
We can see that the output returned by lsol_cv includes the
class assigned to each data point when it was the only point
left out. Additionally, a matrix of posterior group membership probabilities is
returned. The class assigned to each data point is stored in
lsol_cv$class. We can use this to count how many points
were misclassified.
table(lsol_cv$class, salmon[, 1])
##         
##          Alaska Canada
##   Alaska     44      1
##   Canada      6     49
We can see that the model misclassified 7 observations: it assigned 1 Canadian salmon to Alaska and 6 Alaskan salmon to Canada.
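To make explicit what CV = TRUE does, the following sketch compares it with an explicit leave-one-out loop. It runs on simulated data, because the salmon file is not assumed to be available here. The means, the covariance matrix, and the seed are made up for illustration and only loosely resemble the salmon measurements. The prior is fixed at 0.5 in both versions so that they use the same prior.

```r
library(MASS)  # provides lda() and mvrnorm()

set.seed(42)  # arbitrary seed, for reproducibility only

# Simulated stand-in for the salmon data: two groups with a shared covariance
sigma <- matrix(c(300, -15, -15, 1000), nrow = 2)
x <- rbind(mvrnorm(50, mu = c(100, 420), Sigma = sigma),
           mvrnorm(50, mu = c(140, 370), Sigma = sigma))
colnames(x) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 50))

# Built-in leave-one-out cross-validation
cv_class <- lda(x, grouping = y, prior = c(0.5, 0.5), CV = TRUE)$class

# Explicit loop: refit without observation i, then classify observation i
loop_class <- y  # placeholder with the right levels; every entry is overwritten
for (i in seq_len(nrow(x))) {
  fit <- lda(x[-i, ], grouping = y[-i], prior = c(0.5, 0.5))
  loop_class[i] <- predict(fit, x[i, , drop = FALSE])$class
}

# Share of observations on which the two approaches agree
agreement <- mean(loop_class == cv_class)
agreement

# Leave-one-out confusion table from the explicit loop
table(predicted = loop_class, truth = y)
```

The two sets of predicted classes should agree on essentially every observation. The built-in option gives the same kind of result without refitting the model once per data point.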
We will plot the data colored with the true class, but the points will be marked with the class that the LDA model predicted.
plot(salmon[, c(2, 3)], col = as.factor(salmon[, 1]), pch = as.numeric(lsol_cv$class))
QDA assumes that each cluster has its own covariance matrix.
qsol <- qda(train[, c(2, 3)], grouping = train[, 1])
predict(qsol, test[, c(2, 3)])
## $class
##  [1] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [11] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## Levels: Alaska Canada
## 
## $posterior
##          Alaska       Canada
## 41  0.999999603 3.973313e-07
## 42  0.999934364 6.563641e-05
## 43  0.999975258 2.474169e-05
## 44  0.999739798 2.602019e-04
## 45  0.992748801 7.251199e-03
## 46  0.997730481 2.269519e-03
## 47  0.985325240 1.467476e-02
## 48  0.999895544 1.044557e-04
## 49  0.999982257 1.774293e-05
## 50  0.999987507 1.249314e-05
## 91  0.112831650 8.871684e-01
## 92  0.237765000 7.622350e-01
## 93  0.111230243 8.887698e-01
## 94  0.030850466 9.691495e-01
## 95  0.074449894 9.255501e-01
## 96  0.008231741 9.917683e-01
## 97  0.076453258 9.235467e-01
## 98  0.085706823 9.142932e-01
## 99  0.007297430 9.927026e-01
## 100 0.182480379 8.175196e-01
We will again perform cross-validation.
qsol_cv <- qda(salmon[, c(2, 3)], grouping = salmon[, 1], CV = TRUE)
table(qsol_cv$class, salmon[, 1])
##         
##          Alaska Canada
##   Alaska     45      3
##   Canada      5     47
plot(salmon[, c(2, 3)], col = as.factor(salmon[, 1]), pch = as.numeric(qsol_cv$class))
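As a small worked comparison, the two leave-one-out tables can be turned into error rates. The counts below are copied from the tables printed in this section (rows are predicted classes, columns are true classes); the helper `error_rate` is a name introduced here for illustration.

```r
# Leave-one-out confusion tables, copied from the output above
lda_cv <- matrix(c(44, 6, 1, 49), nrow = 2,
                 dimnames = list(predicted = c("Alaska", "Canada"),
                                 truth = c("Alaska", "Canada")))
qda_cv <- matrix(c(45, 5, 3, 47), nrow = 2, dimnames = dimnames(lda_cv))

# Error rate = share of observations that fall off the diagonal
error_rate <- function(cm) 1 - sum(diag(cm)) / sum(cm)

rates <- c(LDA = error_rate(lda_cv), QDA = error_rate(qda_cv))
rates
```

This gives leave-one-out error rates of about 0.07 for LDA (7 of 100) and 0.08 for QDA (8 of 100). A difference of a single observation is small, so on this dataset the extra flexibility of QDA does not show a clear benefit.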