From 85dab5beb6be27cc4285a04e62f60644300ac440 Mon Sep 17 00:00:00 2001
From: Laurent Modolo <laurent.modolo@ens-lyon.fr>
Date: Fri, 20 Oct 2023 10:17:01 +0200
Subject: [PATCH] dev/flat_full_poisson.Rmd: add Bayesian prior on cluster
 proportions

---
 dev/flat_full_poisson.Rmd | 74 +++++++++++++++++++++++++--------------
 1 file changed, 48 insertions(+), 26 deletions(-)

diff --git a/dev/flat_full_poisson.Rmd b/dev/flat_full_poisson.Rmd
index c96b50c..28370e5 100644
--- a/dev/flat_full_poisson.Rmd
+++ b/dev/flat_full_poisson.Rmd
@@ -208,15 +208,32 @@ e_bipoiss_clust <- function(x, params) {
 #'
 #' @param x the data matrix (2 columns)
 #' @param s of size 3 time nrow(x)
+#' @param h of size 3 times nrow(x)
+#' @param gamma a vector of Dirichlet prior parameters for the kappa
+#'   proportions (one value per column of h)
+#' @param m_chr a list of maximization functions, one per type of chromosome
 #' @return a list of 3 lambda
 #' @importFrom extraDistr dbvpois
 #' @export 
-m_bipoiss_clust <- function(x, s, h,
+m_bipoiss_clust <- function(x, s, h, gamma,
                             m_chr = list(m_bipoiss_clust_a_chr, m_bipoiss_clust_x_chr, m_bipoiss_clust_y_chr)) {
     params <- list()
+    # one prior value per cluster: a mismatch would give NA or mis-normalised kappa
+    if (length(gamma) != ncol(h)) {
+      stop("gamma must have one value per column of h", call. = FALSE)
+    }
+    gamma_sum <- sum(gamma)
+    kappa <- numeric(ncol(h))
     for (i in 1:ncol(h)) {
       params[[i]] <- m_chr[[i]](x, s[, i], h[, i])
-      params[[i]]$kappa <- sum(h[, i]) / nrow(h)
+      # proportion estimate shifted by the Dirichlet prior (posterior mode)
+      params[[i]]$kappa <- (sum(h[, i]) + gamma[i] - 1) / (nrow(h) + gamma_sum - ncol(h))
+      kappa[i] <- params[[i]]$kappa
+    }
+    # the first cluster takes the remaining mass so the proportions sum to 1
+    params[[1]]$kappa <- 1
+    if (length(kappa) > 1) {
+      params[[1]]$kappa <- 1 - sum(kappa[-1])
     }
     return(params)
 }
@@ -342,21 +352,25 @@ m_bipoiss_clust_y_chr <- function(x, s, h) {
 #'
 #' @param x the data matrix (2 columns)
 #' @param params the list of 3 lambda
+#' @param gamma a vector of Dirichlet prior parameters for the kappa
+#'   proportions (one value per element of params)
 #' @return logliklihood
 #' @importFrom extraDistr dbvpois
+#' @importFrom Boom ddirichlet
 #' @export 
-loglik_bipoiss_clust <- function(x, params) {
+loglik_bipoiss_clust <- function(x, params, gamma) {
   loglik <- matrix(0, nrow(x), length(params))
+  kappa <- numeric(length(params))
   for (i in 1:length(params)) {
     loglik[, i] <- params[[i]]$kappa *
         extraDistr::dbvpois(x = x, a = params[[i]]$l1, b = params[[i]]$l2, c = params[[i]]$l3)
+    kappa[i] <- params[[i]]$kappa
   }
   loglik <- sum(log(rowSums(loglik)))
-  if (is.infinite(loglik)) {
-    warning("Inf loglik")
-  }
-  if (is.na(loglik)) {
-    warning("NA loglik")
+  # add the log density of the Dirichlet(gamma) prior on the proportions;
+  # skipped when there is a single cluster
+  if (length(kappa) > 1) {
+    loglik <- loglik + Boom::ddirichlet(probabilities = kappa, nu = gamma, logscale = TRUE)
   }
   return(loglik)
 }
@@ -376,10 +387,10 @@ loglik_bipoiss_clust <- function(x, params) {
 bic_bipoiss_clust <- function(x, params, loglik) {
   k <- 3 
   if (length(params) >= 2) {
-      k <- k + 3
+      k <- k + 5 # presumably 3 lambdas plus 2 mixing proportions now counted — TODO confirm
   }
   if (length(params) >= 3) {
-      k <- k + 3
+      k <- k + 4 # presumably 3 lambdas plus 1 mixing proportion — TODO confirm
   }
   return(k * log(nrow(x)) - 2 * loglik)
 }
@@ -507,6 +518,7 @@ e_bipoiss_clust_batch <- function(x, params, hidden = NULL, nbatch = 10) {
 #'
 #' @param x the data matrix (2 columns)
 #' @param params the list of 3 lambda (initial value)
+#' @param gamma a vector of Dirichlet prior parameters for the kappa proportions (one per cluster)
 #' @param nbatch (default: 10) the number of batch
 #' @param threshold (default: c(0.1, 5)) maximum difference between kappa and lambda
 #' @param max_iter (default: 100) maximum number of iteration
@@ -524,7 +536,9 @@ em_bipoiss_clust <- function(x, params = list(
     A = list(kappa = 0.4, l1 = mean(x), l2 = mean(x), l3 = mean(x)),
     X = list(kappa = 0.3, l1 = 1, l2 = mean(x), l3 = mean(x)),
     Y = list(kappa = 0.3, l1 = mean(x), l2 = 1, l3 = mean(x))
-    ), nbatch = 10, threshold = c(0.1, 1), max_iter = 100) {
+    ),
+    gamma = c(round(nrow(x) * .8), round(nrow(x) * .15), round(nrow(x) * .05)), 
+    nbatch = 10, threshold = c(0.1, 1), max_iter = 100) {
     old_params <- params
     old_params[[1]]$kappa <- Inf
     hidden <- NULL
@@ -532,7 +546,7 @@ em_bipoiss_clust <- function(x, params = list(
     while (param_diff_poiss(params, old_params, threshold = threshold) & iter < max_iter)  {
         old_params <- params
         hidden <- e_bipoiss_clust_batch(x, params = params, hidden = hidden, nbatch = nbatch)
-        params <- m_bipoiss_clust(x, s = hidden$s, h = round(hidden$h))
+        params <- m_bipoiss_clust(x, s = hidden$s, h = round(hidden$h), gamma = gamma)
         iter <- iter + 1
     }
     res <- params
@@ -545,8 +559,8 @@ em_bipoiss_clust <- function(x, params = list(
     if (length(res) == 3) {
       names(res) <- c("A", "X", "Y")
     }
-    res$loglik <- loglik_bipoiss_clust(x, params)
-    res$BIC <- bic_bipoiss_clust(x, params, loglik = res$loglik)
+    res$loglik <- loglik_bipoiss_clust(x, params = params, gamma = gamma)
+    res$BIC <- bic_bipoiss_clust(x, params = params, loglik = res$loglik)
     res <- get_cluster_poiss(x, res = res, params = params, hidden = hidden)
     return(res)
 }
@@ -616,7 +630,9 @@ res <- data %>% em_bipoiss_clust(params = list(
       l2 = mean(data),
       l3 = mean(data)
     )
- ))
+   ),
+   gamma = c(round(nrow(data) * .8), round(nrow(data) * .2))
+ )
 rbind(
   extraDistr::rbvpois(1000*res[[1]]$kappa, res[[1]]$l1, res[[1]]$l2, res[[1]]$l3), # X
   extraDistr::rbvpois(1000*res[[2]]$kappa, res[[2]]$l1, res[[2]]$l2, res[[2]]$l3) # A
@@ -647,7 +663,8 @@ res <- data %>% em_bipoiss_clust(params = list(
       l2 = mean(data),
       l3 = mean(data)
     )
- ))
+ ),
+ gamma = c(nrow(data)))
 extraDistr::rbvpois(1000*res[[1]]$kappa, res[[1]]$l1, res[[1]]$l2, res[[1]]$l3) %>%  # X
     as_tibble() %>% 
     dplyr::rename(male = 'V1', female = 'V2') %>% 
@@ -667,6 +684,7 @@ extraDistr::rbvpois(1000*res[[1]]$kappa, res[[1]]$l1, res[[1]]$l2, res[[1]]$l3)
 #'
 #' @param x a matrix of male female kmers counts
 #' @param params a list of init parameters
+#' @param gamma a vector of Dirichlet prior parameters for the kappa proportions, passed to em_bipoiss_clust
 #' @param threshold theshold to stop the EM algorithm
 #' @param nboot number of boostrap sample
 #' @param bootsize size of the boostrap sample (if < 0 we take a percentage of x)
@@ -682,7 +700,7 @@ extraDistr::rbvpois(1000*res[[1]]$kappa, res[[1]]$l1, res[[1]]$l2, res[[1]]$l3)
 #'   extraDistr::rbvpois(500, 0, 20, 20), # X
 #'   extraDistr::rbvpois(100, 20, 0, 5) # Y
 #'   ) %>% pois_boostrap_EM(nboot = 10)
-poiss_boostrap_EM_sub <- function(iter, x, bootsize, params, max_iter, max_error) {
+poiss_boostrap_EM_sub <- function(iter, x, bootsize, params, gamma, max_iter, max_error) {
   res <- list()
   res$loglik <- c(NA)
   iter <- 0
@@ -691,7 +709,7 @@ poiss_boostrap_EM_sub <- function(iter, x, bootsize, params, max_iter, max_error
       as_tibble() %>%
       sample_frac(bootsize, replace = T) %>%
       as.matrix() %>%
-      em_bipoiss_clust(params = params, max_iter = max_iter)
+      em_bipoiss_clust(params = params, gamma = gamma, max_iter = max_iter)
     iter <- iter + 1
   }
   return(tibble(loglik = res$loglik, BIC = res$BIC))
@@ -703,6 +721,7 @@ poiss_boostrap_EM_sub <- function(iter, x, bootsize, params, max_iter, max_error
 #'
 #' @param x a matrix of male female kmers counts
 #' @param params a list of init parameters
+#' @details the kappa prior (gamma) is not an argument: it is derived from nrow(x), bootsize and sex
 #' @param threshold theshold to stop the EM algorithm
 #' @param nboot number of boostrap sample
 #' @param bootsize size of the boostrap sample (if < 0 we take a percentage of x)
@@ -722,16 +741,19 @@ poiss_boostrap_EM <- function(x, sex = "XY", threshold = 1, nboot = 100, bootsiz
   params = list(
       A = list(kappa = 0.4, l1 = mean(x), l2 = mean(x), l3 = mean(x))
     )
-    if (sex %in% c("XO", "XY")) {
-      params$X <- list(kappa = 0.3, l1 = 1, l2 = mean(x), l3 = mean(x))
-    }
-    if (sex == "XY") {
-      params$Y <- list(kappa = 0.3, l1 = mean(x), l2 = 1, l3 = mean(x))
-    }
+  gamma <- c(round(nrow(x) * bootsize))
+  if (sex %in% c("XO", "XY")) {
+    params$X <- list(kappa = 0.3, l1 = 1, l2 = mean(x), l3 = mean(x))
+    gamma <- c(round(nrow(x) * bootsize) * .8, round(nrow(x) * bootsize) * .2)
+  }
+  if (sex == "XY") {
+    params$Y <- list(kappa = 0.3, l1 = mean(x), l2 = 1, l3 = mean(x))
+    gamma <- c(round(nrow(x) * bootsize) * .8, round(nrow(x) * bootsize) * .15, round(nrow(x) * bootsize) * .05)
+  }
   if (core == 1) {
-    res <- lapply(as.list(1:nboot), FUN = poiss_boostrap_EM_sub, x = x, bootsize = bootsize, params = params, max_iter = max_iter, max_error = max_error)
+    res <- lapply(as.list(1:nboot), FUN = poiss_boostrap_EM_sub, x = x, bootsize = bootsize, params = params, gamma = gamma, max_iter = max_iter, max_error = max_error)
   } else {
-    res <- parallel::mclapply(as.list(1:nboot), FUN = poiss_boostrap_EM_sub, x = x, bootsize = bootsize, params = params, max_iter = max_iter, max_error = max_error, mc.cores = core)
+    res <- parallel::mclapply(as.list(1:nboot), FUN = poiss_boostrap_EM_sub, x = x, bootsize = bootsize, params = params, gamma = gamma, max_iter = max_iter, max_error = max_error, mc.cores = core)
   }
   return(res)
 }
-- 
GitLab