diff --git a/dev/flat_full_bayesian.Rmd b/dev/flat_full_bayesian.Rmd
index 562293d62d557932d2e266e916132ffca66141fa..e483fa2e44629958bfb379ebff1572db3d947763 100644
--- a/dev/flat_full_bayesian.Rmd
+++ b/dev/flat_full_bayesian.Rmd
@@ -188,33 +188,33 @@ M_sigma <- function(x, h, mu, rho, nu, alpha, beta) {
 #' @return a list of parameters
 EM_init_params <- function(
     x,
-    kappa_weight = nrow(x) * .5,
-    mu_weight = nrow(x) * .01,
-    sigma_weight = nrow(x) * .001,
-    kappa_prior = c(.8, .15, .05),
-    mu_prior = list(c(1, 1), c(1, 2), c(1,-2))
+    kappa_weight = c(.9, .9, .9),
+    mu_weight = c(.1, .8, .8),
+    sigma_weight = c(.001, .001, .001),
+    kappa_prior = c(.7, .2, .1),
+    mu_prior = list(c(1, 1), c(1, 1.5), c(1, .5))
     ) {
   x_mean <- colMeans(x)
   params <- list(
-    gamma = kappa_prior * kappa_weight + 1
+    gamma = kappa_prior * nrow(x) * kappa_weight + 1
   )
   params$rho <- list()
   params$nu <- list()
+  params$beta <- list()
+  params$mu <- list()
+  params$sigma <- list()
+  params$sigma_prior <- list()
+  
   for (k in 1:length(params$gamma)) {
     params$rho[[k]] <- x_mean
-    params$nu[[k]] <- rep(mu_weight, ncol(x))
+    params$nu[[k]] <- nrow(x) * mu_weight[k]
   }
   params$kappa <- as.vector(MCMCpack::rdirichlet(1, params$gamma))
   params$kappa <- params$kappa / sum(params$kappa)
-  params$beta <- list()
-  params$mu <- list()
-  params$sigma <- list()
-  params$sigma_prior <- list()
-  params$alpha <- (sigma_weight + ncol(x)) / 2
-  params$alpha <- max(c(params$alpha, ncol(x)))
-  params$alpha <- rep(params$alpha, length(params$gamma))
+  params$alpha <- (nrow(x) * sigma_weight + ncol(x)) / 2
   for (k in 1:length(params$gamma)) {
-    params$beta[[k]] <- sigma_weight / 2 * cov(x)
+    params$alpha[k] <- max(c(params$alpha[k], ncol(x)))
+    params$beta[[k]] <- nrow(x) * sigma_weight[k] / 2 * cov(x)
     params$sigma_prior[[k]] <- retry::retry(MCMCpack::riwish(
       params$alpha[k], params$beta[[k]]
     ), when = "random error", max_tries = 5)
@@ -223,6 +223,15 @@ EM_init_params <- function(
   }
   return(params)
 }
+test_model_XY <- function(data) {
+  data %>%
+      dplyr::select(count_m, count_f) %>%
+      as.matrix() %>% 
+      compute_tpm() %>%
+      EM_constraint()
+}
+model_XY <- test_model_XY(data)
+plot_model(x = model_XY$x, h = model_XY$h, params = model_XY$params)
 ```
 
 ## cluster id
@@ -257,10 +266,10 @@ h_2_clust_proba <- function(h) {
 EM_constraint <- function(
     x,
     threshold = 1,
-    frac = .8,
-    kappa_weight = nrow(x) * .9,
-    mu_weight = nrow(x) * .1,
-    sigma_weight = nrow(x) * .001,
+    frac = 1,
+    kappa_weight = c(.9, .9, .9),
+    mu_weight = c(.1, .8, .8),
+    sigma_weight = c(.001, .001, .001),
     kappa_prior = c(.7, .2, .1),
     mu_prior = list(c(1, 1), c(1, 1.5), c(1, .5)),
     verbose = F