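LDA (linear discriminant analysis) and QDA (quadratic discriminant analysis) are parametric statistical methods for classification that assume the observations in each class (cluster) follow a multivariate normal (MVN) distribution.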
LDA additionally assumes that every class has the same covariance matrix. We can use the lda function from the MASS package to fit an LDA model.
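The classification rule behind both methods can be written out for reference. Suppose class \(k\) has prior probability \(\pi_k\), mean \(\mu_k\) and covariance \(\Sigma_k\), and write \(\phi\) for the MVN density. An observation \(x\) is assigned to the class with the largest posterior probability

\[
P(Y = k \mid x) = \frac{\pi_k \, \phi(x;\, \mu_k, \Sigma_k)}{\sum_{j} \pi_j \, \phi(x;\, \mu_j, \Sigma_j)}.
\]

Taking logs and dropping the terms that are the same for all classes gives the quadratic (QDA) discriminant

\[
\delta_k(x) = \log \pi_k - \tfrac{1}{2}\log\lvert\Sigma_k\rvert - \tfrac{1}{2}(x - \mu_k)^\top \Sigma_k^{-1} (x - \mu_k).
\]

LDA assumes every class shares one covariance matrix \(\Sigma\). Under that assumption, \(\log\lvert\Sigma\rvert\) and \(x^\top \Sigma^{-1} x\) are identical for every class and cancel, which leaves a discriminant that is linear in \(x\):

\[
\delta_k(x) = \log \pi_k + x^\top \Sigma^{-1} \mu_k - \tfrac{1}{2}\mu_k^\top \Sigma^{-1} \mu_k.
\]

In practice the unknown quantities are replaced by training-data estimates. The priors are the class proportions (the lda default), the means are the sample means, and the covariance is either the pooled covariance (LDA) or the per-class sample covariances (QDA).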
salmon <- read.table("datasets/salmon.txt")
print(dim(salmon))
## [1] 100 3
head(salmon)
## SalmonOrigin Freshwater Marine
## 1 Alaska 108 368
## 2 Alaska 131 355
## 3 Alaska 105 469
## 4 Alaska 86 506
## 5 Alaska 99 402
## 6 Alaska 87 423
plot(salmon[, -1], col = as.factor(salmon[, 1]))
We can do something a bit fancier with the ellipse package and draw a 50% contour around each group.
library(ellipse)
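# Scatter plot colored by origin: rows 1-50 are Alaska (black), rows 51-100 Canada (col = 2)
plot(salmon[, c(2, 3)], col = as.factor(salmon[, 1]), xlim = c(50, 190), ylim = c(290, 530))
# 50% contour for each group, built from that group's sample covariance and mean
lines(ellipse(cov(salmon[1:50, c(2, 3)]), centre = colMeans(salmon[1:50, c(2, 3)]), level = 0.5))
lines(ellipse(cov(salmon[51:100, c(2, 3)]), centre = colMeans(salmon[51:100, c(2, 3)]), level = 0.5), col = 2)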
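The level argument controls which contour is drawn. As we understand the ellipse package, it traces the points whose squared Mahalanobis distance from the centre equals qchisq(level, 2). With level = 0.5, each curve would therefore enclose about half of its group if that group were bivariate normal. The sketch below rebuilds such a contour in base R and checks both properties. Note that ellipse_points is a hypothetical helper and is not part of the ellipse package. The covariance matrix and centre are made-up values on roughly the salmon scale.

```r
library(MASS)  # only for mvrnorm() in the coverage check

# Hypothetical helper: points p with (p - centre)' S^{-1} (p - centre) = qchisq(level, 2)
ellipse_points <- function(S, centre, level = 0.5, npoints = 100) {
  r <- sqrt(qchisq(level, df = 2))                # radius in whitened coordinates
  theta <- seq(0, 2 * pi, length.out = npoints)
  circle <- r * cbind(cos(theta), sin(theta))     # circle of radius r
  L <- t(chol(S))                                 # lower-triangular factor, S = L %*% t(L)
  sweep(circle %*% t(L), 2, centre, "+")          # stretch the circle, then shift it
}

# Made-up covariance and centre on roughly the same scale as the salmon data
S <- matrix(c(320, -15, -15, 1090), 2, 2)
centre <- c(100, 420)
e <- ellipse_points(S, centre, level = 0.5)

# Every contour point sits at squared Mahalanobis distance qchisq(0.5, 2)
range(mahalanobis(e, centre, S))

# About half of a large bivariate normal sample falls inside the contour
set.seed(42)
z <- mvrnorm(10000, mu = centre, Sigma = S)
mean(mahalanobis(z, centre, S) <= qchisq(0.5, 2))
```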
Before we use LDA, we split the data into training and test sets. We don't need a validation set because LDA has no hyperparameters to tune. The group means, the pooled covariance matrix and the priors are all estimated directly from the training data. The test set is used only to evaluate the fitted model.
The first 50 observations are from Alaska and the next 50 are from Canada. We therefore train on the first 40 of each group and test on the remaining 10 of each group.
train <- salmon[c(1:40, 51:90), ]
test <- salmon[c(41:50, 91:100), ]
library(MASS)
lsol <- lda(train[, c(2, 3)], grouping = train[, 1])
print(lsol)
## Call:
## lda(train[, c(2, 3)], grouping = train[, 1])
##
## Prior probabilities of groups:
## Alaska Canada
## 0.5 0.5
##
## Group means:
## Freshwater Marine
## Alaska 100.550 422.275
## Canada 138.625 368.650
##
## Coefficients of linear discriminants:
## LD1
## Freshwater 0.04390178
## Marine -0.01806237
We can access this information from the lsol object with lsol$prior, lsol$means, lsol$scaling and so on. The pooled covariance matrix is not stored in the object, but we can calculate it from the group covariance matrices returned by the cov function.
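# Split the training data by origin (column 1) and keep the two measurements
alaska <- train[train[, 1] == "Alaska", c(2, 3)]
canada <- train[train[, 1] == "Canada", c(2, 3)]
n_alaska <- nrow(alaska)
n_canada <- nrow(canada)
# Pooled covariance: weighted average of the group covariances with divisor n - 2
pooled_cov <- ((n_alaska - 1) * cov(alaska) + (n_canada - 1) * cov(canada)) / (n_alaska + n_canada - 2)
pooled_cov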
## Freshwater Marine
## Freshwater 322.22147 -15.24744
## Marine -15.24744 1087.44968
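The pooled covariance also explains the printed LD1 coefficients. For two groups, LD1 should point along the Fisher direction \(S_p^{-1}(\bar{x}_2 - \bar{x}_1)\) and be scaled so that its pooled within-group variance is 1. With the salmon output above, the coefficient ratio 0.0439 / -0.0181 and the quadratic form \(a^\top S_p a\) both agree with this to about three decimals. The sketch below checks the same two properties on simulated data (the group means and covariance are made-up values). It uses only fit$means and fit$scaling from the lda object.

```r
library(MASS)

# Simulated two-group data (hypothetical stand-in for the salmon training set)
set.seed(1)
Sigma <- matrix(c(320, -15, -15, 1090), 2, 2)
x <- rbind(mvrnorm(40, mu = c(100, 420), Sigma = Sigma),
           mvrnorm(40, mu = c(140, 370), Sigma = Sigma))
colnames(x) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 40))
fit <- lda(x, grouping = y)

# Pooled within-group covariance, computed as in the salmon code
n_a <- sum(y == "Alaska")
n_c <- sum(y == "Canada")
S_p <- ((n_a - 1) * cov(x[y == "Alaska", ]) + (n_c - 1) * cov(x[y == "Canada", ])) /
  (n_a + n_c - 2)

# Fisher direction S_p^{-1} (mean_Canada - mean_Alaska), using the stored group means
w <- solve(S_p, fit$means["Canada", ] - fit$means["Alaska", ])
a <- fit$scaling[, "LD1"]

a / w                        # the two ratios agree: LD1 is a rescaled Fisher direction
drop(t(a) %*% S_p %*% a)     # LD1 has unit pooled within-group variance
```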
We can use the predict function to classify new observations. Let's first try a single salmon with a Freshwater recording of 120 and a Marine recording of 380.
predict(lsol, c(120, 380))
## $class
## [1] Canada
## Levels: Alaska Canada
##
## $posterior
## Alaska Canada
## [1,] 0.3132047 0.6867953
##
## $x
## LD1
## [1,] 0.2973989
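The posterior probabilities returned by predict can be reproduced with Bayes' rule. For each class, take the log prior, subtract half the squared Mahalanobis distance to that class mean under the pooled covariance, and normalise. This is the LDA discriminant with the shared terms removed. The sketch below fits lda to simulated data (made-up values, not the salmon file) and compares the hand computation with predict for the same point (120, 380).

```r
library(MASS)

# Simulated two-group data (hypothetical stand-in for the salmon training set)
set.seed(1)
Sigma <- matrix(c(320, -15, -15, 1090), 2, 2)
x <- rbind(mvrnorm(40, mu = c(100, 420), Sigma = Sigma),
           mvrnorm(40, mu = c(140, 370), Sigma = Sigma))
colnames(x) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 40))
fit <- lda(x, grouping = y)

# Pooled within-group covariance shared by both classes (40 points per class)
S_p <- (39 * cov(x[y == "Alaska", ]) + 39 * cov(x[y == "Canada", ])) / 78

# Bayes rule: log prior minus half the squared Mahalanobis distance to each mean
new_x <- c(120, 380)
log_post <- sapply(levels(y), function(k)
  log(fit$prior[[k]]) - 0.5 * mahalanobis(new_x, fit$means[k, ], S_p))
by_hand <- exp(log_post - max(log_post))   # subtract the maximum for numerical stability
by_hand <- by_hand / sum(by_hand)

by_hand
predict(fit, new_x)$posterior
```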
Let's evaluate the model on the test set.
predict(lsol, test[, c(2, 3)])
## $class
## [1] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [11] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## Levels: Alaska Canada
##
## $posterior
## Alaska Canada
## 41 0.999934575 6.542453e-05
## 42 0.998909821 1.090179e-03
## 43 0.999641196 3.588039e-04
## 44 0.997267179 2.732821e-03
## 45 0.991071121 8.928879e-03
## 46 0.990434148 9.565852e-03
## 47 0.973525192 2.647481e-02
## 48 0.998445913 1.554087e-03
## 49 0.999459094 5.409062e-04
## 50 0.999593904 4.060962e-04
## 91 0.073753358 9.262466e-01
## 92 0.172305247 8.276948e-01
## 93 0.068420264 9.315797e-01
## 94 0.019825308 9.801747e-01
## 95 0.061697460 9.383025e-01
## 96 0.001990077 9.980099e-01
## 97 0.042753089 9.572469e-01
## 98 0.048058245 9.519418e-01
## 99 0.002611083 9.973889e-01
## 100 0.205956271 7.940437e-01
##
## $x
## LD1
## 41 -3.6492357
## 42 -2.5833037
## 43 -3.0045114
## 44 -2.2345978
## 45 -1.7837951
## 46 -1.7574513
## 47 -1.3653479
## 48 -2.4488379
## 49 -2.8489705
## 50 -2.9575969
## 91 0.9584339
## 92 0.5944260
## 93 0.9890377
## 94 1.4774783
## 95 1.0309356
## 96 2.3550095
## 97 1.1774384
## 98 1.1310284
## 99 2.2519041
## 100 0.5111346
How accurate are these classifications, and how much uncertainty is attached to them? We first cross-tabulate the predicted classes (rows) against the true origins (columns).
table(predict(lsol, test[, c(2, 3)])$class, test[, 1])
##
## Alaska Canada
## Alaska 10 0
## Canada 0 10
The model correctly classified all 20 test observations. It was also very confident about them, except for observations 92 and 100, whose posterior probabilities for Canada are only about 0.83 and 0.79.
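Uncertainty can be summarised per observation by its largest posterior probability. The sketch below computes the accuracy from the confusion table and lists the test points whose confidence falls below a cut-off of 0.9. The cut-off is an arbitrary illustrative choice. On the salmon test set above, this cut-off would flag exactly observations 92 and 100. The data here are simulated with made-up values, using the same 40/10 split per class.

```r
library(MASS)

# Simulated data with the same 40/10 per-class split (hypothetical values)
set.seed(2)
Sigma <- matrix(c(320, -15, -15, 1090), 2, 2)
x <- rbind(mvrnorm(50, mu = c(100, 420), Sigma = Sigma),
           mvrnorm(50, mu = c(140, 370), Sigma = Sigma))
colnames(x) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 50))
train_idx <- c(1:40, 51:90)

fit <- lda(x[train_idx, ], grouping = y[train_idx])
pred <- predict(fit, x[-train_idx, ])

# Accuracy from the confusion table (rows = predicted, columns = true)
tab <- table(predicted = pred$class, true = y[-train_idx])
accuracy <- sum(diag(tab)) / sum(tab)

# Confidence of each prediction = its largest posterior probability;
# the 0.9 cut-off is an arbitrary choice for illustration
confidence <- apply(pred$posterior, 1, max)
round(confidence[confidence < 0.9], 3)
accuracy
```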
We can perform leave-one-out cross-validation by setting the CV argument of lda to TRUE. Here we use the full data set, because each observation is predicted by a model fitted without it.
lsol_cv <- lda(salmon[, c(2, 3)], grouping = salmon[, 1], CV = TRUE)
lsol_cv
## $class
## [1] Canada Canada Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [11] Alaska Canada Canada Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [21] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Canada
## [31] Alaska Canada Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [41] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [51] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## [61] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## [71] Alaska Canada Canada Canada Canada Canada Canada Canada Canada Canada
## [81] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## [91] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## Levels: Alaska Canada
##
## $posterior
## Alaska Canada
## 1 3.948409e-01 6.051591e-01
## 2 1.175753e-02 9.882425e-01
## 3 9.949318e-01 5.068154e-03
## 4 9.999539e-01 4.613785e-05
## 5 9.302454e-01 6.975460e-02
## 6 9.945646e-01 5.435388e-03
## 7 9.944957e-01 5.504346e-03
## 8 9.912424e-01 8.757632e-03
## 9 9.988077e-01 1.192294e-03
## 10 9.335866e-01 6.641344e-02
## 11 8.829710e-01 1.170290e-01
## 12 9.950019e-02 9.004998e-01
## 13 9.950019e-02 9.004998e-01
## 14 9.052959e-01 9.470410e-02
## 15 6.247805e-01 3.752195e-01
## 16 9.021690e-01 9.783100e-02
## 17 8.912112e-01 1.087888e-01
## 18 5.421235e-01 4.578765e-01
## 19 9.701981e-01 2.980188e-02
## 20 9.754263e-01 2.457366e-02
## 21 5.884179e-01 4.115821e-01
## 22 9.975001e-01 2.499890e-03
## 23 9.869529e-01 1.304709e-02
## 24 9.736560e-01 2.634404e-02
## 25 9.969952e-01 3.004795e-03
## 26 9.986092e-01 1.390781e-03
## 27 7.421947e-01 2.578053e-01
## 28 9.772908e-01 2.270923e-02
## 29 9.981309e-01 1.869084e-03
## 30 2.694046e-01 7.305954e-01
## 31 7.470959e-01 2.529041e-01
## 32 4.514649e-01 5.485351e-01
## 33 9.990434e-01 9.566242e-04
## 34 9.993426e-01 6.574269e-04
## 35 9.999630e-01 3.698873e-05
## 36 9.725100e-01 2.748998e-02
## 37 9.995363e-01 4.636552e-04
## 38 9.871862e-01 1.281376e-02
## 39 9.837609e-01 1.623907e-02
## 40 9.976475e-01 2.352489e-03
## 41 9.999751e-01 2.490113e-05
## 42 9.991973e-01 8.027494e-04
## 43 9.997861e-01 2.139083e-04
## 44 9.976951e-01 2.304917e-03
## 45 9.917228e-01 8.277233e-03
## 46 9.910411e-01 8.958916e-03
## 47 9.728086e-01 2.719136e-02
## 48 9.987883e-01 1.211663e-03
## 49 9.996480e-01 3.520384e-04
## 50 9.997472e-01 2.527688e-04
## 51 4.505982e-01 5.494018e-01
## 52 5.373865e-03 9.946261e-01
## 53 5.686601e-04 9.994313e-01
## 54 5.431215e-03 9.945688e-01
## 55 6.857915e-04 9.993142e-01
## 56 2.253202e-01 7.746798e-01
## 57 2.473720e-02 9.752628e-01
## 58 1.231348e-02 9.876865e-01
## 59 9.331176e-03 9.906688e-01
## 60 2.314767e-03 9.976852e-01
## 61 1.488799e-02 9.851120e-01
## 62 1.108498e-01 8.891502e-01
## 63 1.229294e-02 9.877071e-01
## 64 6.432612e-04 9.993567e-01
## 65 6.576817e-05 9.999342e-01
## 66 1.110309e-02 9.888969e-01
## 67 6.576817e-05 9.999342e-01
## 68 4.880101e-01 5.119899e-01
## 69 5.345196e-03 9.946548e-01
## 70 1.002757e-02 9.899724e-01
## 71 9.673576e-01 3.264243e-02
## 72 1.328459e-03 9.986715e-01
## 73 8.407354e-02 9.159265e-01
## 74 1.018754e-02 9.898125e-01
## 75 1.370577e-01 8.629423e-01
## 76 5.617608e-02 9.438239e-01
## 77 1.134551e-01 8.865449e-01
## 78 2.325374e-02 9.767463e-01
## 79 2.745928e-01 7.254072e-01
## 80 1.487332e-01 8.512668e-01
## 81 1.526065e-02 9.847394e-01
## 82 1.701460e-03 9.982985e-01
## 83 6.751387e-03 9.932486e-01
## 84 6.706322e-04 9.993294e-01
## 85 7.792248e-02 9.220775e-01
## 86 6.867805e-02 9.313219e-01
## 87 2.373124e-01 7.626876e-01
## 88 4.619613e-02 9.538039e-01
## 89 7.075930e-04 9.992924e-01
## 90 3.433555e-03 9.965664e-01
## 91 4.280047e-02 9.571995e-01
## 92 1.130620e-01 8.869380e-01
## 93 4.127517e-02 9.587248e-01
## 94 9.901025e-03 9.900990e-01
## 95 3.495695e-02 9.650431e-01
## 96 7.623827e-04 9.992376e-01
## 97 2.459784e-02 9.754022e-01
## 98 2.772808e-02 9.722719e-01
## 99 1.029729e-03 9.989703e-01
## 100 1.631781e-01 8.368219e-01
##
## $call
## lda(x = salmon[, c(2, 3)], grouping = salmon[, 1], CV = TRUE)
With CV = TRUE, lda does not return a fitted model. Instead, lsol_cv$class holds the class each observation was assigned when it was left out of the fit, and lsol_cv$posterior holds the corresponding posterior probabilities of group membership. We can tabulate lsol_cv$class against the true origins to count how many points were misclassified.
table(lsol_cv$class, salmon[, 1])
##
## Alaska Canada
## Alaska 44 1
## Canada 6 49
The model misclassified 7 of the 100 observations, a leave-one-out error rate of 7%. It assigned 1 Canadian salmon to Alaska and 6 Alaskan salmon to Canada.
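The leave-one-out results from lda(CV = TRUE) should match an explicit loop that refits the model without each point and predicts the point that was left out. As far as we can tell from the MASS source, lda does not refit 100 times. It applies an exact closed-form update and keeps the prior at the full-data class proportions. The loop below therefore fixes the prior to those proportions. This is a sketch on simulated data with made-up values. It also shows how to list the indices of the misclassified points.

```r
library(MASS)

# Simulated two-group data (hypothetical values)
set.seed(3)
Sigma <- matrix(c(320, -15, -15, 1090), 2, 2)
x <- rbind(mvrnorm(50, mu = c(100, 420), Sigma = Sigma),
           mvrnorm(50, mu = c(140, 370), Sigma = Sigma))
colnames(x) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 50))

cv <- lda(x, grouping = y, CV = TRUE)

# Explicit leave-one-out loop; the prior is fixed at the full-data class
# proportions so that it matches the prior used by lda(CV = TRUE)
prior <- as.vector(table(y)) / length(y)
loo_post <- t(sapply(seq_len(nrow(x)), function(i) {
  fit_i <- lda(x[-i, ], grouping = y[-i], prior = prior)
  predict(fit_i, x[i, , drop = FALSE])$posterior[1, ]
}))

all.equal(unname(loo_post), unname(cv$posterior))
table(predicted = cv$class, true = y)
which(cv$class != y)   # indices of the misclassified observations
```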
Next we plot the data colored by true origin, with the plotting symbol showing the class predicted by leave-one-out LDA (open circles for Alaska, triangles for Canada).
plot(salmon[, c(2, 3)], col = as.factor(salmon[, 1]), pch = as.numeric(lsol_cv$class))
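QDA relaxes the equal-covariance assumption. It estimates a separate covariance matrix for each class, so the decision boundary is quadratic rather than linear. The qda function from MASS is used in the same way as lda.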
qsol <- qda(train[, c(2, 3)], grouping = train[, 1])
predict(qsol, test[, c(2, 3)])
## $class
## [1] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [11] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## Levels: Alaska Canada
##
## $posterior
## Alaska Canada
## 41 0.999999603 3.973313e-07
## 42 0.999934364 6.563641e-05
## 43 0.999975258 2.474169e-05
## 44 0.999739798 2.602019e-04
## 45 0.992748801 7.251199e-03
## 46 0.997730481 2.269519e-03
## 47 0.985325240 1.467476e-02
## 48 0.999895544 1.044557e-04
## 49 0.999982257 1.774293e-05
## 50 0.999987507 1.249314e-05
## 91 0.112831650 8.871684e-01
## 92 0.237765000 7.622350e-01
## 93 0.111230243 8.887698e-01
## 94 0.030850466 9.691495e-01
## 95 0.074449894 9.255501e-01
## 96 0.008231741 9.917683e-01
## 97 0.076453258 9.235467e-01
## 98 0.085706823 9.142932e-01
## 99 0.007297430 9.927026e-01
## 100 0.182480379 8.175196e-01
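The QDA posteriors follow the same Bayes rule as LDA, except that each class uses its own sample covariance matrix. The log-determinant term therefore no longer cancels. The sketch below rebuilds the posteriors with cov and mahalanobis for two new points and compares them with predict. The data are simulated with two made-up covariance matrices. We assume that qda's default moment estimate uses the usual n_k - 1 divisor, as cov does.

```r
library(MASS)

# Simulated classes with different covariance matrices (hypothetical values)
set.seed(4)
x <- rbind(mvrnorm(40, mu = c(100, 420), Sigma = matrix(c(320, 40, 40, 1500), 2, 2)),
           mvrnorm(40, mu = c(140, 370), Sigma = matrix(c(330, -60, -60, 700), 2, 2)))
colnames(x) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 40))
fit <- qda(x, grouping = y)

new_x <- matrix(c(120, 380, 95, 450), ncol = 2, byrow = TRUE,
                dimnames = list(NULL, colnames(x)))

# Log prior + log normal density with the class's own covariance (divisor n_k - 1);
# the -log(2 * pi) term is common to both classes and is dropped
log_post <- sapply(levels(y), function(k) {
  S_k <- cov(x[y == k, ])
  log(fit$prior[[k]]) - 0.5 * log(det(S_k)) -
    0.5 * mahalanobis(new_x, fit$means[k, ], S_k)
})
post <- exp(log_post - apply(log_post, 1, max))   # subtract row maxima for stability
by_hand <- post / rowSums(post)

by_hand
predict(fit, new_x)$posterior
```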
We again perform leave-one-out cross-validation on the full data set.
qsol_cv <- qda(salmon[, c(2, 3)], grouping = salmon[, 1], CV = TRUE)
table(qsol_cv$class, salmon[, 1])
##
## Alaska Canada
## Alaska 45 3
## Canada 5 47
plot(salmon[, c(2, 3)], col = as.factor(salmon[, 1]), pch = as.numeric(qsol_cv$class))
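The two leave-one-out tables give LDA 1 + 6 = 7 and QDA 3 + 5 = 8 misclassified salmon out of 100. On this data, QDA's extra covariance parameters do not lower the error. Because both lda and qda accept CV = TRUE, the comparison can be wrapped in a small function. loo_error in the sketch below is a hypothetical helper, and the data are simulated with made-up, unequal covariance matrices.

```r
library(MASS)

# Simulated classes with unequal covariance matrices (hypothetical values)
set.seed(5)
x <- rbind(mvrnorm(50, mu = c(100, 420), Sigma = matrix(c(320, 40, 40, 1500), 2, 2)),
           mvrnorm(50, mu = c(140, 370), Sigma = matrix(c(330, -60, -60, 700), 2, 2)))
colnames(x) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 50))

# Hypothetical helper: leave-one-out error rate for a fitter that accepts CV = TRUE
loo_error <- function(fitter) {
  cv <- fitter(x, grouping = y, CV = TRUE)
  mean(cv$class != y)
}

errors <- c(LDA = loo_error(lda), QDA = loo_error(qda))
errors
```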