-
Laurent Modolo authoredLaurent Modolo authored
clustering.Rmd 9.81 KiB
title: "kmer clustering"
author: "Laurent Modolo"
date: "`r Sys.Date()`"
output: html_document
knitr::opts_chunk$set(echo = TRUE)
library(mclust)
library(tidyverse)
library(mvtnorm)
Simulation
With $X_f$ the female k-mer count and $X_m$ the male k-mer count, we expect the following relations for an XY species:
- For the autosomal chromosomes: $X_m = X_f$
- For the X chromosomes: $X_f = 2 X_m$
- For the Y chromosomes: $X_f = 0^+ X_m$
Which becomes, on the log scale:
- For the autosomal chromosomes: $\log(X_m) = \log(X_f)$
- For the X chromosomes: $\log(X_f) = \log(2) + \log(X_m)$
- For the Y chromosomes: $\log(X_f) = \log(0^+) + \log(X_m)$
Test the slope for a given sigma matrix
# Draw 1e4 points from a centred bivariate normal distribution and plot them
# together with a fitted regression line. The plot title is the slope of the
# regression of the second coordinate on the first.
#
# x:   standard deviation of the first coordinate.
# y:   standard deviation of the second coordinate.
# rho: correlation between the two coordinates.
#
# Returns the ggplot object.
test_slope <- function(x, y, rho) {
  covariance <- rho * x * y
  sigma <- matrix(c(x^2, covariance, covariance, y^2), ncol = 2)
  draws <- mvtnorm::rmvnorm(
    1e4,
    mean = c(0, 0),
    sigma = sigma,
    checkSymmetry = FALSE,
    method = "svd"
  )
  # The two columns are referred to as V1 and V2 from here on.
  test <- as_tibble(draws)
  slope <- lm(test$V2 ~ test$V1)$coefficients[2]
  ggplot(data = test, aes(x = V1, y = V2)) +
    geom_point() +
    geom_smooth(method = lm) +
    labs(title = slope) +
    coord_fixed()
}

test_slope(1.05, 2.05, 0.95)
Simulate k-mer counts data
# Simulate k-mer counts for three groups of k-mers, each drawn from its own
# bivariate normal distribution. After the final rename, the first simulated
# coordinate is count_m and the second is count_f.
n_kmer <- 1e2
mean_coef <- 1000
data <- bind_rows(
  # "F" group: 10% of the k-mers, centred on (1, 2) * mean_coef.
  tibble(
    sex = "F",
    count = as_tibble(mvtnorm::rmvnorm(
      n_kmer * .1,
      mean = c(1, 2) * mean_coef,
      sigma = matrix(c(1.05, 2, 2, 4.05) * mean_coef^1.5, ncol = 2),
      checkSymmetry = FALSE,
      method = "svd"
    ))
  ) %>%
    unnest(count),
  # "M" group: 5% of the k-mers, centred on (1, 0) * mean_coef.
  tibble(
    sex = "M",
    count = as_tibble(mvtnorm::rmvnorm(
      n_kmer * .05,
      mean = c(1, 0) * mean_coef,
      sigma = matrix(c(.9, .05, .05, .05) * mean_coef^1.5, ncol = 2),
      method = "svd"
    ))
  ) %>%
    unnest(count),
  # "A" group: 85% of the k-mers, centred on (2, 2) * mean_coef.
  tibble(
    sex = "A",
    count = as_tibble(mvtnorm::rmvnorm(
      n_kmer * 0.85,
      mean = c(2, 2) * mean_coef,
      sigma = matrix(c(1, .95, .95, 1) * mean_coef^1.5 * 4, ncol = 2),
      method = "svd"
    ))
  ) %>%
    unnest(count)
) %>%
  rename(
    count_m = V1,
    count_f = V2
  )
# Scatter plot of the simulated counts, coloured by simulated group.
ggplot(data, aes(x = count_m, y = count_f, color = sex)) +
  geom_point() +
  coord_fixed()
