EM算法

混合正态分布的EM算法

我们将生成两组数据,每组来自不同的正态分布,然后将这两组数据混合起来。假设这两个正态分布分别为 \(N\left(\mu_1=5, \sigma_1^2=2\right)\) 和 \(N\left(\mu_2=10, \sigma_2^2=3\right)\),混合系数分别为 \(\pi_1=0.7\) 和 \(\pi_2=0.3\)(即每个观测以概率 0.3 来自第二个分布)。

# Simulate a two-component normal mixture.
# NOTE: the workspace is intentionally not cleared here; wiping the global
# environment from inside a script destroys unrelated user objects.
set.seed(123) # fix the RNG state so the results are reproducible
n <- 1000

# Draws from the first component, N(mean = 5, variance = 2)
mu1 <- 5
sigma1 <- sqrt(2)
Y1 <- rnorm(n = n, mean = mu1, sd = sigma1)

# Draws from the second component, N(mean = 10, variance = 3)
mu2 <- 10
sigma2 <- sqrt(3)
Y2 <- rnorm(n = n, mean = mu2, sd = sigma2)

# Latent labels: W = 1 (probability 0.3) selects the second component,
# W = 0 selects the first one.
W <- rbinom(n = n, size = 1, prob = 0.3)

# Mixed sample: each observation is taken from the component chosen by W
Y <- (1 - W) * Y1 + W * Y2

看一下分布

library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.4.0      ✔ purrr   0.3.4 
## ✔ tibble  3.1.8      ✔ dplyr   1.0.10
## ✔ tidyr   1.2.1      ✔ stringr 1.4.1 
## ✔ readr   2.1.3      ✔ forcats 0.5.2
## Warning: 程辑包'ggplot2'是用R版本4.2.2 来建造的
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
# Collect the two component samples and the mixed sample side by side
plot_data <- data.frame(Y1 = Y1, Y2 = Y2, Y = Y)

# Reshape from wide to long: one row per value, labelled by its source column
long_data <- plot_data %>%
  pivot_longer(
    cols = c(Y1, Y2, Y),
    names_to = "Variable",
    values_to = "Value"
  )

# Overlay the kernel density estimates of the three samples
long_data %>%
  ggplot(aes(x = Value, fill = Variable)) +
  geom_density(alpha = 0.5) +
  scale_fill_brewer(palette = "Set1") + # ColorBrewer qualitative palette
  labs(title = "数据的密度估计", x = " ", y = " ") +
  theme_minimal()

用EM算法估计参数

# Expected complete-data log-likelihood (the EM "Q" function) of a
# two-component Gaussian mixture.
#
# Args:
#   lambda: mixing weight of component 2; log(lambda) and log(1 - lambda) are
#           taken, so it must lie strictly between 0 and 1 for a finite result.
#   gamma:  per-observation weights attached to component 2 (same length as
#           data); component 1 receives 1 - gamma.
#   data:   numeric vector of observations.
#   mu:     component means; mu[1] and mu[2] are used.
#   sigma:  component variances (square-rooted before being passed as sd);
#           sigma[1] and sigma[2] are used.
#
# Returns: a single numeric value.
Qt <- function(lambda, gamma, data, mu, sigma) {
  # log = TRUE avoids log(0) = -Inf when a density underflows
  log_f1 <- dnorm(x = data, mean = mu[1], sd = sqrt(sigma[1]), log = TRUE)
  log_f2 <- dnorm(x = data, mean = mu[2], sd = sqrt(sigma[2]), log = TRUE)

  # The mixing-weight terms are scalars. They must be added once, outside the
  # per-observation sum; inside it, recycling would count them length(data)
  # times.
  sum((1 - gamma) * log_f1 + gamma * log_f2) +
    sum(gamma) * log(lambda) +
    (length(data) - sum(gamma)) * log(1 - lambda)
}

