Discriminant analysis

LDA and QDA are parametric classification methods that assume the observations in each class (cluster) follow a multivariate normal (MVN) distribution.
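For reference, here is the standard form of the model both methods use. Each class k has a prior probability π_k and an MVN density φ with mean μ_k and covariance Σ_k. A point x is assigned to the class with the largest posterior probability, which is the same as the class with the largest discriminant score δ_k(x). The notation is ours, not taken from the MASS documentation.

```latex
\begin{aligned}
P(Y = k \mid \mathbf{x}) &= \frac{\pi_k\, \phi(\mathbf{x};\, \boldsymbol{\mu}_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j\, \phi(\mathbf{x};\, \boldsymbol{\mu}_j, \Sigma_j)} \\
\delta_k^{\text{QDA}}(\mathbf{x}) &= -\tfrac{1}{2}\log\lvert\Sigma_k\rvert - \tfrac{1}{2}(\mathbf{x}-\boldsymbol{\mu}_k)^\top \Sigma_k^{-1}(\mathbf{x}-\boldsymbol{\mu}_k) + \log\pi_k \\
\delta_k^{\text{LDA}}(\mathbf{x}) &= \mathbf{x}^\top \Sigma^{-1}\boldsymbol{\mu}_k - \tfrac{1}{2}\boldsymbol{\mu}_k^\top \Sigma^{-1}\boldsymbol{\mu}_k + \log\pi_k \qquad (\Sigma_k = \Sigma \text{ for all } k)
\end{aligned}
```

In LDA the covariance is shared, so the term quadratic in x is identical for every class and cancels. This is why LDA decision boundaries are linear. In QDA that term does not cancel, and the boundaries are quadratic.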

Linear Discriminant Analysis (LDA)

LDA additionally assumes that all classes share the same covariance matrix. We can use the lda function from the MASS package to fit an LDA model.

salmon <- read.table("datasets/salmon.txt")
print(dim(salmon))
## [1] 100   3
head(salmon)
##   SalmonOrigin Freshwater Marine
## 1       Alaska        108    368
## 2       Alaska        131    355
## 3       Alaska        105    469
## 4       Alaska         86    506
## 5       Alaska         99    402
## 6       Alaska         87    423
plot(salmon[, -1], col = as.factor(salmon[, 1]))

We can do something a bit fancier with the ellipse package and draw a 50% normal-theory contour around each group.

library(ellipse)
## Warning: package 'ellipse' was built under R version 4.4.2
## 
## Attaching package: 'ellipse'
## The following object is masked from 'package:graphics':
## 
##     pairs
# 50% normal contours: Alaska (rows 1-50) in black, Canada (rows 51-100) in red
plot(salmon[, c(2, 3)], col = as.factor(salmon[, 1]), xlim = c(50, 190), ylim = c(290, 530))
lines(ellipse(cov(salmon[1:50, c(2, 3)]), centre = colMeans(salmon[1:50, c(2, 3)]), level = 0.5))
lines(ellipse(cov(salmon[51:100, c(2, 3)]), centre = colMeans(salmon[51:100, c(2, 3)]), level = 0.5), col = 2)
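If the ellipse package is unavailable, the same kind of contour can be drawn in base R. The sketch below assumes, based on our reading of the ellipse documentation, that `ellipse()` traces the points whose squared Mahalanobis distance from `centre` equals `qchisq(level, 2)`. The helper name `normal_ellipse` is hypothetical.

```r
# Hypothetical helper: points on the `level` probability contour of a
# bivariate normal with covariance `sigma` centred at `centre`
normal_ellipse <- function(sigma, centre, level = 0.5, npoints = 100) {
  # Radius on the Mahalanobis scale enclosing `level` of the probability mass
  r <- sqrt(qchisq(level, df = 2))
  # Points on the unit circle
  theta <- seq(0, 2 * pi, length.out = npoints)
  circle <- cbind(cos(theta), sin(theta))
  # Map the circle through the Cholesky factor (sigma = t(R) %*% R)
  R <- chol(sigma)
  pts <- r * circle %*% R
  # Shift every point to the centre
  sweep(pts, 2, centre, "+")
}

e <- normal_ellipse(matrix(c(322, -15, -15, 1087), 2), centre = c(100, 420))
dim(e)
```

It could then replace the `ellipse()` calls above, for example `lines(normal_ellipse(cov(salmon[1:50, c(2, 3)]), colMeans(salmon[1:50, c(2, 3)])))`.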

Splitting the data

Before fitting LDA, we split the data into training and test sets. A validation set is not needed: LDA has no hyperparameters to tune, because its parameters (group means, pooled covariance and prior probabilities) are estimated directly from the training data. The test set is used only to evaluate the model.

The first 50 observations are from Alaska and the last 50 are from Canada, so we use the first 40 of each group for training and the remaining 10 of each group for testing.

train <- salmon[c(1:40, 51:90), ]
test <- salmon[c(41:50, 91:100), ]
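Hard-coding row ranges works here only because the rows are sorted by origin. If they were not, a random split within each class would keep the groups balanced. A minimal sketch: `stratified_split` and the seed are hypothetical choices, and the labels below only mimic the 50/50 layout of the salmon data.

```r
# Hypothetical helper: draw n_train row indices at random within each class
stratified_split <- function(labels, n_train, seed = 1) {
  set.seed(seed)
  # Row indices grouped by class label
  idx_by_class <- split(seq_along(labels), labels)
  # Index into each group (sample() on a length-1 vector would misbehave)
  train_idx <- unlist(lapply(idx_by_class, function(idx) idx[sample.int(length(idx), n_train)]),
                      use.names = FALSE)
  sort(train_idx)
}

labels <- rep(c("Alaska", "Canada"), each = 50)
train_idx <- stratified_split(labels, n_train = 40)
table(labels[train_idx])
```

With the real data this would be used as `idx <- stratified_split(salmon[, 1], 40)`, then `train <- salmon[idx, ]` and `test <- salmon[-idx, ]`. The resulting split, and therefore the numbers reported below, would differ from the fixed split used here.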

Fitting the LDA model

library(MASS)
lsol <- lda(train[, c(2, 3)], grouping = train[, 1])
print(lsol)
## Call:
## lda(train[, c(2, 3)], grouping = train[, 1])
## 
## Prior probabilities of groups:
## Alaska Canada 
##    0.5    0.5 
## 
## Group means:
##        Freshwater  Marine
## Alaska    100.550 422.275
## Canada    138.625 368.650
## 
## Coefficients of linear discriminants:
##                    LD1
## Freshwater  0.04390178
## Marine     -0.01806237
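The LD1 coefficients define a linear score: the coefficient-weighted sum of an observation's deviations from a central point. The sketch below recomputes scores by hand from the printed numbers. It assumes, from our reading of `predict.lda`, that the central point is the prior-weighted average of the group means. With equal priors the two group means get scores of equal size and opposite sign, so a positive LD1 score points to Canada and a negative one to Alaska.

```r
# Values copied from the printed lda() output above
coef_ld1 <- c(Freshwater = 0.04390178, Marine = -0.01806237)
group_means <- rbind(Alaska = c(Freshwater = 100.550, Marine = 422.275),
                     Canada = c(Freshwater = 138.625, Marine = 368.650))
prior <- c(Alaska = 0.5, Canada = 0.5)

# Prior-weighted average of the group means (row k weighted by prior[k])
centre <- colSums(prior * group_means)

# LD1 score: weighted deviations from the centre
ld1_score <- function(x) sum((x - centre) * coef_ld1)

ld1_score(c(Freshwater = 120, Marine = 380))  # a new salmon
ld1_score(group_means["Alaska", ])
ld1_score(group_means["Canada", ])
```

The first value should match the `$x` entry that `predict()` reports for this salmon in the Classification section. The test-set LD1 values printed there are negative for observations 41 to 50 and positive for 91 to 100.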

This information can be accessed from the lsol object via lsol$prior, lsol$means, lsol$scaling, and so on. Note that the pooled covariance matrix is not returned, but we can compute it from the within-group covariance matrices, weighting each by its degrees of freedom.

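# Select rows by the origin column (column 1), keeping the two measurements
alaska <- train[train[, 1] == "Alaska", c(2, 3)]
canada <- train[train[, 1] == "Canada", c(2, 3)]
n_alaska <- nrow(alaska)
n_canada <- nrow(canada)
# Weight each group's covariance by its degrees of freedom (n - 1)
pooled_cov <- ((n_alaska - 1) * cov(alaska) + (n_canada - 1) * cov(canada)) / (n_alaska + n_canada - 2)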
pooled_cov
##            Freshwater     Marine
## Freshwater  322.22147  -15.24744
## Marine      -15.24744 1087.44968
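The two-group formula generalizes to K groups: the pooled matrix is the sum of (n_k - 1) S_k over all groups divided by n - K. Equivalently, it is the cross-product of each row's deviation from its own group mean, divided by n - K. A sketch with a hypothetical helper `pooled_within_cov`, run on synthetic three-group data:

```r
# Hypothetical helper: pooled within-group covariance for K groups,
# i.e. sum_k (n_k - 1) * S_k / (n - K)
pooled_within_cov <- function(X, groups) {
  X <- as.matrix(X)
  groups <- as.factor(groups)
  K <- nlevels(groups)
  # Replace each value by its group mean, column by column
  group_means <- apply(X, 2, function(col) ave(col, groups))
  # Deviations of each row from its own group mean
  centred <- X - group_means
  crossprod(centred) / (nrow(X) - K)
}

# Synthetic three-group example (illustration only)
set.seed(42)
X <- matrix(rnorm(60), ncol = 2, dimnames = list(NULL, c("a", "b")))
g <- rep(c("A", "B", "C"), times = c(8, 10, 12))
pooled_within_cov(X, g)
```

Applied to the training data, `pooled_within_cov(train[, c(2, 3)], train[, 1])` should reproduce the matrix printed above, since it is the same formula with K = 2.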

Classification

We can use the predict function to classify new observations. Let's try a salmon with a Freshwater recording of 120 and a Marine recording of 380.

predict(lsol, c(120, 380))
## $class
## [1] Canada
## Levels: Alaska Canada
## 
## $posterior
##         Alaska    Canada
## [1,] 0.3132047 0.6867953
## 
## $x
##            LD1
## [1,] 0.2973989
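To see where these posterior probabilities come from, they can be recomputed from the MVN model using the pooled covariance, group means and priors printed above. This sketch assumes `predict()` uses its default plug-in estimates. `lda_posterior` is a hypothetical helper. Because the pooled covariance was printed to only five decimals, the result agrees with the printed posterior closely but not to every digit.

```r
# Numbers copied from the printed output above
sigma <- matrix(c(322.22147, -15.24744,
                  -15.24744, 1087.44968), nrow = 2)
group_means <- rbind(Alaska = c(100.550, 422.275),
                     Canada = c(138.625, 368.650))
prior <- c(Alaska = 0.5, Canada = 0.5)

# Hypothetical helper: LDA posterior for one point x
lda_posterior <- function(x, means, sigma, prior) {
  # Squared Mahalanobis distance from x to each group mean (shared covariance)
  d2 <- apply(means, 1, function(mu) mahalanobis(x, center = mu, cov = sigma))
  # log prior - d2 / 2; the common determinant term cancels between groups
  log_score <- log(prior) - 0.5 * d2
  # Normalise on the log scale for numerical stability
  w <- exp(log_score - max(log_score))
  w / sum(w)
}

lda_posterior(c(120, 380), group_means, sigma, prior)
```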

Let's evaluate the model on the test set.

predict(lsol, test[, c(2, 3)])
## $class
##  [1] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [11] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## Levels: Alaska Canada
## 
## $posterior
##          Alaska       Canada
## 41  0.999934575 6.542453e-05
## 42  0.998909821 1.090179e-03
## 43  0.999641196 3.588039e-04
## 44  0.997267179 2.732821e-03
## 45  0.991071121 8.928879e-03
## 46  0.990434148 9.565852e-03
## 47  0.973525192 2.647481e-02
## 48  0.998445913 1.554087e-03
## 49  0.999459094 5.409062e-04
## 50  0.999593904 4.060962e-04
## 91  0.073753358 9.262466e-01
## 92  0.172305247 8.276948e-01
## 93  0.068420264 9.315797e-01
## 94  0.019825308 9.801747e-01
## 95  0.061697460 9.383025e-01
## 96  0.001990077 9.980099e-01
## 97  0.042753089 9.572469e-01
## 98  0.048058245 9.519418e-01
## 99  0.002611083 9.973889e-01
## 100 0.205956271 7.940437e-01
## 
## $x
##            LD1
## 41  -3.6492357
## 42  -2.5833037
## 43  -3.0045114
## 44  -2.2345978
## 45  -1.7837951
## 46  -1.7574513
## 47  -1.3653479
## 48  -2.4488379
## 49  -2.8489705
## 50  -2.9575969
## 91   0.9584339
## 92   0.5944260
## 93   0.9890377
## 94   1.4774783
## 95   1.0309356
## 96   2.3550095
## 97   1.1774384
## 98   1.1310284
## 99   2.2519041
## 100  0.5111346

To check accuracy, we cross-tabulate the predicted classes (rows) against the true classes (columns):

table(predict(lsol, test[, c(2, 3)])$class, test[, 1])
##         
##          Alaska Canada
##   Alaska     10      0
##   Canada      0     10

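The model correctly classified all 20 test observations. It was also confident about most of them: the posterior probability of the assigned class is above 0.92 for every test salmon except observations 92 (0.83) and 100 (0.79).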
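One simple way to flag uncertain predictions is to take, for each observation, the posterior probability of the class it was assigned and compare it with a cut-off. The sketch below uses a few rows copied from the test-set posterior matrix printed above. The 0.9 threshold is an arbitrary, hypothetical choice.

```r
# A few rows copied from the test-set posterior matrix printed above
posterior <- rbind(
  "47"  = c(Alaska = 0.973525192, Canada = 2.647481e-02),
  "91"  = c(Alaska = 0.073753358, Canada = 9.262466e-01),
  "92"  = c(Alaska = 0.172305247, Canada = 8.276948e-01),
  "96"  = c(Alaska = 0.001990077, Canada = 9.980099e-01),
  "100" = c(Alaska = 0.205956271, Canada = 7.940437e-01)
)

# Confidence = posterior probability of the assigned (most probable) class
confidence <- apply(posterior, 1, max)

# Observations below the hypothetical 0.9 cut-off
names(confidence)[confidence < 0.9]
```

With the fitted model, the same idea applies to the full matrix: `apply(predict(lsol, test[, c(2, 3)])$posterior, 1, max)`.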

Cross validation

We can perform leave-one-out cross-validation by setting the CV argument of lda to TRUE. Each observation is then classified by a model fitted to all the other observations, so we can use the full dataset.
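Conceptually, `CV = TRUE` is equivalent to refitting the model once per observation with that observation left out. MASS computes this with an update formula rather than by refitting. The explicit loop below is only a sketch for checking the idea on synthetic data (not the salmon file). One detail we assume from reading the MASS source: the priors are kept at their full-data values, so the loop passes them explicitly.

```r
library(MASS)

# Synthetic two-class data for illustration only
set.seed(1)
X <- rbind(mvrnorm(30, mu = c(100, 420), Sigma = diag(c(300, 1000))),
           mvrnorm(30, mu = c(140, 370), Sigma = diag(c(300, 1000))))
colnames(X) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 30))

# Full-data class proportions, kept fixed in every refit
prior <- as.vector(table(y)) / length(y)

loo_class <- character(length(y))
loo_post <- matrix(NA_real_, length(y), nlevels(y), dimnames = list(NULL, levels(y)))
for (i in seq_along(y)) {
  fit <- lda(X[-i, ], grouping = y[-i], prior = prior)  # fit without observation i
  pred <- predict(fit, X[i, , drop = FALSE])            # classify the held-out point
  loo_class[i] <- as.character(pred$class)
  loo_post[i, ] <- pred$posterior
}

cv <- lda(X, grouping = y, CV = TRUE)
table(loo_class, cv$class)
```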

lsol_cv <- lda(salmon[, c(2, 3)], grouping = salmon[, 1], CV = TRUE)
lsol_cv
## $class
##   [1] Canada Canada Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
##  [11] Alaska Canada Canada Alaska Alaska Alaska Alaska Alaska Alaska Alaska
##  [21] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Canada
##  [31] Alaska Canada Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
##  [41] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
##  [51] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
##  [61] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
##  [71] Alaska Canada Canada Canada Canada Canada Canada Canada Canada Canada
##  [81] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
##  [91] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## Levels: Alaska Canada
## 
## $posterior
##           Alaska       Canada
## 1   3.948409e-01 6.051591e-01
## 2   1.175753e-02 9.882425e-01
## 3   9.949318e-01 5.068154e-03
## 4   9.999539e-01 4.613785e-05
## 5   9.302454e-01 6.975460e-02
## 6   9.945646e-01 5.435388e-03
## 7   9.944957e-01 5.504346e-03
## 8   9.912424e-01 8.757632e-03
## 9   9.988077e-01 1.192294e-03
## 10  9.335866e-01 6.641344e-02
## 11  8.829710e-01 1.170290e-01
## 12  9.950019e-02 9.004998e-01
## 13  9.950019e-02 9.004998e-01
## 14  9.052959e-01 9.470410e-02
## 15  6.247805e-01 3.752195e-01
## 16  9.021690e-01 9.783100e-02
## 17  8.912112e-01 1.087888e-01
## 18  5.421235e-01 4.578765e-01
## 19  9.701981e-01 2.980188e-02
## 20  9.754263e-01 2.457366e-02
## 21  5.884179e-01 4.115821e-01
## 22  9.975001e-01 2.499890e-03
## 23  9.869529e-01 1.304709e-02
## 24  9.736560e-01 2.634404e-02
## 25  9.969952e-01 3.004795e-03
## 26  9.986092e-01 1.390781e-03
## 27  7.421947e-01 2.578053e-01
## 28  9.772908e-01 2.270923e-02
## 29  9.981309e-01 1.869084e-03
## 30  2.694046e-01 7.305954e-01
## 31  7.470959e-01 2.529041e-01
## 32  4.514649e-01 5.485351e-01
## 33  9.990434e-01 9.566242e-04
## 34  9.993426e-01 6.574269e-04
## 35  9.999630e-01 3.698873e-05
## 36  9.725100e-01 2.748998e-02
## 37  9.995363e-01 4.636552e-04
## 38  9.871862e-01 1.281376e-02
## 39  9.837609e-01 1.623907e-02
## 40  9.976475e-01 2.352489e-03
## 41  9.999751e-01 2.490113e-05
## 42  9.991973e-01 8.027494e-04
## 43  9.997861e-01 2.139083e-04
## 44  9.976951e-01 2.304917e-03
## 45  9.917228e-01 8.277233e-03
## 46  9.910411e-01 8.958916e-03
## 47  9.728086e-01 2.719136e-02
## 48  9.987883e-01 1.211663e-03
## 49  9.996480e-01 3.520384e-04
## 50  9.997472e-01 2.527688e-04
## 51  4.505982e-01 5.494018e-01
## 52  5.373865e-03 9.946261e-01
## 53  5.686601e-04 9.994313e-01
## 54  5.431215e-03 9.945688e-01
## 55  6.857915e-04 9.993142e-01
## 56  2.253202e-01 7.746798e-01
## 57  2.473720e-02 9.752628e-01
## 58  1.231348e-02 9.876865e-01
## 59  9.331176e-03 9.906688e-01
## 60  2.314767e-03 9.976852e-01
## 61  1.488799e-02 9.851120e-01
## 62  1.108498e-01 8.891502e-01
## 63  1.229294e-02 9.877071e-01
## 64  6.432612e-04 9.993567e-01
## 65  6.576817e-05 9.999342e-01
## 66  1.110309e-02 9.888969e-01
## 67  6.576817e-05 9.999342e-01
## 68  4.880101e-01 5.119899e-01
## 69  5.345196e-03 9.946548e-01
## 70  1.002757e-02 9.899724e-01
## 71  9.673576e-01 3.264243e-02
## 72  1.328459e-03 9.986715e-01
## 73  8.407354e-02 9.159265e-01
## 74  1.018754e-02 9.898125e-01
## 75  1.370577e-01 8.629423e-01
## 76  5.617608e-02 9.438239e-01
## 77  1.134551e-01 8.865449e-01
## 78  2.325374e-02 9.767463e-01
## 79  2.745928e-01 7.254072e-01
## 80  1.487332e-01 8.512668e-01
## 81  1.526065e-02 9.847394e-01
## 82  1.701460e-03 9.982985e-01
## 83  6.751387e-03 9.932486e-01
## 84  6.706322e-04 9.993294e-01
## 85  7.792248e-02 9.220775e-01
## 86  6.867805e-02 9.313219e-01
## 87  2.373124e-01 7.626876e-01
## 88  4.619613e-02 9.538039e-01
## 89  7.075930e-04 9.992924e-01
## 90  3.433555e-03 9.965664e-01
## 91  4.280047e-02 9.571995e-01
## 92  1.130620e-01 8.869380e-01
## 93  4.127517e-02 9.587248e-01
## 94  9.901025e-03 9.900990e-01
## 95  3.495695e-02 9.650431e-01
## 96  7.623827e-04 9.992376e-01
## 97  2.459784e-02 9.754022e-01
## 98  2.772808e-02 9.722719e-01
## 99  1.029729e-03 9.989703e-01
## 100 1.631781e-01 8.368219e-01
## 
## $call
## lda(x = salmon[, c(2, 3)], grouping = salmon[, 1], CV = TRUE)

With CV = TRUE, lda does not return a fitted model. Instead, lsol_cv$class holds the class assigned to each observation by the model fitted without it, and lsol_cv$posterior holds the corresponding posterior probabilities of group membership. We can tabulate lsol_cv$class against the true origin to count the misclassified points.

table(lsol_cv$class, salmon[, 1])
##         
##          Alaska Canada
##   Alaska     44      1
##   Canada      6     49

The model misclassified 7 of the 100 observations, a leave-one-out error rate of 7%: it assigned 1 Canadian salmon to Alaska and 6 Alaskan salmon to Canada.
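The error rate and the indices of the misclassified salmon can be computed directly. Since the data file is not bundled here, this sketch rebuilds the predicted classes from the printed `lsol_cv$class` output. With the real objects, `which(lsol_cv$class != salmon[, 1])` gives the same indices.

```r
# Predicted classes reconstructed from the printed lsol_cv$class output:
# rows 1, 2, 12, 13, 30, 32 (Alaskan) were predicted Canada,
# row 71 (Canadian) was predicted Alaska
truth <- factor(rep(c("Alaska", "Canada"), each = 50))
pred <- truth
pred[c(1, 2, 12, 13, 30, 32)] <- "Canada"
pred[71] <- "Alaska"

conf <- table(pred, truth)
conf
error_rate <- 1 - sum(diag(conf)) / sum(conf)
error_rate            # 0.07
which(pred != truth)  # indices of misclassified salmon
```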

Visualizing the results

We plot the data colored by the true class (black for Alaska, red for Canada), with the plotting symbol showing the class predicted by leave-one-out LDA (circle for Alaska, triangle for Canada).

plot(salmon[, c(2, 3)], col = as.factor(salmon[, 1]), pch = as.numeric(lsol_cv$class))

Quadratic Discriminant Analysis (QDA)

Unlike LDA, QDA allows each class to have its own covariance matrix, so the decision boundary between classes is quadratic rather than linear.
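With class-specific covariances, the log-determinant of each covariance no longer cancels between classes, so it enters the posterior. The sketch below computes QDA posteriors by hand and, on synthetic data (not the salmon file), compares them with `predict()` on a `qda()` fit. It assumes `qda()` uses the usual sample covariance of each class and priors equal to the class proportions. `qda_posterior` is a hypothetical helper.

```r
library(MASS)

# Synthetic two-class data with different covariances (illustration only)
set.seed(2)
X <- rbind(mvrnorm(40, mu = c(100, 420), Sigma = matrix(c(300, 100, 100, 1200), 2)),
           mvrnorm(40, mu = c(140, 370), Sigma = matrix(c(600, -150, -150, 500), 2)))
colnames(X) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 40))

# Hypothetical helper: QDA posteriors for the rows of x
qda_posterior <- function(x, X, y) {
  prior <- as.vector(table(y)) / length(y)
  log_score <- sapply(seq_len(nlevels(y)), function(k) {
    Xk <- X[y == levels(y)[k], , drop = FALSE]
    S <- cov(Xk)                      # class-specific covariance
    mu <- colMeans(Xk)
    # -1/2 log|S_k| - 1/2 Mahalanobis distance + log prior
    -0.5 * as.numeric(determinant(S, logarithm = TRUE)$modulus) -
      0.5 * mahalanobis(x, center = mu, cov = S) + log(prior[k])
  })
  log_score <- matrix(log_score, ncol = nlevels(y))
  # Normalise each row on the log scale
  w <- exp(log_score - apply(log_score, 1, max))
  post <- w / rowSums(w)
  colnames(post) <- levels(y)
  post
}

pts <- rbind(c(120, 380), c(90, 450), c(150, 360))
colnames(pts) <- colnames(X)
qda_posterior(pts, X, y)
```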

qsol <- qda(train[, c(2, 3)], grouping = train[, 1])
predict(qsol, test[, c(2, 3)])
## $class
##  [1] Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska Alaska
## [11] Canada Canada Canada Canada Canada Canada Canada Canada Canada Canada
## Levels: Alaska Canada
## 
## $posterior
##          Alaska       Canada
## 41  0.999999603 3.973313e-07
## 42  0.999934364 6.563641e-05
## 43  0.999975258 2.474169e-05
## 44  0.999739798 2.602019e-04
## 45  0.992748801 7.251199e-03
## 46  0.997730481 2.269519e-03
## 47  0.985325240 1.467476e-02
## 48  0.999895544 1.044557e-04
## 49  0.999982257 1.774293e-05
## 50  0.999987507 1.249314e-05
## 91  0.112831650 8.871684e-01
## 92  0.237765000 7.622350e-01
## 93  0.111230243 8.887698e-01
## 94  0.030850466 9.691495e-01
## 95  0.074449894 9.255501e-01
## 96  0.008231741 9.917683e-01
## 97  0.076453258 9.235467e-01
## 98  0.085706823 9.142932e-01
## 99  0.007297430 9.927026e-01
## 100 0.182480379 8.175196e-01

We again perform leave-one-out cross-validation on the full dataset.

qsol_cv <- qda(salmon[, c(2, 3)], grouping = salmon[, 1], CV = TRUE)
table(qsol_cv$class, salmon[, 1])
##         
##          Alaska Canada
##   Alaska     45      3
##   Canada      5     47
plot(salmon[, c(2, 3)], col = as.factor(salmon[, 1]), pch = as.numeric(qsol_cv$class))
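On the salmon data, QDA misclassified 8 observations under leave-one-out cross-validation (3 Canadian and 5 Alaskan salmon), one more than LDA, so its extra flexibility does not improve the error here. To see how the two boundaries differ in shape, one can classify a fine grid of points with both models and draw the 0.5 contour of P(Canada | x). The sketch below does this on synthetic data (illustration only, not the salmon file). The grid size of 100 by 100 is an arbitrary choice.

```r
library(MASS)

# Synthetic data with unequal covariances (illustration only)
set.seed(3)
X <- rbind(mvrnorm(50, mu = c(100, 420), Sigma = matrix(c(300, 0, 0, 1500), 2)),
           mvrnorm(50, mu = c(140, 370), Sigma = matrix(c(700, 0, 0, 600), 2)))
colnames(X) <- c("Freshwater", "Marine")
y <- factor(rep(c("Alaska", "Canada"), each = 50))

lda_fit <- lda(X, grouping = y)
qda_fit <- qda(X, grouping = y)

# Regular grid covering the data (Freshwater varies fastest)
fw <- seq(min(X[, 1]), max(X[, 1]), length.out = 100)
ma <- seq(min(X[, 2]), max(X[, 2]), length.out = 100)
grid <- as.matrix(expand.grid(Freshwater = fw, Marine = ma))

lda_grid <- predict(lda_fit, grid)$class
qda_grid <- predict(qda_fit, grid)$class

# Share of the grid where the two rules disagree
mean(lda_grid != qda_grid)

# Boundaries as the 0.5 contour of P(Canada | x): solid = LDA, dashed = QDA
plot(X, col = as.integer(y), pch = as.integer(y))
contour(fw, ma, matrix(predict(lda_fit, grid)$posterior[, "Canada"], 100),
        levels = 0.5, drawlabels = FALSE, add = TRUE)
contour(fw, ma, matrix(predict(qda_fit, grid)$posterior[, "Canada"], 100),
        levels = 0.5, drawlabels = FALSE, add = TRUE, lty = 2)
```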