From 74d7a14be3d85bac1ad83f8bb42fc989ea348cf9 Mon Sep 17 00:00:00 2001
From: Laurent Modolo <laurent.modolo@ens-lyon.fr>
Date: Wed, 30 Aug 2023 11:03:37 +0200
Subject: [PATCH] update package with fusen

---
 DESCRIPTION                                   |  4 +-
 R/boostrap_em_constraint.R                    | 22 ++++++--
 R/em_constraint.R                             | 12 ++---
 R/em_init_params.R                            | 38 +++++++-------
 R/plot_model.R                                |  2 -
 dev/0-dev_history.Rmd                         |  2 +-
 dev/config_fusen.yaml                         |  6 +--
 dev/config_not_registered.csv                 |  2 +-
 dev/flat_full_bayesian.Rmd                    | 50 +++++++++++--------
 .../getting-started-bayesian-version.Rmd      | 11 ++--
 10 files changed, 88 insertions(+), 61 deletions(-)

diff --git a/DESCRIPTION b/DESCRIPTION
index dc5d084..31c1d18 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,7 +1,7 @@
 Package: kmerclust
 Title: Kmerclust A Small Package To Test If The Kmer Count Correspond To A
     XY Or XO Genome
-Version: 0.0.3
+Version: 0.0.4
 Authors@R: 
     person("Laurent", "Modolo", , "laurent.modolo@ens-lyon.fr", role = c("aut", "cre"),
            comment = c(ORCID = "0000-0002-7606-4110"))
@@ -30,7 +30,7 @@ Suggests:
 VignetteBuilder: 
     knitr
 Config/fusen/version: 0.5.2
-Date/Publication: 2023/08/05
+Date/Publication: 2023/08/30
 Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
 RoxygenNote: 7.2.3
