线性判别分析(Linear Discriminant Analysis,LDA)是一种经典的线性分类方法。LDA是一种有监督的线性分类算法。LDA的基本思想是将数据投影到低维空间后,使得:同一类数据尽可能接近,不同类数据尽可能疏远。在对新样本进行分类时,将其投影到同样的低维空间上,再根据投影点的位置来确定新样本的类别。

avatar

值得一提的是,LDA 可从贝时斯决策理论的角度来阐释,并可证明,当两类数据同先验、满足高斯分布且协方差相等时,LDA 可达到最优分类。

Fisher discriminant criterion

  1. $X_i$ 表示 $i$ 类示例的集合
  2. $N$ 表示有 $N$ 类,且第 $i$ 类的样本数为 $N_i$
  3. $n$ 表示总共有 $n$ 个数据
  4. $\mu_i$ 表示 $i$ 类示例的均值向量,$\mu$ 表示总的均值向量
  5. $\Sigma_i$ 表示 $i$ 类示例的协方差矩阵
  6. $S_w$ 表示类内散度矩阵(within-class scatter matrix)
  7. $S_b$ 表示类间散度矩阵(between-class scatter matrix)
  8. $S_t$ 表示全局散度矩阵
  9. $w$ 为变换矩阵
  10. $J(w)$ 是最大目标Fisher判别准则

$$\boldsymbol{\mu_i} = \frac{1}{N_i} \sum_{\boldsymbol{x} \in X_i}\boldsymbol{x}$$

$$\Sigma_i = \sum_{\boldsymbol{x} \in X_i}(\boldsymbol{x}-\boldsymbol{\mu_i})(\boldsymbol{x}-\boldsymbol{\mu_i})^T$$

$$S_w = \sum^N_{i=1}S_{w_i} = \sum^N_{i=1}\Sigma_i = \sum^N_{i=1}\sum_{\boldsymbol{x} \in X_i}(\boldsymbol{x}-\boldsymbol{\mu_i})(\boldsymbol{x}-\boldsymbol{\mu_i})^T$$

$$S_b = S_t - S_w = \sum^N_{i=1}N_i(\boldsymbol{\mu_i}-\boldsymbol{\mu})(\boldsymbol{\mu_i}-\boldsymbol{\mu})^T$$

$$S_t = S_b + S_w =\sum_{i=1}^{n}(\boldsymbol{x_i}-\boldsymbol{\mu})(\boldsymbol{x_i}-\boldsymbol{\mu})^T$$

$$J(w) = \frac{w^TS_bw}{w^TS_ww}$$

优化目标推导过程

$$ \frac{\mathrm{d}J(w)}{\mathrm{d}w} = 0 $$

$$
\frac{\mathrm{d}J(w)}{\mathrm{d}w} = \frac{\mathrm{d}}{\mathrm{d}w}(\frac{w^TS_bw}{w^TS_ww}) =
$$

avatar

多分类 LDA 可以有多种实现方法:使用 $S_b$, $S_w$ , $S_t$ 三者中的任何两个即可

LDA 算法的训练流程

  1. 计算类内散度矩阵 $S_w$
  2. 计算类间散度矩阵 $S_b$
  3. 计算矩阵 $S_w^{-1}S_b$
  4. 计算矩阵 $S_w^{-1}S_b$ 的特征值与特征向量,按从小到大的顺序选取前 $d$ 个特征值和对应的 $d$ 个特征向量,得到投影矩阵 $w$

sklearn包 LDA 的使用

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

lda = LDA(solver='svd', n_components=LDA_components)

# fit
lda.fit(train_data, train_label)

# transform
train_data = lda.transform(train_data)
test_data = lda.transform(test_data)

LDA Python 代码实现

相关论文

参考

[1] CSLT-THU王东老师PPT-Static Analysis: LDA

[2] 西瓜书-线性判别分析

[3] CSLT-THU王东老师-现代机器学习导论

[4] 机器学习实验室微信公众号-数学推导LDA线性判别分析

[5] 博客园-LDA

[6] 知乎-线性判别分析LDA原理及推导过程(非常详细)

[7] THU袁博老师数据挖掘课程-数据预处理PPT

[8] 知乎-Fisher判别分析(Fisher Discriminant Analysis)