diff --git a/dev/flat_full_bayesian.Rmd b/dev/flat_full_bayesian.Rmd index 562293d62d557932d2e266e916132ffca66141fa..e483fa2e44629958bfb379ebff1572db3d947763 100644 --- a/dev/flat_full_bayesian.Rmd +++ b/dev/flat_full_bayesian.Rmd @@ -188,33 +188,33 @@ M_sigma <- function(x, h, mu, rho, nu, alpha, beta) { #' @return a list of parameters EM_init_params <- function( x, - kappa_weight = nrow(x) * .5, - mu_weight = nrow(x) * .01, - sigma_weight = nrow(x) * .001, - kappa_prior = c(.8, .15, .05), - mu_prior = list(c(1, 1), c(1, 2), c(1,-2)) + kappa_weight = c(.9, .9, .9), + mu_weight = c(.1, .8, .8), + sigma_weight = c(.001, .001, .001), + kappa_prior = c(.7, .2, .1), + mu_prior = list(c(1, 1), c(1, 1.5), c(1, .5)) ) { x_mean <- colMeans(x) params <- list( - gamma = kappa_prior * kappa_weight + 1 + gamma = kappa_prior * nrow(x) * kappa_weight + 1 ) params$rho <- list() params$nu <- list() + params$beta <- list() + params$mu <- list() + params$sigma <- list() + params$sigma_prior <- list() + for (k in 1:length(params$gamma)) { params$rho[[k]] <- x_mean - params$nu[[k]] <- rep(mu_weight, ncol(x)) + params$nu[[k]] <- nrow(x) * mu_weight[k] } params$kappa <- as.vector(MCMCpack::rdirichlet(1, params$gamma)) params$kappa <- params$kappa / sum(params$kappa) - params$beta <- list() - params$mu <- list() - params$sigma <- list() - params$sigma_prior <- list() - params$alpha <- (sigma_weight + ncol(x)) / 2 - params$alpha <- max(c(params$alpha, ncol(x))) - params$alpha <- rep(params$alpha, length(params$gamma)) + params$alpha <- (nrow(x) * sigma_weight + ncol(x)) / 2 for (k in 1:length(params$gamma)) { - params$beta[[k]] <- sigma_weight / 2 * cov(x) + params$alpha[k] <- max(c(params$alpha[k], ncol(x))) + params$beta[[k]] <- nrow(x) * sigma_weight[k] / 2 * cov(x) params$sigma_prior[[k]] <- retry::retry(MCMCpack::riwish( params$alpha[k], params$beta[[k]] ), when = "random error", max_tries = 5) @@ -223,6 +223,15 @@ EM_init_params <- function( } return(params) } +test_model_XY <- function(data) { + data %>% + dplyr::select(count_m, count_f) %>% + as.matrix() %>% + compute_tpm() %>% + EM_constraint() +} +model_XY <- test_model_XY(data) +plot_model(x = model_XY$x, h = model_XY$h, params = model_XY$params) ``` ## cluster id @@ -257,10 +266,10 @@ h_2_clust_proba <- function(h) { EM_constraint <- function( x, threshold = 1, - frac = .8, - kappa_weight = nrow(x) * .9, - mu_weight = nrow(x) * .1, - sigma_weight = nrow(x) * .001, + frac = 1, + kappa_weight = c(.9, .9, .9), + mu_weight = c(.1, .8, .8), + sigma_weight = c(.001, .001, .001), kappa_prior = c(.7, .2, .1), mu_prior = list(c(1, 1), c(1, 1.5), c(1, .5)), verbose = F