Skip to content
Snippets Groups Projects
Verified Commit d5e24c1b authored by Laurent Modolo's avatar Laurent Modolo
Browse files

add clustering XY and XO

parent a7d32fba
No related branches found
No related tags found
No related merge requests found
...@@ -90,24 +90,26 @@ plot(data_clust, what = "uncertainty") ...@@ -90,24 +90,26 @@ plot(data_clust, what = "uncertainty")
```{r} ```{r}
# Expand the shared mixture parameters into one parameter set per cluster.
#
# Arguments:
#   theta        list with `pi` (mixture weights, read in the order a, f, m),
#                `mu` (shared base mean) and `sigma` (list of covariance
#                matrices named by cluster).
#   cluster_coef named list of multipliers applied to `theta$mu` to obtain
#                each cluster's mean.
#   sex          "XY" adds the third cluster "m"; any other value keeps only
#                the clusters "a" and "f".
#
# Returns a named list (a, f and, for "XY", m); each element is a list with
# the entries `pi`, `mu` and `sigma` of that cluster.
expand_theta <- function(theta, cluster_coef, sex) {
  cluster_names <- c("a", "f")
  if (sex == "XY") {
    cluster_names <- c(cluster_names, "m")
  }
  theta_ref <- lapply(seq_along(cluster_names), function(k) {
    cluster <- cluster_names[k]
    list(
      "pi" = theta$pi[k],
      "mu" = cluster_coef[[cluster]] * theta$mu,
      "sigma" = theta$sigma[[cluster]]
    )
  })
  names(theta_ref) <- cluster_names
  theta_ref
}
params_diff <- function(old_theta, theta, threshold) { params_diff <- function(old_theta, theta, threshold) {
...@@ -121,18 +123,18 @@ params_diff <- function(old_theta, theta, threshold) { ...@@ -121,18 +123,18 @@ params_diff <- function(old_theta, theta, threshold) {
return(T) return(T)
} }
# Mixture density of the observations in `x`: the sum over clusters of the
# cluster weight times the multivariate normal density of that cluster.
#
# Arguments:
#   x            observations passed unchanged to mvtnorm::dmvnorm().
#   theta, cluster_coef, sex
#                forwarded to expand_theta() to obtain the per-cluster
#                pi / mu / sigma.
#
# Returns the accumulated weighted density (what dmvnorm() returns for `x`,
# summed over the clusters).
proba_total <- function(x, theta, cluster_coef, sex) {
  weighted_density <- lapply(
    expand_theta(theta, cluster_coef, sex),
    function(params) {
      params$pi *
        mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma)
    }
  )
  # Start from 0 so the accumulation order matches a plain running sum.
  Reduce(`+`, weighted_density, 0)
}
proba_point <- function(x, theta, cluster_coef) { proba_point <- function(x, theta, cluster_coef, sex) {
proba <- c() proba <- c()
for (params in expand_theta(theta, cluster_coef)) { for (params in expand_theta(theta, cluster_coef, sex)) {
proba <- cbind(proba, params$pi * proba <- cbind(proba, params$pi *
mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma) mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma)
) )
...@@ -140,13 +142,13 @@ proba_point <- function(x, theta, cluster_coef) { ...@@ -140,13 +142,13 @@ proba_point <- function(x, theta, cluster_coef) {
return(proba) return(proba)
} }
# Convergence criterion used by the EM loop: minus the log of the summed
# mixture density returned by proba_total().
#
# NOTE(review): this is the log of the *sum* of the densities, not the sum of
# their logs, so it is not the usual log-likelihood of independent
# observations. Behaviour is kept as is because the EM stopping threshold is
# calibrated against this quantity -- confirm whether that is intended.
loglik <- function(x, theta, cluster_coef, sex) {
  total_density <- sum(proba_total(x, theta, cluster_coef, sex))
  -log(total_density)
}
# EM function # EM function
E_proba <- function(x, theta, cluster_coef) { E_proba <- function(x, theta, cluster_coef, sex) {
proba <- proba_point(x, theta, cluster_coef) proba <- proba_point(x, theta, cluster_coef, sex)
proba_norm <- rowSums(proba) proba_norm <- rowSums(proba)
for (cluster in 1:ncol(proba)) { for (cluster in 1:ncol(proba)) {
proba[, cluster] <- proba[, cluster] / proba_norm proba[, cluster] <- proba[, cluster] / proba_norm
...@@ -160,78 +162,101 @@ E_N_clust <- function(proba) { ...@@ -160,78 +162,101 @@ E_N_clust <- function(proba) {
} }
# Function for mean update # Function for mean update
# M-step update of the shared base mean `mu`.
#
# Each cluster yields one estimate of `mu` from the responsibility-weighted
# column means of `x`, after rescaling each column by the reciprocal of the
# cluster coefficient set in init_param() (a: c(2, 2), f: c(1, 2),
# m: c(1, 0)). The estimates are then averaged over the clusters.
#
# Arguments:
#   x        two-column matrix of observations.
#   proba    matrix of responsibilities, one column per cluster (a, f, m).
#   N_clust  per-cluster normalising constants, indexed like the columns of
#            `proba`.
#   sex      "XY" averages over three clusters, anything else over two.
#
# Returns the updated `mu` as an unnamed numeric scalar.
M_mean <- function(x, proba, N_clust, sex) {
  # Column weights per cluster. They must be applied per COLUMN: a plain
  # `x * c(1, 0.5)` recycles the vector down the rows (column-major order)
  # and would scale alternating rows instead, hence sweep().
  col_weights <- list(c(0.5, 0.5), c(1, 0.5), c(1, 0))
  mu <- 0
  for (cluster in seq_len(min(ncol(proba), length(col_weights)))) {
    scaled <- sweep(x, 2, col_weights[[cluster]], "*")
    col_estimates <- colSums(scaled * proba[, cluster]) / N_clust[cluster]
    if (cluster == 3) {
      # The second coordinate of cluster "m" has coefficient 0 and carries
      # no information about mu: use the first column only.
      mu <- mu + col_estimates[[1]]
    } else {
      mu <- mu + mean(col_estimates)
    }
  }
  if (sex == "XY") {
    return(mu / 3)
  }
  mu / 2
}
# M-step update of the per-cluster covariance matrices.
#
# For each column of `proba` the observations are centred on that cluster's
# mean (`mu * cluster_coef[[cluster]]`) and the responsibility-weighted
# cross-product is divided by `N_clust[cluster]`.
#
# Arguments:
#   x             two-column matrix of observations.
#   proba         matrix of responsibilities, one column per cluster.
#   mu            shared base mean (scalar).
#   N_clust       per-cluster normalising constants.
#   cluster_coef  list of per-cluster multipliers of `mu`, indexed by
#                 position in the same order as the columns of `proba`.
#   sex           "XY" also returns the covariance of the third cluster.
#
# Returns a list with the entries `a`, `f` and, for "XY", `m`.
M_cov <- function(x, proba, mu, N_clust, cluster_coef, sex) {
  cov_clust <- lapply(seq_len(ncol(proba)), function(cluster) {
    # Centre per COLUMN: a plain `x - center` would recycle the length-2
    # centre down the rows and subtract the wrong coordinate on alternate
    # rows.
    centered <- sweep(x, 2, mu * cluster_coef[[cluster]], "-")
    crossprod(proba[, cluster] * centered, centered) / N_clust[cluster]
  })
  sigma <- list()
  sigma$a <- cov_clust[[1]]
  sigma$f <- cov_clust[[2]]
  if (sex == "XY") {
    sigma$m <- cov_clust[[3]]
  }
  sigma
}
# Scatter plot of count_f against count_m in which each point is coloured by
# its cluster responsibilities.
#
# Arguments:
#   x      object convertible by as_tibble() that provides the columns
#          `count_m` and `count_f`.
#   proba  matrix of responsibilities; column 1 is read as cluster "a",
#          column 2 as "f" and, for sex = "XY", column 3 as "m".
#   sex    "XY" (default) uses all three columns; any other value uses only
#          the first two.
#
# Colour encoding: red = P(f), green = P(m) (fixed at 0 when sex is not
# "XY"), blue = P(a).
#
# Returns the ggplot object.
plot_proba <- function(x, proba, sex = "XY") {
  green_channel <- if (sex == "XY") proba[, 3] else 0
  as_tibble(x) %>%
    mutate(
      proba_a = proba[, 1],
      proba_f = proba[, 2],
      clust_proba = rgb(proba_f, green_channel, proba_a, maxColorValue = 1)
    ) %>%
    ggplot(aes(x = count_m, y = count_f, color = clust_proba)) +
    geom_point() +
    scale_color_identity()
}
# Build the starting values of the EM algorithm.
#
# Arguments:
#   x    numeric matrix of observations (one column per coordinate).
#   sex  "XY" initialises three clusters (a, f, m); any other value
#        initialises the two clusters a and f only.
#
# Returns a list with
#   cluster_coef  per-cluster multipliers of the shared mean
#                 (a: c(2, 2), f: c(1, 2), m: c(1, 0)),
#   theta         list with the mixture weights `pi`, the shared base mean
#                 `mu` and the per-cluster covariance matrices `sigma`.
init_param <- function(x, sex) {
  # Half of the overall mean of the column means is the starting base mean.
  base_mu <- mean(colMeans(x)) * .5
  # NOTE(review): an all-ones matrix scaled by mu is singular (rank 1);
  # confirm that mvtnorm::dmvnorm() handles this starting covariance as
  # intended.
  start_sigma <- function() {
    matrix(c(1, 1, 1, 1) * base_mu, ncol = 2)
  }

  cluster_coef <- list(
    "a" = c(2, 2),
    "f" = c(1, 2)
  )
  theta <- list(
    "pi" = c(.85, .1, .05),
    "mu" = base_mu
  )
  theta$sigma <- list(
    "a" = start_sigma(),
    "f" = start_sigma()
  )
  if (sex == "XY") {
    cluster_coef$m <- c(1, 0)
    theta$sigma$m <- start_sigma()
  } else {
    # Only two clusters: drop the unused third weight and renormalise so the
    # mixture weights sum to 1. The a:f ratio is unchanged, so the first
    # E-step responsibilities are unaffected.
    theta$pi <- theta$pi[1:2] / sum(theta$pi[1:2])
  }
  list(cluster_coef = cluster_coef, theta = theta)
}
# Fit the constrained Gaussian mixture by EM and return the responsibilities.
#
# Arguments:
#   x          two-column matrix of observations.
#   threshold  iteration stops once the absolute change of loglik() between
#              two consecutive iterations is not larger than this value.
#   sex        "XY" (default) fits three clusters; any other value fits two.
#
# Returns the responsibility matrix computed by the last E-step, i.e. with
# the parameters as they were before the final M-step.
EM_clust <- function(x, threshold = 0.1, sex = "XY") {
  param <- init_param(x, sex)
  # The criterion at the start of an iteration equals the one at the end of
  # the previous iteration, so it is carried over instead of recomputed.
  old_loglik <- loglik(x, param$theta, param$cluster_coef, sex)
  repeat {
    # E-step: responsibilities and the per-cluster totals derived from them.
    proba <- E_proba(x, param$theta, param$cluster_coef, sex)
    n_clust <- E_N_clust(proba)

    # M-step: the un-normalised totals are the denominators of the mean and
    # covariance updates; the mixture weights are the totals over nrow(x).
    param$theta$mu <- M_mean(x, proba, n_clust, sex)
    param$theta$sigma <- M_cov(
      x, proba, param$theta$mu, n_clust, param$cluster_coef, sex
    )
    param$theta$pi <- n_clust / nrow(x)

    new_loglik <- loglik(x, param$theta, param$cluster_coef, sex)
    delta <- abs(new_loglik - old_loglik)
    if (is.na(delta)) {
      # Would otherwise surface as "missing value where TRUE/FALSE needed".
      stop(
        "EM_clust: convergence criterion is undefined (NA/NaN); ",
        "check the starting values and the data",
        call. = FALSE
      )
    }
    if (delta <= threshold) {
      break
    }
    old_loglik <- new_loglik
  }
  proba
}
...@@ -245,6 +270,14 @@ data %>% ...@@ -245,6 +270,14 @@ data %>%
as.matrix() %>% as.matrix() %>%
plot_proba(proba) plot_proba(proba)
# Fit the mixture a second time on the same two count columns, now passing
# sex = "X0".
# NOTE(review): the value is written with the digit zero ("X0") while the
# commit title says "XO"; confirm that every `sex` check treats both the same.
proba <- data %>%
dplyr::select(count_m, count_f) %>%
as.matrix() %>%
EM_clust(sex = "X0")
# Plot the same count matrix coloured by the responsibilities obtained above;
# `sex` is passed again so plot_proba() is called with the same setting as
# EM_clust().
data %>%
dplyr::select(count_m, count_f) %>%
as.matrix() %>%
plot_proba(proba, sex = "X0")
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment