Skip to content
Snippets Groups Projects
Verified Commit 05b23237 authored by Laurent Modolo's avatar Laurent Modolo
Browse files

flat_full_bayesian.Rmd: add individual weight for each prior parameters

parent 10bbc366
Branches
Tags
No related merge requests found
...@@ -188,33 +188,33 @@ M_sigma <- function(x, h, mu, rho, nu, alpha, beta) { ...@@ -188,33 +188,33 @@ M_sigma <- function(x, h, mu, rho, nu, alpha, beta) {
#' @return a list of parameters #' @return a list of parameters
EM_init_params <- function( EM_init_params <- function(
x, x,
kappa_weight = nrow(x) * .5, kappa_weight = c(.9, .9, .9),
mu_weight = nrow(x) * .01, mu_weight = c(.1, .8, .8),
sigma_weight = nrow(x) * .001, sigma_weight = c(.001, .001, .001),
kappa_prior = c(.8, .15, .05), kappa_prior = c(.7, .2, .1),
mu_prior = list(c(1, 1), c(1, 2), c(1,-2)) mu_prior = list(c(1, 1), c(1, 1.5), c(1, .5))
) { ) {
x_mean <- colMeans(x) x_mean <- colMeans(x)
params <- list( params <- list(
gamma = kappa_prior * kappa_weight + 1 gamma = kappa_prior * nrow(x) * kappa_weight + 1
) )
params$rho <- list() params$rho <- list()
params$nu <- list() params$nu <- list()
params$beta <- list()
params$mu <- list()
params$sigma <- list()
params$sigma_prior <- list()
for (k in 1:length(params$gamma)) { for (k in 1:length(params$gamma)) {
params$rho[[k]] <- x_mean params$rho[[k]] <- x_mean
params$nu[[k]] <- rep(mu_weight, ncol(x)) params$nu[[k]] <- nrow(x) * mu_weight[k]
} }
params$kappa <- as.vector(MCMCpack::rdirichlet(1, params$gamma)) params$kappa <- as.vector(MCMCpack::rdirichlet(1, params$gamma))
params$kappa <- params$kappa / sum(params$kappa) params$kappa <- params$kappa / sum(params$kappa)
params$beta <- list() params$alpha <- (nrow(x) * sigma_weight + ncol(x)) / 2
params$mu <- list()
params$sigma <- list()
params$sigma_prior <- list()
params$alpha <- (sigma_weight + ncol(x)) / 2
params$alpha <- max(c(params$alpha, ncol(x)))
params$alpha <- rep(params$alpha, length(params$gamma))
for (k in 1:length(params$gamma)) { for (k in 1:length(params$gamma)) {
params$beta[[k]] <- sigma_weight / 2 * cov(x) params$alpha[k] <- max(c(params$alpha[k], ncol(x)))
params$beta[[k]] <- nrow(x) * sigma_weight[k] / 2 * cov(x)
params$sigma_prior[[k]] <- retry::retry(MCMCpack::riwish( params$sigma_prior[[k]] <- retry::retry(MCMCpack::riwish(
params$alpha[k], params$beta[[k]] params$alpha[k], params$beta[[k]]
), when = "random error", max_tries = 5) ), when = "random error", max_tries = 5)
...@@ -223,6 +223,15 @@ EM_init_params <- function( ...@@ -223,6 +223,15 @@ EM_init_params <- function(
} }
return(params) return(params)
} }
test_model_XY <- function(data) {
data %>%
dplyr::select(count_m, count_f) %>%
as.matrix() %>%
compute_tpm() %>%
EM_constraint()
}
model_XY <- test_model_XY(data)
plot_model(x = model_XY$x, h = model_XY$h, params = model_XY$params)
``` ```
## cluster id ## cluster id
...@@ -257,10 +266,10 @@ h_2_clust_proba <- function(h) { ...@@ -257,10 +266,10 @@ h_2_clust_proba <- function(h) {
EM_constraint <- function( EM_constraint <- function(
x, x,
threshold = 1, threshold = 1,
frac = .8, frac = 1,
kappa_weight = nrow(x) * .9, kappa_weight = c(.9, .9, .9),
mu_weight = nrow(x) * .1, mu_weight = c(.1, .8, .8),
sigma_weight = nrow(x) * .001, sigma_weight = c(.001, .001, .001),
kappa_prior = c(.7, .2, .1), kappa_prior = c(.7, .2, .1),
mu_prior = list(c(1, 1), c(1, 1.5), c(1, .5)), mu_prior = list(c(1, 1), c(1, 1.5), c(1, .5)),
verbose = F verbose = F
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment