Skip to content
Snippets Groups Projects
title: "kmer clustering"
author: "Laurent Modolo"
date: "`r Sys.Date()`"
output: html_document
knitr::opts_chunk$set(echo = TRUE)
library(mclust)
library(tidyverse)
library(mvtnorm)

Simulation

Let $X_f$ be the female k-mer count and $X_m$ the male k-mer count. For an XY species, we expect the following relations:

  • For the autosomal chromosomes:
    $X_m = X_f$
  • For the X chromosome (two copies in females, one in males):
    $X_f = 2 X_m$
  • For the Y chromosome (absent from females):
    $X_f = 0^+ \, X_m$

On the log scale, these relations become:

  • For the autosomal chromosomes:
    $\log(X_m) = \log(X_f)$
  • For the X chromosome:
    $\log(X_f) = \log(2) + \log(X_m)$
  • For the Y chromosome:
    $\log(X_f) = \log(0^+) + \log(X_m)$

Test the regression slope for a given covariance (sigma) matrix

# Draw `n` points from a zero-mean bivariate normal distribution and plot them
# with their least-squares fit; the plot title is the fitted slope of V2 on V1.
#
# x, y: standard deviations of the first (V1) and second (V2) coordinate.
# rho:  correlation between the two coordinates.
# n:    number of simulated points (default 1e4).
# Returns a ggplot object.
test_slope <- function(x, y, rho, n = 1e4) {
    sigma <- matrix(c(x^2, rho * x * y, rho * x * y, y^2), ncol = 2)
    draws <- mvtnorm::rmvnorm(n, mean = c(0, 0), sigma = sigma,
                              checkSymmetry = FALSE, method = "svd")
    # Name the columns explicitly so as_tibble() does not have to repair names.
    colnames(draws) <- c("V1", "V2")
    test <- as_tibble(draws)
    # coef() rather than `$coef`, which only worked through partial matching
    # of the `coefficients` element.
    slope <- coef(lm(V2 ~ V1, data = test))[["V1"]]
    ggplot(data = test, aes(x = V1, y = V2)) +
        geom_point() +
        geom_smooth(method = lm) +
        labs(title = slope) +
        coord_fixed()
}
test_slope(1.05, 2.05, 0.95)

Simulate k-mer count data

# Simulate bivariate (count_m, count_f) k-mer counts for up to three classes:
#   "F": mean c(1, 2) * mean_coef, round(10 % of n_kmer) k-mers
#   "A": mean c(2, 2) * mean_coef, round(85 % of n_kmer) k-mers
#   "M": mean c(1, 0) * mean_coef, round(5 % of n_kmer) k-mers, only when
#        sex == "XY"
# Covariance matrices scale with mean_coef^1.5.
#
# n_kmer:    nominal total number of k-mers (class sizes are rounded).
# mean_coef: scale of the expected counts.
# sex:       "XY" adds the "M" class; any other value omits it.
# Returns a tibble with columns sex, count_m, count_f (rows F, then A, then M).
sim_kmer <- function(n_kmer, mean_coef, sex = "XY") {
    # Draw one class of k-mers and label its rows.
    draw_class <- function(label, prop, mean, sigma) {
        counts <- mvtnorm::rmvnorm(
            round(n_kmer * prop),
            mean = mean,
            sigma = sigma,
            method = "svd"
        )
        colnames(counts) <- c("count_m", "count_f")
        as_tibble(counts) %>%
            mutate(sex = label, .before = 1)
    }
    classes <- list(
        draw_class("F", .1, c(1, 2) * mean_coef,
                   matrix(c(1.05, 2, 2, 4.05) * mean_coef^1.5, ncol = 2)),
        draw_class("A", 0.85, c(2, 2) * mean_coef,
                   matrix(c(1, .95, .95, 1) * mean_coef^1.5 * 4, ncol = 2))
    )
    if (sex == "XY") {
        classes[[3]] <- draw_class(
            "M", .05, c(1, 0) * mean_coef,
            matrix(c(.9, .05, .05, .05) * mean_coef^1.5, ncol = 2)
        )
    }
    bind_rows(classes)
}
# Simulate an XY dataset, plot it coloured by true class, and fit an
# unconstrained 3-component Gaussian mixture with mclust for comparison.
data <- sim_kmer(1e4, 1000, "XY")
data %>% 
    ggplot(aes(x = count_m, y = count_f, color = sex)) +
    geom_point() +
    coord_fixed()
