From 85dab5beb6be27cc4285a04e62f60644300ac440 Mon Sep 17 00:00:00 2001 From: Laurent Modolo <laurent.modolo@ens-lyon.fr> Date: Fri, 20 Oct 2023 10:17:01 +0200 Subject: [PATCH] dev/flat_full_poisson.Rmd: add bayesian prior on clust proportion --- dev/flat_full_poisson.Rmd | 74 +++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/dev/flat_full_poisson.Rmd b/dev/flat_full_poisson.Rmd index c96b50c..28370e5 100644 --- a/dev/flat_full_poisson.Rmd +++ b/dev/flat_full_poisson.Rmd @@ -208,15 +208,25 @@ e_bipoiss_clust <- function(x, params) { #' #' @param x the data matrix (2 columns) #' @param s of size 3 time nrow(x) +#' @param h of size 3 time nrow(x) +#' @param gamma a vector of kappa prior +#' @param m_chr the maximization function for each type of chromosome #' @return a list of 3 lambda #' @importFrom extraDistr dbvpois #' @export -m_bipoiss_clust <- function(x, s, h, +m_bipoiss_clust <- function(x, s, h, gamma, m_chr = list(m_bipoiss_clust_a_chr, m_bipoiss_clust_x_chr, m_bipoiss_clust_y_chr)) { params <- list() + gamma_sum <- sum(gamma) + kappa <- c() for (i in 1:ncol(h)) { params[[i]] <- m_chr[[i]](x, s[, i], h[, i]) - params[[i]]$kappa <- sum(h[, i]) / nrow(h) + params[[i]]$kappa <- (sum(h[, i]) + gamma[i] - 1) / (nrow(h) + gamma_sum - ncol(h)) + kappa <- c(kappa, params[[i]]$kappa) + } + params[[1]]$kappa <- 1 + if (length(kappa) > 1) { + params[[1]]$kappa <- 1 - sum(kappa[2:length(kappa)]) } return(params) } @@ -342,21 +352,22 @@ m_bipoiss_clust_y_chr <- function(x, s, h) { #' #' @param x the data matrix (2 columns) #' @param params the list of 3 lambda +#' @param gamma a vector of kappa prior #' @return logliklihood #' @importFrom extraDistr dbvpois +#' @importFrom Boom ddirichlet #' @export -loglik_bipoiss_clust <- function(x, params) { +loglik_bipoiss_clust <- function(x, params, gamma) { loglik <- matrix(0, nrow(x), length(params)) + kappa <- c() for (i in 1:length(params)) { loglik[, i] <- params[[i]]$kappa * extraDistr::dbvpois(x = x, a = params[[i]]$l1, b = params[[i]]$l2, c = params[[i]]$l3) + kappa <- c(kappa, params[[i]]$kappa) } loglik <- sum(log(rowSums(loglik))) - if (is.infinite(loglik)) { - warning("Inf loglik") - } - if (is.na(loglik)) { - warning("NA loglik") + if (length(kappa) > 1) { + loglik <- loglik + Boom::ddirichlet(probabilities = kappa, nu = gamma, logscale = T) } return(loglik) } @@ -376,10 +387,10 @@ loglik_bipoiss_clust <- function(x, params) { bic_bipoiss_clust <- function(x, params, loglik) { k <- 3 if (length(params) >= 2) { - k <- k + 3 + k <- k + 5 } if (length(params) >= 3) { - k <- k + 3 + k <- k + 4 } return(k * log(nrow(x)) - 2 * loglik) } @@ -507,6 +518,7 @@ e_bipoiss_clust_batch <- function(x, params, hidden = NULL, nbatch = 10) { #' #' @param x the data matrix (2 columns) #' @param params the list of 3 lambda (initial value) +#' @param gamma a vector of kappa prior #' @param nbatch (default: 10) the number of batch #' @param threshold (default: c(0.1, 5)) maximum difference between kappa and lambda #' @param max_iter (default: 100) maximum number of iteration @@ -524,7 +536,9 @@ em_bipoiss_clust <- function(x, params = list( A = list(kappa = 0.4, l1 = mean(x), l2 = mean(x), l3 = mean(x)), X = list(kappa = 0.3, l1 = 1, l2 = mean(x), l3 = mean(x)), Y = list(kappa = 0.3, l1 = mean(x), l2 = 1, l3 = mean(x)) - ), nbatch = 10, threshold = c(0.1, 1), max_iter = 100) { + ), + gamma = c(round(nrow(x) * .8), round(nrow(x) * .15), round(nrow(x) * .05)), + nbatch = 10, threshold = c(0.1, 1), max_iter = 100) { old_params <- params old_params[[1]]$kappa <- Inf hidden <- NULL @@ -532,7 +546,7 @@ em_bipoiss_clust <- function(x, params = list( while (param_diff_poiss(params, old_params, threshold = threshold) & iter < max_iter) { old_params <- params hidden <- e_bipoiss_clust_batch(x, params = params, hidden = hidden, nbatch = nbatch) - params <- m_bipoiss_clust(x, s = hidden$s, h = round(hidden$h)) + params <- m_bipoiss_clust(x, s = hidden$s, h = round(hidden$h), gamma = gamma) iter <- iter + 1 } res <- params @@ -545,8 +559,8 @@ em_bipoiss_clust <- function(x, params = list( if (length(res) == 3) { names(res) <- c("A", "X", "Y") } - res$loglik <- loglik_bipoiss_clust(x, params) - res$BIC <- bic_bipoiss_clust(x, params, loglik = res$loglik) + res$loglik <- loglik_bipoiss_clust(x, params = params, gamma = gamma) + res$BIC <- bic_bipoiss_clust(x, params = params, loglik = res$loglik) res <- get_cluster_poiss(x, res = res, params = params, hidden = hidden) return(res) } @@ -616,7 +630,9 @@ res <- data %>% em_bipoiss_clust(params = list( l2 = mean(data), l3 = mean(data) ) - )) + ), + gamma = c(round(nrow(data) * .8), round(nrow(data) * .2)) + ) rbind( extraDistr::rbvpois(1000*res[[1]]$kappa, res[[1]]$l1, res[[1]]$l2, res[[1]]$l3), # X extraDistr::rbvpois(1000*res[[2]]$kappa, res[[2]]$l1, res[[2]]$l2, res[[2]]$l3) # A @@ -647,7 +663,8 @@ res <- data %>% em_bipoiss_clust(params = list( l2 = mean(data), l3 = mean(data) ) - )) + ), + gamma = c(nrow(data))) extraDistr::rbvpois(1000*res[[1]]$kappa, res[[1]]$l1, res[[1]]$l2, res[[1]]$l3) %>% # X as_tibble() %>% dplyr::rename(male = 'V1', female = 'V2') %>% @@ -667,6 +684,7 @@ extraDistr::rbvpois(1000*res[[1]]$kappa, res[[1]]$l1, res[[1]]$l2, res[[1]]$l3) #' #' @param x a matrix of male female kmers counts #' @param params a list of init parameters +#' @param gamma a vector of kappa prior #' @param threshold theshold to stop the EM algorithm #' @param nboot number of boostrap sample #' @param bootsize size of the boostrap sample (if < 0 we take a percentage of x) @@ -682,7 +700,7 @@ extraDistr::rbvpois(1000*res[[1]]$kappa, res[[1]]$l1, res[[1]]$l2, res[[1]]$l3) #' extraDistr::rbvpois(500, 0, 20, 20), # X #' extraDistr::rbvpois(100, 20, 0, 5) # Y #' ) %>% pois_boostrap_EM(nboot = 10) -poiss_boostrap_EM_sub <- function(iter, x, bootsize, params, max_iter, max_error) { +poiss_boostrap_EM_sub <- function(iter, x, bootsize, params, gamma, max_iter, max_error) { res <- list() res$loglik <- c(NA) iter <- 0 @@ -691,7 +709,7 @@ poiss_boostrap_EM_sub <- function(iter, x, bootsize, params, max_iter, max_error as_tibble() %>% sample_frac(bootsize, replace = T) %>% as.matrix() %>% - em_bipoiss_clust(params = params, max_iter = max_iter) + em_bipoiss_clust(params = params, gamma = gamma, max_iter = max_iter) iter <- iter + 1 } return(tibble(loglik = res$loglik, BIC = res$BIC)) @@ -703,6 +721,7 @@ poiss_boostrap_EM_sub <- function(iter, x, bootsize, params, max_iter, max_error #' #' @param x a matrix of male female kmers counts #' @param params a list of init parameters +#' @param gamma a vector of kappa prior #' @param threshold theshold to stop the EM algorithm #' @param nboot number of boostrap sample #' @param bootsize size of the boostrap sample (if < 0 we take a percentage of x) @@ -722,16 +741,19 @@ poiss_boostrap_EM <- function(x, sex = "XY", threshold = 1, nboot = 100, bootsiz params = list( A = list(kappa = 0.4, l1 = mean(x), l2 = mean(x), l3 = mean(x)) ) - if (sex %in% c("XO", "XY")) { - params$X <- list(kappa = 0.3, l1 = 1, l2 = mean(x), l3 = mean(x)) - } - if (sex == "XY") { - params$Y <- list(kappa = 0.3, l1 = mean(x), l2 = 1, l3 = mean(x)) - } + gamma <- c(round(nrow(x) * bootsize)) + if (sex %in% c("XO", "XY")) { + params$X <- list(kappa = 0.3, l1 = 1, l2 = mean(x), l3 = mean(x)) + gamma <- c(round(nrow(x) * bootsize) * .8, round(nrow(x) * bootsize) * .2) + } + if (sex == "XY") { + params$Y <- list(kappa = 0.3, l1 = mean(x), l2 = 1, l3 = mean(x)) + gamma <- c(round(nrow(x) * bootsize) * .8, round(nrow(x) * bootsize) * .15, round(nrow(x) * bootsize) * .05) + } if (core == 1) { - res <- lapply(as.list(1:nboot), FUN = poiss_boostrap_EM_sub, x = x, bootsize = bootsize, params = params, max_iter = max_iter, max_error = max_error) + res <- lapply(as.list(1:nboot), FUN = poiss_boostrap_EM_sub, x = x, bootsize = bootsize, params = params, gamma = gamma, max_iter = max_iter, max_error = max_error) } else { - res <- parallel::mclapply(as.list(1:nboot), FUN = poiss_boostrap_EM_sub, x = x, bootsize = bootsize, params = params, max_iter = max_iter, max_error = max_error, mc.cores = core) + res <- parallel::mclapply(as.list(1:nboot), FUN = poiss_boostrap_EM_sub, x = x, bootsize = bootsize, params = params, gamma = gamma, max_iter = max_iter, max_error = max_error, mc.cores = core) } return(res) } -- GitLab