# Fit a two-component univariate Gaussian mixture with the EM algorithm.
#
# Args:
#   data:     numeric vector of observations.
#   max.iter: maximum number of EM iterations.
#   tol:      the iteration stops once the change in Qt() produced by one
#             M-step is smaller than this value.
#
# Returns: a list with
#   mu         - estimated component means (length 2),
#   sigma      - estimated component variances (length 2),
#   lambda     - estimated mixing weight of component 2,
#   iterations - number of EM iterations performed,
#   seq        - the lambda values visited, starting with the initial value,
#   converged  - TRUE if the stopping rule was met within max.iter iterations.
# A warning is raised when the stopping rule was not met; the last iterate is
# still returned so that callers can inspect it.
#
# Relies on Qt() for the stopping rule.
EMAlgorithm <- function(data, max.iter = 1000, tol = 1e-6) {
  # Starting values: both variances equal the sample variance, the second
  # mean is shifted by one standard deviation so the components differ.
  mu <- c(mean(data), mean(data) + sd(data))
  sigma <- c(var(data), var(data))
  lambda <- 0.4
  iter <- 0
  converged <- FALSE
  seq_lambda <- c(lambda)

  while (iter < max.iter) {
    # E-step: posterior weight of component 2 for every observation
    dens2 <- lambda * dnorm(data, mean = mu[2], sd = sqrt(sigma[2]))
    dens1 <- (1 - lambda) * dnorm(data, mean = mu[1], sd = sqrt(sigma[1]))
    gamma <- dens2 / (dens2 + dens1)
    Qt_old <- Qt(lambda, gamma, data, mu, sigma)

    # M-step: weighted means, weighted variances and the new mixing weight
    mu[1] <- sum((1 - gamma) * data) / sum(1 - gamma)
    sigma[1] <- sum((1 - gamma) * (data - mu[1])^2) / sum(1 - gamma)
    mu[2] <- sum(gamma * data) / sum(gamma)
    sigma[2] <- sum(gamma * (data - mu[2])^2) / sum(gamma)
    lambda_new <- mean(gamma)

    iter <- iter + 1
    seq_lambda <- c(seq_lambda, lambda_new)
    delta <- abs(Qt(lambda = lambda_new, gamma, data, mu, sigma) - Qt_old)

    # Commit the new weight before testing for convergence, so the returned
    # lambda belongs to the same iteration as mu, sigma and seq.
    lambda <- lambda_new

    if (delta < tol) {
      converged <- TRUE
      break
    }
  }

  # Use an explicit flag: comparing iter with max.iter would misreport a run
  # that converges exactly on the last allowed iteration.
  if (!converged) {
    warning(
      "Maximum iterations reached. Solution may not be accurate.",
      call. = FALSE
    )
  }

  list(
    mu = mu,
    sigma = sigma,
    lambda = lambda,
    iterations = iter,
    seq = seq_lambda,
    converged = converged
  )
}

# Fit the mixture to the simulated sample and report each estimate
result <- EMAlgorithm(data = Y)
print(result[["mu"]])
## [1]  5.03671 10.08486
print(result[["sigma"]])
## [1] 1.927304 2.871522
print(result[["lambda"]])
## [1] 0.284193
print(result[["iterations"]])
## [1] 120
# Trace of the mixing-weight iterates across iterations
plot(
  result[["seq"]],
  type = "l",
  ylim = c(0.1, 0.5),
  xlab = "迭代次数",
  ylab = "",
  main = "迭代的收敛"
)

混合Gamma分布的EM算法

考虑一个由三个Gamma分布构成的混合分布(各成分的混合权重相等, 均为 \(1/3\)), \[ X_i \sim \Gamma\left(\frac{1}{2}, \frac{1}{2\lambda_i}\right), \quad i = 1,2,3. \] 且满足\(\lambda_1 + \lambda_2 + \lambda_3 = 1\). 由于该约束, 只有两个自由参数, 因此用极大似然估计求解是一个二维搜索问题.