data_clust <- data %>% select(-sex) %>% mclust::Mclust(G = 3)
summary(data_clust)
plot(data_clust, what = "classification")
plot(data_clust, what = "uncertainty")
# Expand the shared parameters into one parameter set per cluster.
#
# theta:        list with `pi` (proportions in cluster order a, f, m), `mu`
#               (shared mean) and `sigma` (named list of covariances).
# cluster_coef: named list of per-cluster multipliers of `mu`.
# sex:          "XY" includes the third cluster `m`; otherwise only a and f.
# Returns a named list (a, f[, m]); each element is list(pi, mu, sigma) where
# mu = cluster_coef[[cluster]] * theta$mu.
expand_theta <- function(theta, cluster_coef, sex) {
    cluster_ids <- c("a", "f")
    if (sex == "XY") {
        cluster_ids <- c(cluster_ids, "m")
    }
    per_cluster <- lapply(seq_along(cluster_ids), function(i) {
        id <- cluster_ids[[i]]
        list(
            "pi" = theta$pi[i],
            "mu" = cluster_coef[[id]] * theta$mu,
            "sigma" = theta$sigma[[id]]
        )
    })
    names(per_cluster) <- cluster_ids
    per_cluster
}

# Decide whether the parameters moved by more than `threshold`.
#
# Compares the mixture proportions `pi` and the shared mean `mu` of two
# parameter sets (covariances are not compared). Returns TRUE when the
# largest absolute change exceeds `threshold` or is not finite, FALSE
# otherwise.
params_diff <- function(old_theta, theta, threshold) {
    result <- max(
        max(abs(old_theta$pi - theta$pi)),
        max(abs(old_theta$mu - theta$mu))
    )
    if (is.finite(result)) {
        # Previously returned `results` (typo), a name not defined here.
        return(result > threshold)
    }
    TRUE
}

# Weighted component densities for every point.
#
# For each cluster returned by expand_theta(), computes
# pi * dmvnorm(x, mu, sigma) and binds the results column-wise, giving a
# matrix with one row per row of `x` and one column per cluster.
proba_point <- function(x, theta, cluster_coef, sex) {
    components <- expand_theta(theta, cluster_coef, sex)
    weighted_densities <- lapply(components, function(component) {
        component$pi *
            mvtnorm::dmvnorm(x, mean = component$mu, sigma = component$sigma)
    })
    # unname() so the columns stay unnamed, as when built with cbind() in a loop.
    do.call(cbind, unname(weighted_densities))
}

# Mixture log-likelihood: sum over points of the log of the total weighted
# density across clusters.
loglik <- function(x, theta, cluster_coef, sex) {
    mixture_density <- rowSums(proba_point(x, theta, cluster_coef, sex))
    sum(log(mixture_density))
}

# EM function
# E-step: cluster responsibilities.
#
# Each row of proba_point() is divided by its total so it sums to 1; rows
# whose total weighted density is exactly 0 get a uniform 1 / k instead.
E_proba <- function(x, theta, cluster_coef, sex) {
    weighted <- proba_point(x, theta, cluster_coef, sex)
    totals <- rowSums(weighted)
    # A length-n vector recycles down the columns, so element (i, j) is
    # divided by totals[i].
    responsibilities <- weighted / totals
    responsibilities[totals == 0, ] <- 1 / ncol(weighted)
    responsibilities
}

# Soft size of each cluster: column totals of the responsibility matrix.
E_N_clust <- function(proba) {
    cluster_sizes <- colSums(proba)
    cluster_sizes
}

# Function for mean update
# M-step for the shared mean `mu`.
#
# Each cluster's expected (count_m, count_f) is `mu` times a fixed coefficient
# vector (the values init_param() uses): cluster 1 c(2, 2), cluster 2 c(1, 2),
# cluster 3 c(1, 0). For each cluster, the responsibility-weighted column
# means of `x` are divided by these coefficients (columns with coefficient 0
# are skipped) and averaged. The per-cluster estimates are then summed and
# divided by 3 when sex == "XY", and by 2 otherwise.
#
# x:       n x 2 matrix of counts (count_m, count_f).
# proba:   n x k responsibility matrix.
# N_clust: column sums of `proba`.
# Returns a single number.
M_mean <- function(x, proba, N_clust, sex) {
    coefs <- list(c(2, 2), c(1, 2), c(1, 0))
    x <- as.matrix(x)
    n_used <- min(ncol(proba), length(coefs))
    mu_clust <- vapply(seq_len(n_used), function(cluster) {
        coef <- coefs[[cluster]]
        # Multiplying by a length-n vector weights the rows, giving the
        # weighted mean of each column. The former `x * c(1, 0.5)` recycled
        # the 2-vector along rows, so it did not scale the columns.
        weighted_means <- colSums(x * proba[, cluster]) / N_clust[cluster]
        keep <- coef != 0
        mean(weighted_means[keep] / coef[keep])
    }, numeric(1))
    n_div <- if (sex == "XY") 3 else 2
    sum(mu_clust) / n_div
}