# Fit a 3-component mixture with mclust on the two count columns (the
# simulated group label is dropped first), then print the summary and draw the
# classification and uncertainty plots.
data_clust <- data %>%
  select(-sex) %>%
  mclust::Mclust(G = 3)
summary(data_clust)
plot(data_clust, what = "classification")
plot(data_clust, what = "uncertainty")
# Expand the shared parameters into one parameter set per cluster.
#
# theta:        list with `pi` (mixing weights, read by position), `mu` (the
#               shared scale) and `sigma` (list of covariances named a, f, m).
# cluster_coef: list of per-cluster multipliers of `mu`, named a, f, m.
# sex:          "XY" adds the third ("m") cluster; any other value gives only
#               the "a" and "f" clusters.
#
# Returns a named list; each element is a list with `pi`, `mu` (multiplier
# times theta$mu) and `sigma`.
expand_theta <- function(theta, cluster_coef, sex) {
  component <- function(weight, coef, sigma) {
    list("pi" = weight, "mu" = coef * theta$mu, "sigma" = sigma)
  }
  theta_ref <- list(
    "a" = component(theta$pi[1], cluster_coef$a, theta$sigma$a),
    "f" = component(theta$pi[2], cluster_coef$f, theta$sigma$f)
  )
  if (sex == "XY") {
    theta_ref$m <- component(theta$pi[3], cluster_coef$m, theta$sigma$m)
  }
  theta_ref
}
# Tell whether two successive parameter sets still differ by more than
# `threshold`.
#
# old_theta, theta: lists with numeric `pi` and `mu` entries.
# threshold:        largest absolute change that counts as "no change".
#
# Returns TRUE when the largest absolute change in `pi` or `mu` exceeds
# `threshold`, and also when that change is not finite (NA, NaN or infinite).
# Returns FALSE otherwise.
params_diff <- function(old_theta, theta, threshold) {
  result <- max(
    max(abs(old_theta$pi - theta$pi)),
    max(abs(old_theta$mu - theta$mu))
  )
  if (is.finite(result)) {
    # The original compared an undefined `results` object here.
    return(result > threshold)
  }
  TRUE
}
# Mixture density of every row of x: the sum over clusters of the cluster
# weight times the cluster's multivariate normal density.
#
# x:            matrix of observations, one row per point.
# theta, cluster_coef, sex: passed to expand_theta() to get the per-cluster
#               weight, mean and covariance.
#
# Returns one density value per row of x.
proba_total <- function(x, theta, cluster_coef, sex) {
  weighted_density <- function(params) {
    params$pi *
      mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma)
  }
  Reduce(
    function(total, params) total + weighted_density(params),
    expand_theta(theta, cluster_coef, sex),
    0
  )
}
# Weighted density of every row of x under each cluster, kept separate.
#
# x:            matrix of observations, one row per point.
# theta, cluster_coef, sex: passed to expand_theta() to get the per-cluster
#               weight, mean and covariance.
#
# Returns a matrix with one row per point and one column per cluster, in the
# order returned by expand_theta(); columns are not named.
proba_point <- function(x, theta, cluster_coef, sex) {
  weighted_density <- function(params) {
    params$pi *
      mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma)
  }
  per_cluster <- lapply(expand_theta(theta, cluster_coef, sex), weighted_density)
  # Drop the cluster names so cbind() does not turn them into column names.
  do.call(cbind, unname(per_cluster))
}
# Convergence score used by the EM loop: minus the log of the mixture density
# summed over all rows of x.
#
# NOTE(review): this is the log of a sum of densities, not a sum of per-row
# log densities - confirm this is the intended criterion.
loglik <- function(x, theta, cluster_coef, sex) {
  mixture_density <- proba_total(x, theta, cluster_coef, sex)
  -log(sum(mixture_density))
}
# EM function
# E step: probability of each cluster for each row of x.
#
# The weighted densities from proba_point() are divided by their row total.
# Rows whose total is exactly zero get the same probability for every cluster.
#
# Returns a matrix with one row per point and one column per cluster.
E_proba <- function(x, theta, cluster_coef, sex) {
  weighted <- proba_point(x, theta, cluster_coef, sex)
  row_total <- rowSums(weighted)
  # row_total has one value per row, so recycling down the columns divides
  # every row by its own total.
  posterior <- weighted / row_total
  posterior[row_total == 0, ] <- 1 / ncol(weighted)
  posterior
}
# Sum of the cluster probabilities over all points, one value per cluster
# (column) of `proba`.
E_N_clust <- function(proba) {
  cluster_sizes <- colSums(proba)
  return(cluster_sizes)
}
# Function for mean update
# M step: update of the shared scale `mu`.
#
# Each cluster gives its own estimate of `mu`: the probability-weighted mean
# of each column of x is rescaled by a fixed per-column factor, and the
# estimates are then averaged over the clusters.
#   cluster 1: both columns are halved (centre at c(2, 2) * mu),
#   cluster 2: the second column is halved (centre at c(1, 2) * mu),
#   cluster 3: only the first column is used (centre at c(1, 0) * mu).
#
# x:       matrix with one row per point and two columns.
# proba:   matrix of cluster probabilities, one column per cluster.
# N_clust: column sums of `proba`.
# sex:     "XY" averages over three clusters, any other value over two.
#
# Returns the updated scalar `mu`.
M_mean <- function(x, proba, N_clust, sex) {
  mu <- 0
  for (cluster in seq_len(ncol(proba))) {
    # proba[, cluster] has one value per row, so recycling down the columns
    # weights every row of x by its own probability.
    col_means <- unname(colSums(x * proba[, cluster]) / N_clust[cluster])
    # The rescaling is applied to the two column means. Multiplying the whole
    # matrix by a length-2 vector would alternate the factors along the rows
    # instead of across the columns.
    if (cluster == 1) {
      mu <- mu + mean(col_means * c(0.5, 0.5))
    }
    if (cluster == 2) {
      mu <- mu + mean(col_means * c(1, 0.5))
    }
    if (cluster == 3) {
      mu <- mu + col_means[1]
    }
  }
  if (sex == "XY") {
    return(mu / 3)
  }
  mu / 2
}
# M step: update of the per-cluster covariance matrices.
#
# For every cluster, x is centred on mu * cluster_coef[[cluster]] column by
# column, and the probability-weighted cross-product of the centred points is
# divided by the cluster's total probability.
#
# x:            matrix with one row per point and one column per coordinate.
# proba:        matrix of cluster probabilities, one column per cluster.
# mu:           shared scalar scale.
# N_clust:      column sums of `proba`.
# cluster_coef: list of per-cluster multipliers of `mu`, read by position.
# sex:          "XY" returns covariances a, f and m; any other value returns
#               a and f only.
#
# Returns a named list of covariance matrices.
M_cov <- function(x, proba, mu, N_clust, cluster_coef, sex) {
  cov_clust <- lapply(seq_len(ncol(proba)), function(cluster) {
    # sweep() subtracts one centre value per column. Writing `x - centre`
    # would recycle the centre down the rows instead.
    centered <- sweep(x, 2, mu * cluster_coef[[cluster]])
    crossprod(proba[, cluster] * centered, centered) / N_clust[cluster]
  })
  sigma <- list()
  sigma$a <- cov_clust[[1]]
  sigma$f <- cov_clust[[2]]
  if (sex == "XY") {
    sigma$m <- cov_clust[[3]]
  }
  sigma
}
# Scatter plot of the points in x, each coloured by its cluster probabilities.
#
# x:     matrix with columns count_m and count_f.
# proba: matrix of cluster probabilities; column 1 drives the blue channel,
#        column 2 the red channel and, for sex = "XY", column 3 the green
#        channel (green is 0 otherwise).
# sex:   "XY" (default) uses three clusters, any other value two.
#
# Returns the ggplot object.
plot_proba <- function(x, proba, sex = "XY") {
  points <- as_tibble(x) %>%
    mutate(
      proba_a = proba[, 1],
      proba_f = proba[, 2]
    )
  if (sex == "XY") {
    points <- points %>%
      mutate(
        proba_m = proba[, 3],
        clust_proba = rgb(proba_f, proba_m, proba_a, maxColorValue = 1)
      )
  } else {
    points <- points %>%
      mutate(
        clust_proba = rgb(proba_f, 0, proba_a, maxColorValue = 1)
      )
  }
  ggplot(points, aes(x = count_m, y = count_f, color = clust_proba)) +
    geom_point() +
    scale_color_identity()
}
# Starting values for the EM algorithm.
#
# x:   matrix of observations.
# sex: "XY" adds a third ("m") cluster to the "a" and "f" clusters.
#
# Returns a list with
#   cluster_coef: per-cluster multipliers of the shared scale mu,
#   theta:        `pi` (starting weights), `mu` (half the mean of the column
#                 means of x) and `sigma` (one starting covariance per
#                 cluster).
#
# NOTE(review): every entry of each starting covariance equals mu, so these
# matrices are rank one - confirm this is intended.
# NOTE(review): `pi` always has three entries, also when sex is not "XY".
init_param <- function(x, sex) {
  start_mu <- mean(colMeans(x)) * .5
  start_sigma <- matrix(c(1, 1, 1, 1) * start_mu, ncol = 2)

  cluster_coef <- list("a" = c(2, 2), "f" = c(1, 2))
  sigma <- list("a" = start_sigma, "f" = start_sigma)
  if (sex == "XY") {
    cluster_coef$m <- c(1, 0)
    sigma$m <- start_sigma
  }

  list(
    cluster_coef = cluster_coef,
    theta = list(
      "pi" = c(.85, .1, .05),
      "mu" = start_mu,
      "sigma" = sigma
    )
  )
}
# Fit the constrained mixture to x by alternating E and M steps until the
# score returned by loglik() changes by no more than `threshold` between the
# start and the end of an iteration.
#
# x:         matrix of observations with columns count_m and count_f.
# threshold: stopping tolerance on the change of the loglik() score.
# sex:       "XY" (default) fits three clusters, any other value two.
#
# Returns the matrix of cluster probabilities from the last E step.
EM_clust <- function(x, threshold = 0.1, sex = "XY") {
  param <- init_param(x, sex)
  score_before <- -Inf
  score_after <- 0
  while (abs(score_after - score_before) > threshold) {
    score_before <- loglik(x, param$theta, param$cluster_coef, sex)

    # E step: cluster probabilities and their per-cluster totals.
    proba <- E_proba(x, param$theta, param$cluster_coef, sex)
    n_clust <- E_N_clust(proba)

    # M step: shared scale, covariances, then mixing weights.
    param$theta$mu <- M_mean(x, proba, n_clust, sex)
    param$theta$sigma <- M_cov(
      x, proba, param$theta$mu, n_clust, param$cluster_coef, sex
    )
    param$theta$pi <- n_clust / nrow(x)

    score_after <- loglik(x, param$theta, param$cluster_coef, sex)
  }
  proba
}
# Fit and plot the three-cluster (XY) model on the simulated counts.
proba <- EM_clust(as.matrix(dplyr::select(data, count_m, count_f)))
plot_proba(as.matrix(dplyr::select(data, count_m, count_f)), proba)