diff --git a/R/boostrap_em_constraint.R b/R/boostrap_em_constraint.R
index 7c3e92e..c67aa08 100644
--- a/R/boostrap_em_constraint.R
+++ b/R/boostrap_em_constraint.R
@@ -17,16 +17,25 @@
 #' @noRd
 boostrap_EM_constraint <- function(x, sex = "XY", threshold = 1, nboot = 100, bootsize = 1000, core = 6) {
     if (sex == "XY") {
-      kappa_prior <- c(.8, .15, .05)
-      mu_prior <- list(c(0, 0), c(-0.5, 2), c(0,-2))
+      kappa_prior <- c(.7, .2, .1)
+      mu_prior <- list(c(1, 1), c(1, 1.5), c(1, .5))
+      kappa_weight <- c(.9, .9, .9)
+      mu_weight <- c(.1, .8, .8)
+      sigma_weight <- c(.001, .001, .001)
     }
     if (sex == "XO") {
       kappa_prior <- c(.85, .15)
-      mu_prior <- list(c(0, 0), c(-0.5, 2))
+      mu_prior <- list(c(1, 1), c(1, 1.5))
+      kappa_weight <- c(.9, .9)
+      mu_weight <- c(.1, .8)
+      sigma_weight <- c(.001, .001)
     }
     if (sex == "OO") {
       kappa_prior <- c(.1)
-      mu_prior <- list(c(0, 0))
+      mu_prior <- list(c(1, 1))
+      kappa_weight <- c(.9)
+      mu_weight <- c(.1) # one weight per component; OO has a single one
+      sigma_weight <- c(.001)
     }
     parallel::mclapply(as.list(1:nboot), function(iter, x, bootsize, sex) {
         x_boot <- x %>% 
@@ -37,7 +46,10 @@ boostrap_EM_constraint <- function(x, sex = "XY", threshold = 1, nboot = 100, bo
         res <- x_boot %>% 
               EM_constraint(
                 kappa_prior = kappa_prior,
-                mu_prior = mu_prior
+                mu_prior = mu_prior,
+                kappa_weight = kappa_weight,
+                mu_weight = mu_weight,
+                sigma_weight = sigma_weight
               )
         res$bic <- compute_bic_constraint(x_boot, loglik = res$loglik, sex = sex)
         return(tibble(loglik = res$loglik, BIC = res$BIC))
diff --git a/R/em_constraint.R b/R/em_constraint.R
index c174e93..83fd7e3 100644
--- a/R/em_constraint.R
+++ b/R/em_constraint.R
@@ -15,12 +15,12 @@
 EM_constraint <- function(
     x,
     threshold = 1,
-    frac = .8,
-    kappa_weight = nrow(x) * .5,
-    mu_weight = nrow(x) * .01,
-    sigma_weight = nrow(x) * .0001,
-    kappa_prior = c(.8, .15, .05),
-    mu_prior = list(c(0, 0), c(-0.5, 2), c(0,-2)),
+    frac = 1,
+    kappa_weight = c(.9, .9, .9),
+    mu_weight = c(.1, .8, .8),
+    sigma_weight = c(.001, .001, .001),
+    kappa_prior = c(.7, .2, .1),
+    mu_prior = list(c(1, 1), c(1, 1.5), c(1, .5)),
     verbose = F
   ) {
   params <- EM_init_params(
diff --git a/R/em_init_params.R b/R/em_init_params.R
index 2faa279..4323526 100644
--- a/R/em_init_params.R
+++ b/R/em_init_params.R
@@ -15,35 +15,37 @@
 #' @noRd
 EM_init_params <- function(
     x,
-    kappa_weight = nrow(x) * .5,
-    mu_weight = nrow(x) * .01,
-    sigma_weight = nrow(x) * .0001,
-    kappa_prior = c(.8, .15, .05),
-    mu_prior = list(c(0, 0), c(-.5, 2), c(0,-2))
+    kappa_weight = c(.9, .9, .9),
+    mu_weight = c(.1, .8, .8),
+    sigma_weight = c(.001, .001, .001),
+    kappa_prior = c(.7, .2, .1),
+    mu_prior = list(c(1, 1), c(1, 1.5), c(1, .5))
     ) {
   x_mean <- colMeans(x)
   params <- list(
-    gamma = kappa_prior * kappa_weight + 1
+    gamma = kappa_prior * nrow(x) * kappa_weight + 1
   )
   params$rho <- list()
   params$nu <- list()
-  for (k in 1:length(params$gamma)) {
-    params$rho[[k]] <- x_mean + mu_prior[[k]]
-    params$nu[[k]] <- rep(mu_weight, ncol(x))
-  }
-  params$kappa <- as.vector(MCMCpack::rdirichlet(1, params$gamma))
-  params$kappa <- params$kappa / sum(params$kappa)
   params$beta <- list()
   params$mu <- list()
   params$sigma <- list()
   params$sigma_prior <- list()
-  # params$alpha <- rep((sigma_weight + ncol(x)) / 2, length(params$gamma))
-  params$alpha <- rep(ncol(x), length(params$gamma))
+  
+  for (k in 1:length(params$gamma)) {
+    params$rho[[k]] <- x_mean * mu_prior[[k]]
+    params$nu[[k]] <- nrow(x) * mu_weight[k]
+  }
+  params$kappa <- as.vector(MCMCpack::rdirichlet(1, params$gamma))
+  params$kappa <- params$kappa / sum(params$kappa)
+  params$alpha <- (nrow(x) * sigma_weight + ncol(x)) / 2
   for (k in 1:length(params$gamma)) {
-    params$beta[[k]] <- sigma_weight / 2 * cov(x)
-    params$sigma_prior[[k]] <- retry::retry(MCMCpack::riwish(
-      params$alpha[k], params$beta[[k]]
-    ), when = "random error", max_tries = 5)
+    params$alpha[k] <- max(c(params$alpha[k], ncol(x)))
+    params$beta[[k]] <- nrow(x) * sigma_weight[k] / 2 * cov(x)
+    # params$sigma_prior[[k]] <- retry::retry(MCMCpack::riwish(
+    #   params$alpha[k], params$beta[[k]]
+    # ), when = "random error", max_tries = 5)
+    params$sigma_prior[[k]] <- cov(x)
     params$sigma[[k]] <- params$sigma_prior[[k]]
     params$mu[[k]] <- mvtnorm::rmvnorm(1, mean = params$rho[[k]], sigma = params$nu[[k]]^-1 * params$sigma[[k]])
   }
diff --git a/R/plot_model.R b/R/plot_model.R
index b277d28..431e415 100644
--- a/R/plot_model.R
+++ b/R/plot_model.R
@@ -30,8 +30,6 @@ plot_model <- function(x, h, params) {
     chr_name <- c("A")
   }
   p1 <- x %>%
-    exp() %>% 
-    log10() %>% 
     as_tibble() %>% 
     mutate(
         clust_proba = clust_proba,
diff --git a/dev/0-dev_history.Rmd b/dev/0-dev_history.Rmd
index 5a31083..dc77786 100755
--- a/dev/0-dev_history.Rmd
+++ b/dev/0-dev_history.Rmd
@@ -21,7 +21,7 @@ fusen::fill_description(
     `Authors@R` = c(
       person("Laurent", "Modolo", email = "laurent.modolo@ens-lyon.fr", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-7606-4110"))
     ),
-    Version = "0.0.3",
+    Version = "0.0.4",
     "Date/Publication" = format(Sys.time(), "%Y/%m/%d")
   ),
   overwrite = T
diff --git a/dev/config_fusen.yaml b/dev/config_fusen.yaml
index 9652dfc..f73f416 100644
--- a/dev/config_fusen.yaml
+++ b/dev/config_fusen.yaml
@@ -2,7 +2,7 @@ flat_full.Rmd:
   path: dev/flat_full.Rmd
   state: active
   R:
-  - R/sim_kmer.R
+  - R/compute_tpm.R
   - R/expand_theta.R
   - R/parse_annotation.R
   tests: []
@@ -11,7 +11,7 @@ flat_full.Rmd:
     flat_file: dev/flat_full.Rmd
     vignette_name: Getting started
     open_vignette: false
-    check: true
+    check: false
     document: true
     overwrite: 'yes'
 flat_full_bayesian.Rmd:
@@ -41,6 +41,6 @@ flat_full_bayesian.Rmd:
     flat_file: dev/flat_full_bayesian.Rmd
     vignette_name: Getting started bayesian version
     open_vignette: false
-    check: true
+    check: false
     document: true
     overwrite: 'yes'
diff --git a/dev/config_not_registered.csv b/dev/config_not_registered.csv
index 844ecc1..aea35fb 100644
--- a/dev/config_not_registered.csv
+++ b/dev/config_not_registered.csv
@@ -1,4 +1,4 @@
 "type","path","origin"
-"R","R/compute_tpm.R","Possibly deprecated file. Please check its link with detected flat source: dev/flat_full.Rmd"
 "R","R/kmerclust-package.R","No existing source path found."
+"R","R/sim_kmer.R","Possibly deprecated file. Please check its link with detected flat source: dev/flat_full.Rmd"
 "R","R/utils-pipe.R","No existing source path found."
diff --git a/dev/flat_full_bayesian.Rmd b/dev/flat_full_bayesian.Rmd
index e483fa2..002e5c5 100644
--- a/dev/flat_full_bayesian.Rmd
+++ b/dev/flat_full_bayesian.Rmd
@@ -206,7 +206,7 @@ EM_init_params <- function(
   params$sigma_prior <- list()
   
   for (k in 1:length(params$gamma)) {
-    params$rho[[k]] <- x_mean
+    params$rho[[k]] <- x_mean * mu_prior[[k]]
     params$nu[[k]] <- nrow(x) * mu_weight[k]
   }
   params$kappa <- as.vector(MCMCpack::rdirichlet(1, params$gamma))
@@ -215,23 +215,15 @@ EM_init_params <- function(
   for (k in 1:length(params$gamma)) {
     params$alpha[k] <- max(c(params$alpha[k], ncol(x)))
     params$beta[[k]] <- nrow(x) * sigma_weight[k] / 2 * cov(x)
-    params$sigma_prior[[k]] <- retry::retry(MCMCpack::riwish(
-      params$alpha[k], params$beta[[k]]
-    ), when = "random error", max_tries = 5)
+    # params$sigma_prior[[k]] <- retry::retry(MCMCpack::riwish(
+    #   params$alpha[k], params$beta[[k]]
+    # ), when = "random error", max_tries = 5)
+    params$sigma_prior[[k]] <- cov(x)
     params$sigma[[k]] <- params$sigma_prior[[k]]
     params$mu[[k]] <- mvtnorm::rmvnorm(1, mean = params$rho[[k]], sigma = params$nu[[k]]^-1 * params$sigma[[k]])
   }
   return(params)
 }
-test_model_XY <- function(data) {
-  data %>%
-      dplyr::select(count_m, count_f) %>%
-      as.matrix() %>% 
-      compute_tpm() %>%
-      EM_constraint()
-}
-model_XY <- test_model_XY(data)
-plot_model(x = model_XY$x, h = model_XY$h, params = model_XY$params)
 ```
 
 ## cluster id
@@ -540,16 +532,25 @@ compute_bic_constraint <- function(x, loglik, sex = "OO") {
 #' res <- boostrap_EM(data, nboot = 10, bootsize = 0.01)
 boostrap_EM_constraint <- function(x, sex = "XY", threshold = 1, nboot = 100, bootsize = 1000, core = 6) {
     if (sex == "XY") {
-      kappa_prior <- c(.8, .15, .05)
-      mu_prior <- list(c(0, 0), c(-0.5, 2), c(0,-2))
+      kappa_prior <- c(.7, .2, .1)
+      mu_prior <- list(c(1, 1), c(1, 1.5), c(1, .5))
+      kappa_weight <- c(.9, .9, .9)
+      mu_weight <- c(.1, .8, .8)
+      sigma_weight <- c(.001, .001, .001)
     }
     if (sex == "XO") {
       kappa_prior <- c(.85, .15)
-      mu_prior <- list(c(0, 0), c(-0.5, 2))
+      mu_prior <- list(c(1, 1), c(1, 1.5))
+      kappa_weight <- c(.9, .9)
+      mu_weight <- c(.1, .8)
+      sigma_weight <- c(.001, .001)
     }
     if (sex == "OO") {
       kappa_prior <- c(.1)
-      mu_prior <- list(c(0, 0))
+      mu_prior <- list(c(1, 1))
+      kappa_weight <- c(.9)
+      mu_weight <- c(.1) # one weight per component; OO has a single one
+      sigma_weight <- c(.001)
     }
     parallel::mclapply(as.list(1:nboot), function(iter, x, bootsize, sex) {
         x_boot <- x %>% 
@@ -560,7 +561,10 @@ boostrap_EM_constraint <- function(x, sex = "XY", threshold = 1, nboot = 100, bo
         res <- x_boot %>% 
               EM_constraint(
                 kappa_prior = kappa_prior,
-                mu_prior = mu_prior
+                mu_prior = mu_prior,
+                kappa_weight = kappa_weight,
+                mu_weight = mu_weight,
+                sigma_weight = sigma_weight
               )
         res$bic <- compute_bic_constraint(x_boot, loglik = res$loglik, sex = sex)
         return(tibble(loglik = res$loglik, BIC = res$BIC))
@@ -627,7 +631,10 @@ test_model_XO <- function(data) {
       compute_tpm() %>%
       EM_constraint(
         kappa_prior = c(.85, .15),
-        mu_prior = list(c(0, 0), c(-0.5, 2))
+        mu_prior = list(c(1, 1), c(1, 1.5)),
+        kappa_weight = c(.9, .9),
+        mu_weight = c(.1, .8),
+        sigma_weight = c(.001, .001)
       )
 }
 model_XO <- test_model_XO(data)
@@ -644,7 +651,10 @@ test_model_OO <- function(data) {
       compute_tpm() %>%
       EM_constraint(
         kappa_prior = c(1),
-        mu_prior = list(c(0, 0))
+        mu_prior = list(c(1, 1)),
+        kappa_weight = c(.9),
+        mu_weight = c(.1),
+        sigma_weight = c(.001)
       )
 }
 model_OO <- test_model_OO(data)
diff --git a/vignettes/getting-started-bayesian-version.Rmd b/vignettes/getting-started-bayesian-version.Rmd
index 74c0a1d..1403f53 100644
--- a/vignettes/getting-started-bayesian-version.Rmd
+++ b/vignettes/getting-started-bayesian-version.Rmd
@@ -103,7 +103,10 @@ test_model_XO <- function(data) {
       compute_tpm() %>%
       EM_constraint(
         kappa_prior = c(.85, .15),
-        mu_prior = list(c(0, 0), c(-0.5, 2))
+        mu_prior = list(c(1, 1), c(1, 1.5)),
+        kappa_weight = c(.9, .9),
+        mu_weight = c(.1, .8),
+        sigma_weight = c(.001, .001)
       )
 }
 model_XO <- test_model_XO(data)
@@ -120,7 +123,10 @@ test_model_OO <- function(data) {
       compute_tpm() %>%
       EM_constraint(
         kappa_prior = c(1),
-        mu_prior = list(c(0, 0))
+        mu_prior = list(c(1, 1)),
+        kappa_weight = c(.9),
+        mu_weight = c(.1),
+        sigma_weight = c(.001)
       )
 }
 model_OO <- test_model_OO(data)
@@ -134,7 +140,6 @@ comparison <- data %>%
     dplyr::select(count_m, count_f) %>%
     as.matrix() %>% 
     compute_tpm() %>%
-    # log() %>%
     compare_models_constaint(nboot = 10, bootsize = 1000, core = 1)
 comparison %>% 
   ggplot(aes(x = name, y = loglik, fill = name)) +
-- 
GitLab