# M-step for the per-cluster covariance matrices.
#
# Cluster j (positional: 1 = a, 2 = f, 3 = m) is centred on
# mu * cluster_coef[[j]]; its covariance is the responsibility-weighted
# scatter matrix divided by N_clust[j].
# Returns list(a = , f = ), plus `m` when sex == "XY".
M_cov <- function(x, proba, mu, N_clust, cluster_coef, sex) {
    x <- as.matrix(x)
    cov_clust <- lapply(seq_len(ncol(proba)), function(cluster) {
        # sweep() subtracts the centre from each column; the former
        # `x - mu * coef` recycled the 2-vector along rows instead.
        centered <- sweep(x, 2, mu * cluster_coef[[cluster]])
        crossprod(centered * proba[, cluster], centered) / N_clust[cluster]
    })
    sigma <- list(a = cov_clust[[1]], f = cov_clust[[2]])
    if (sex == "XY") {
        sigma$m <- cov_clust[[3]]
    }
    sigma
}

# Scatter plot of count_f against count_m, colouring each point by its
# responsibilities: red = column 2 (f), green = column 3 (m) when
# sex == "XY" and 0 otherwise, blue = column 1 (a).
#
# x:     matrix or data frame with columns count_m and count_f.
# proba: responsibility matrix (columns in cluster order a, f[, m]).
# Returns a ggplot object.
plot_proba <- function(x, proba, sex = "XY") {
    with_m <- sex == "XY"
    points <- as_tibble(x) %>%
        mutate(
            proba_a = proba[, 1],
            proba_f = proba[, 2]
        )
    if (with_m) {
        points <- points %>% mutate(proba_m = proba[, 3])
    }
    points %>%
        mutate(clust_proba = rgb(proba_f, if (with_m) proba_m else 0, proba_a,
                                 maxColorValue = 1)) %>%
        ggplot(aes(x = count_m, y = count_f, color = clust_proba)) +
        geom_point() +
        scale_color_identity()
}

# Starting parameters for EM_clust().
#
# cluster_coef: fixed multipliers of the shared mean giving each cluster's
#   expected (count_m, count_f): a = c(2, 2), f = c(1, 2), plus m = c(1, 0)
#   when sex == "XY".
# theta$pi:    c(.85, .1, .05) for XY; otherwise the first two values
#   renormalised to sum to 1.
# theta$mu:    half of the mean of the column means of `x`.
# theta$sigma: one covariance per cluster, mu times the 2 x 2 identity.
init_param <- function(x, sex) {
    cluster_coef <- list(
        "a" = c(2, 2),
        "f" = c(1, 2)
    )
    pi_init <- c(.85, .1, .05)
    if (sex != "XY") {
        # Only two clusters without the Y component: keep a proper
        # distribution over them.
        pi_init <- pi_init[1:2] / sum(pi_init[1:2])
    }
    theta <- list(
        "pi" = pi_init,
        "mu" = mean(colMeans(x)) * .5
    )
    # The former matrix(c(1, 1, 1, 1) * mu) is singular (rank 1) and cannot
    # serve as a bivariate normal covariance; a scaled identity is positive
    # definite whenever mu > 0.
    start_sigma <- diag(2) * theta$mu
    theta$sigma <- list("a" = start_sigma, "f" = start_sigma)
    if (sex == "XY") {
        cluster_coef$m <- c(1, 0)
        theta$sigma$m <- start_sigma
    }
    list(cluster_coef = cluster_coef, theta = theta)
}

# Bayesian information criterion: k * log(n) - 2 * loglik, where n is the
# number of rows of `x`. k is 9 for the two-cluster model and 13 for the XY
# (three-cluster) model.
compute_bic <- function(x, loglik, sex = "XY") {
    k <- 1 + 4 * 2
    # Was `sex == "YX"`, a typo that never matched, so the XY model was never
    # penalised for its extra cluster.
    if (sex == "XY") {
        k <- k + 4
    }
    k * log(nrow(x)) - 2 * loglik
}


