diff --git a/src/clustering.Rmd b/src/clustering.Rmd
index e4fef2d34b569b7fdfe3575b157c174359161467..6150909f30f6f68a370039554502f49a5d76065c 100644
--- a/src/clustering.Rmd
+++ b/src/clustering.Rmd
@@ -90,24 +90,26 @@ plot(data_clust, what = "uncertainty")
 
 ```{r}
 
-expand_theta <- function(theta, cluster_coef) {
-    list(
-    "f" = list(
-        "pi" = theta$pi[1],
-        "mu" = cluster_coef$f * theta$mu,
-        "sigma" = theta$sigma$f
-    ),
-    "m" = list(
-        "pi" = theta$pi[2],
-        "mu" = cluster_coef$m * theta$mu,
-        "sigma" = theta$sigma$m
-    ),
+expand_theta <- function(theta, cluster_coef, sex) {
+    theta_ref <- list(
     "a" = list(
-        "pi" = theta$pi[3],
+        "pi" = theta$pi[1],
         "mu" = cluster_coef$a * theta$mu,
         "sigma" = theta$sigma$a
-    )
-)
+    ),
+    "f" = list(
+        "pi" = theta$pi[2],
+        "mu" = cluster_coef$f * theta$mu,
+        "sigma" = theta$sigma$f
+    ))
+    if (sex == "XY") {
+        theta_ref[["m"]] <- list(
+            "pi" = theta$pi[3],
+            "mu" = cluster_coef$m * theta$mu,
+            "sigma" = theta$sigma$m
+        )
+    }
+    return(theta_ref)
 }
 
 params_diff <- function(old_theta, theta, threshold) {
@@ -121,18 +123,18 @@ params_diff <- function(old_theta, theta, threshold) {
     return(T)
 }
 
-proba_total <- function(x, theta, cluster_coef) {
+proba_total <- function(x, theta, cluster_coef, sex) {
     proba <- 0
-    for (params in expand_theta(theta, cluster_coef)) {
+    for (params in expand_theta(theta, cluster_coef, sex)) {
         proba <- proba + params$pi * 
             mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma)
     }
     return(proba)
 }
 
-proba_point <- function(x, theta, cluster_coef) {
+proba_point <- function(x, theta, cluster_coef, sex) {
     proba <- c()
-    for (params in expand_theta(theta, cluster_coef)) {
+    for (params in expand_theta(theta, cluster_coef, sex)) {
         proba <- cbind(proba, params$pi * 
             mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma)
         )
@@ -140,13 +142,13 @@ proba_point <- function(x, theta, cluster_coef) {
     return(proba)
 }
 
-loglik <- function(x, theta, cluster_coef) {
-    -log(sum(proba_total(x, theta, cluster_coef)))
+loglik <- function(x, theta, cluster_coef, sex) {
+    -log(sum(proba_total(x, theta, cluster_coef, sex)))
 }
 
 # EM function
-E_proba <- function(x, theta, cluster_coef) {
-    proba <- proba_point(x, theta, cluster_coef)
+E_proba <- function(x, theta, cluster_coef, sex) {
+    proba <- proba_point(x, theta, cluster_coef, sex)
     proba_norm <- rowSums(proba)
     for (cluster in 1:ncol(proba)) {
         proba[, cluster] <- proba[, cluster] / proba_norm
@@ -160,78 +162,101 @@ E_N_clust <- function(proba) {
 }
 
 # Function for mean update
-M_mean <- function(x, proba, N_clust) {
+M_mean <- function(x, proba, N_clust, sex) {
     mu <- 0
     for (cluster in 1:ncol(proba)) {
         if (cluster == 1) {
-            mu <- mu + 1/3 *
-                mean(colSums(x * c(1, 0.5) * proba[, cluster]) / N_clust[cluster])
+            mu <- mu + 
+                mean(colSums(x * c(0.5, 0.5) * proba[, cluster]) / N_clust[cluster])
         }
         if (cluster == 2) {
-            mu <- mu + 1/3 *
-                (colSums(x * c(1, 0) * proba[, cluster]) / N_clust[cluster])[1]
+            mu <- mu +
+                mean(colSums(x * c(1, 0.5) * proba[, cluster]) / N_clust[cluster])
         }
-        if (cluster == 2) {
-            mu <- mu + 1/3 *
-                mean(colSums(x * c(0.5, 0.5) * proba[, cluster]) / N_clust[cluster])
+        if (cluster == 3) {
+            mu <- mu +
+                (colSums(x * c(1, 0) * proba[, cluster]) / N_clust[cluster])[1]
         }
     }
-    return(mu)
+    if (sex == "XY") {
+        return(mu / 3)
+    }
+    return(mu / 2)
 }
 
-M_cov <- function(x, proba, mu, N_clust, cluster_coef) {
+M_cov <- function(x, proba, mu, N_clust, cluster_coef, sex) {
     cov_clust <- list() 
     for (cluster in 1:ncol(proba)) {
-        print(cluster_coef[[cluster]])
         cov_clust[[cluster]] <- t(proba[, cluster] * (x - mu * cluster_coef[[cluster]])) %*% (x - mu * cluster_coef[[cluster]]) / N_clust[cluster]
     }
     sigma <- list()
-    sigma$f <- cov_clust[[1]]
-    sigma$m <- cov_clust[[2]]
-    sigma$a <- cov_clust[[3]]
+    sigma$a <- cov_clust[[1]]
+    sigma$f <- cov_clust[[2]]
+    if (sex == "XY") {
+        sigma$m <- cov_clust[[3]]
+    }
     return(sigma)
 }
 
-plot_proba <- function(x, proba) {
-    as_tibble(x) %>% 
-        mutate(
-            proba_f = proba[, 1],
-            proba_m = proba[, 2],
-            proba_a = proba[, 3],
-            clust_proba = rgb(proba_f, proba_m, proba_a, maxColorValue = 1)
-        ) %>% 
-        ggplot(aes(x = count_m, y = count_f, color = clust_proba)) +
-        geom_point() +
-        scale_color_identity()
+plot_proba <- function(x, proba, sex = "XY") {
+    if (sex == "XY") {
+        as_tibble(x) %>% 
+            mutate(
+                proba_a = proba[, 1],
+                proba_f = proba[, 2],
+                proba_m = proba[, 3],
+                clust_proba = rgb(proba_f, proba_m, proba_a, maxColorValue = 1)
+            ) %>% 
+            ggplot(aes(x = count_m, y = count_f, color = clust_proba)) +
+            geom_point() +
+            scale_color_identity()
+    } else {
+        as_tibble(x) %>% 
+            mutate(
+                proba_a = proba[, 1],
+                proba_f = proba[, 2],
+                clust_proba = rgb(proba_f, 0,  proba_a, maxColorValue = 1)
+            ) %>% 
+            ggplot(aes(x = count_m, y = count_f, color = clust_proba)) +
+            geom_point() +
+            scale_color_identity()
+    }
 }
 
-EM_clust <- function(x, threshold = 0.1) {
-    old_loglik <- -Inf
-    new_loglik <- 0
+init_param <- function(x, sex) {
     cluster_coef <- list(
-        "f" = c(1, 2),
-        "m" = c(1, 0),
-        "a" = c(2, 2)
+        "a" = c(2, 2),
+        "f" = c(1, 2)
     )
     theta <- list(
-        "pi" = c(.1, .05, .85),
+        "pi" = c(.85, .1, .05),
         "mu" = mean(colMeans(x)) * .5
     )
     theta$sigma <- list(
-            "f" = matrix(c(1, 1, 1, 1) * theta$mu, ncol = 2),
-            "m" = matrix(c(1, 1, 1, 1) * theta$mu, ncol = 2),
-            "a" = matrix(c(1, 1, 1, 1) * theta$mu, ncol = 2)
+            "a" = matrix(c(1, 1, 1, 1) * theta$mu, ncol = 2),
+            "f" = matrix(c(1, 1, 1, 1) * theta$mu, ncol = 2)
     )
+    if (sex == "XY") {
+        cluster_coef$m <- c(1, 0)
+        theta$sigma$m <- matrix(c(1, 1, 1, 1) * theta$mu, ncol = 2)
+    }
+    return(list(cluster_coef = cluster_coef, theta = theta))
+}
+
+
+EM_clust <- function(x, threshold = 0.1, sex = "XY") {
+    old_loglik <- -Inf
+    new_loglik <- 0
+    param <- init_param(x, sex)
     while (abs(new_loglik - old_loglik) > threshold) {
-        old_loglik <- loglik(x, theta, cluster_coef)
-        proba <- E_proba(x, theta, cluster_coef)
-        theta$pi <- E_N_clust(proba)
-        theta$mu <- M_mean(x, proba, theta$pi)
-        theta$sigma <- M_cov(x, proba, theta$mu, theta$pi, cluster_coef)
-        theta$pi <- theta$pi / nrow(x)
-        new_loglik <- loglik(x, theta, cluster_coef)
+        old_loglik <- loglik(x, param$theta, param$cluster_coef, sex)
+        proba <- E_proba(x, param$theta, param$cluster_coef, sex)
+        param$theta$pi <- E_N_clust(proba)
+        param$theta$mu <- M_mean(x, proba, param$theta$pi, sex)
+        param$theta$sigma <- M_cov(x, proba, param$theta$mu, param$theta$pi, param$cluster_coef, sex)
+        param$theta$pi <- param$theta$pi / nrow(x)
+        new_loglik <- loglik(x, param$theta, param$cluster_coef, sex)
     }
-    print(theta)
     return(proba)
 }
 
@@ -245,6 +270,14 @@ data %>%
     as.matrix() %>% 
     plot_proba(proba)
     
+proba <- data %>%
+    dplyr::select(count_m, count_f) %>%
+    as.matrix() %>% 
+    EM_clust(sex = "X0")
+data %>%
+    dplyr::select(count_m, count_f) %>%
+    as.matrix() %>% 
+    plot_proba(proba, sex = "X0")
 ```