diff --git a/src/clustering.Rmd b/src/clustering.Rmd
index 555a4d90503a919677270b59e04739e1ae81f644..60c80324df7ab3f4f0dd5c27856b600d75044238 100644
--- a/src/clustering.Rmd
+++ b/src/clustering.Rmd
@@ -79,10 +79,6 @@ data %>%
     coord_fixed()
 ```
 
-
-
-
-
 ```{r}
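+# fit a 3-component Gaussian mixture with mclust, excluding the sex column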
 data_clust = data %>% select(-c("sex")) %>% mclust::Mclust(G = 3)
 ```
@@ -93,7 +89,7 @@ plot(data_clust, what = "uncertainty")
 ```
 
 ```{r}
-theta <- list(
+theta2 <- list(
     "pi" = c(.1, .05, .85),
     "mu" = list(c(1000, 2000), c(1000, 0), c(1000, 1000)),
     "sigma" = list(
@@ -142,35 +138,36 @@ params_diff <- function(old_theta, theta, threshold) {
 proba_total <- function(x, theta) {
-    proba <- 0
-    for (params in expand_theta(theta)) {
-        proba <- proba + params$pi * 
-            mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma)
-    }
-    return(proba)
+    # per-point log-likelihood of the mixture: combine the per-cluster
+    # log-probabilities with log-sum-exp for numerical stability
+    log_proba <- proba_point(x, theta)
+    max_log <- apply(log_proba, 1, max)
+    return(max_log + log(rowSums(exp(log_proba - max_log))))
 }
 
 proba_point <- function(x, theta) {
     proba <- c()
     for (params in expand_theta(theta)) {
-        proba <- cbind(proba, params$pi * 
-            mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma)
+        # log(pi_k) + log N(x | mu_k, sigma_k); the weight must be logged too
+        proba <- cbind(proba, log(params$pi) + 
+            mvtnorm::dmvnorm(x, mean = params$mu, sigma = params$sigma, log = TRUE)
         )
     }
     return(proba)
 }
 
 loglik <- function(x, theta) {
-    -log(sum(proba_total(x, theta)))
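+    # the *negative* log-likelihood of the sample, so smaller is better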
+    -sum(proba_total(x, theta))
 }
 
 # EM function
 E_proba <- function(x, theta) {
     proba <- proba_point(x, theta)
-    proba_norm <- rowSums(proba)
+    # shift by the per-row max before exponentiating so the largest term is exp(0)
+    proba <- exp(proba - apply(proba, 1, max))
+    proba_norm <- rowSums(proba)
     for (cluster in 1:ncol(proba)) {
         proba[, cluster] <- proba[, cluster] / proba_norm
         proba[proba_norm == 0, cluster] <- 1 / ncol(proba)
     }
-    
     return(proba)
 }
 
@@ -205,181 +202,68 @@ plot_proba <- function(x, proba) {
             proba_f = proba[, 1],
             proba_m = proba[, 2],
             proba_a = proba[, 3],
-            color = rgb(proba_f * 255, proba_m * 255, proba_a * 255, maxColorValue = 255)
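+            # map the three responsibilities directly onto the RGB channels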
+            clust_proba = rgb(proba_f, proba_m, proba_a, maxColorValue = 1)
         ) %>% 
-        ggplot(aes(x = count_m, y = count_f, color = color)) +
-        geom_point()
+        ggplot(aes(x = count_m, y = count_f, color = clust_proba)) +
+        geom_point() +
+        scale_color_identity()
 }
 
 EM_clust <- function(x, theta, threshold = 0.1) {
-    old_theta <- theta
-    old_theta$mu <- list(c(-Inf, -Inf), c(-Inf, -Inf), c(-Inf, -Inf))
-    while (params_diff(old_theta, theta, threshold)) {
+    old_loglik <- -Inf
+    new_loglik <- loglik(x, theta)
+    # iterate until the negative log-likelihood changes by less than threshold
+    while (abs(new_loglik - old_loglik) > threshold) {
+        old_loglik <- new_loglik
         proba <- E_proba(x, theta)
         theta$pi <- E_N_clust(proba)
         theta$mu <- M_mean(x, proba, theta$pi)
         theta$sigma <- M_cov(x, proba, theta$mu, theta$pi)
         theta$pi <- theta$pi / nrow(x)
-        print(head(proba_point(x, theta)))
-        print(plot_proba(x, proba_point(x, theta)))
-        print(loglik(x, theta))
-        Sys.sleep(.5)
+        new_loglik <- loglik(x, theta)
     }
     return(proba)
 }
 
 
+proba <- data %>%
+    dplyr::select(count_m, count_f) %>%
+    as.matrix() %>% 
+    EM_clust(theta2)
 data %>%
     dplyr::select(count_m, count_f) %>%
     as.matrix() %>% 
-    EM_clust(theta)
+    plot_proba(proba)
 ```
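+
+For reference, the loop in `EM_clust` implements the standard EM updates for a
+Gaussian mixture. With responsibilities $r_{ik}$ computed by `E_proba`,
+
+$$r_{ik} = \frac{\pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}
+                {\sum_j \pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)},$$
+
+the M step re-estimates the parameters as
+
+$$N_k = \sum_i r_{ik}, \qquad \pi_k = \frac{N_k}{n}, \qquad
+\mu_k = \frac{1}{N_k} \sum_i r_{ik} \, x_i, \qquad
+\Sigma_k = \frac{1}{N_k} \sum_i r_{ik} \, (x_i - \mu_k)(x_i - \mu_k)^\top,$$
+
+which is what `E_N_clust`, `M_mean` and `M_cov` compute above.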
 
 
 ```{r}
+theta4 <- list(
+    "pi" = c(.1, .05, .85),
+    "mu" = list(c(1000, 2000, 1000, 2000), c(1000, 0, 1000, 0), c(1000, 1000, 1000, 1000)),
+    "sigma" = list(
+        "f" = diag(1000, nrow=4, ncol=4),
+        "m" = diag(1000, nrow=4, ncol=4),
+        "a" = diag(1000, nrow=4, ncol=4)
+    )
+)
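+# count_m and count_f are duplicated below just to exercise a 4-D input; note
+# that exactly duplicated columns can make the fitted covariances singular,
+# which mvtnorm::dmvnorm rejects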
+proba4 <- data %>%
+    dplyr::select(count_m, count_f) %>%
+    dplyr::mutate(
+        count_m2 = count_m,
+        count_f2 = count_f) %>%
+    as.matrix() %>% 
+    EM_clust(theta4)
 
+data %>%
+    dplyr::select(count_m, count_f) %>%
+    as.matrix() %>% 
+    plot_proba(proba4)
 ```
 
-```{r}
-## Example code for clustering on a three-component mixture model using the EM-algorithm.
-
-### First we load some libraries and define some useful functions
+## With real data
 
-library(mvtnorm)
-library(MASS)
-
-# Create a 'true' data set (an easy one)
-.create.data <- function(n)
-{
-  l <- list()
-  l[[1]] <- list(component=1,
-    mixing.weight=0.5,
-                 means=c(0,0),
-                 cov=matrix(c(1,0,0,1), ncol=2, byrow=T))
-  l[[2]] <- list(mixing.weight=0.3,
-                 component=2,
-                 means=c(5,5),
-                 cov=matrix(c(1, 0.5, 0.5, 1), ncol=2, byrow=T))
-  l[[3]] <- list(mixing.weight=0.2,
-                 component=3,
-                 means=c(10,10),
-                 cov=matrix(c(1,0.75,0.75,1), ncol=2, byrow=T))
-
-  do.call("rbind",sapply(l, function(e) {
-    dat <- mvtnorm::rmvnorm(e$mixing.weight * n, e$means, e$cov)
-    cbind(component=e$component,
-          x1=dat[,1],
-          x2=dat[,2])
-  }))
-}
-
-# Function for covariance update
-.cov <- function(n, r, dat, m, N.k)
-{
-  (t(r * (dat[,2:3] -m))  %*%  (( dat[,2:3]-m))) / N.k
-}
-
-# Generate starting values for means/covs/mixing weights
-.init <- function()
-{
-  l <- list()
-  l[[1]] <- list(mixing.weight=0.1,
-                 means=c(-2, -2),
-                 cov=matrix(c(1,0,0,1), ncol=2, byrow=T))
-  l[[2]] <- list(mixing.weight=0.1,
-                 means=c(10, 0),
-                 cov=matrix(c(1,0,0,1), ncol=2, byrow=T))
-  l[[3]] <- list(mixing.weight=0.8,
-                 means=c(0, 10),
-                 cov=matrix(c(1,0,0,1), ncol=2, byrow=T))
-
-  l
-}
-
-# Plot the 2D contours of the estimated Gaussian components
-.contour <- function(means, cov, l)
-{
-  X <- mvtnorm::rmvnorm(1000, means, cov)
-  z <- MASS::kde2d(X[,1], X[,2], n=50)
-  contour(z, drawlabels=FALSE, add=TRUE, lty=l, lwd=1.5)
-}
-
-# Do a scatter plot
-.scatter <- function(dat, clusters)
-{
-  plot(dat[,2], dat[,3],
-       xlab="X", ylab="Y", main="Three component Gaussian mixture model",
-       col=c("blue", "red", "orange", "black")[clusters],
-       pch=(1:4)[clusters])
-  col    <- c("blue", "red", "orange")
-  pch    <- 1:3
-  legend <- paste("Cluster", 1:3)
-  if (clusters == 4) {
-    col <- "black"
-    pch <- 4
-    legend = "No clusters"
-  }
-  legend("topleft", col=col, pch=pch, legend=legend)
-}
-
-n   <- 10000
-# create data with n samples
-dat <- .create.data(n)
-repeat
-{
-  # set initial parameters
-  l    <- .init()
-  # plot initial data
-  .scatter(dat, 4)
-  invisible(lapply(1:3, function(e) .contour(l[[e]]$means, l[[e]]$cov, 1)))
-  # Usually we would do a convergence criterion, e.g. compare difference of likelihoods
-  # but this will suffice for the hands on
-  for (i in seq(50))
-  {
-
-    ### E step
-
-    # Compute the sum of all responsibilities (for normalization)
-    r <- sapply(l, function(r)
-    {
-      r$mixing.weight *  mvtnorm::dmvnorm(dat[,2:3], r$means, r$cov)
-    })
-    r <- apply(r, 1, sum)
-    # Compute the responsibilities for each sample
-    rs <- sapply(l, function(e)
-    {
-      e$mixing.weight * mvtnorm::dmvnorm(dat[,2:3], e$means, e$cov) / r
-    })
-    # Compute number of points per cluster
-    N.k <- apply(rs, 2, sum)
-
-
-    ### M step
-
-    # Compute the new means
-    m <- lapply(1:3, function(e)
-    {
-      apply(rs[,e] * dat[,2:3], 2, sum) / N.k[e]
-    })
-    # Compute the new covariances
-    c <- lapply(1:3, function(e)
-    {
-      .cov(n, rs[,e], dat, m[[e]], N.k[e])
-    })
-    # Compute the new mixing weights
-    mi <- N.k / n
-    # Update the old parameters
-    l <- lapply(1:3, function(e)
-    {
-      list(mixing.weight = mi[e], means=m[[e]], cov=c[[e]])
-    })
-    # Plot a 2D density (contour) to show the estimated means and covariances
-    if (i %% 5 == 0)
-    {
-      Sys.sleep(1.5)
-      .scatter(dat, apply(rs, 1, which.max))
-      invisible(lapply(1:3, function(e) .contour(l[[e]]$means, l[[e]]$cov, e + 1)))
-    }
-  }
-}
+```{r}
+data <- read_csv("../results/fusion.csv")
 ```
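+
+As a minimal sketch, the same pipeline could be run on the loaded table,
+assuming `fusion.csv` exposes the same `count_m` and `count_f` columns as the
+simulated data (`proba_real` is just an illustrative name):
+
+```{r}
+# hypothetical follow-up: reuse theta2 as a starting point on the real counts
+proba_real <- data %>%
+    dplyr::select(count_m, count_f) %>%
+    as.matrix() %>%
+    EM_clust(theta2)
+data %>%
+    dplyr::select(count_m, count_f) %>%
+    as.matrix() %>%
+    plot_proba(proba_real)
+```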