# Fit the constrained mixture model by expectation-maximisation.
#
# x:         n x 2 numeric matrix (count_m, count_f).
# threshold: stop once the log-likelihood changes by at most this amount
#            between two successive iterations.
# sex:       "XY" fits three clusters (a, f, m); any other value fits two.
# max_iter:  optional cap on the number of iterations (default: no cap).
# Stops early if the log-likelihood becomes infinite.
# Returns list(proba = responsibilities, theta = parameters,
#              loglik = final log-likelihood, BIC = compute_bic(...)).
EM_clust <- function(x, threshold = 1, sex = "XY", max_iter = Inf) {
    param <- init_param(x, sex)
    # Log-likelihood of the current parameters, carried over between
    # iterations instead of being recomputed at the top of each one.
    old_loglik <- loglik(x, param$theta, param$cluster_coef, sex)
    iter <- 0
    repeat {
        iter <- iter + 1
        proba <- E_proba(x, param$theta, param$cluster_coef, sex)
        N_clust <- E_N_clust(proba)
        param$theta$mu <- M_mean(x, proba, N_clust, sex)
        param$theta$sigma <- M_cov(x, proba, param$theta$mu, N_clust,
                                   param$cluster_coef, sex)
        param$theta$pi <- N_clust / nrow(x)
        new_loglik <- loglik(x, param$theta, param$cluster_coef, sex)
        if (is.infinite(new_loglik) ||
            abs(new_loglik - old_loglik) <= threshold ||
            iter >= max_iter) {
            break
        }
        old_loglik <- new_loglik
    }
    list(proba = proba, theta = param$theta, loglik = new_loglik,
         BIC = compute_bic(x, new_loglik, sex))
}

# Bootstrap distribution of the BIC for one model.
#
# Draws `nboot` resamples of `bootsize` rows (with replacement) from the
# count_m / count_f columns of `x`, fits EM_clust() with `sex` and
# `threshold` on each, and returns the BIC values as a numeric vector.
# Resamples are processed with parallel::mclapply() on `core` workers.
boostrap_BIC <- function(x, sex = "XY", threshold = 1, nboot = 100, bootsize = 1000, core = 6) {
    counts <- x %>% dplyr::select(count_m, count_f)
    parallel::mclapply(seq_len(nboot), function(iter) {
        fit <- counts %>%
            sample_n(bootsize, replace = TRUE) %>%
            as.matrix() %>%
            # `threshold` and `core` were previously accepted but ignored.
            EM_clust(threshold = threshold, sex = sex)
        fit$BIC
    }, mc.cores = core) %>%
        unlist()
}

# Bootstrap BIC values for the XY (three-cluster) and X0 (two-cluster)
# models on the same data; returns a tibble with columns BIC_XY and BIC_XO.
compare_BIC <- function(x, threshold = 1, nboot = 100, bootsize = 1000, core = 6) {
    run_boot <- function(model_sex) {
        boostrap_BIC(x, sex = model_sex, threshold = threshold, nboot = nboot,
                     bootsize = bootsize, core = core)
    }
    bic_xy <- run_boot("XY")
    bic_xo <- run_boot("X0")
    tibble(BIC_XY = bic_xy, BIC_XO = bic_xo)
}

Clustering with the XY model

# Fit the three-cluster (XY) model to the simulated counts and colour each
# k-mer by its cluster responsibilities.
model_XY <- EM_clust(as.matrix(dplyr::select(data, count_m, count_f)))
plot_proba(as.matrix(dplyr::select(data, count_m, count_f)), model_XY$proba)

Clustering with the XO model

# Fit the two-cluster (X0) model to the same counts and plot its
# responsibilities.
model_XO <- EM_clust(as.matrix(dplyr::select(data, count_m, count_f)), sex = "X0")
plot_proba(as.matrix(dplyr::select(data, count_m, count_f)), model_XO$proba, sex = "X0")

Likelihood-ratio test (LRT)

For XY

# Simulate a small XY dataset, fit both models and compute the LRT p-value:
# statistic 2 * (loglik_XY - loglik_X0) on 4 degrees of freedom (the
# parameter difference used by compute_bic()).
# NOTE(review): X0 is XY with the Y proportion on the boundary (0), so the
# chi-square reference distribution is only approximate.
data <- sim_kmer(1e2, 1000, "XY")
model_XY <- data %>%
    dplyr::select(count_m, count_f) %>%
    as.matrix() %>% 
    EM_clust()