# Same data with the two-cluster (X0) model.
proba <- EM_clust(
  as.matrix(dplyr::select(data, count_m, count_f)),
  sex = "X0"
)
plot_proba(
  as.matrix(dplyr::select(data, count_m, count_f)),
  proba,
  sex = "X0"
)
# Hand-written parameters for a 4-column variant of the model: three weights,
# three mean vectors of length 4 and three 4x4 diagonal covariance matrices
# named f, m and a.
theta4 <- list(
"pi" = c(.1, .05, .85),
"mu" = list(c(1000, 2000, 1000, 2000), c(1000, 0, 1000, 0), c(1000, 1000, 1000, 1000)),
"sigma" = list(
"f" = diag(1000, nrow=4, ncol=4),
"m" = diag(1000, nrow=4, ncol=4),
"a" = diag(1000, nrow=4, ncol=4)
)
)
# Duplicate the two count columns to get a 4-column matrix.
# NOTE(review): the second positional parameter of EM_clust() is `threshold`,
# so theta4 is passed as the convergence threshold here, not as starting
# parameters. This chunk presumably targets a different EM_clust() signature -
# confirm before running.
proba4 <- data %>%
dplyr::select(count_m, count_f) %>%
dplyr::mutate(
count_m2 = count_m,
count_f2 = count_f) %>%
as.matrix() %>%
EM_clust(theta4)
# Plot the two original count columns coloured by proba4.
data %>%
dplyr::select(count_m, count_f) %>%
as.matrix() %>%
plot_proba(proba4)
With real data
# Load the k-mer count table. The file is parsed as tab-separated although its
# name ends in .csv. Then report the size of the table in memory.
data <- read_tsv(
  "../results/12/mbelari/mbelari.csv",
  show_col_types = FALSE
)
data %>%
  object.size() %>%
  format(units = "Mb")