对于二维求解, 我们可以用R语言内置的optim函数, 其默认方法是Nelder-Mead算法, 此外还支持拟牛顿法(BFGS), 共轭梯度法(CG), 带界约束的最优化(L-BFGS-B), 模拟退火(SANN)等.

# Simulate an equally weighted mixture of three Gamma(shape = 1/2) components.
# NOTE: the workspace is intentionally not cleared here; wiping the global
# environment from inside a script destroys unrelated user objects.
set.seed(123)
m <- 2000 # sample size
lambda <- c(0.6, 0.25, 0.15)

# Pick one of the three lambda values for every observation; sample() without
# a prob argument draws them with equal probability. The size is taken from m
# so the sample size is defined in a single place.
lam <- sample(lambda, size = m, replace = TRUE)

# Gamma(shape = 1/2, rate = 1 / (2 * lam)) has mean shape / rate = lam
y <- rgamma(m, shape = 0.5, rate = 1 / (2 * lam))

二维最优化

# Negative log-likelihood of an equally weighted mixture of three
# Gamma(shape = 1/2, rate = 1 / (2 * lambda_j)) densities.
#
# Args:
#   lambda: the free parameters (lambda_1, lambda_2); the third one is
#           obtained as 1 - sum(lambda).
#   y:      numeric vector of observations.
#
# Returns: minus the summed log mixture density (a value to be minimised).
LL <- function(lambda, y) {
  # Density of one component at every observation
  component <- function(scale) {
    dgamma(y, shape = 1 / 2, rate = 1 / (2 * scale))
  }

  first <- component(lambda[1])
  second <- component(lambda[2])
  third <- component(1 - sum(lambda))

  mixture <- first / 3 + second / 3 + third / 3
  -sum(log(mixture))
}
# Minimise the negative log-likelihood over (lambda_1, lambda_2) with optim's
# default method, starting from (0.5, 0.3); y is forwarded to LL.
opt <- optim(par = c(0.5, 0.3), fn = LL, y = y)
# The printed expression is kept as is: it becomes the column header below.
print(as.data.frame(unlist(opt)))
##                  unlist(opt)
## par1               0.6234268
## par2               0.1879653
## value           -850.9974794
## counts.function   41.0000000
## counts.gradient           NA
## convergence        0.0000000
# Recover the third parameter from the sum-to-one constraint
free_par <- opt$par
theta <- c(free_par, 1 - sum(free_par))
print(theta)
## [1] 0.6234268 0.1879653 0.1886079

EM算法

N <- 10000 # maximum number of iterations
L <- c(0.5, 0.4, 0.1) # starting values
tol <- .Machine$double.eps^0.5
L_old <- L + 1 # deliberately far from L so the first convergence check fails

# Density of Gamma(shape = 1/2, rate = 1 / (2 * scale)) at every observation
component_density <- function(scale) {
  dgamma(y, shape = 1 / 2, rate = 1 / (2 * scale))
}

for (i in seq_len(N)) {
  d1 <- component_density(L[1])
  d2 <- component_density(L[2])
  d3 <- component_density(L[3])

  # Posterior membership weights (the components are equally weighted)
  total <- d1 + d2 + d3
  w1 <- d1 / total
  w2 <- d2 / total
  w3 <- d3 / total

  # Weighted means of y per component, rescaled to sum to one
  L <- c(
    sum(y * w1) / sum(w1),
    sum(y * w2) / sum(w2),
    sum(y * w3) / sum(w3)
  )
  L <- L / sum(L)

  # Stop once the estimates no longer move (L1 distance below tol)
  if (sum(abs(L - L_old)) < tol) {
    break
  }
  L_old <- L
}
print(list(lambda = L / sum(L), iter = i, tol = tol))
## $lambda
## [1] 0.6176558 0.1911725 0.1911716
## 
## $iter
## [1] 725
## 
## $tol
## [1] 1.490116e-08