model_XO <- data %>%
    dplyr::select(count_m, count_f) %>%
    as.matrix() %>% 
    EM_clust(sex = "X0")
pchisq(2 * (model_XY$loglik - model_XO$loglik), df = 4, lower.tail = FALSE)

# NOTE(review): this large simulation is reassigned by the XO section below
# before being used.
data <- sim_kmer(1e6, 1000, "XY")

For XO

# Simulate a small dataset without Y k-mers, fit both models and compute the
# LRT p-value for XY (larger model) against X0 on 4 degrees of freedom.
# The statistic is 2 * (loglik_XY - loglik_X0) and the p-value is the upper
# tail; previously the sign was flipped and the CDF (lower tail) was returned.
data <- sim_kmer(1e2, 1000, "X0")
model_XY <- data %>%
    dplyr::select(count_m, count_f) %>%
    as.matrix() %>% 
    EM_clust()
model_XO <- data %>%
    dplyr::select(count_m, count_f) %>%
    as.matrix() %>% 
    EM_clust(sex = "X0")
pchisq(2 * (model_XY$loglik - model_XO$loglik), df = 4, lower.tail = FALSE)

Get Y k-mers

# Bootstrap BIC distributions of the XY and X0 models on the current `data`.
res <- compare_BIC(data)
res %>%
    pivot_longer(cols = c(BIC_XY, BIC_XO)) %>%
    ggplot(aes(x = name, y = value)) +
    geom_violin()

# Colour each k-mer by the third-cluster (m) responsibility of the XY model.
data %>%
    mutate(y_proba = model_XY$proba[, 3]) %>%
    ggplot(aes(x = count_m, y = count_f, color = y_proba)) +
    geom_point() +
    theme_bw()

With real data

# Load the M. belari k-mer count table and compute, per specie, the mean
# count of each k-mer over the female and over the male sample columns.
data <- read_tsv("results/12/mbelari/mbelari.csv", show_col_types = FALSE)
format(object.size(data), units = "Mb")
# One row per (specie, sex); `files` holds the matching count column names.
annotation <- read_csv("data/sample.csv", show_col_types = FALSE) %>% 
  pivot_longer(!c(sex, specie), names_to = "read", values_to = "file") %>% 
  mutate(
    file = gsub("/scratch/Bio/lmodolo/kmer_diff/data/.*/", "", file, perl = TRUE),
    file = gsub("\\.fasta\\.gz", "", file, perl = TRUE)
  ) %>% 
  mutate(
    file = paste0(file, ".csv")
  ) %>% 
  select(!c(read)) %>% 
  group_by(specie, sex) %>% 
  nest(.key = "files")
count <- annotation %>% 
  group_by(specie) %>% 
  nest(.key = "sex") %>% 
  mutate(count = lapply(sex, function(sex_files, data) {
    files_f <- sex_files %>% filter(sex == "female") %>% unnest(files) %>% pull(file) %>% as.vector()
    files_m <- sex_files %>% filter(sex == "male") %>% unnest(files) %>% pull(file) %>% as.vector()
    data %>% 
      select(kmer) %>% 
      mutate(
         female = data %>% select(any_of(files_f)) %>% rowMeans(),
         male = data %>% select(any_of(files_m)) %>% rowMeans()
      )
  }, data = data)) %>%
  # Drop the nested per-sex annotation before unnesting the counts:
  # unnesting it first produced one row per sex and so duplicated every k-mer.
  select(specie, count) %>%
  unnest(count)
save(count, file = "results/12/mbelari/counts.Rdata")

*M. belari* data

load("results/12/mbelari/counts.Rdata")
# 1 % subsample of the k-mers on the log1p scale. transmute() keeps only
# count_m and count_f: the models use a 2-D mean, and the raw male/female
# columns previously left in made the matrix 4 columns wide.
s_count <- count %>%
    ungroup() %>% 
    sample_frac(0.01) %>% 
    transmute(
        count_m = log1p(male),
        count_f = log1p(female)
    )
model_XY <- s_count %>%
    as.matrix() %>% 
    EM_clust()
model_XO <- s_count %>%
    as.matrix() %>% 
    EM_clust(sex = "X0")
model_XO$BIC
model_XY$BIC

s_count %>% 
    as.matrix() %>% 
    plot_proba(model_XO$proba, sex = "X0")