# Build the sample annotation: one row per (specie, sex) with a nested `files`
# table. Each file entry has its directory prefix and ".fasta.gz" extension
# removed and ".csv" appended.
annotation <- read_csv("../data/sample.csv", show_col_types = FALSE) %>%
  pivot_longer(!c(sex, specie), names_to = "read", values_to = "file") %>%
  mutate(
    file = gsub("/scratch/Bio/lmodolo/kmer_diff/data/.*/", "", file, perl = TRUE),
    file = gsub("\\.fasta\\.gz", "", file, perl = TRUE),
    file = paste0(file, ".csv")
  ) %>%
  select(!read) %>%
  group_by(specie, sex) %>%
  nest(.key = "files")
# For every specie, compute the per-k-mer mean count over the female files and
# over the male files, then save the result.
count <- annotation %>%
group_by(specie) %>%
# One row per specie; the `sex` column holds that specie's nested
# (sex, files) table.
nest(.key = "sex") %>%
# `files` is one specie's nested table; `data` is the count table, passed
# explicitly through lapply().
mutate(count = lapply(sex, function(files, data){
# File names listed for the female and for the male samples.
files_f <- files %>% filter(sex == "female") %>% unnest(files) %>% pull(file) %>% as.vector()
files_m <- files %>% filter(sex == "male") %>% unnest(files) %>% pull(file) %>% as.vector()
data %>%
select(kmer) %>%
mutate(
# any_of() keeps only the listed names that are columns of `data`.
female = data %>% select(any_of(files_f)) %>% rowMeans(),
male = data %>% select(any_of(files_m)) %>% rowMeans()
)
}, data = data)) %>%
# Expand the nested annotation first, then the per-specie count table.
unnest(sex) %>%
unnest(count)
save(count, file = "../results/12/mbelari/counts.Rdata")
M. belari data
# Per-k-mer mean count over the female columns and over the male columns of
# `data`, saved to disk.
# NOTE(review): mb_f and mb_m are not defined in the visible part of this
# file; they are presumably vectors of column names - confirm where they are
# set.
mb_data <- data %>%
  select(kmer) %>%
  mutate(
    female = rowMeans(select(data, any_of(mb_f))),
    male = rowMeans(select(data, any_of(mb_m)))
  )
save(mb_data, file = "../results/mb_data.Rdata")
# Reload the saved table and plot a random 10% of the k-mers on the log1p
# scale, male counts against female counts.
load("../results/mb_data.Rdata")
ggplot(
  sample_frac(mb_data, 0.1),
  aes(x = log1p(male), y = log1p(female))
) +
  geom_point() +
  coord_fixed() +
  theme_bw()