---
title: "flat_full.Rmd for working package"
output: html_document
editor_options: 
chunk_output_type: console
---

<!-- Run this 'development' chunk -->
<!-- Store every call to library() that you need to explore your functions -->

```{r development, include=FALSE}
library(testthat)
```

<!--
 You need to run the 'description' chunk in the '0-dev_history.Rmd' file before continuing your code there.

If it is the first time you use {fusen}, after 'description', you can directly run the last chunk of the present file with inflate() inside.
--> 

```{r development-load}
# Load already included functions if relevant
pkgload::load_all(export_all = FALSE)
```


```{r function-utils, filename = "utils"}
#' Join two data frames using data.table
#'
#' @param d1 Data frame 1
#' @param d2 Data frame 2
#' @param k1 Key columns for data frame 1
#' @param k2 Key columns for data frame 2
#' @importFrom data.table data.table
#' @return Joined data frame
#' @export
#'
#' @examples
#'
#' # Example usage:
#' df1 <- data.frame(id = 1:5, value = letters[1:5])
#' df2 <- data.frame(id = 1:5, category = LETTERS[1:5])
#' join_dtf(df1, df2, "id", "id")
join_dtf <- function(d1, d2, k1, k2) {
  d1.dt_table <- data.table::data.table(d1, key = k1)
  d2.dt_table <- data.table::data.table(d2, key = k2)
  dt_joined <- d1.dt_table[d2.dt_table, allow.cartesian = TRUE]
  return(dt_joined %>% as.data.frame())
}

#' Finds the index of the first non-null element in a list.
#'
#' This function searches a list and returns the index of the first non-null element.
#'
#' @param lst The list to search.
#' @return The index of the first non-null element, or NULL if no non-null element is found.
#' @export
#' 
#' @examples
#' my_list <- list(NULL, NULL, 3, 5, NULL)
#' first_non_null_index(my_list)  # Returns 3
first_non_null_index <- function(lst) {
  for (i in seq_along(lst)) {
    if (!is.null(lst[[i]])) {
      return(i)
    }
  }
  return(NULL)
}



#' Detect rows in a matrix with all values below a given threshold
#'
#' This function detects rows in a matrix where all values are below a specified threshold.
#'
#' @param matrix The input matrix
#' @param threshold The threshold value
#' @return A logical vector indicating rows below the threshold
#' @export
detect_row_matx_bellow_threshold <- function(matrix, threshold) {
    apply(matrix, 1, function(row) all(row < threshold))
}


#' Clean Variable Name
#'
#' This function removes digits, spaces, and special characters from a variable name.
#' If any of these are present, they will be replaced with an underscore '_'.
#'
#' @param name The input variable name to be cleaned.
#' @return The cleaned variable name without digits, spaces, or special characters.
#'
#' @details
#' This function will check the input variable name for the presence of digits,
#' spaces, and special characters. If any of these are found, they will be removed
#' from the variable name and replaced with an underscore '_'. Additionally, it will
#' check if the cleaned name is not one of the reserved names "interactions" or
#' "correlations" which are not allowed as variable names.
#' @export
#' @examples
#' clean_variable_name("my_var,:&$àà(-i abl23 e_na__ç^me ")
clean_variable_name <- function(name){
      message("Variable name should not contain digits, spaces, or special characters.\nIf any of these are present, they will be removed from the variable name.")
      # avoid space in variable name
      name <- gsub(" ", "_", name, fixed = TRUE)
      # avoid digit in variable name
      name <-  gsub("[0-9]", "", name)
      # avoid special character in variable name
      name <-  gsub("[[:punct:]]", "", name)
  
      forbidden_names <- c("interactions", "correlations")
      if (name %in% forbidden_names) {
        forbidden_str <- paste(forbidden_names, collapse = " and ")
        stop(forbidden_str, "cannot be used as variable name")
      }
      return(name)
    
}

#' Convert specified columns to factor
#'
#' @param data The input data frame
#' @param columns The column names to be converted to factors
#' @return The modified data frame with specified columns converted to factors
#' @export
#' @examples
#' data <- data.frame( Category1 = c("A", "B", "A", "B"),
#'                      Category2 = c("X", "Y", "X", "Z"),
#'                      Value = 1:4,
#'                      stringsAsFactors = FALSE )
#' ## -- Convert columns to factors
#' convert2Factor(data, columns = c("Category1", "Category2"))
convert2Factor <- function(data, columns) {
  if (is.character(columns)) {
    columns <- match(columns, colnames(data))
  }

  if (length(columns) > 1) data[, columns] <- lapply(data[, columns], as.factor )
  else data[, columns] <- as.factor(data[, columns])
  return(data)
}

#' Get Setting Table
#'
#' Create a table of experimental settings.
#'
#' This function takes various experimental parameters and returns a data frame
#' that represents the experimental settings.
#'
#' @param n_genes Number of genes in the experiment.
#' @param max_replicates Maximum number of replicates for each gene.
#' @param min_replicates Minimum number of replicates for each gene.
#' @param lib_size  total number of reads
#'
#' @return A data frame containing the experimental settings with their corresponding values.
#' @export
getSettingsTable <- function(n_genes, max_replicates, min_replicates, lib_size ){
  
  settings_df <- data.frame(parameters = c("# genes", "Max # replicates", "Min # replicates", "Library size" ),
                            values = c(n_genes, max_replicates, min_replicates, lib_size))
  rownames(settings_df) <- NULL
  
  return(settings_df)
}


#' Check if a matrix is positive definite
#' This function checks whether a given matrix is positive definite, i.e., all of its eigenvalues are positive.
#' @param mat The matrix to be checked.
#' @return A logical value indicating whether the matrix is positive definite.
#' @export
#' @examples
#' # Create a positive definite matrix
#' mat1 <- matrix(c(4, 2, 2, 3), nrow = 2)
#' is_positive_definite(mat1)
#' # Expected output: TRUE
#'
#' # Create a non-positive definite matrix
#' mat2 <- matrix(c(4, 2, 2, -3), nrow = 2)
#' is_positive_definite(mat2)
#' # Expected output: FALSE
#'
#' # Check an empty matrix
#' mat3 <- matrix(nrow = 0, ncol = 0)
#' is_positive_definite(mat3)
#' # Expected output: TRUE
#'
is_positive_definite <- function(mat) {
  if (nrow(mat) == 0 && ncol(mat) == 0) return(TRUE)
  eigenvalues <- eigen(mat)$values
  all(eigenvalues > 0)
}




#' Get the list of variable names
#'
#' @param list_var R list, e.g., output of init_variable
#'
#' @return
#' A character vector with the names of variables
#' @examples
#' getListVar(init_variable())
#' @export
getListVar <- function(list_var) attributes(list_var)$names

#' Get a given attribute from a list of variables
#'
#' @param list_var A list of variables (already initialized with init_variable)
#' @param attribute A string specifying the attribute to retrieve in all occurrences of the list
#' @export
#' @return
#' A list without NULL values
#' @examples
#' getGivenAttribute(init_variable(), "level")
getGivenAttribute <- function(list_var, attribute) {
  l <- lapply(list_var, FUN = function(var) var[[attribute]]) 
  l_withoutNull <- l[!vapply(l, is.null, logical(1))]
  return(l_withoutNull)
}


#' Get labels for variables
#'
#' @param l_variables2labelized A list of variables
#' @param l_nb_label A list of numeric values representing the number of levels per variable
#' @export
#' @return
#' A list of labels per variable
#' 
#' @examples
#' labels <- getLabels(c("varA", "varB"), c(2, 3))
getLabels <- function(l_variables2labelized, l_nb_label) {
  getVarNameLabel <- function(name, level) {
    list_label <- paste(name, 1:level, sep = "")
    return(list_label)
  }
  
  listLabels <- Map(getVarNameLabel, l_variables2labelized, l_nb_label)
  return(listLabels)
}


#' getGridCombination
#'
#' Generates all possible combinations of labels.
#'
#' @param l_labels List of label vectors
#'
#' @return A data frame with all possible combinations of labels
#' @export
#'
#' @examples
#' l_labels <- list(
#'   c("A", "B", "C"),
#'   c("X", "Y")
#' )
#' getGridCombination(l_labels)
getGridCombination <- function(l_labels) {
  grid <- expand.grid(l_labels)
  colnames(grid) <- paste("label", seq_along(l_labels), sep = "_")
  return(grid)
}



#' Get grid combination from list_var
#'
#' @param list_var A list of variables (already initialized)
#'
#' @return
#' The grid combination between variable in list_var
#' @export
#' @examples
#' generateGridCombination_fromListVar(init_variable())
generateGridCombination_fromListVar <- function (list_var){
  l_levels <- getGivenAttribute(list_var, "level") %>% unlist()
  vars <- names(l_levels)
  l_levels <- l_levels[vars]
  l_labels <- getLabels(l_variables2labelized = vars, l_nb_label = l_levels)
  gridComb <- getGridCombination(l_labels)
  colnames(gridComb) <- paste("label", vars, sep = "_")
  return(gridComb)
}

#' Remove Duplicated Words from Strings
#'
#' This function takes a vector of strings and removes duplicated words within each string.
#'
#' @param strings A character vector containing strings with potential duplicated words.
#' @return A character vector with duplicated words removed from each string.
#' @export
#' @examples
#' words <- c("hellohello", "worldworld", "programmingprogramming", "R isis great", "duplicateeee1333")
#' cleaned_words <- removeDuplicatedWord(words)
removeDuplicatedWord <- function(strings){
  gsub("([A-Za-z]{1,})(\\1{1,})", "\\1", strings, perl = TRUE)
  #gsub("(.*)\\1+", "\\1", strings, perl = TRUE)
}


#' Reorder the columns of a dataframe
#'
#' This function reorders the columns of a dataframe according to the specified column order.
#'
#' @param df The input dataframe.
#' @param columnOrder A vector specifying the desired order of columns.
#'
#' @return A dataframe with columns reordered according to the specified column order.
#' @export
#' @examples
#' # Example dataframe
#' df <- data.frame(A = 1:3, B = 4:6, C = 7:9)
#'
#' # Define the desired column order
#' columnOrder <- c("B", "C", "A")
#'
#' # Reorder the columns of the dataframe
#' df <- reorderColumns(df, columnOrder)
reorderColumns <- function(df, columnOrder) {
  df <- df[, columnOrder, drop = FALSE]
  return(df)
}



#' Check if a list of glmmTMB objects is valid
#'
#' This function checks if a list of glmmTMB objects is valid. It ensures that the input 
#' list contains glmmTMB objects generated by the `fitModelParallel` function.
#'
#' @param list_tmb A list of glmmTMB objects.
#' @return TRUE if the list is valid, otherwise an error is thrown.
#' @export
isValidList_tmb <- function(list_tmb) {
  stopifnot(is.list(list_tmb))
  
  if (all(sapply(list_tmb, is.null))) {
    stop("All elements in 'list_tmb' are NULL")
  }
  
  invisible(lapply(names(list_tmb), function(i) isValidGlmmTmb(i, list_tmb[[i]])))
  return(TRUE)
}

#' Check if a glmmTMB object is valid
#'
#' This function checks if a glmmTMB object is valid. It ensures that the input object 
#' is a glmmTMB object generated by the `fitModelParallel` function.
#'
#' @param i The name of the object being checked.
#' @param obj The glmmTMB object being checked.
#' @return TRUE if the object is valid, otherwise an error is thrown.
#' @export
isValidGlmmTmb <- function(i, obj) {
  if (is.null(obj)) {
    return(TRUE)
  }
  
  if (!inherits(obj, "glmmTMB")) {
    stop(paste("Element", i, "is not a glmmTMB object. 'list_tmb' should be generated by fitModelParallel"))
  }
  return(TRUE)
}



#' Checks if an object corresponds to a mock object generated by `mock_rnaseq()`.
#'
#' This function verifies if the provided object matches the structure of a mock object generated
#' by `mock_rnaseq()`. A mock object should contain specific named elements: "settings", "init",
#' "groundTruth", "counts", and "metadata".
#'
#' @param obj Object to be checked.
#' @return TRUE or error message
#' @export
isValidMock_obj <- function(obj) {
  message_err <- "'mock_obj' does not correspond to HTRfit mock_obj. 'mock_obj' can be generated using mock_rnaseq()."
  
  if (!is.list(obj)) {
    stop(message_err)
  }
  
  expected_names <- c("settings", "init", "groundTruth", "counts", "metadata", "scaling_factors")
  
  if (!all(expected_names %in% names(obj))) {
    stop(message_err)
  }
  
  if (!all(names(obj) %in% expected_names)){
    warning("Unexpected list element in 'mock_obj'")
  }
    
  return(TRUE)
}



clear_memory <- function(except_obj){
  rm(list = setdiff(ls(), except_obj)) ; invisible(gc( reset = TRUE, verbose = FALSE ))
}


```


```{r tests-utils}

# Test for first_non_null_index function
test_that("first_non_null_index returns the correct index", {
  lst <- list(NULL, NULL, 3, 5, NULL)
  expect_equal(first_non_null_index(lst), 3)
})



# Test unitaires pour la fonction join_dtf
test_that("join_dtf réalise la jointure correctement", {
  # Création de données de test
  df1 <- data.frame(id = 1:5, value = letters[1:5])
  df2 <- data.frame(id = 1:5, category = LETTERS[1:5])
  
  # Exécution de la fonction
  result <- join_dtf(df1, df2, "id", "id")
  
  # Vérification des résultats
  expect_true(is.data.frame(result))
  expect_equal(nrow(result), 5)
  expect_equal(ncol(result), 3)
  expect_equal(names(result), c("id", "value", "category"))
  expect_true(all.equal(result$id, df1$id))
  expect_true(all.equal(result$id, df2$id))
})


test_that("clean_variable_name correctly removes digits, spaces, and special characters", {
  expect_equal(clean_variable_name("my variable name"), "myvariablename")
  expect_equal(clean_variable_name("variable_1"), "variable")
  expect_equal(clean_variable_name("^spec(ial#chars! "), "specialchars")
})

test_that("clean_variable_name handles reserved names properly", {
  expect_error(clean_variable_name("interactions"))
  expect_error(clean_variable_name("correlations"))
})


test_that("getLabels generates labels for variables", {
  labels <- getLabels(c("varA", "varB"), c(2, 3))
  expect_equal(length(labels), 2)
  expect_equal(length(labels[[1]]), 2)
  expect_equal(length(labels[[2]]), 3)
})

test_that("getGridCombination generates a grid of combinations", {
  labels <- list(A = c("A1", "A2"), B = c("B1", "B2", "B3"))
  grid_combination <- getGridCombination(labels)
  expect_equal(dim(grid_combination), c(6, 2))
})


test_that("generateGridCombination_fromListVar returns expected output", {
  result <- generateGridCombination_fromListVar(init_variable())
  expect <- data.frame(label_myVariable = c("myVariable1", "myVariable2"))
  expect_equal(nrow(result), nrow(expect))
  expect_equal(ncol(result), ncol(expect))
})

# Tests for convert2Factor
test_that("convert2Factor converts specified columns to factors", {
  # Create a sample data frame
  data <- data.frame(
    Category1 = c("A", "B", "A", "B"),
    Category2 = c("X", "Y", "X", "Z"),
    Value = 1:4,
    stringsAsFactors = FALSE
  )
  
  # Convert columns to factors
  result <- convert2Factor(data, columns = c("Category1", "Category2"))
  
  # Check the output
  expect_is(result$Category1, "factor")  # Category1 column converted to factor
  expect_is(result$Category2, "factor")  # Category2 column converted to factor
})

test_that("removeDuplicatedWord returns expected output", {
  words <- c("hellohello", "worldworld", "programmingprogramming", "R isis great")
  cleaned_words <- removeDuplicatedWord(words)
  expect_equal(cleaned_words, c("hello", "world", "programming", "R is great"))
})


# Test for detect_row_matx_bellow_threshold function
test_that("detect_row_matx_bellow_threshold detects rows below threshold", {
  # Create a sample matrix
  matrix <- matrix(c(0.5, 0.7, 1.2, 0.2, 0.9, 0.9), nrow = 2)
  # Test with threshold 1
  expect_equal(detect_row_matx_bellow_threshold(matrix, 1), c(FALSE, TRUE))
  # Test with threshold 0.5
  expect_equal(detect_row_matx_bellow_threshold(matrix, 0.5), c(FALSE, FALSE))
  expect_equal(detect_row_matx_bellow_threshold(matrix, 2), c(TRUE, TRUE))
})


test_that("reorderColumns returns expected output",{
    df <- data.frame(A = 1:3, B = 4:6, C = 7:9)
    # Define the desired column order
    columnOrder <- c("B", "C", "A")
    # Reorder the columns of the dataframe
    df_reorder <- reorderColumns(df, columnOrder)
    expect_equal(colnames(df_reorder), columnOrder)
    expect_equal(dim(df_reorder), dim(df))

})


test_that( "generateGridCombination_fromListVar return expected output", {
    ## case 1
    gridcom <- generateGridCombination_fromListVar(init_variable())
    expect_s3_class(gridcom, "data.frame")
    expect_equal(gridcom$label_myVariable, factor(c("myVariable1", "myVariable2")))

    ## case 2
    init_variables <- init_variable() %>% init_variable(name = "var" , mu = 2, sd = 1, level = 3) 
    gridcom <- generateGridCombination_fromListVar(init_variables)
    expect_s3_class(gridcom, "data.frame")
    expect_equal(unique(gridcom$label_myVariable), factor(c("myVariable1", "myVariable2")))
    expect_equal(unique(gridcom$label_var), factor(c("var1", "var2", "var3")))

  })

test_that( "getGivenAttribute return expected output", {
  ## -- case 1
  level_attr <- getGivenAttribute(init_variable(), "level")
  expect_equal(level_attr$myVariable, 2)

  ## -- case 2
  init_variables <- init_variable() %>% init_variable(name = "var" , mu = 2, sd = 1, level = 3) 
  mu_attr <- getGivenAttribute(init_variables, "mu")
  expect_equal(mu_attr$var, 2)
} )



test_that("isValidList_tmb function", {
  # Test with a valid list of glmmTMB objects
  l_tmb <- list("model1" = glmmTMB::glmmTMB(mpg ~ hp + vs + am + (1|cyl), data = mtcars),
                 "model2" = glmmTMB::glmmTMB(mpg ~ hp + vs + am + (1|cyl), data = mtcars))
  expect_true(isValidList_tmb(l_tmb))
  
  # Test with a list containing NULL elements
  expect_error(isValidList_tmb(list(NULL, NULL)), "All elements in 'list_tmb' are NULL")
  
  # Test with an empty list
  expect_error(isValidList_tmb(list()), "All elements in 'list_tmb' are NULL")
})

test_that("isValidGlmmTmb function", {
  # Test with a valid glmmTMB object
  valid_model <- glmmTMB::glmmTMB(mpg ~ hp + vs + am + (1|cyl), data = mtcars)
  expect_true(isValidGlmmTmb("model", valid_model))
  
  # Test with an invalid object (not a glmmTMB object)
  invalid_object <- list(a = 1, b = 2)
  expect_error(isValidGlmmTmb("object", invalid_object), "Element object is not a glmmTMB object.")
  
  # Test with NULL object
  expect_true(isValidGlmmTmb("model", NULL))
})



# Mock object

test_that("isValidMock_obj checks if the provided object is a valid mock object", {
  mock_obj <- mock_rnaseq(init_variable(), n_genes = 100, 4, 4)

  # Test with a valid mock object
  expect_true(isValidMock_obj(mock_obj))
  
  # Test with an object missing an element
  missing_element_obj <- list(settings = list(), init = list(), groundTruth = list(), counts = list())
  expect_error(isValidMock_obj(missing_element_obj))
  
  # Test with an object containing additional elements
  additional_element_obj <- mock_obj
  additional_element_obj$error_name <- list()
  expect_warning(isValidMock_obj(additional_element_obj))
})

```


```{r function-init_variable, filename = "simulation_initialization"}
#' Initialize variable
#'
#' @param list_var Either c() or output of init_variable
#' @param name Variable name
#' @param sd Either numeric value or NA. Use to specify range of effect sizes.
#' @param level Numeric value to specify the number of levels to simulate. Default = 2.
#' @param mu Either a numeric value or a numeric vector (of length = level). Default : 0. Not recommended to modify.
#'
#' @return
#' A list with initialized variables
#' @export
#'
#' @examples
#' init_variable(name = "my_varA", sd = 0.50, level = 200)
init_variable <- function(list_var = c(), name = "myVariable", sd = 0.2, level = 2, mu = 0) {
  
  name <- clean_variable_name(name)
  
  # Only mu specified by user => set level param
  if (is.na(sd)) {
    level <- length(mu)
  }
  
  # Validate inputs
  inputs_checking(list_var, name, mu, sd, level)
  
  if (endsWithDigit(name)) {
    warning("Names ending with digits are not allowed. They will be removed from the variable name.")
    name <- removeDigitsAtEnd(name)
  }
  
  # Initialize new variable
  list_var[[name]] <- fillInVariable(name, mu, sd, level)
  
  return(list_var)
}



#' Check if a string ends with a digit
#'
#' This function checks whether a given string ends with a digit.
#'
#' @param string The input string to be checked
#' @return \code{TRUE} if the string ends with a digit, \code{FALSE} otherwise
#' @export
#' @examples
#' endsWithDigit("abc123")  # Output: TRUE
#' endsWithDigit("xyz")     # Output: FALSE
endsWithDigit <- function(string) {
  lastChar <- substring(string, nchar(string))
  return(grepl("[0-9]", lastChar))
}

#' Remove digits at the end of a string
#'
#' This function removes any digits occurring at the end of a given string.
#'
#' @param string The input string from which digits are to be removed
#' @return The modified string with digits removed from the end
#' @export
#' @examples
#' removeDigitsAtEnd("abc123")  # Output: "abc"
#' removeDigitsAtEnd("xyz")     # Output: "xyz"
removeDigitsAtEnd <- function(string) {
  return(gsub("\\d+$", "", string))
}


#' Check Input Parameters
#'
#' This function checks the validity of the input parameters for initializing a variable.
#' It ensures that the necessary conditions are met for the input parameters.
#'
#' @param list_var List containing the variables to be initialized.
#' @param name Name of the variable.
#' @param mu Mean of the variable.
#' @param sd Standard deviation of the variable (optional).
#' @param level Number of levels for categorical variables.
#' 
#' @return NULL
#' @export
#'
#' @examples
#' inputs_checking(list_var = c(), name = "var1", mu = 0, sd = 1, level = 2)
inputs_checking <- function(list_var, name, mu, sd, level) {
  stopifnot(name != "")
  stopifnot(is.character(name))
  stopifnot(is.numeric(mu))
  stopifnot(is.numeric(sd) | is.na(sd))
  stopifnot(is.numeric(level))
  stopifnot(length(level) == 1)
  stopifnot(level >= 2)

  if (!is.null(list_var)) {
    error_msg <- "Non conformable list_var parameter.\nlist_var must be set as an init_var output or initialized as c()"
    if (!is.list(list_var)) {
      stop(error_msg)
    }
  }

  if (length(mu) > 1) {
    stopifnot(length(mu) == level)
  }

  if (is.na(sd)) {
    if (level != length(mu)) {
      stop("sd was specified as NA. mu should have the same length as the number of levels\n")
    }
  }

  # Check if variable is already initialized
  name_not_in_list_var <- identical(which(already_init_variable(list_var, name)), integer(0))
  if (!name_not_in_list_var) {
    message(paste(name, "is already initialized in list_var.\nIt will be updated", sep = " "))
  }

  return(NULL)
}


#' Check if Variable is Already Initialized
#'
#' This function checks if a variable is already initialized in the variable list.
#'
#' @param list_var A list object representing the variable list.
#' @param new_var_name A character string specifying the name of the new variable.
#'
#' @return TRUE if the variable is already initialized, FALSE otherwise.
#' @export
#'
#' @examples
#' my_list <- list(var1 = 1, var2 = 2, var3 = 3)
#' already_initialized <- already_init_variable(list_var = my_list, new_var_name = "myVariable")
already_init_variable <- function(list_var, new_var_name) {
  if (is.null(list_var)) {
    return(FALSE)
  }
  
  var_names <- names(list_var)
  return(new_var_name %in% var_names)
}

#' Fill in Variable
#'
#' This function fills in a variable with simulated data based on the provided parameters.
#'
#' @param name The name of the variable.
#' @param mu A numeric value or a numeric vector (of length = level) representing the mean.
#' @param sd A numeric value representing the standard deviation, or NA if not applicable.
#' @param level A numeric value specifying the number of levels to simulate.
#'
#' @return A data frame or a list containing the simulated data for the variable.
#' @export
#'
#' @examples
#' variable_data <- fillInVariable(name = "myVariable", mu = c(2, 3), sd = NA, level = 2)
fillInVariable <- function(name, mu, sd, level) {
  
  if (length(mu) > 1 | is.na(sd)) {  # Effects given by user
    level <- length(mu)
    l_labels <- paste(name, 1:level, sep = '')
    l_betaEffects <- mu
    column_names <- c(paste("label", name, sep = "_"), name)
    sub_obj <- build_sub_obj_return_to_user(level, metaData = l_labels,
                                       effectsGivenByUser = l_betaEffects,
                                       column_names)
  } else {
    sub_obj <- as.data.frame(list(mu = mu, sd = sd, level = level))
  }
  
  return(sub_obj)  
}

#' Build Sub Object to Return to User
#'
#' This function builds the sub-object to be returned to the user.
#'
#' @param level A numeric value specifying the number of levels.
#' @param metaData A list of labels.
#' @param effectsGivenByUser A list of effects given by the user.
#' @param col_names A character vector specifying the column names to use.
#' @importFrom utils tail
#'
#' @return A list with the sub-object details.
build_sub_obj_return_to_user <- function(level, metaData, effectsGivenByUser, col_names) {
  sub_obj <- list(level = level)
  data <- cbind(metaData, effectsGivenByUser) %>% as.data.frame()
  colnames(data) <- col_names
  var_name <- utils::tail(col_names, n = 1)
  data[, var_name] <- as.numeric(data[, var_name])
  sub_obj$data <- data
  return(sub_obj)
}


#' Add interaction
#'
#' @param list_var A list of variables (already initialized)
#' @param between_var A vector of variable names to include in the interaction
#' @param sd Either numeric value or NA. Use to specify range of effect sizes. Default 0 for no interaction effects.
#' @param mu Either a numeric value or a numeric vector (of length = level). Default : 0. Not recommended to modify.

#'
#' @return
#' A list with initialized interaction
#' @export
#'
#' @examples
#' init_variable(name = "myvarA", sd = 3, level = 200) %>%
#' init_variable(name = "myvarB", sd = 0.2, level = 2 ) %>%
#' add_interaction(between_var = c("myvarA", "myvarB"), sd = 2)
add_interaction <- function(list_var, between_var, sd = 0, mu = 0) {
  name_interaction <- paste(between_var, collapse = ":")
  check_input2interaction(name_interaction, list_var, between_var, mu, sd)
  
  # Check the number of variables in the interaction
  if (length(between_var) > 3) {
    stop("Cannot initialize an interaction with more than 3 variables.")
  }
  
  interactionCombinations <- getNumberOfCombinationsInInteraction(list_var, between_var)
  list_var$interactions[[name_interaction]] <- fillInInteraction(list_var, between_var, mu, sd, interactionCombinations)
  return(list_var)
}

#' Check input for interaction
#'
#' @param name_interaction String specifying the name of the interaction (example: "varA:varB")
#' @param list_var A list of variables (already initialized)
#' @param between_var A vector of variable names to include in the interaction
#' @param mu Either a numeric value or a numeric vector (of length = level)
#' @param sd Either numeric value or NA
#'
#' @return
#' NULL (throws an error if the input is invalid)
#' @export
check_input2interaction <- function(name_interaction, list_var, between_var, mu, sd) {
  # Check if variables in between_var are declared and initialized
  bool_checkInteractionValidity <- function(between_var, list_var) {
    nb_varInInteraction <- length(unique(between_var))
    stopifnot(nb_varInInteraction > 1)
    existingVar_nb <- getListVar(list_var) %in% between_var %>% sum()
    if (existingVar_nb != nb_varInInteraction) {
      return(FALSE)
    } else {
      return(TRUE)
    }
  }
  
  bool_valid_interaction <- bool_checkInteractionValidity(between_var, list_var)
  if (!bool_valid_interaction) {
    stop("At least one variable in between_var is not declared. Variable not initialized cannot be used in an interaction.")
  }
  
  requestedNumberOfValues <- getNumberOfCombinationsInInteraction(list_var, between_var)
  if (is.na(sd) && requestedNumberOfValues != length(mu)) {
    msg_e <- "sd was specified as NA. mu should have the same length as the possible number of interactions:\n"
    msg_e2 <- paste(requestedNumberOfValues, "interaction values are requested.")
    stop(paste(msg_e, msg_e2))
  }
  
  level <- requestedNumberOfValues
  inputs_checking(list_var$interactions, name_interaction, mu, sd, level)
}

#' Get the number of combinations in an interaction
#'
#' @param list_var A list of variables (already initialized)
#' @param between A vector of variable names to include in the interaction
#'
#' @return
#' The number of combinations in the interaction
#' @export
getNumberOfCombinationsInInteraction <- function(list_var, between) {
  levelInlistVar <- getGivenAttribute(list_var, "level") %>% unlist()
  n_combinations <- prod(levelInlistVar[between]) 
  return(n_combinations)
}

#' Fill in interaction
#'
#' @param list_var A list of variables (already initialized)
#' @param between A vector of variable names to include in the interaction
#' @param mu Either a numeric value or a numeric vector (of length = level)
#' @param sd Either numeric value or NA
#' @param level Number of interactions
#'
#' @return
#' A data frame with the filled-in interaction values
#' @export
fillInInteraction <- function(list_var, between, mu, sd, level) {
  if (length(mu) > 1 || is.na(sd)) {
    l_levels <- getGivenAttribute(list_var, "level") %>% unlist()
    l_levelsOfInterest <- l_levels[between]
    l_labels_varOfInterest <- getLabels(l_variables2labelized = between, l_nb_label = l_levelsOfInterest ) 
    
    grid_combination <- getGridCombination(l_labels_varOfInterest)
    n_combinations <- dim(grid_combination)[1]
    column_names <- c(paste("label", between, sep = "_"), paste(between, collapse = ":"))
    sub_dtf <- build_sub_obj_return_to_user(level = n_combinations,
                                            metaData = grid_combination,
                                            effectsGivenByUser = mu, 
                                            col_names = column_names)
  } else {
    sub_dtf <- list(mu = mu, sd = sd, level = level) %>% as.data.frame()
  }
  
  return(sub_dtf)
}

```


```{r tests-init_variable}

test_that("endsWithDigit returns the correct result", {
  expect_true(endsWithDigit("abc123"))
  expect_false(endsWithDigit("xyz"))
})

test_that("removeDigitsAtEnd removes digits at the end of a string", {
  expect_equal(removeDigitsAtEnd("abc123"), "abc")
  expect_equal(removeDigitsAtEnd("xyz"), "xyz")
})


test_that("init_variable initializes a variable correctly", {
  # Test case 1: Initialize a variable with default parameters
  list_var <- init_variable()
  expect_true("myVariable" %in% names(list_var))

  # Test case 2: Initialize a variable with custom parameters
  list_var <- init_variable(name = "custom_variable", mu = c(1, 2, 3), sd = 0.5, level = 3)
  expect_true("customvariable" %in% names(list_var))
  expect_equal(nrow(list_var$customvariable$data), 3)
})

test_that("inputs_checking performs input validation", {
  
  # Test case 1: Invalid inputs - sd is NA but mu has unique values
  expect_error(inputs_checking(list_var = c(), name = "myVariable", mu = 2, sd = NA, level = 2))
  
  # Test case 2: Invalid inputs - empty name
  expect_error(inputs_checking(list_var = c(), name = "", mu = 2, sd = NA, level = 2))
  
  # Test case 3: Invalid inputs - non-numeric mu
  expect_error(inputs_checking(list_var = c(), name = "myVariable", mu = "invalid", sd = NA, level = 2))
  
  # Test case 4: Invalid inputs - non-numeric sd
  expect_error(inputs_checking(list_var = c(), name = "myVariable", mu = 2, sd = "invalid", level = 2))
  
  # Test case 5: Invalid inputs - level less than 2
  expect_error(inputs_checking(list_var = c(), name = "myVariable", mu = 2, sd = NA, level = 1))
  
  # Test case 6: Invalid inputs - mu and level have different lengths
  expect_error(inputs_checking(list_var = c(), name = "myVariable", mu = c(1, 2, 3), sd = NA, level = 2))
  
  # Test case 7: Valid inputs
  expect_silent(inputs_checking(list_var = c(), name = "myVariable", mu = c(1, 2, 3), sd = NA, level = 3))
})



test_that("already_init_variable checks if a variable is already initialized", {
  list_var <- init_variable()
  
  # Test case 1: Variable not initialized
  list_var <- init_variable(name = "custom_variable", mu = c(2, 3), sd = NA, level = 2)
  expect_true(already_init_variable(list_var, "customvariable"))
  
  # Test case 2: Variable already initialized 
  expect_false(already_init_variable(list_var, "myVariable"))
  
})

test_that("fillInVariable fills in variable correctly", {
  # Test case 1: Effects given by user
  sub_obj <- fillInVariable("myVariable", c(1, 2, 3), NA, NA)
  expect_equal(sub_obj$level, 3)
  expect_equal(ncol(sub_obj$data), 2)
  
  # Test case 2: Effects simulated using mvrnorm
  sub_obj <- fillInVariable("myVariable", 2, 0.5, 3)
  expect_equal(sub_obj$level, 3)
  expect_equal(sub_obj$sd, 0.5)
  expect_equal(sub_obj$mu, 2)
})

test_that("build_sub_obj_return_to_user returns the expected output", {
  level <- 3
  metaData <- paste("label", 1:level, sep = "_")
  effectsGivenByUser <- c(2, 3, 4)
  col_names <- c("metadata", "effects")
  
  result <- build_sub_obj_return_to_user(level, metaData, effectsGivenByUser, col_names)
  
  expect_equal(result$level, level)
  expect_identical(result$data$metadata, metaData)
  expect_identical(result$data$effects, effectsGivenByUser)
  
  
})

test_that("add_interaction adds an interaction between variables", {
  list_var <- init_variable(name = "varA", mu = 1, sd = 1, level = 2)
  list_var <- init_variable(list_var, name = "varB", mu = 2, sd = 1, level = 3)
  list_var <- add_interaction(list_var, between_var = c("varA", "varB"), mu = 0.5, sd = 3)
  expect_true("varA:varB" %in% names(list_var$interactions))
})

test_that("add_interaction throws an error for invalid variables", {
  list_var <- init_variable(name = "varA", mu = 1, sd = 1, level = 2)
  expect_error(add_interaction(list_var, between_var = c("varA", "varB"), mu = 0.5, sd = NA))
})


test_that("getNumberOfCombinationsInInteraction calculates the number of combinations", {
  list_var <- init_variable(name = "varA", mu = 1, sd = 1, level = 2)
  list_var <- init_variable(list_var, name = "varB", mu = 2, sd = 1, level = 3)
  expect_equal(getNumberOfCombinationsInInteraction(list_var, c("varA", "varB")), 6)
})
```

```{r function-mvrnorm, filename = "datafrommvrnorm_manipulations" }
#' getInput2mvrnorm
#'
#' @inheritParams init_variable
#'
#' @return
#' a list that can be used as input for MASS::mvrnorm
#' @export
#'
#' @examples
#' list_var <- init_variable(name = "my_var", mu = 0, sd = 2, level = 3)
#' getInput2mvrnorm(list_var)
getInput2mvrnorm <- function(list_var){
  # -- pick up sd provided by user
  variable_standard_dev <- getGivenAttribute(list_var, attribute = "sd") %>% unlist()
  interaction_standard_dev <- getGivenAttribute(list_var$interactions, attribute = "sd") %>% unlist()
  list_stdev_2covmatx <- c(variable_standard_dev, interaction_standard_dev)
  if (is.null(list_stdev_2covmatx)) ## NO SD provided
    return(list(mu = NULL, covMatrix = NULL))

  # - COV matrix
  covar_userProvided = getGivenAttribute(list_var$correlations, "covar")
  covMatrix <- getCovarianceMatrix(list_stdev_2covmatx, covar_userProvided)

  # -- MU
  variable_mu <- getGivenAttribute(list_var, attribute = "mu") %>% unlist()
  interaction_mu <- getGivenAttribute(list_var$interactions, attribute = "mu") %>% unlist()
  list_mu <- c(variable_mu, interaction_mu)

  return(list(mu = list_mu, covMatrix = covMatrix))

}


#' getCovarianceMatrix 
#' @param list_stdev standard deviation list
#' @param list_covar covariance list
#' 
#' @return
#' covariance matrix
#' @export
#'
#' @examples
#' vector_sd <- c(1,2, 3)
#' names(vector_sd) <- c("varA", "varB", "varC")
#' vector_covar <- c(8, 12, 24)
#' names(vector_covar) <- c("varA.varB", "varA.varC", "varB.varC")
#' covMatrix <- getCovarianceMatrix(vector_sd, vector_covar)
getCovarianceMatrix <- function(list_stdev, list_covar){
  # -- cov(A, A) = sd(A)^2
  diag_cov <- list_stdev^2
  dimension <- length(diag_cov)
  covariance_matrix <- matrix(0,nrow = dimension, ncol = dimension)
  diag(covariance_matrix) <- diag_cov
  colnames(covariance_matrix) <- paste("label", names(diag_cov), sep = "_")
  rownames(covariance_matrix) <- paste("label", names(diag_cov), sep = "_")
  names_covaration <- names(list_covar)

  ###### -- utils -- #####
  convertDF <- function(name, value){
    ret <- data.frame(value)
    colnames(ret) <- name
    ret
  }

  ## -- needed to use reduce after ;)
  l_covarUserDf <- lapply(names_covaration, function(n_cov) convertDF(n_cov, list_covar[n_cov] ))
  covariance_matrix2ret <- Reduce(fillInCovarMatrice, x = l_covarUserDf, init =  covariance_matrix)
  covariance_matrix2ret
}


#' Fill in Covariance Matrix
#'
#' This function updates the covariance matrix with the specified covariance value between two variables.
#'
#' @param covarMatrice The input covariance matrix.
#' @param covar A data frame containing the covariance value between two variables.
#' @return The updated covariance matrix with the specified covariance value filled in.
#' @export
#' @examples
#' covarMat <- matrix(0, nrow = 3, ncol = 3)
#' colnames(covarMat) <- c("label_varA", "label_varB", "label_varC")
#' rownames(covarMat) <- c("label_varA", "label_varB", "label_varC")
#' covarValue <- data.frame("varA.varB" = 0.5)
#' fillInCovarMatrice(covarMatrice = covarMat, covar = covarValue)
fillInCovarMatrice <- function(covarMatrice, covar){
  varsInCovar <- strsplit(colnames(covar), split = "[.]") %>% unlist()
  index_matrix <- paste("label",varsInCovar, sep  = "_")
  covar_value <- covar[1,1]
  covarMatrice[index_matrix[1], index_matrix[2]] <- covar_value
  covarMatrice[index_matrix[2], index_matrix[1]] <- covar_value
  return(covarMatrice)
}

#' getGeneMetadata
#'
#' @inheritParams init_variable
#' @param n_genes Number of genes to simulate
#'
#' @return
#' metadata matrix
#' 
#' @export
#'
#' @examples
#' list_var <- init_variable()
#' metadata <- getGeneMetadata(list_var, n_genes = 10)
getGeneMetadata <- function(list_var, n_genes) {
  metaData <- generateGridCombination_fromListVar(list_var)
  n_combinations <- dim(metaData)[1]
  genes_vec <- paste("gene", 1:n_genes, sep = "")
  geneID <- rep(genes_vec, each = n_combinations)
  metaData <- cbind(geneID, metaData)
  
  return(metaData)
}


#' getDataFromMvrnorm
#'
#' @inheritParams init_variable 
#' @param input2mvrnorm list with mu and covariance matrix, output of getInput2mvrnorm
#' @param n_genes Number of genes to simulate
#' 
#' @return
#' data simulated from multivariate normal distribution
#' 
#' @export
#'
#' @examples
#' list_var <- init_variable()
#' input <- getInput2mvrnorm(list_var)
#' simulated_data <- getDataFromMvrnorm(list_var, input, n_genes = 10)
getDataFromMvrnorm <- function(list_var, input2mvrnorm, n_genes = 1) {
  if (is.null(input2mvrnorm$covMatrix))
    return(list())
  
  metaData <- getGeneMetadata(list_var, n_genes)
  n_tirages <- dim(metaData)[1]
  
  mtx_mvrnormSamplings <- samplingFromMvrnorm(n_samplings = n_tirages, 
                                             l_mu = input2mvrnorm$mu, matx_cov = input2mvrnorm$covMatrix)
  
  dataFromMvrnorm <- cbind(metaData, mtx_mvrnormSamplings)
  
  return(list(dataFromMvrnorm))
}


#' getDataFromMvrnorm
#'
#' @param n_samplings number of samplings using mvrnorm
#' @param l_mu vector of mu
#' @param matx_cov covariance matrix
#'
#' @return
#' samples generated from multivariate normal distribution
#' 
#' @export
#' @importFrom MASS mvrnorm
#' @examples
#' n <- 100
#' mu <- c(0, 0)
#' covMatrix <- matrix(c(1, 0.5, 0.5, 1), ncol = 2)
#' samples <- samplingFromMvrnorm(n_samplings = n, l_mu = mu, matx_cov = covMatrix)
samplingFromMvrnorm <- function(n_samplings, l_mu, matx_cov) {
  mvrnormSamp <-  MASS::mvrnorm(n = n_samplings, mu = l_mu, Sigma = matx_cov, empirical = TRUE)
  return(mvrnormSamp)
}

```

```{r  tests-mvrnorm}
test_that("getInput2mvrnorm returns the correct list", {
  list_var <- init_variable()
  input <- getInput2mvrnorm(list_var)
  expect_is(input, "list")
  expect_true("mu" %in% names(input))
  expect_true("covMatrix" %in% names(input))
})


test_that("fillInCovarMatrice returns the correct matrix", {
  covarMat <- matrix(0, nrow = 3, ncol = 3)
  colnames(covarMat) <- c("label_varA", "label_varB", "label_varC")
  rownames(covarMat) <- c("label_varA", "label_varB", "label_varC")
  covarValue <- data.frame("varA.varB" = 18)
  matrice <- fillInCovarMatrice(covarMatrice = covarMat, covar = covarValue)
  
  expected_matrice <- matrix(0, nrow = 3, ncol = 3)
  colnames(expected_matrice) <- c("label_varA", "label_varB", "label_varC")
  rownames(expected_matrice) <- c("label_varA", "label_varB", "label_varC")
  expected_matrice["label_varA", "label_varB"] <- 18
  expected_matrice["label_varB", "label_varA"] <- 18
  expect_identical(matrice, expected_matrice)
})

test_that("getCovarianceMatrix returns the correct covariance matrix", {
  vector_sd <- c(1,2, 3)
  names(vector_sd) <- c("varA", "varB", "varC")
  vector_covar <- c(8, 12, 24)
  names(vector_covar) <- c("varA.varB", "varA.varC", "varB.varC")
  covMatrix <- getCovarianceMatrix(vector_sd, vector_covar)
  
  expect_is(covMatrix, "matrix")
  expect_equal(dim(covMatrix), c(3, 3))
  expected_matrix <- matrix(c(1,8,12,8,4,24, 12,24,9), nrow = 3,  byrow = T)
  rownames(expected_matrix) <- c("label_varA", "label_varB", "label_varC")
  colnames(expected_matrix) <- c("label_varA", "label_varB", "label_varC")
  expect_equal(expected_matrix, covMatrix)
})

test_that("getGeneMetadata returns the correct metadata", {
  list_var <- init_variable()
  n_genes <- 10
  metadata <- getGeneMetadata(list_var, n_genes)
  expect_is(metadata, "data.frame")
  expect_equal(colnames(metadata), c("geneID", paste("label", (attributes(list_var)$names), sep ="_")))
  expect_equal(nrow(metadata), n_genes * list_var$myVariable$level)
})

test_that("getDataFromMvrnorm returns the correct data", {
  list_var <- init_variable(name = "varA", mu = 1, sd = 4, level = 3) %>% init_variable("varB", mu = 2, sd = 1, level = 2)
  input <- getInput2mvrnorm(list_var)
  n_genes <- 10
  n_samplings <- n_genes * (list_var$varA$level ) * (list_var$varB$level )
  data <- getDataFromMvrnorm(list_var, input, n_genes)
  expect_is(data, "list")
  expect_equal(length(data), 1)
  expect_is(data[[1]], "data.frame")
  expect_equal(nrow(data[[1]]), n_samplings)
  
})

test_that("getDataFromMvrnomr returns empty list",{
  list_var <- init_variable()
  input <- getInput2mvrnorm(list_var)
  n_genes <- 10
  n_samplings <- n_genes * (list_var$varA$level ) * (list_var$varB$level )
  data <- getDataFromMvrnorm(list_var, input, n_genes)
  expect_is(data, "list")
  expect_equal(colnames(data[[1]]), c("geneID","label_myVariable" ,"myVariable"))
})

test_that("samplingFromMvrnorm returns the correct sampling", {
  n_samplings <- 100
  l_mu <- c(1, 2)
  matx_cov <- matrix(c(1, 0.5, 0.5, 1), ncol = 2)
  sampling <- samplingFromMvrnorm(n_samplings, l_mu, matx_cov)
  
  expect_is(sampling, "matrix")
  expect_equal(dim(sampling), c(n_samplings, length(l_mu)))
})


```

```{r function-dataFromUser, filename = "datafromUser_manipulations"}

#' Get data from user
#'
#'
#' @param list_var A list of variables (already initialized)
#' @return A list of data to join
#' @export
#'
#' @examples
#' getDataFromUser(init_variable())
getDataFromUser <- function(list_var) {
  variable_data2join <- getGivenAttribute(list_var, "data")
  id_var2join <- names(variable_data2join)
  interaction_data2join <- getGivenAttribute(list_var$interactions, "data")
  id_interaction2join <- names(interaction_data2join)
  
  data2join <- list(variable_data2join, interaction_data2join) %>%
    unlist(recursive = FALSE)
  id2join <- c(id_var2join, id_interaction2join)
  l_data2join <- lapply(id2join, function(id) data2join[[id]])
  
  return(l_data2join)
}

```

```{r test-dataFromUser}
# Test unitaires pour la fonction join_dtf
test_that("join_dtf réalise la jointure correctement", {
  # Création de données de test
  df1 <- data.frame(id = 1:5, value = letters[1:5])
  df2 <- data.frame(id = 1:5, category = LETTERS[1:5])
  
  # Exécution de la fonction
  result <- join_dtf(df1, df2, "id", "id")
  
  # Vérification des résultats
  expect_true(is.data.frame(result))
  expect_equal(nrow(result), 5)
  expect_equal(ncol(result), 3)
  expect_equal(names(result), c("id", "value", "category"))
  expect_true(all.equal(result$id, df1$id))
  expect_true(all.equal(result$id, df2$id))
})


# Test unitaires pour la fonction getDataFromUser
test_that("getDataFromUser renvoie les données appropriées", {
  # Exécution de la fonction
  list_var <- init_variable()
  list_var <- init_variable(list_var, "second_var")
  result <- getDataFromUser(list_var)
  
  # Vérification des résultats
  expect_true(is.list(result))
  expect_equal(length(result), 0)
  
  
  list_var <- init_variable(mu = c(1,2,3), sd = NA)
  list_var <- init_variable(list_var, "second_var")
  result <- getDataFromUser(list_var)
  expect_true(all(sapply(result, is.data.frame)))
  expect_equal(names(result[[1]]), c("label_myVariable", "myVariable"))
})
```

```{r function-setcorrelation, filename =  "setcorrelation"}

#' Compute Covariation from Correlation and Standard Deviations
#'
#' This function computes the covariation between two variables (A and B) given their correlation and standard deviations.
#'
#' @param corr_AB The correlation coefficient between variables A and B.
#' @param sd_A The standard deviation of variable A.
#' @param sd_B The standard deviation of variable B.
#'
#' @return The covariation between variables A and B.
#' @export
#' @examples
#' corr <- 0.7
#' sd_A <- 3
#' sd_B <- 4
#' compute_covariation(corr, sd_A, sd_B)
compute_covariation <- function(corr_AB, sd_A, sd_B) {
  cov_AB <- corr_AB * sd_A * sd_B
  return(cov_AB)
}


#' Get Standard Deviations for Variables in Correlation
#'
#' This function extracts the standard deviations for the variables involved in the correlation.
#'
#' @param list_var A list containing the variables and their attributes.
#' @param between_var A character vector containing the names of the variables involved in the correlation.
#'
#' @return A numeric vector containing the standard deviations for the variables in the correlation.
#' @export
#' @examples
#' list_var <- init_variable(name = "varA", mu = 0, sd = 5, level = 3) %>%
#'          init_variable(name = "varB", mu = 0, sd = 25, level = 3)
#' between_var <- c("varA", "varB")
#' getStandardDeviationInCorrelation(list_var, between_var)
getStandardDeviationInCorrelation <- function(list_var, between_var){
  for (var in between_var) sd_List <- getGivenAttribute(list_var, "sd")
  for (var in between_var) sd_ListFromInteraction <- getGivenAttribute(list_var$interactions, "sd")
  sd_List <- c(sd_List, sd_ListFromInteraction)
  return(unname(unlist(sd_List[between_var])))
}



#' Set Correlation between Variables
#'
#' Set the correlation between two or more variables in a simulation.
#'
#' @param list_var A list containing the variables used in the simulation, initialized using \code{\link{init_variable}}.
#' @param between_var Character vector specifying the names of the variables to set the correlation between.
#' @param corr Numeric value specifying the desired correlation between the variables.
#'
#' @return Updated \code{list_var} with the specified correlation set between the variables.
#'
#' @details The function checks if the variables specified in \code{between_var} are declared and initialized in the \code{list_var}. It also ensures that at least two variables with provided standard deviation are required to set a correlation in the simulation.
#' The specified correlation value must be within the range (-1, 1). The function computes the corresponding covariance between the variables based on the specified correlation and standard deviations.
#' The correlation information is then added to the \code{list_var} in the form of a data frame containing the correlation value and the corresponding covariance value.
#' @export
#' @examples
#' list_var <- init_variable(name = "varA", mu = 0, sd = 5, level = 3) %>%
#'             init_variable(name = "varB", mu = 0, sd = 25, level = 3)
#' list_var <- set_correlation(list_var, between_var = c("varA", "varB"), corr = 0.7)
set_correlation <- function(list_var, between_var, corr) {

  # Check if variables in between_var are declared and initialized
  bool_checkBetweenVarValidity <- function(between_var, list_var) {
    nb_varInCorrelation <- length(unique(between_var))
    stopifnot(nb_varInCorrelation > 1)
    # -- check also for interaction
    varInitialized <- c(getListVar(list_var), getListVar(list_var$interactions))
    existingVar_nb <- varInitialized  %in% between_var %>% sum()
    if (existingVar_nb != nb_varInCorrelation) {
      return(FALSE)
    } else {
      return(TRUE)
    }
  }
  
  name_correlation <- paste(between_var, collapse = ".")
  bool_valid_corr <- bool_checkBetweenVarValidity(between_var, list_var)
  if (!bool_valid_corr) {
    stop("At least one variable in between_var is not declared. Variable not initialized cannot be used in a correlation.")
  }
  
  vec_standardDev <- getStandardDeviationInCorrelation(list_var, between_var)
  if (length(vec_standardDev) < 2) {
    stop("Exactly two variables with provided standard deviation are required to set a correlation in simulation.")
  }
  # Validate the specified correlation value to be within the range [-1, 1]
  if (corr < -1 || corr > 1) {
    stop("Invalid correlation value. Correlation must be in the range [-1, 1].")
  }
  
  name_interaction <- paste(between_var, collapse = ":")
  corr <- data.frame(cor = corr, covar = compute_covariation(corr, vec_standardDev[1], vec_standardDev[2] ))
  list_var$correlations[[name_correlation]] <- corr
  return(list_var)
}


```

```{r  test-setcorrelation}

test_that("compute_covariation returns the correct covariation", {
  # Test case 1: Positive correlation
  corr <- 0.7
  sd_A <- 3
  sd_B <- 4
  expected_cov <- corr * sd_A * sd_B
  actual_cov <- compute_covariation(corr, sd_A, sd_B)
  expect_equal(actual_cov, expected_cov)

  # Test case 2: Negative correlation
  corr <- -0.5
  sd_A <- 2.5
  sd_B <- 3.5
  expected_cov <- corr * sd_A * sd_B
  actual_cov <- compute_covariation(corr, sd_A, sd_B)
  expect_equal(actual_cov, expected_cov)

  # Test case 3: Zero correlation
  corr <- 0
  sd_A <- 1
  sd_B <- 2
  expected_cov <- corr * sd_A * sd_B
  actual_cov <- compute_covariation(corr, sd_A, sd_B)
  expect_equal(actual_cov, expected_cov)
})


# Unit tests for getStandardDeviationInCorrelation
test_that("getStandardDeviationInCorrelation returns correct standard deviations", {
  
  # Initialize list_var
  list_var <- init_variable(name = "varA", mu = 0, sd = 5, level = 3) %>%
              init_variable(name = "varB", mu = 0, sd = 25, level = 3)
  
  # Test case 1: Two variables correlation
  between_var_1 <- c("varA", "varB")
  sd_expected_1 <- c(5, 25)
  sd_result_1 <- getStandardDeviationInCorrelation(list_var, between_var_1)
  expect_equal(sd_result_1, sd_expected_1)
  
})



test_that("set_correlation sets the correlation between variables correctly", {
  # Initialize variables in the list_var
  list_var <- init_variable(name = "varA", mu = 0, sd = 5, level = 3) %>%
              init_variable(name = "varB", mu = 0, sd = 25, level = 3)

  # Test setting correlation between varA and varB
  list_var <- set_correlation(list_var, between_var = c("varA", "varB"), corr = 0.7)
  
  corr_result <- list_var$correlations$varA.varB$cor
  covar_result <- list_var$correlations$varA.varB$covar
  expect_equal(corr_result, 0.7)
  expect_equal(covar_result, 87.5)

  # Test setting correlation between varA and varC (should raise an error)
  expect_error(set_correlation(list_var, between_var = c("varA", "varC"), corr = 0.8),
               "At least one variable in between_var is not declared. Variable not initialized cannot be used in a correlation.")

  # Test setting correlation with invalid correlation value
  expect_error(set_correlation(list_var, between_var = c("varA", "varB"), corr = 1.5))

  # Test setting correlation with less than 2 variables with provided standard deviation
  expect_error(set_correlation(list_var, between_var = c("varA"), corr = 0.7))
})


```

```{r function-simulation , filename = "simulation"}

#' Get input for simulation based on coefficients
#'
#' This function generates input data for simulation based on the coefficients provided in the \code{list_var} argument.
#'
#' @param list_var A list of variables (already initialized)
#' @param n_genes Number of genes to simulate (default: 1)
#' @param normal_distr Specifies the distribution type for generating effects. Choose between 'univariate' (default) or 'multivariate' .
#' - 'univariate': Effects are drawn independently from univariate normal distributions. 
#' - 'multivariate': Effects are drawn jointly from a multivariate normal distribution. (not recommended)
#' @param input2mvrnorm Input to the \code{mvrnorm} function for simulating data from multivariate normal distribution (default: NULL)
#' @return A data frame with input coefficients for simulation
#' @export
#' @examples
#' # Example usage
#' list_var <- init_variable()
#' getInput2simulation(list_var, n_genes = 10)
getInput2simulation <- function(list_var, n_genes = 1, normal_distr = "univariate",  input2mvrnorm = NULL) {
  
  stopifnot( normal_distr %in% c("multivariate", "univariate") )

  if (normal_distr == "multivariate"){
      if (is.null(input2mvrnorm)) input2mvrnorm = getInput2mvrnorm(list_var)    
      l_dataFrom_normdistr <- getDataFromMvrnorm(list_var, input2mvrnorm, n_genes)
    } 
  if (normal_distr == "univariate"){
      l_dataFrom_normdistr <- getDataFromRnorm(list_var, n_genes)
    }
  
  l_dataFromUser = getDataFromUser(list_var)
  
  df_input2simu <- getCoefficients(list_var, l_dataFrom_normdistr, l_dataFromUser, n_genes)
  
  return(df_input2simu)
}



#' Get the reference level for categorical variables in the data
#'
#' This function extracts the reference level for each categorical variable in the data.
#' The reference level is the first level encountered for each categorical variable.
#'
#' @param data The data frame containing the categorical variables.
#' @return A list containing the reference level for each categorical variable.
#' @export
getRefLevel <- function(data){
  col_names <- colnames(data)
  categorical_vars <- col_names[grepl(col_names, pattern = "label_")]
  if (length(categorical_vars) == 1){
    l_labels <- list()
    l_labels[[categorical_vars]] <- levels(data[, categorical_vars])
    
  } else l_labels <- lapply(data[, categorical_vars], levels)
  l_labels_ref <- sapply(l_labels, function(vec) vec[1])
  return(l_labels_ref)
}

#' Replace the effect by 0 in the data
#'
#' This function replaces the effect in interactions columns by 0, when needed.
#'
#' @param list_var The list of variables containing the effects to modify.
#' @param l_labels_ref A list containing the reference level for each categorical variable.
#' @param data The data frame containing the effects to modify.
#' @return The modified data frame 
#' @export
replaceUnexpectedInteractionValuesBy0 <- function(list_var, l_labels_ref , data){
  varInteraction <- getListVar(list_var$interactions)
  df_interaction_with0 <- sapply(varInteraction, function(var){
    categorical_var <- paste("label", unlist(strsplit(var, ":")), sep = "_")
    bool_matrix <- sapply(categorical_var, function(uniq_cat_var) data[uniq_cat_var] ==  l_labels_ref[uniq_cat_var])
    idx_0 <- rowSums(bool_matrix) > 0 ## line without interactions effects
    return(replace(data[[var]], idx_0, 0)) 
  })
  col_names <- colnames(data)
  categorical_vars <- col_names[grepl(col_names, pattern = "label_")]
  data[, varInteraction] <- df_interaction_with0
  return(data)
}




#' Prepare data using effects from a normal distribution
#'
#' Prepares the data by generating effects from a normal distribution for each gene.
#'
#' @param list_var A list of variables (already initialized)
#' @param n_genes Number of genes to generate data for.
#' @return A dataframe containing gene metadata and effects generated from a normal distribution.
#' @export
getDataFromRnorm <- function(list_var, n_genes){
    ## -- check if all data have been provided by user
    if (is.null(getInput2mvrnorm(list_var)$covMatrix))
        return(list())
    metadata <- getGeneMetadata(list_var , n_genes)
    df_effects <- get_effects_from_rnorm(list_var, metadata)
    data <- cbind(metadata, df_effects)  
    
    if(!is.null(getListVar(list_var$interactions))){
      l_labels_ref <- getRefLevel(data)
      data <- replaceUnexpectedInteractionValuesBy0(list_var, l_labels_ref, data)
    }
    
    return(list(data))
}

#' Generate effects from a normal distribution
#'
#' Generates effects from a normal distribution for each gene.
#'
#' @param list_var A list of variables (already initialized)
#' @param metadata Gene metadata.
#' @return A dataframe containing effects generated from a normal distribution.
#' @export
get_effects_from_rnorm <- function(list_var, metadata){
  
  variable_standard_dev <- getGivenAttribute(list_var, attribute = "sd") %>% unlist()
  interaction_standard_dev <- getGivenAttribute(list_var$interactions, attribute = "sd") %>% unlist()
  list_stdev <- c(variable_standard_dev, interaction_standard_dev)
  
  # -- mu
  variable_mu <- getGivenAttribute(list_var, attribute = "mu") %>% unlist()
  interaction_mu <- getGivenAttribute(list_var$interactions, attribute = "mu") %>% unlist()
  list_mu <- c(variable_mu, interaction_mu)
  
  variable_2rnorm <- names(list_stdev)
  l_effects <- lapply(stats::setNames(variable_2rnorm, variable_2rnorm) , function(var){
    col_labels <- paste("label", unlist(strsplit(var, ":")), sep = "_")
    cols2paste <- c("geneID", col_labels)
    list_combinations <- apply( metadata[ , cols2paste ] , 1 , paste , collapse = "-" )
    list_effects <- unique(list_combinations)
    list_beta <- rnorm(length(list_effects), mean = list_mu[var], sd = list_stdev[var])
    names(list_beta) <- list_effects
    unname(list_beta[list_combinations])
  })
       
  
  df_effects <- do.call("cbind", l_effects)
  return(df_effects)
}




#' getCoefficients
#'
#' Get the coefficients.
#'
#' @param list_var A list of variables (already initialized)
#' @param l_dataFromMvrnorm Data from the `getGeneMetadata` function (optional).
#' @param l_dataFromUser Data from the `getDataFromUser` function (optional).
#' @param n_genes The number of genes.
#' @export
#' @return A dataframe containing the coefficients.
#' @examples
#' # Example usage
#' list_var <- init_variable()
#' input2mvrnorm = getInput2mvrnorm(list_var)
#' l_dataFromMvrnorm = getDataFromMvrnorm(list_var, input2mvrnorm, n_genes=3)
#' l_dataFromUser = getDataFromUser(list_var)
#' getCoefficients(list_var, l_dataFromMvrnorm, l_dataFromUser, n_genes = 3)
getCoefficients <- function(list_var, l_dataFromMvrnorm, l_dataFromUser, n_genes) {
  if (length(l_dataFromMvrnorm) == 0) {
    metaData <- getGeneMetadata(list_var, n_genes)
    l_dataFromMvrnorm <- list(metaData)
  }
  l_df2join <- c(l_dataFromMvrnorm, l_dataFromUser)
  
  
  df_coef <- Reduce(function(d1, d2){ column_names = colnames(d2)
                                      idx_key = grepl(pattern = "label", column_names )
                                      keys = column_names[idx_key]
                                      join_dtf(d1, d2, k1 = keys , k2 = keys)
                                    } 
                    , l_df2join ) %>% as.data.frame()
  column_names <- colnames(df_coef)
  idx_column2factor <- grep(pattern = "label_", column_names)
  
  if (length(idx_column2factor) > 1) {
    df_coef[, idx_column2factor] <- lapply(df_coef[, idx_column2factor], as.factor)
  } else {
    df_coef[, idx_column2factor] <- as.factor(df_coef[, idx_column2factor])
  }
  
  return(df_coef)
}


#' Get the log_qij values from the coefficient data frame.
#'
#' @param dtf_coef The coefficient data frame.
#' @return The coefficient data frame with log_qij column added.
#' @export
#' @examples
#' list_var <- init_variable()
#' dtf_coef <- getInput2simulation(list_var, 10)
#' dtf_coef <- getLog_qij(dtf_coef)
getLog_qij <- function(dtf_coef) {
  dtf_beta_numeric <- dtf_coef[sapply(dtf_coef, is.numeric)]
  dtf_coef$log_qij <- rowSums(dtf_beta_numeric, na.rm = TRUE)
  return(dtf_coef)
}


#' Calculate mu_ij values based on coefficient data frame and scaling factor
#'
#' This function calculates mu_ij values by raising 2 to the power of the log_qij values
#' from the coefficient data frame and multiplying it by the provided scaling factor.
#'
#' @param dtf_coef Coefficient data frame containing the log_qij values
#'
#' @return Coefficient data frame with an additional mu_ij column
#'
#' @examples
#' list_var <- init_variable()
#' dtf_coef <- getInput2simulation(list_var, 10)
#' dtf_coef <- getLog_qij(dtf_coef)
#' dtf_coef <- addBasalExpression(dtf_coef, 10, c(10, 20, 0))
#' getMu_ij(dtf_coef)
#' @export
getMu_ij <- function(dtf_coef) {
  log_qij_scaled <- dtf_coef$log_qij + dtf_coef$basalExpr
  dtf_coef$log_qij_scaled <- log_qij_scaled
  mu_ij <- exp(log_qij_scaled)  
  dtf_coef$mu_ij <- mu_ij
  return(dtf_coef)
}

#' getMu_ij_matrix
#'
#' Get the Mu_ij matrix.
#'
#' @param dtf_coef A dataframe containing the coefficients.
#' @importFrom reshape2 dcast
#' @importFrom stats as.formula

#' @export
#' @return A Mu_ij matrix.
#' @examples
#' list_var <- init_variable()
#' dtf_coef <- getInput2simulation(list_var, 10)
#' dtf_coef <- getLog_qij(dtf_coef)
#' dtf_coef <- addBasalExpression(dtf_coef, 10, c(10, 20, 0))
#' dtf_coef<- getMu_ij(dtf_coef)
#' getMu_ij_matrix(dtf_coef)
getMu_ij_matrix <- function(dtf_coef) {
  column_names <- colnames(dtf_coef)
  idx_var <- grepl(pattern = "label", column_names)
  l_var <- column_names[idx_var]
  str_formula_rigth <- paste(l_var, collapse = " + ")
  if (str_formula_rigth == "") stop("no variable label detected")
  str_formula <- paste(c("geneID", str_formula_rigth), collapse = " ~ ")
  formula <- stats::as.formula(str_formula)
  dtf_Muij <- dtf_coef %>% reshape2::dcast(formula = formula, value.var = "mu_ij", drop = F)
  dtf_Muij[is.na(dtf_Muij)] <- 0
  mtx_Muij <- data.frame(dtf_Muij[, -1], row.names = dtf_Muij[, 1]) %>% as.matrix()
  mtx_Muij <- mtx_Muij[, order(colnames(mtx_Muij)), drop = F]
  return(mtx_Muij)
}

#' getSubCountsTable
#'
#' Get the subcounts table.
#'
#' @param matx_Muij The Mu_ij matrix.
#' @param matx_dispersion The dispersion matrix.
#' @param replicateID The replication identifier.
#' @param l_bool_replication A boolean vector indicating the replicates.
#' @importFrom stats rnbinom
#' 
#' @return A subcounts table.
getSubCountsTable <- function(matx_Muij, matx_dispersion, replicateID, l_bool_replication) {
  getKijMatrix <- function(matx_Muij, matx_dispersion, n_genes, n_samples) {
    k_ij <- stats::rnbinom(n_genes * n_samples,
                           size = matx_dispersion,
                           mu = matx_Muij) %>%
              matrix(nrow = n_genes, ncol = n_samples)
    
    k_ij[is.na(k_ij)] <- 0
    return(k_ij)
  }
  
  if (!any(l_bool_replication))
    return(NULL) 
  
  matx_Muij <- matx_Muij[, l_bool_replication, drop = FALSE]
  matx_dispersion <- matx_dispersion[, l_bool_replication, drop = FALSE] 
  l_sampleID <- colnames(matx_Muij)
  l_geneID <- rownames(matx_Muij)
  dimension_mtx <- dim(matx_Muij)
  n_genes <- dimension_mtx[1]
  n_samples <- dimension_mtx[2]
  matx_kij <- getKijMatrix(matx_Muij, matx_dispersion, n_genes, n_samples)
  colnames(matx_kij) <- paste(l_sampleID, replicateID, sep = "_")
  rownames(matx_kij) <- l_geneID
  return(matx_kij)
}

#' getReplicationMatrix
#'
#' @param minN Minimum number of replicates for each sample
#' @param maxN Maximum number of replicates for each sample
#' @param n_samples Number of samples
#' @export
#' @return A replication matrix indicating which samples are replicated
getReplicationMatrix <- function(minN, maxN, n_samples) {
  
  # Create a list of logical vectors representing the minimum number of replicates
  l_replication_minimum = lapply(1:n_samples, 
                                 FUN = function(i) rep(TRUE, times = minN) )
  
  # Create a list of random logical vectors representing additional replicates
  l_replication_random = lapply(1:n_samples, 
                                FUN = function(i) sample(x = c(TRUE, FALSE), size = maxN-minN, replace = T) )
  
  # Combine the replication vectors into matrices
  matx_replication_minimum <- do.call(cbind, l_replication_minimum)
  matx_replication_random <- do.call(cbind, l_replication_random)
  
  # Combine the minimum replicates and random replicates into a single matrix
  matx_replication <- rbind(matx_replication_minimum, matx_replication_random)
  
  # Sort the columns of the replication matrix in descending order
  matx_replication = apply(matx_replication, 2, sort, decreasing = TRUE ) %>% matrix(nrow = maxN)
  
  return(matx_replication)
}

#' getCountsTable
#'
#' @param matx_Muij Matrix of mean expression values for each gene and sample
#' @param matx_dispersion Matrix of dispersion values for each gene and sample
#' @param matx_bool_replication Replication matrix indicating which samples are replicated
#'
#' @return A counts table containing simulated read counts for each gene and sample
getCountsTable <- function(matx_Muij ,  matx_dispersion, matx_bool_replication ){
  max_replicates <-  dim(matx_bool_replication)[1]
  
  # Apply the getSubCountsTable function to each row of the replication matrix
  l_countsTable = lapply(1:max_replicates, function(i) getSubCountsTable(matx_Muij , matx_dispersion, i, matx_bool_replication[i,]  ))
  
  # Combine the counts tables into a single matrix
  countsTable = do.call(cbind, l_countsTable)
  
  return(countsTable %>% as.data.frame())
}

#' getDispersionMatrix
#'
#' @param list_var A list of variables (already initialized)
#' @param n_genes Number of genes
#' @param dispersion Vector of dispersion values for each gene
#' @export
#'
#' @return A matrix of dispersion values for each gene and sample
getDispersionMatrix <- function(list_var, n_genes, dispersion = stats::runif(n_genes, min = 0, max = 1000)){
  l_geneID = paste("gene", 1:n_genes, sep = "")
  l_sampleID = getSampleID(list_var) 
  n_samples = length(l_sampleID) 
  l_dispersion <- dispersion
  
  # Create a data frame for the dispersion values
  dtf_dispersion = list(dispersion =  l_dispersion) %>% as.data.frame()
  dtf_dispersion <- dtf_dispersion[, rep("dispersion", n_samples)]
  rownames(dtf_dispersion) = l_geneID
  colnames(dtf_dispersion) = l_sampleID
  
  matx_dispersion = dtf_dispersion %>% as.matrix()
  
  return(matx_dispersion)
}





#' Replicate rows of a data frame by group
#'
#' Replicates the rows of a data frame based on a grouping variable and replication counts for each group.
#'
#' @param df Data frame to replicate
#' @param group_var Name of the grouping variable in the data frame
#' @param rep_list Vector of replication counts for each group
#' @return Data frame with replicated rows
#' @examples
#' df <- data.frame(group = c("A", "B"), value = c(1, 2))
#' replicateByGroup(df, "group", c(2, 3))
#'
#' @export
replicateByGroup <- function(df, group_var, rep_list) {
  l_group_var <- df[[group_var]]
  group_levels <- unique(l_group_var)
  names(rep_list) <- group_levels
  group_indices <- rep_list[l_group_var]
  replicated_indices <- rep(seq_len(nrow(df)), times = group_indices)
  replicated_df <- df[replicated_indices, ]
  suffix_sampleID <- sequence(group_indices)
  replicated_df[["sampleID"]] <- paste(replicated_df[["sampleID"]], suffix_sampleID, sep = "_")
  rownames(replicated_df) <- NULL
  return(replicated_df)
}



#' Replicate rows of a data frame
#'
#' Replicates the rows of a data frame by a specified factor.
#'
#' @param df Data frame to replicate
#' @param n Replication factor for each row
#' @return Data frame with replicated rows
#' @export
#' @examples
#' df <- data.frame(a = 1:3, b = letters[1:3])
#' replicateRows(df, 2)
#'
replicateRows <- function(df, n) {
  indices <- rep(seq_len(nrow(df)), each = n)
  replicated_df <- df[indices, , drop = FALSE]
  rownames(replicated_df) <- NULL
  return(replicated_df)
}

#' Get sample metadata
#'
#' Generates sample metadata based on the input variables, replication matrix, and number of genes.
#'
#' @param list_var A list of variables (already initialized)
#' @param replicationMatrix Replication matrix
#' @param n_genes Number of genes
#' @return Data frame of sample metadata
#' @importFrom data.table setorderv
#' @export
#' @examples
#' list_var <- init_variable()
#' replicationMatrix <- generateReplicationMatrix(list_var, 3, 3)
#' getSampleMetadata(list_var,  replicationMatrix)
getSampleMetadata <- function(list_var, replicationMatrix) {
  l_sampleIDs = getSampleID(list_var)
  metaData <- generateGridCombination_fromListVar(list_var)
  metaData[] <- lapply(metaData, as.character) ## before reordering
  data.table::setorderv(metaData, cols = colnames(metaData))
  metaData[] <- lapply(metaData, as.factor)
  metaData$sampleID <- l_sampleIDs
  rep_list <- colSums(replicationMatrix)
  metaData$sampleID <- as.character(metaData$sampleID) ## before replicating
  sampleMetadata <- replicateByGroup(metaData, "sampleID", rep_list)
  colnames(sampleMetadata) <- gsub("label_", "", colnames(sampleMetadata))
  return(sampleMetadata)
}


#' getSampleID
#'
#' @param list_var A list of variables (already initialized)
#' @export
#' @return A sorted vector of sample IDs
#' @examples
#' getSampleID(init_variable())
getSampleID <- function(list_var){
  gridCombination <- generateGridCombination_fromListVar(list_var)
  l_sampleID <- apply( gridCombination , 1 , paste , collapse = "_" ) %>% unname()
  return(sort(l_sampleID))
}

```

```{r test-simulation}

test_that("getDataFromRnorm generates correct data frame", {
  input_var_list <- init_variable(name = "varA", mu = 10, sd = 0.1, level = 3) %>%
                    init_variable(name = "varB", mu = 1, sd = 2, level = 2)
  metadata <- getGeneMetadata(input_var_list , n_genes = 5)
  df <- getDataFromRnorm(input_var_list, n_genes = 5)
  expect_is(df[[1]], "data.frame")
  expect_equal(nrow(df[[1]]), 30)
  expect_equal(colnames(df[[1]]), c("geneID", "label_varA", "label_varB", "varA", "varB"))  
})

test_that("get_effects_from_rnorm generates correct effects", {
  input_var_list <- init_variable(name = "varA", mu = 10, sd = 0.1, level = 3) %>%
                    init_variable(name = "varB", mu = 1, sd = 2, level = 2)
  metadata <- getGeneMetadata(input_var_list , n_genes = 5)
  df_effects <- get_effects_from_rnorm(input_var_list, metadata)
  
  expect_is(df_effects, "matrix")
  expect_equal(nrow(df_effects), nrow(metadata))
  expect_equal(colnames(df_effects), c("varA", "varB"))
})


# Test case 1: Check if the function returns a data frame
test_that("getInput2simulation returns a data frame", {
  list_var <- init_variable()
  set.seed(101)
  result <- getInput2simulation(list_var, normal_distr = 'multivariate')
  expect_is(result, "data.frame")
  expected <- data.frame(geneID = c("gene1", "gene1"), label_myVariable = as.factor(c("myVariable1", "myVariable2")), 
                         myVariable = c(-0.1414214,0.1414214))
  expect_equal(result, expected, tolerance = 1e-3)
  })

# Test for getCoefficients function
test_that("getCoefficients returns the correct output", {
  # Create dummy data
  n_genes <- 3
  list_var = init_variable()
  # Call the function
  coefficients <- getCoefficients(list_var, list(), list(), n_genes)
  
  # Check the output
  expect_equal(nrow(coefficients), n_genes*list_var$myVariable$level)
  expect_equal(colnames(coefficients), c("geneID", "label_myVariable")) 
})

# Test for getMu_ij_matrix function
test_that("getMu_ij_matrix returns the correct output", {
  # Create a dummy coefficients dataframe
  dtf_coef <- data.frame(geneID = c("Gene1", "Gene1", "Gene1"),
                         label_varA = c("A1", "A2", "A3"),
                         label_varB = c("B1", "B2", "B3"),
                         mu_ij = c(1, 2, 3))
  
  # Call the function
  mu_matrix <- getMu_ij_matrix(dtf_coef)
  # Check the output
  expect_equal(dim(mu_matrix), c(1, 9)) 
  
})

# Test for getSubCountsTable function
test_that("getSubCountsTable returns the correct output", {
  # Create dummy data
  l_genes <- c("gene1", "gene2", "gene3")
  matx_Muij = data.frame(sple1 = c(1,3,4), sple2 = c(2, 0, 9), sple3 = c(1, 69, 2)) %>% as.matrix()
  rownames(matx_Muij) <- l_genes
  matx_dispersion <- matrix(0.5, nrow = 3, ncol = 3)
  replicateID <- 1
  l_bool_replication <- c(TRUE, FALSE, TRUE)
  
  # Call the function
  subcounts_table <- getSubCountsTable(matx_Muij, matx_dispersion, 1, l_bool_replication)
  
  # Check the output
  expect_equal(dim(subcounts_table), c(3, 2))
  expect_equal(rownames(subcounts_table), l_genes)
})


test_that("getReplicationMatrix returns the correct replication matrix", {
  minN <- 2
  maxN <- 4
  n_samples <- 3
  expected <- matrix(c(TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, FALSE, TRUE, TRUE, TRUE, FALSE), nrow = maxN)
  
  set.seed(123)
  result <- getReplicationMatrix(minN, maxN, n_samples)
  
  expect_equal(result, expected)
})

test_that("getSampleID return the correct list of sampleID",{
   expect_equal(getSampleID(init_variable()), c("myVariable1", "myVariable2"))
})

# Create a test case for getMu_ij
test_that("getMu_ij returns the correct output", {
  # Create a sample coefficient data frame
  dtf_coef <- data.frame(
    log_qij = c(1, 9, 0.1),
    basalExpr = c(2, 3, 4)
  )

    # Call the getMu_ij function
  result <- getMu_ij(dtf_coef)

  # Check if the mu_ij column is added
  expect_true("mu_ij" %in% colnames(result))

  # Check the values of mu_ij
  #expected_mu_ij <- c(20.08554, 162754.79142 , 60.34029)
  #expect_equal(result$mu_ij, expected_mu_ij, tolerance = 0.000001)
})


# Create a test case for getLog_qij
test_that("getLog_qij returns the correct output", {
  # Create a sample coefficient data frame
  dtf_coef <- data.frame(
    beta1 = c(1.2, 2.3, 3.4),
    beta2 = c(0.5, 1.0, 1.5),
    non_numeric = c("a", "b", "c")
  )

  # Call the getLog_qij function
  result <- getLog_qij(dtf_coef)

  # Check if the log_qij column is added
  expect_true("log_qij" %in% colnames(result))

  # Check the values of log_qij
  expected_log_qij <- c(1.7, 3.3, 4.9)
  expect_equal(result$log_qij, expected_log_qij)
})

test_that("getCountsTable returns the correct counts table", {
  mat_mu_ij <- matrix(c(1,2,3,4,5,6), ncol = 3, byrow = T)
  rownames(mat_mu_ij) <- c("gene1", "gene2")
  colnames(mat_mu_ij) <- c("sample1", "sample2", "sample3")
  mat_disp <- matrix(c(0.3,0.3,0.3, 0.5,0.5,0.5), ncol = 3, byrow = T)
  rownames(mat_disp) <- c("gene1", "gene2")
  colnames(mat_disp) <- c("sample1", "sample2", "sample3")
  mat_repl <- matrix(c(TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE), ncol = 3, byrow = T)
  
  expected_df <- matrix(c(0,0,1,0,0,0,0,1,0,2,34,18,0,0,3,10,7,2), nrow = 2, byrow = T) %>% as.data.frame()
  rownames(expected_df) <- c("gene1", "gene2")
  colnames(expected_df) <- c("sample1_1", "sample2_1", "sample3_1", "sample1_2", 
                             "sample2_2","sample3_2","sample1_3", "sample2_3" ,"sample3_3")
  
  set.seed(123)
  result <- getCountsTable(mat_mu_ij, mat_disp, mat_repl)

  expect_true(is.data.frame(result))
  expect_equal(colnames(result), colnames(expected_df))
  expect_equal(rownames(result), rownames(expected_df))

})



test_that("getSampleMetadata returns expected output", {
  # Set up input variables
  list_var <- init_variable()
  replicationMatrix <- matrix(TRUE, nrow = 2, ncol = 2)

  # Run the function
  result <- getSampleMetadata(list_var,  replicationMatrix)
  
  # Define expected output
  expected_colnames <- c("myVariable", "sampleID")
  expect_equal(colnames(result), expected_colnames)
  
  # Check the output class
  expect_true(is.data.frame(result))
  
  # check nrow output
  expect_equal(nrow(result), 4)

})


test_that("replicateByGroup return the correct ouptut", {
  df <- data.frame(group = c("A", "B"), value = c(1, 2))
  result <- replicateByGroup(df, "group", c(2, 3))
  
  expect <- data.frame(group = c("A", "A", "B", "B", "B"), 
                       value = c(1, 1, 2,2,2), 
                       sampleID = c("_1", "_2", "_1", "_2", "_3" ))
  expect_equal(result, expect)

})


test_that("getDispersionMatrix returns the correct dispersion matrix", {
  n_genes = 3
  list_var = init_variable()
  dispersion <- 1:3
  expected <- matrix(1:3,byrow = F, nrow = 3, ncol = 2)
  rownames(expected) <- c("gene1", "gene2", "gene3")
  colnames(expected) <- c("myVariable1", "myVariable2")
  result <- getDispersionMatrix(list_var, n_genes, dispersion )
  expect_equal(result, expected)
})


# Test case: Valid input vector with numeric and positive values
test_that("Valid input vector with numeric and positive values", {
  input_vector <- c(0.5, 1.2, 0.8)
  result <- getValidDispersion(input_vector)
  expect_identical(result, input_vector)
})

# Test case: Valid input vector with numeric, positive, and negative values
test_that("Valid input vector with numeric, positive, and negative values", {
  input_vector <- c(0.5, -0.3, 1.2, 0.8)
  result <- getValidDispersion(input_vector)
  expect_identical(result, c(0.5, 1.2, 0.8))
})

# Test case: Valid input vector with non-numeric elements
test_that("Valid input vector with non-numeric elements", {
  input_vector <- c(0.5, "invalid", 0.8)
  result <- getValidDispersion(input_vector)
  expect_identical(result, c(0.5, 0.8))
})

# Test case: Empty input vector
test_that("Empty input vector", {
  input_vector <- numeric(0)
  expect_error(getValidDispersion(input_vector), "Invalid dispersion values provided.")
})

# Test case: unique value in vector
test_that("unique value in vector", {
  input_vector <- 5
  expect_equal(getValidDispersion(input_vector), 5)
})

# Test case: All negative values
test_that("All negative values", {
  input_vector <- c(-0.5, -1.2, -0.8)
  expect_error(getValidDispersion(input_vector), "Invalid dispersion values provided.")
})





# Test for getRefLevel function
test_that("getRefLevel returns correct reference levels", {
  # Create a sample data frame
  data <- data.frame(
    label_genotype = factor(c("A", "B", "A", "B")),
    label_environment = factor(c("X", "Y", "X", "Y"))
  )
  
  # Expected reference levels
  expected_ref_levels <- list(
    label_genotype = "A",
    label_environment = "X"
  ) %>% unlist()
  
  # Test getRefLevel function
  ref_levels <- getRefLevel(data)
  
  # Check if reference levels match the expected ones
  expect_identical(ref_levels, expected_ref_levels)
})

# Test for replaceUnexpectedInteractionValuesBy0 function
test_that("replaceUnexpectedInteractionValuesBy0 replaces effects correctly", {
  
  input_var_list <- init_variable( name = "genotype", mu = 0, sd = 2.18, level = 2) %>%
                      init_variable( name = "env", mu = 0, sd = 0.57, level = 4 ) %>%
                      init_variable(name = "T", mu = 0, sd = 0.39, level = 3) %>% 
                      add_interaction(between_var = c("genotype", "env"), mu = 0 , sd = 1.018) %>%
                      add_interaction(between_var = c("genotype", "env", "T"), mu = 0 , sd = 1.018)
  
  
  N_GENES <- 10
  l_dataFrom_normdistr <- getDataFromRnorm(input_var_list, N_GENES)
  metadata <- getGeneMetadata(input_var_list , N_GENES)
  set.seed(101)
  df_effects <- get_effects_from_rnorm(input_var_list, metadata)
  data <- cbind(metadata, df_effects) 
  
  l_labels_ref <- getRefLevel(data)
  data <- replaceUnexpectedInteractionValuesBy0(input_var_list, l_labels_ref, data)
  
  # Check if modified data matches the expected data
  expect_identical(colnames(data), c("geneID","label_genotype","label_env", 
                                     "label_T", "genotype", "env" , "T", "genotype:env"  ,"genotype:env:T"))
  expected_data <- data.frame(geneID = "gene1", label_genotype = "genotype1", label_env = "env1", label_T = "T1", 
                              genotype =  -0.7107595, env = -0.09334073, T = -0.1013228, "genotype:env" = 0, "genotype:env:T" = 0 )
  colnames(expected_data) <- c("geneID", "label_genotype", "label_env", "label_T", "genotype", 
                              "env", "T" , "genotype:env", "genotype:env:T" )
  expected_data$label_genotype <- factor(expected_data$label_genotype , levels = c("genotype1", "genotype2"))
  expected_data$label_env <- factor(expected_data$label_env , levels = c("env1", "env2", "env3", "env4"))
  expected_data$label_T <- factor(expected_data$label_T , levels = c("T1", "T2", "T3"))

  expect_equal(data[1,], expected_data ,tolerance = 0.0001 )

})



```


```{r function-mock , filename = "mock_rnaseq" }

#' Check the validity of the dispersion matrix
#'
#' Checks if the dispersion matrix has the correct dimensions.
#'
#' @param matx_dispersion Replication matrix
#' @param matx_bool_replication Replication matrix
#' @return TRUE if the dimensions are valid, FALSE otherwise
#' @export
#' @examples
#' matx_dispersion <- matrix(1:12, nrow = 3, ncol = 4)
#' matx_bool_replication <- matrix(TRUE, nrow = 3, ncol = 4)
#' is_dispersionMatrixValid(matx_dispersion, matx_bool_replication)
is_dispersionMatrixValid <- function(matx_dispersion, matx_bool_replication) {
  expected_nb_column <- dim(matx_bool_replication)[2]
  if (expected_nb_column != dim(matx_dispersion)[2]) {
    return(FALSE)
  }
  return(TRUE)
}

#' Generate count table
#'
#' Generates the count table based on the mu_ij matrix, dispersion matrix, and replication matrix.
#'
#' @param mu_ij_matx_rep Replicated mu_ij matrix
#' @param matx_dispersion_rep Replicated dispersion matrix
#' @return Count table
#' @export
#' @examples
#' mu_ij_matx_rep <- matrix(1:12, nrow = 3, ncol = 4)
#' matx_dispersion_rep <- matrix(1:12, nrow = 3, ncol = 4)
#' generateCountTable(mu_ij_matx_rep, matx_dispersion_rep)
generateCountTable <- function(mu_ij_matx_rep, matx_dispersion_rep) {
  message("k_ij ~ Nbinom(mu_ij, dispersion)")
  n_genes <- dim(mu_ij_matx_rep)[1]
  n_samples <- dim(mu_ij_matx_rep)[2]
  n_samplings <- prod(n_genes * n_samples)
  mat_countsTable <- rnbinom(n_samplings, 
                             size = matx_dispersion_rep, 
                             mu = mu_ij_matx_rep) %>%
                      matrix(nrow = n_genes, ncol = n_samples)
  colnames(mat_countsTable) <- colnames(mu_ij_matx_rep)
  rownames(mat_countsTable) <- rownames(mu_ij_matx_rep)
  mat_countsTable[is.na(mat_countsTable)] <- 0
  return(mat_countsTable)
}


#' Get messages related to sequencing depth
#'
#' This function generates informative messages regarding the scaling of counts by sequencing depth.
#'
#' @param scaling_factors Scaling factors obtained from sequencing depth
#' @param threshold_cov_var Threshold coefficient of variation to detect heterogeneity in scaling
#' @return NULL
#' @export
get_messages_sequencing_depth <- function(scaling_factors, threshold_cov_var = 1.5){
    message("Scaling count table according to sequencing depth: Done")
    message("INFO: Scaling counts by sequencing depth may exhibit some randomness due to certain parameter combinations, resulting in erratic behavior. This can be minimized by simulating more genes. We advise verifying the simulated sequencing depth to avoid drawing incorrect conclusions.\n")
    
    cov_var_sj <- sd(scaling_factors)/mean(scaling_factors)
    if (cov_var_sj > threshold_cov_var) 
        message("INFO: Heterogeneity in scaling by sequencing depth has been detected. Simulated effects may be distorted. This can occur when the number of genes is too low.\n")
  return(NULL)
}



#' Emit a warning message for rows with low mu_ij values
#'
#' This function emits a warning message if any rows in the mu_ij matrix have all values below a specified threshold.
#'
#' @param mu_ij_matrix Matrix of mu_ij values
#' @param threshold Threshold value
#' @return NULL
#' @export
warning_too_low_mu_ij_row <- function(mu_ij_matrix, threshold = 1 ){
  n_too_low_row <- length(which(detect_row_matx_bellow_threshold(mu_ij_matrix, threshold)))
  if (n_too_low_row > 0){
    msg <- paste("INFO:", n_too_low_row, "genes have all(mu_ij) < 1, indicating very low counts. Consider removing them for future analysis using prepareData2fit with row_threshold = 10. To detect them in future experiment, try increasing sequencing depth.\n", 
                 sep = " ")
    message(msg)
  }
  return(NULL)
}

    
 
  

#' Perform RNA-seq simulation
#'
#' Simulates RNA-seq data based on the input variables.
#'
#' @param list_var List of input variables
#' @param n_genes Number of genes
#' @param min_replicates  Minimum number of replicates (mandatory when generate_counts = TRUE).
#' If min_replicates is different from max replicates, the number of replicates is randomly selected 
#' from a uniform distribution between min and max replicates.
#' @param max_replicates Maximum replicates number (mandotory only if generate_counts = TRUE)
#' If min_replicates is different from max replicates, the number of replicates is randomly selected 
#' from a uniform distribution between min and max replicates.
#' @param sequencing_depth Sequencing depth
#' @param basal_expression base expression gene
#' @param dispersion User-provided dispersion vector (optional)
#' @param normal_distr Specifies the distribution type for generating effects. Choose between 'univariate' (default) or 'multivariate' .
#' - 'univariate': Effects are drawn independently from univariate normal distributions. 
#' - 'multivariate': Effects are drawn jointly from a multivariate normal distribution. (not recommended)
#' @param generate_counts  Logical indicating whether to generate counts (default = TRUE).
#' If TRUE, gene expression counts will be generated.
#' If FALSE, gene expression counts will not be generated. Useful for scenarios
#' where you plan to combine mock objects later to save computing time and resources.
#' When using `combine_mock_rnaseq`, counts will be generated during combination.
#' @return List containing the ground truth, counts, and metadata
#' @export
#' @examples
#' mock_rnaseq(list_var = init_variable(), 
#'              n_genes = 1000, min_replicates = 2,   
#'               max_replicates = 4)
mock_rnaseq <- function(list_var, n_genes, min_replicates = NULL, max_replicates = NULL, sequencing_depth = NULL,  
                        basal_expression = 0 , dispersion = stats::runif(n_genes, min = 0, max = 1000), 
                        normal_distr = "univariate", generate_counts = TRUE) {
  
  ## -- get my effect
  df_inputSimulation <- getInput2simulation(list_var, n_genes, normal_distr )
  ## -- add column logQij
  df_inputSimulation <- getLog_qij(df_inputSimulation)
  df_inputSimulation <- addBasalExpression(df_inputSimulation, n_genes, basal_expression)
  df_inputSimulation <- getMu_ij(df_inputSimulation )
  
  dispersion <- getValidDispersion(dispersion)
  genes_dispersion <- sample(dispersion , size = n_genes, replace = T)
  l_geneID = base::paste("gene", 1:n_genes, sep = "")
  names(genes_dispersion) <- l_geneID
  
  dtf_countsTable <- data.frame()
  metaData <- data.frame()
  scaling_factors <- NULL
  libSize <- 0
  if (isTRUE(generate_counts)){
      stopifnot(is.numeric(max_replicates))
      stopifnot(is.numeric(min_replicates))
      message("Building mu_ij matrix")
      ## -- matrix
      matx_Muij <- getMu_ij_matrix(df_inputSimulation)
      l_sampleID <- getSampleID(list_var)
      matx_bool_replication <- generateReplicationMatrix(list_var, min_replicates, max_replicates)
      mu_ij_matx_rep <- replicateMatrix(matx_Muij, matx_bool_replication)
      matx_dispersion <- getDispersionMatrix(list_var, n_genes, genes_dispersion)
      
      ## same order as mu_ij_matx_rep
      matx_dispersion <- matx_dispersion[ order(row.names(matx_dispersion)), ]
      matx_dispersion_rep <- replicateMatrix(matx_dispersion, matx_bool_replication)
      
      if (!is.null(sequencing_depth)) {
        scaling_factors <- get_scaling_factor(mu_ij_matx_rep, sequencing_depth)
        invisible(get_messages_sequencing_depth(scaling_factors))
        mu_ij_dtf_rep <- scaleCountsTable(mu_ij_matx_rep, scaling_factors)
        mu_ij_matx_rep <- as.matrix(mu_ij_dtf_rep)
        ## -- rescaling effect
        df_inputSimulation$log_qij_scaled <- df_inputSimulation$log_qij_scaled + log(mean(scaling_factors, na.rm = T))
      } else{
        scaling_factors <- NULL
      }
      
      invisible(warning_too_low_mu_ij_row(mu_ij_matx_rep))
      matx_countsTable <- generateCountTable(mu_ij_matx_rep, matx_dispersion_rep)
      message("Counts simulation: Done")
      
      dtf_countsTable <- matx_countsTable %>% as.data.frame()
      checkFractionOfZero(dtf_countsTable)
      metaData <- getSampleMetadata(list_var, matx_bool_replication)
      libSize <- sum(colSums(dtf_countsTable))
  }
  
  settings_df <- getSettingsTable(n_genes, min_replicates, max_replicates, libSize)
  list2ret <- list( settings = settings_df, init = list_var, 
                    groundTruth = list(effects = df_inputSimulation, gene_dispersion = genes_dispersion),
                    counts = dtf_countsTable,
                    metadata = metaData,
                    scaling_factors = scaling_factors)
  
  ## -- clean garbage collector to save memory 
  invisible(gc(reset = TRUE, verbose = FALSE));
  
  return(list2ret)
}


#' Check Fraction of Zero or One in Counts Table
#'
#' This function checks the percentage of counts in a given counts table that are either zero or one.
#' If more than 50% of the counts fall in this category, a warning is issued, suggesting a review of input parameters.
#'
#' @param counts_table A matrix or data frame representing counts.
#' @return NULL
#' @export
#' @examples
#' # Example usage:
#' counts_table <- matrix(c(0, 1, 2, 3, 4, 0, 0, 1, 1), ncol = 3)
#' checkFractionOfZero(counts_table)
checkFractionOfZero <- function(counts_table){
  
    dim_matrix <- dim(counts_table)
    
    n_counts <- dim_matrix[1]*dim_matrix[2]
    
    n_zero_or_one <- sum(counts_table < 1)
    
    fractionOfZero <- n_zero_or_one/n_counts*100
    
    if( fractionOfZero > 50) {
      message("50% of the counts simulated are bellow 1. Consider reviewing your input parameters.")
    }
    
    return(NULL)
    
}



#' Validate and Filter Dispersion Values
#'
#' This function takes an input vector and validates it to ensure that it meets certain criteria.
#'
#' @param input_vector A vector to be validated.
#' @return A validated and filtered numeric vector.
#' @details The function checks whether the input is a vector, suppresses warnings while converting to numeric,
#' and filters out non-numeric elements. It also checks for values greater than zero and removes negative values.
#' If the resulting vector has a length of zero, an error is thrown.
#' @examples
#' getValidDispersion(c(0.5, 1.2, -0.3, "invalid", 0.8))
#' @export
getValidDispersion <- function(input_vector) {
  # Verify if it's a vector
  if (!is.vector(input_vector)) {
    stop("dispersion param is not a vector.")
  }

  input_vector <- suppressWarnings(as.numeric(input_vector))

  # Filter numeric elements
  numeric_elements <- !is.na(input_vector)
  if (sum(!numeric_elements) > 0) {
    message("Non-numeric elements were removed from the dispersion vector")
    input_vector <- input_vector[numeric_elements]
  }

  # Check and filter values > 0
  numeric_positive_elements <- input_vector > 0
  if (sum(!numeric_positive_elements) > 0) {
    message("Negative numeric values were removed from the dispersion vector")
    input_vector <- input_vector[numeric_positive_elements]
  }

  if (length(input_vector) == 0) stop("Invalid dispersion values provided.")

  return(input_vector)
}


#' Generate replication matrix
#'
#' Generates the replication matrix based on the minimum and maximum replication counts.
#'
#' @param list_var Number of samples
#' @param min_replicates Minimum replication count
#' @param max_replicates Maximum replication count
#' @return Replication matrix
#' @export
#' @examples
#' list_var = init_variable()
#' generateReplicationMatrix(list_var, min_replicates = 2, max_replicates = 4)
generateReplicationMatrix <- function(list_var, min_replicates, max_replicates) {
  if (min_replicates > max_replicates) {
    message("min_replicates > max_replicates have been supplied")
    message("Automatic reversing")
    tmp_min_replicates <- min_replicates
    min_replicates <- max_replicates
    max_replicates <- tmp_min_replicates
  }
  l_sampleIDs <- getSampleID(list_var)
  n_samples <-  length(l_sampleIDs)
  return(getReplicationMatrix(min_replicates, max_replicates, n_samples = n_samples))
}

#' Replicate matrix
#'
#' Replicates a matrix based on a replication matrix.
#'
#' @param matrix Matrix to replicate
#' @param replication_matrix Replication matrix
#' @return Replicated matrix
#' @export
#' @examples
#' matrix <- matrix(1:9, nrow = 3, ncol = 3)
#' replication_matrix <- matrix(TRUE, nrow = 3, ncol = 3)
#' replicateMatrix(matrix, replication_matrix)
replicateMatrix <- function(matrix, replication_matrix) {
  n_genes <- dim(matrix)[1]
  rep_list <- colSums(replication_matrix)
  replicated_indices <- rep(seq_len(ncol(matrix)), times = rep_list)
  replicated_matrix <- matrix[, replicated_indices, drop = FALSE]
  suffix_sampleID <- sequence(rep_list)
  colnames(replicated_matrix) <- paste(colnames(replicated_matrix), suffix_sampleID, sep = "_")
  return(replicated_matrix)
}


```

```{r test-mock}

# Test for is_dispersionMatrixValid
test_that("is_dispersionMatrixValid returns TRUE for valid dimensions", {
  matx_dispersion <- matrix(1:6, nrow = 2, ncol = 3)
  matx_bool_replication <- matrix(TRUE, nrow = 2, ncol = 3)
  expect_true(is_dispersionMatrixValid(matx_dispersion, matx_bool_replication))
})

test_that("is_dispersionMatrixValid throws an error for invalid dimensions", {
  matx_dispersion <- matrix(1:4, nrow = 2, ncol = 2)
  matx_bool_replication <- matrix(TRUE, nrow = 2, ncol = 3)
  expect_false(is_dispersionMatrixValid(matx_dispersion, matx_bool_replication))
})

# Test for generateCountTable
test_that("generateCountTable generates count table with correct dimensions", {
  mu_ij_matx_rep <- matrix(1:6, nrow = 2, ncol = 3)
  matx_dispersion_rep <- matrix(1:6, nrow = 2, ncol = 3)
  count_table <- generateCountTable(mu_ij_matx_rep, matx_dispersion_rep)
  expect_equal(dim(count_table), c(2, 3))
})



test_that("checkFractionofZero issues a warning for high fraction of zeros/ones", {
  # Test case 1: Less than 50% zeros and ones
  counts_table_1 <- matrix(c(0, 1, 2, 0, 0, 0, 0, 1, 1), ncol = 3)
  expect_message(checkFractionOfZero(counts_table_1), "50% of the counts simulated are bellow 1. Consider reviewing your input parameters.")
  
  # Test case 2: More than 50% zeros and ones
  counts_table_2 <- matrix(c(0, 1, 0, 0, 1, 10, 100, 1, 1), ncol = 3)
  expect_null(checkFractionOfZero(counts_table_2))
  
})


# Test for replicateMatrix
test_that("replicateMatrix replicates matrix correctly", {
  matrix <- matrix(1:9, nrow = 3, ncol = 3)
  replication_matrix <- matrix(TRUE, nrow = 3, ncol = 3)
  replicated_matrix <- replicateMatrix(matrix, replication_matrix)
  expect_equal(dim(replicated_matrix), c(3, 9))
})

# Test for warning_too_low_mu_ij_row function
test_that("warning_too_low_mu_ij_row emits warning for low mu_ij values", {
  # Create a sample mu_ij matrix
  mu_ij_matrix <- matrix(c(0.5, 0.7, 1.2, 0.2, 0.9, 1.1), nrow = 2)
  # Verify if warning message is emitted
  expect_message(warning_too_low_mu_ij_row(mu_ij_matrix, 2))
})


# Test for generateReplicationMatrix
test_that("generateReplicationMatrix generates replication matrix correctly", {
  replication_matrix <- generateReplicationMatrix(init_variable(),min_replicates = 2, max_replicates = 4)
  expect_equal(dim(replication_matrix), c(4, 2))
})

```


```{r  function-preparingData , filename = "prepare_data2fit"}

#' Convert count matrix to long data frame
#'
#' Converts a count matrix to a long data frame format using geneID as the identifier.
#'
#' @param countMatrix Count matrix
#' @param value_name Name for the value column
#' @param id_vars Name for the id column (default "geneID")
#' @return Long data frame
#' @importFrom reshape2 melt
#' @export
#' @examples
#' list_var <- init_variable()
#' mock_data <- mock_rnaseq(list_var, n_genes = 3, 2, 2)
#' countMatrix_2longDtf(mock_data$counts)
countMatrix_2longDtf <- function(countMatrix, value_name = "kij", id_vars = "geneID") {
  countMatrix <- as.data.frame(countMatrix)
  countMatrix$geneID <- rownames(countMatrix)
  dtf_countLong <- reshape2::melt(countMatrix, id.vars = id_vars, variable.name = "sampleID", 
                                  value.name = value_name)
  dtf_countLong$sampleID <- as.character(dtf_countLong$sampleID)
  return(dtf_countLong)
}

#' Get column name with sampleID
#'
#' Returns the column name in the metadata data frame that corresponds to the given sampleID.
#'
#' @param dtf_countsLong Long data frame of counts
#' @param metadata Metadata data frame
#' @return Column name with sampleID
#' @export
#' @examples
#' list_var <- init_variable()
#' mock_data <- mock_rnaseq(list_var, n_genes = 3, 2,2)
#' dtf_countLong <- countMatrix_2longDtf(mock_data$counts)
#' getColumnWithSampleID(dtf_countLong, mock_data$metadata)
getColumnWithSampleID <- function(dtf_countsLong, metadata) {
  example_spleID <- as.character(dtf_countsLong[1, "sampleID"])
  regex <- paste("^", as.character(dtf_countsLong[1, "sampleID"]), "$", sep = "")
  ## -- init
  name_column <- NA
  for (indice_col in 1:dim(metadata)[2]) {
    bool_column_samples <- grep(pattern = regex, metadata[, indice_col])
    if (!identical(bool_column_samples, integer(0))) {
       name_column <- colnames(metadata)[indice_col]
    } 
  }
  return(name_column)
}

#' Validates a custom expression.
#'
#' This function checks whether the provided expression is a valid R expression and can be correctly parsed.
#'
#' @param custom_expression The custom expression to validate.
#' @return TRUE if the expression is valid, FALSE otherwise.
#' @export
#' @examples
#' isValidExpression("x + 1")
#' isValidExpression("log(x +)")
isValidExpression <- function(custom_expression) {
  if (!is.character(custom_expression)) {
    message("Custom expression must be a character string.")
    return(FALSE)
  }
  
  if (length(custom_expression) != 1) {
    message("Custom expression must be a single string.")
    return(FALSE)
  }
  
  tryCatch({
    parse(text = custom_expression)
    return(TRUE)
  }, error = function(e) {
    message("Invalid custom expression:", e$message)
    return(FALSE)
  })
}

#' Applies a custom function to a count matrix.
#'
#' This function applies a custom R function to each element of the count matrix and returns the transformed matrix.
#'
#' @param count_matrix The count matrix to transform.
#' @param custom_expression The custom R expression to apply to each element of the count matrix.
#' @return The transformed count matrix.
#' @export
#' @examples
#' # Define a count matrix
#' count_matrix <- matrix(1:9, nrow = 3)
#' # Apply the expression x + 1 to each element of the count matrix
#' custom_matrix_transform(count_matrix, "x + 1")
custom_matrix_transform <- function(count_matrix, custom_expression) {
  
  isValidExpression(custom_expression)
  
  # Convertir la chaîne en expression R
  custom_expression <- parse(text = custom_expression)
  
   custom_function <- function(x) {
     eval(custom_expression, list(x = x))
   }
   transformed_matrix <- apply(count_matrix, c(1, 2), custom_function)
   return(transformed_matrix)
}



#' Prepares data for fitting.
#'
#' This function prepares the countMatrix and metadata for fitting by converting the countMatrix to a long format 
#' and joining it with metadata. Optionally, it can apply median ratio normalization and a custom transformation to the countMatrix.
#'
#' @param countMatrix Count matrix.
#' @param metadata Metadata data frame.
#' @param response_name String referring to the target variable name that is being modeled and predicted (default: "kij").
#' @param groupID String referring to the group variable name (default: "geneID").
#' @param row_threshold Numeric threshold for removing rows with all counts below a specified value. Default 0.
#'                This filtering is applied before transformation and normalization.
#' @param transform A custom R expression to apply to each element of the countMatrix. 
#'                  This expression should be provided as a character string.
#'                  For example, to apply log transformation, use `"log(x)"`. 
#'                  Note that `x` represents each element in the countMatrix. See examples for more details.
#'                  The transformation is applied before normalization (if normalization = \code{TRUE}).
#' @param normalization a vector character specifying method to use (default: NULL, possible choices: c('MRN', 'TTM'))
#'        - MRN: median ratio normalization
#'        - TMM: Trimmed Mean of M-values
#' @return Data frame suitable for fitting.
#' @export
#' @examples
#' # Initialize variables and create mock RNA-Seq data
#' list_var <- init_variable()
#' mock_data <- mock_rnaseq(list_var, n_genes = 3, 2,2)
#' # Prepare data for fitting with log transformation
#' data2fit <- prepareData2fit(mock_data$counts, mock_data$metadata, transform = "log(x)")
#' # Prepare data for fitting with custom expression
#' data2fit <- prepareData2fit(mock_data$counts, mock_data$metadata, transform = "sqrt(x + 1)")
prepareData2fit <- function(countMatrix, metadata, response_name = "kij", 
                            groupID = "geneID", row_threshold = 0, transform = NULL , 
                            normalization = NULL) {
  ## -- first check
  stopifnot("ncol(countMatrix) and nrow(metadata) do not match!" = dim(countMatrix)[2] == dim(metadata)[1] )
  
  stopifnot( length(row_threshold) == 1 && is.numeric(row_threshold) && row_threshold >= 0  )
  if (row_threshold > 0){          
      message(paste("INFO: filtering", response_name, "<", row_threshold, sep = " " ))
      idx <- detect_row_matx_bellow_threshold(countMatrix, threshold = row_threshold)
      message(paste(length(which(idx)), "genes removed from data.", sep = " "))
      countMatrix <- countMatrix[ !idx , ]
  }
  
  ## -- user transform
  if (!is.null(transform))
    countMatrix <- custom_matrix_transform(countMatrix, transform)
  
  stopifnot(normalization %in% c(NULL, 'MRN', 'TTM'))
  if(!is.null(normalization)){
      ## -- median ratio normalization
      if ( normalization == "MRN")  {
          message("INFO: Median ratio normalization.")
          countMatrix <- medianRatioNormalization(countMatrix)
      }
       ## -- Trimmed Mean of M-values normalization
      if ( normalization == "TTM")  {
          message("INFO: Trimmed Mean of M-values normalization.")
          countMatrix <- trimmedMeanMvaluesNormalization(countMatrix)
      }
    
  }

  dtf_countsLong <- countMatrix_2longDtf(countMatrix, response_name)
  metadata_columnForjoining <- getColumnWithSampleID(dtf_countsLong, metadata)
  if (is.na(metadata_columnForjoining)) {
    stop("SampleIDs do not seem to correspond between countMatrix and metadata")
  }
  data2fit <- join_dtf(dtf_countsLong, metadata, k1 = "sampleID", k2 = metadata_columnForjoining)
  if (sum(is.na(data2fit[[groupID]])) > 0) {
    warning("Something went wrong. NA introduced in the geneID column. Check the coherence between countMatrix and metadata.")
  }
  return(data2fit)
}



#' Apply Median Ratio Normalization to a Counts Matrix
#'
#' This function performs median ratio normalization on a counts matrix to
#' adjust for differences in sequencing depth across samples.
#'
#' @param countsMatrix A counts matrix where rows represent genes and columns
#'                     represent samples.
#'
#' @return A normalized counts matrix after applying median ratio normalization.
#'
#' @details This function calculates the logarithm of the counts matrix,
#' computes the average log expression for each gene, and then scales each
#' sample's counts by the exponential of the difference between its average log
#' expression and the median of those averages.
#' 
#' @importFrom stats median
#'
#' @examples
#' counts <- matrix(c(100, 200, 300, 1000, 1500, 2500), ncol = 2)
#' normalized_counts <- medianRatioNormalization(counts)
#'
#' @export
medianRatioNormalization <- function(countsMatrix) {
  log_data <- log(countsMatrix)
  average_log <- rowMeans(log_data)
  
  if (all(is.infinite(average_log)))
    stop("Every gene contains at least one zero, cannot compute log geometric means")
  
  idx2keep <- average_log != "-Inf"
  average_log <- average_log[idx2keep]
  
  ratio_data <- sweep(log_data[idx2keep, ], 1, average_log, "-")
  sample_medians <- apply(ratio_data, 2, stats::median)
  
  scaling_factors <- exp(sample_medians)
  countsMatrix_normalized <- sweep(countsMatrix, 2, scaling_factors, "/")
  
  return(countsMatrix_normalized)
}



#' Normalize count data using Trimmed Mean of M-values (TMM) method
#'
#' This function normalizes count data using the Trimmed Mean of M-values (TMM) method,
#' which calculates scale factors to account for differences in library sizes
#' and RNA composition between samples.
#'
#' @param counts_matrix A matrix of count data where rows represent genes and columns represent samples.
#' @return Normalized count matrix
#' @export
trimmedMeanMvaluesNormalization <- function(counts_matrix) {
  # Check if all counts are 0
  if (all(counts_matrix == 0)) {
    # If all counts are 0, return the input matrix unchanged
    return(counts_matrix)
  }
  
  # Calculate the total count for each sample
  total_counts <- colSums(counts_matrix)
  
  # Calculate the effective library size (geometric mean of total counts)
  library_sizes <- exp(mean(log(total_counts)))
  
  # Calculate the scale factors using TMM method
  scale_factors <- library_sizes / total_counts
  
  # Apply the scale factors to normalize counts
  normalized_counts <- t(t(counts_matrix) * scale_factors)
  
  return(normalized_counts)
}


```

```{r  test-prepareData2fit}


# Unit tests for countMatrix_2longDtf
test_that("countMatrix_2longDtf converts count matrix to long data frame", {
  # Sample count matrix
  list_var <- init_variable()
  mock_data <- mock_rnaseq(list_var, n_genes = 3, 2,2, 1)
  # Convert count matrix to long data frame
  dtf_countLong <- countMatrix_2longDtf(mock_data$counts)
  expect_true(is.character(dtf_countLong$sampleID))
  expect_true(is.character(dtf_countLong$geneID))
  expect_true(is.numeric(dtf_countLong$kij))
  expect_equal(unique(dtf_countLong$geneID), c("gene1", "gene2", "gene3"))
  expect_equal(unique(dtf_countLong$sampleID),c("myVariable1_1", "myVariable1_2", 
                                                "myVariable2_1", "myVariable2_2"))
})

# Unit tests for getColumnWithSampleID
test_that("getColumnWithSampleID returns column name with sampleID", {
  # dummy data
  list_var <- init_variable()
  mock_data <- mock_rnaseq(list_var, n_genes = 3, 2,2)
  dtf_countLong <- countMatrix_2longDtf(mock_data$counts)
  
  # Expected output
  expected_output <- "sampleID"
  
  # Get column name with sampleID
  column_name <- getColumnWithSampleID(dtf_countLong, mock_data$metadata)
  
  # Check if the output matches the expected output
  expect_identical(column_name, expected_output)
})

# Unit tests for prepareData2fit
test_that("prepareData2fit prepares data for fitting", {
  # dummy data
  list_var <- init_variable()
  mock_data <- mock_rnaseq(list_var, n_genes = 3, 2,2)
  
  # Prepare data for fitting
  data2fit <- prepareData2fit(mock_data$counts, mock_data$metadata, normalization = NULL)
  
  expect_true(is.character(data2fit$sampleID))
  expect_true(is.character(data2fit$geneID))
  expect_true(is.numeric(data2fit$kij))
  expect_equal(unique(data2fit$geneID), c("gene1", "gene2", "gene3"))
  expect_equal(unique(data2fit$sampleID),c("myVariable1_1", "myVariable1_2", 
                                                "myVariable2_1", "myVariable2_2"))
  
  # Generate mock data
  countMatrix <- suppressWarnings(matrix(c(1, 2, 3, 4), ncol = 10, nrow = 3))
  colnames(countMatrix) <- paste0("sample", 1:10)
  metadata <- data.frame(sampleID = colnames(countMatrix), condition = rep(c("A", "B"), each = 5))
  
  # Call the function with log transformation
  data2fit <- prepareData2fit(countMatrix, metadata, normalization = NULL,transform = "log(x)")
  
  # Test if log transformation has been applied correctly
  expect_equal(unique(data2fit$kij), log(c(1, 2, 3, 4)))
  
  # Call the function with sqrt transformation
  data2fit <- prepareData2fit(countMatrix, metadata, normalization = NULL , transform = "sqrt(x)")
  # Test if log transformation has been applied correctly
  expect_equal(unique(data2fit$kij), sqrt(c(1, 2, 3, 4)))
})


# Tests unitaires pour isValidExpression
test_that("isValidExpression correctly validates custom expressions", {
  # Test avec une expression valide
  expect_true(isValidExpression("x + 1"))
  
  # Test avec une expression invalide (non chaîne de caractères)
  expect_false(isValidExpression(123))
  
  # Test avec une expression invalide (liste de chaînes de caractères)
  expect_false(isValidExpression(c("x + 1", "y + 2")))
  
  # Test avec une expression invalide (syntaxe incorrecte)
  expect_false(isValidExpression("log(x +)"))
})

# Tests unitaires pour custom_matrix_transform
test_that("custom_matrix_transform correctly applies custom function to count matrix", {
  count_matrix <- matrix(1:9, nrow = 3)
  # Test avec une expression valide
  transformed_matrix <- custom_matrix_transform(count_matrix, "x + 1")
  expect_equal(transformed_matrix, count_matrix + 1)
  
  # Test avec une expression invalide
  expect_error(custom_matrix_transform(count_matrix, "log(x +)"))
})



# Test case 1: Normalization with positive counts
test_that("Median ratio normalization works for positive counts", {
  counts <- matrix(c(100, 200, 300, 1000, 1500, 2500), ncol = 2)
  normalized_counts <- medianRatioNormalization(counts)
  
  expected_normalized_counts <- matrix(c(288.6751 , 577.3503 , 866.0254 , 346.4102, 519.6152, 866.0254), ncol = 2)
  
  expect_equal(normalized_counts, expected_normalized_counts, tolerance = 1e-4)
})

# Test case 2: Normalization with zero counts
test_that("Median ratio normalization return error for zero counts", {
  counts <- matrix(c(0, 0, 0, 1000, 1500, 2500), ncol = 2)
  expect_error(medianRatioNormalization(counts))
  
})


# Test case 5: All-zero genes
test_that("Throws an error when all-zero genes are encountered", {
  counts <- matrix(c(0, 0, 0, 0, 0, 0), ncol = 2)
  expect_error(medianRatioNormalization(counts))
})



# Test case 1: Check if normalization is performed correctly
test_that("Normalized counts are calculated correctly", {
  # Create a mock count matrix
  counts_matrix <- matrix(c(10, 20, 30, 40, 50, 60), nrow = 3, byrow = TRUE)
  
  # Expected normalized counts
  expected_normalized <- matrix(c(11.54701, 34.64102, 57.73503, 17.32051, 34.64102, 51.96152), nrow = 3, byrow = F)
  
  # Perform normalization
  normalized_counts <- trimmedMeanMvaluesNormalization(counts_matrix)
  
  # Check if normalized counts match the expected values
  expect_equal(normalized_counts, expected_normalized, tolerance = 1e-2)
})

# Test case 2: Check if input matrix is unchanged when all counts are 0
test_that("Input matrix remains unchanged when all counts are 0", {
  # Create a mock count matrix with all counts as 0
  counts_matrix <- matrix(0, nrow = 3, ncol = 2)
  
  # Perform normalization
  normalized_counts <- trimmedMeanMvaluesNormalization(counts_matrix)
  
  # Check if input matrix and normalized matrix are identical
  expect_identical(normalized_counts, counts_matrix)
})


```

```{r function-fitmodel, filename = "fitmodel"}
#' Check if Data is Valid for Model Fitting
#'
#' This function checks whether the provided data contains all the variables required in the model formula for fitting.
#'
#' @param data2fit The data frame or tibble containing the variables to be used for model fitting.
#' @param formula The formula specifying the model to be fitted.
#'
#' @return \code{TRUE} if all the variables required in the formula are present in \code{data2fit}, otherwise an error is raised.
#'
#' @examples
#' data(iris)
#' formula <- Sepal.Length ~ Sepal.Width + Petal.Length
#' isValidInput2fit(iris, formula) # Returns TRUE if all required variables are present
#' @export
isValidInput2fit <- function(data2fit, formula){
  vec_bool <- all.vars(formula) %in% colnames(data2fit)
  for (i in seq_along(vec_bool)){
    if (isFALSE(vec_bool[i]) ) {
      stop(paste("Variable", all.vars(formula)[i],  "not found"))
    }
  }
  return(TRUE)
}



#' Check if group by exist in data
#'
#' @param data The data framecontaining the variables to be used for model fitting.
#' @param group_by Column name in data representing the grouping variable 
#'
#' @return \code{TRUE} if exist otherwise an error is raised
#'
#' @examples
#' is_validGroupBy(mtcars, 'mpg')
#' @export
is_validGroupBy <- function(data, group_by){
  validGroupBy <- group_by %in% names(data)
  if (!validGroupBy) 
    stop("<Group by> doen't exist in data !")
  return(TRUE)
}




#' Drop Random Effects from a Formula
#'
#' This function allows you to remove random effects from a formula by specifying 
#' which terms to drop. It checks for the presence of vertical bars ('|') in the 
#' terms of the formula and drops the random effects accordingly. If all terms 
#' are random effects, the function updates the formula to have only an intercept. 
#'
#' @param form The formula from which random effects should be dropped.
#'
#' @return A modified formula with specified random effects dropped.
#'
#' @examples
#' # Create a formula with random effects
#' formula <- y ~ x1 + (1 | group) + (1 | subject)
#' # Drop the random effects related to 'group'
#' modified_formula <- drop_randfx(formula)
#'
#' @importFrom stats terms drop.terms
#' @importFrom stats update
#'
#' @export
drop_randfx <- function(form) {
  form.t <- stats::terms(form)
  dr <- grepl("|", labels(form.t), fixed = TRUE)
  if (mean(dr) == 1) {
    form.u <- stats::update(form, . ~ 1)
  } else {
    if (mean(dr) == 0) {
      form.u <- form
    } else {
      form.td <- stats::drop.terms(form.t, which(dr))
      form.u <- stats::update(form, form.td)
    }
  }
  form.u
}



#' Check if a Model Matrix is Full Rank
#'
#' This function checks whether a model matrix is full rank, which is essential for 
#' certain statistical analyses. It computes the eigenvalues of the crossproduct 
#' of the model matrix and determines if the first eigenvalue is positive and if 
#' the ratio of the last eigenvalue to the first is within a defined tolerance.
#'
#' This function is inspired by a similar function found in the Limma package.
#'
#' @param metadata The metadata used to create the model matrix.
#' @param formula The formula used to specify the model matrix.
#'
#' @return \code{TRUE} if the model matrix is full rank, \code{FALSE} otherwise.
#'
#' @examples
#' metadata <- data.frame(x = rnorm(10), y = rnorm(10))
#' formula <- y ~ x
#' is_fullrank(metadata, formula)
#'
#'
#' @importFrom stats model.matrix
#' @export
is_fullrank <- function(metadata, formula) {
  ## drop random eff
  formula <- drop_randfx(formula)
  ## TEST
  model_matrix <- stats::model.matrix(data = metadata, formula)
  e <- eigen(crossprod(model_matrix), symmetric = TRUE, only.values = TRUE)$values
  modelFullRank <- e[1] > 0 && abs(e[length(e)] / e[1]) > 1e-13
  
  if (!modelFullRank) {
    warning("The model matrix is not full rank. One or more variables or interaction terms in the design formula are linear combinations of the others.")
    return(FALSE)
  }

  
  return(TRUE)
}





#' Fit a model using the fitModel function.
#' @param group ID to fit
#' @param formula Formula specifying the model formula
#' @param data Data frame containing the data
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function
#' @return Fitted model object or NULL if there was an error
#' @export
#' @examples
#' fitModel("mtcars" , formula = mpg ~ cyl + disp, data = mtcars)
fitModel <- function(group , formula, data, ...) {
  is_fullrank(data, formula)
  # Fit the model using glm.nb from the GLmmTMB package
  model <- glmmTMB::glmmTMB(formula, ..., data = data )
  
  ## -- save additional info
  model$frame <- data
  model$groupId <- group
   ## family in ... => avoid error in future update
  additional_args <- list(...)
  familyArgs <- additional_args[['family']]
  if (!is.null(familyArgs)) model$call$family <- familyArgs
  ## control in ... => avoid error in future update
  controlArgs <- additional_args[['control']]
  if (!is.null(controlArgs)) model$call$control <- controlArgs
  
  return(model)
}



#' Fit the model based using fitModel functions.
#'
#' @param groups list of group ID
#' @param group_by Column name in data representing the grouping variable
#' @param data Data frame containing the data
#' @return list of dataframe 
#' @export
#' @importFrom stats setNames
#' @examples
#' prepare_dataParallel(groups = iris$Species, group_by = "Species", 
#'                  data = iris )
prepare_dataParallel <- function(groups, group_by, data) {
  
  l_data2parallel <- lapply( stats::setNames(groups, groups) , function( group_id ){
                      subset_data <- data[ data[[ group_by ]] == group_id, ]
                      return(subset_data)
                  })
  return(l_data2parallel)
}



#' Launch the model fitting process for a specific group.
#'
#' This function fits the model using the specified group, group_by, formula, and data.
#' It handles warnings and errors during the fitting process and returns the fitted model or NULL if there was an error.
#'
#' @param data Data frame containing the data
#' @param group_by Column name in data representing the grouping variable
#' @param formula Formula specifying the model formula
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function
#' @return List with 'glance' and 'summary' attributes representing the fitted model or NULL if there was an error
#' @export
#' @examples
#' launchFit(group_by = "Species", 
#'            formula = Sepal.Length ~ Sepal.Width + Petal.Length, 
#'            data = iris[ iris[["Species"]] == "setosa" , ] )
launchFit <- function(data, group_by, formula, ...) {
  group <- unique(data[[ group_by ]]) 
  tryCatch(
    expr = {
      withCallingHandlers(
          fitModel(group , formula, data, ...),
          warning = function(w) {
            message(paste(Sys.time(), "warning for group", group, ":", conditionMessage(w)))
            invokeRestart("muffleWarning")
          })
    },
    error = function(e) {
      message(paste(Sys.time(), "error for group", group, ":", conditionMessage(e)))
      NULL
    }
  )
}


#' Fit models in parallel for each group using mclapply and handle logging.
#' Uses parallel_fit to fit the models.
#'
#' @param groups Vector of unique group values
#' @param group_by Column name in data representing the grouping variable
#' @param formula Formula specifying the model formula
#' @param data Data frame containing the data
#' @param n.cores The number of CPU cores to use for parallel processing.
#'  If set to NULL (default), the number of available CPU (minus 1) cores will be automatically detected.
#' @param log_file File to write log (default : Rtmpdir/htrfit.log)
#' @param cl_type cluster type (defautl "PSOCK"). "FORK" is recommanded for linux.
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function
#' @return List of fitted model objects or NULL for any errors
#' @export
#' @examples
#' parallel_fit(group_by = "Species", groups =  iris$Species, 
#'                formula = Sepal.Length ~ Sepal.Width + Petal.Length, 
#'                data = iris, n.cores = 1 )
parallel_fit <- function(groups, group_by, formula, data, n.cores = NULL, 
                         log_file = paste(tempdir(check = FALSE), "htrfit.log", sep = "/"), 
                         cl_type = "PSOCK",  ...) {
  
  if (is.null(n.cores)) n.cores <- max(1, parallel::detectCores(logical = FALSE) - 1)
  
  message(paste("CPU(s) number :", n.cores, sep = " "))
  message(paste("Cluster type :", cl_type, sep = " "))

  
  ## get data for parallelization
  l_data2parallel <- prepare_dataParallel(groups, group_by, data)

  clust <- parallel::makeCluster(n.cores, outfile = log_file , type= cl_type )
  parallel::clusterExport(clust, c("fitModel", "is_fullrank", "drop_randfx"))
  results_fit <- parallel::parLapply(clust, X = l_data2parallel, 
                                     fun = launchFit, 
                                     group_by = group_by, formula = formula, ...)
                                     
  parallel::stopCluster(clust) ; invisible(gc(reset = T, verbose = F, full = T));
  #closeAllConnections()
  return(results_fit)
}

#' Fit models in parallel for each group using mclapply and handle logging.
#' Uses parallel_fit to fit the models.
#'
#' @param formula Formula specifying the model formula
#' @param data Data frame containing the data
#' @param group_by Column name in data representing the grouping variable
#' @param n.cores The number of CPU cores to use for parallel processing.
#'               If set to NULL (default), the number of available CPU cores will be automatically detected.
#' @param log_file File path to save the log messages (default : Rtmpdir/htrfit.log)
#' @param cl_type cluster type (default "PSOCK"). "FORK" is recommended for linux.
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function
#' @return List of fitted model objects or NULL for any errors
#' @export
#' @examples
#' fitModelParallel(formula = Sepal.Length ~ Sepal.Width + Petal.Length, 
#'                  data = iris, group_by = "Species", n.cores = 1) 
fitModelParallel <- function(formula, data, group_by, n.cores = NULL, cl_type = "PSOCK" , 
                             log_file = paste(tempdir(check = FALSE), "htrfit.log", sep = "/"), ...) {
  
  stopifnot(cl_type %in% c("PSOCK", "FORK"))
  ## Some verification
  isValidInput2fit(data, formula)
  is_validGroupBy(data, group_by)

  ## -- print log location
  message( paste("Log file location", log_file, sep =': ') ) 
  
  # Fit models in parallel and capture the results
  groups <- unique(data[[ group_by ]])
  results <- parallel_fit(groups, group_by, formula, data, n.cores, log_file, cl_type, ...)
  #results <- mergeListDataframes(results)
  return(results)
}


```


```{r  test-fitmodel}


test_that("isValidInput2fit returns TRUE for valid data", {
  data(iris)
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  result <- isValidInput2fit(iris, formula)
  expect_true(result)
})

# Test that the function raises an error when a required variable is missing
test_that("isValidInput2fit raises an error for missing variable", {
  data(iris)
  formula <- Sepal.Length ~ Sepal.Width + NonExistentVariable
  expect_error(isValidInput2fit(iris, formula), "Variable NonExistentVariable not found")
})

test_that("fitModel returns a fitted model object", {
  data(mtcars)
  formula <- mpg ~ cyl + disp
  fitted_model <- suppressWarnings(fitModel("mtcars", formula, mtcars))
  #expect_warning(fitModel(formula, mtcars))
  expect_s3_class(fitted_model, "glmmTMB")
  
  # Test with invalid formula
  invalid_formula <- mpg ~ cyl + disp + invalid_var
  expect_error(fitModel("mtcars", invalid_formula, mtcars))
  
  ## check groupID attr
  expect_equal(fitted_model$groupId, "mtcars")
  
  
   # Additional parameters: 
   #change family + formula
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length + (1 | Species)
  fitted_models <- suppressWarnings(fitModel("mtcars",
                                             formula = formula, 
                                             data = iris, 
                                            family = glmmTMB::nbinom1(link = "log") ))
  expect_s3_class(fitted_models$call$family, "family")
  expect_equal(fitted_models$call$formula, formula)
  #change control settings
  fitted_models <- suppressWarnings(fitModel("mtcars",
                                              formula = formula, 
                                                    data = iris, 
                                                    family = glmmTMB::nbinom1(link = "log"), 
                                                control = glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,
                                                                                               eval.max=1e3))))
  expect_equal(fitted_models$call$control,  glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,eval.max=1e3)))
  
  
  
})


# Test if random effects are dropped correctly
test_that("Drop random effects from formula", {
  formula <- y ~ x1 + (1 | group) + (1 | subject)
  modified_formula <- drop_randfx(formula)
  expect_equal(deparse(modified_formula), "y ~ x1")
})

# Test if formula with no random effects remains unchanged
test_that("Keep formula with no random effects unchanged", {
  formula <- y ~ x1 + x2
  modified_formula <- drop_randfx(formula)
  expect_equal(deparse(modified_formula), "y ~ x1 + x2")
})

# Test if all random effects are dropped to intercept
test_that("Drop all random effects to intercept", {
  formula <- ~ (1 | group) + (1 | subject)
  modified_formula <- drop_randfx(formula)
  expect_equal(deparse(modified_formula), ". ~ 1")
})


# Test if a full-rank model matrix is identified correctly
test_that("Identify full-rank model matrix", {
  metadata <- data.frame(x = rnorm(10), y = rnorm(10))
  formula <- y ~ x
  expect_true(is_fullrank(metadata, formula))
})

# Test if a rank-deficient model matrix is detected and throws an error
test_that("Detect rank-deficient model matrix and throw error", {
  metadata <- data.frame(x = factor(rep(c("xA","xB"),each = 5)), 
                         w = factor(rep(c("wA","wB"), each = 5)), 
                         z = factor(rep(c("zA","zB"), each = 5)),
                         y = rnorm(10))
  formula <- y ~ x + w + z + y:w
  expect_warning(is_fullrank(metadata, formula))
  res <- suppressWarnings(is_fullrank(metadata, formula))
  expect_false(res)
})

# Test if a rank-deficient model matrix is detected and throws an error
test_that("Detect rank-deficient model matrix and throw error (with random eff)", {
  metadata <- data.frame(x = factor(rep(c("xA","xB"),each = 5)), 
                         w = factor(rep(c("wA","wB"), each = 5)), 
                         z = factor(rep(c("zA","zB"), each = 5)),
                         y = rnorm(10))
  formula <- y ~ x + w + z + y:w + (1 | w)
  expect_warning(is_fullrank(metadata, formula))
  res <- suppressWarnings(is_fullrank(metadata, formula))
  expect_false(res)
})

# Test if a rank-deficient model matrix is detected and throws an error
test_that("Identify full-rank model matrix (with random eff)", {
  metadata <- data.frame(x = factor(rep(c("xA","xB"),each = 5)), 
                         w = factor(rep(c("wA","wB"), each = 5)), 
                         z = factor(rep(c("zA","zB"), each = 5)),
                         y = rnorm(10))
  formula <- y ~ x + (1 | w)
  expect_true(is_fullrank(metadata, formula))
})


test_that("prepare_dataParallel returns a list of dataframe", {
  ## -- valid input
  data(iris)
  group_by <- "Species"
  groups <- unique(iris$Species)
  l_data <- prepare_dataParallel(groups , group_by, iris)
  expect_type(l_data, "list")
  expect_equal(names(l_data), c("setosa", "versicolor", "virginica"))
  
})

test_that("launchFit handles warnings and errors during the fitting process", {

  # Additional parameters: 
   #change family + formula
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  fitted_models <- suppressWarnings(launchFit(formula = formula, 
                                                    data = iris[ iris$Species == "setosa" , ], 
                                                    group_by = "Species", 
                                                    family = glmmTMB::nbinom1(link = "log") ))
  expect_s3_class(fitted_models$call$family, "family")
  expect_equal(fitted_models$call$formula, formula)
  #change control settings
  fitted_models <- suppressWarnings(launchFit(formula = formula, 
                                                    data = iris[ iris$Species == "setosa" , ], 
                                                    group_by = "Species", 
                                                    family = glmmTMB::nbinom1(link = "log") , 
                                                control = glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,
                                                                                               eval.max=1e3))))
  expect_equal(fitted_models$call$control,  glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,eval.max=1e3)))
})

test_that("parallel_fit returns a list of fitted model objects or NULL for any errors", {
  data(iris)
  groups <- unique(iris$Species)
  group_by <- "Species"
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  fitted_models <- parallel_fit(groups, group_by, formula, iris, n.cores = 1)
  expect_s3_class(fitted_models$setosa, "glmmTMB")
  expect_length(fitted_models, length(groups))

  # Test with invalid formula
  invalid_formula <- blabla ~ cyl + disp 
  result <- suppressWarnings(parallel_fit(groups, group_by, invalid_formula,  
                                           iris,  n.cores = 1))
  expect_equal(result, expected = list(setosa = NULL, versicolor = NULL, virginica = NULL))
  
  
   # Additional parameters: 
   #change family + formula
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  fitted_models <- suppressWarnings(parallel_fit(formula = formula, 
                                                    data = iris, 
                                                    group_by = group_by, 
                                                    groups = "setosa",
                                                    n.cores = 1,
                                                    family = glmmTMB::nbinom1(link = "log") ))
  expect_s3_class(fitted_models$setosa$call$family, "family")
  expect_equal(fitted_models$setosa$call$formula, formula)
  #change control settings
  fitted_models <- suppressWarnings(parallel_fit(formula = formula, 
                                                    data = iris, 
                                                    group_by = group_by, 
                                                    groups = "setosa",
                                                    family = glmmTMB::nbinom1(link = "log"),
                                                    n.cores = 1,
                                                    control = glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,
                                                                                               eval.max=1e3))))
  expect_equal(fitted_models$setosa$call$control,  glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,eval.max=1e3)))
})

test_that("fitModelParallel fits models in parallel for each group and returns a list of fitted model objects or NULL for any errors", {
  data(iris)
  groups <- unique(iris$Species)
  group_by <- "Species"
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  fitted_models <- fitModelParallel(formula, iris, group_by, n.cores = 1)
  expect_s3_class(fitted_models$setosa, "glmmTMB")
  expect_length(fitted_models, length(groups))
  
  invalid_formula <- blabla ~ cyl + disp 
  expect_error(fitModelParallel(invalid_formula, iris,  group_by ,  n.cores = 1))
  
   # Additional parameters: 
   #change family + formula
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  fitted_models <- suppressWarnings(fitModelParallel(formula = formula, 
                                                     data = iris, 
                                                     group_by = group_by, 
                                                      n.cores = 1,
                                                      family = glmmTMB::nbinom1(link = "log") ))
  expect_s3_class(fitted_models$setosa$call$family, "family")
  expect_equal(fitted_models$setosa$call$formula, formula)
  #change control settings
  fitted_models <- suppressWarnings(fitModelParallel(formula = formula, 
                                                     data = iris, 
                                                     group_by = group_by, 
                                                      n.cores = 1,
                                                     family = glmmTMB::nbinom1(link = "log"), 
                                                control = glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,
                                                                                               eval.max=1e3))))
  expect_equal(fitted_models$setosa$call$control,  glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,eval.max=1e3)))
  
  ## -- invalid group by 
  data(iris)
  group_by <- "invalid_groupBy"
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  expect_error(fitModelParallel(formula, iris, group_by, n.cores = 1))

})

```

```{r function-filtering_fit, filename = "filtering_fit"}

#' Identify top or low fitting observations based on specified diagnostic metric and filtering method.
#'
#' This function identifies top or low fitting observations based on a specified metric and filtering method.
#'
#' @param list_tmb List of glmmTMB objects.
#' @param metric The metric used for diagnostic (e.g., "AIC", "BIC", "logLik", "deviance", "dispersion").
#' @param filter_method The filtering method to be used (e.g., "mad"). 
#'        Feel free to implement your own filetering method
#' @param keep Whether to keep "top" or "low" fitting observations.
#' @param sort Logical indicating whether to sort the results.
#' @param decreasing Logical indicating whether to sort in decreasing order.
#' @param mad_tolerance Tolerance for MAD-based filtering.
#' @return A character vector of row names corresponding to the top or low fitting observations.
#' @export
#'
#' @examples
#' input_var_list <- init_variable()
#' ## -- simulate RNAseq data 
#' mock_data <- mock_rnaseq(input_var_list, 
#'                       n_genes = 5,
#'                       min_replicates  = 3,
#'                       max_replicates = 3,
#'                       basal_expression = 2,
#'                       sequencing_depth = 1e5)
#' ## -- prepare data & fit a model with mixed effect
#' data2fit = prepareData2fit(countMatrix = mock_data$counts, 
#'                         metadata =  mock_data$metadata)
#'l_tmb <- fitModelParallel(formula = kij ~ myVariable, data = data2fit, 
#'                     group_by = "geneID", family = glmmTMB::nbinom2(link = "log"), 
#'                     n.cores = 1)
#' # Identify top fitting observations based on AIC with MAD filtering
#' identifyTopFit(l_tmb, metric = "AIC", filter_method = "mad", keep = "top", 
#'                sort = TRUE, decreasing = TRUE, mad_tolerance = 3)
#' 
#' # Identify low fitting observations based on BIC without sorting
#' identifyTopFit(l_tmb, metric = "BIC", filter_method = "mad", keep = "low", sort = FALSE)
#' 
#' # Identify top fitting observations based on log-likelihood with MAD filtering and custom tolerance
#' identifyTopFit(l_tmb, metric = "logLik", filter_method = "mad", keep = "top", mad_tolerance = 2)
identifyTopFit <- function(list_tmb, metric = "AIC", filter_method = "mad", keep = "top", 
                           sort = F, decreasing = T, mad_tolerance = 3){
  
  ## -- verif
  invisible(isValidList_tmb(list_tmb))
  stopifnot(metric %in% c("AIC", "BIC", "logLik", "deviance", "dispersion"))
  
  glance_df <- glance_tmb(list_tmb)
  ## -- MAD method
  if (filter_method == "mad"){
    lft_threshold <- get_mad_left_threshold(glance_df[[metric]], mad_tolerance)
    get_mad_user_message(metric, glance_df[[metric]], lft_threshold, keep)
    index2keep <- if (keep == "top") which(glance_df[[metric]] > lft_threshold) 
                  else which(glance_df[[metric]] < lft_threshold)
    id2keep <- rownames(glance_df)[index2keep]
    glance_df <- glance_df[ id2keep , ]
  } ## -- feel free to implement other filtering methods .. 
  
  ## -- sort
  if (isTRUE(sort)){
    id2keep <- rownames(glance_df)[order(glance_df[[metric]], decreasing = decreasing)]
  }
  return(id2keep)
}

#' Calculate the left threshold for MAD-based filtering.
#'
#' This function calculates the left threshold for MAD-based filtering.
#'
#' @param x Vector of values used for filtering.
#' @param tolerance Tolerance value for MAD-based filtering.
#' @return The left threshold value.
#' @importFrom stats median mad
#' @export
#'
#' @examples
#' # Calculate the left threshold for MAD-based filtering with tolerance 3
#' get_mad_left_threshold(c(100, 200, 300), 3)
#' 
#' # Calculate the left threshold for MAD-based filtering with tolerance 2
#' get_mad_left_threshold(c(50, 75, 100), 2)
get_mad_left_threshold <- function(x, tolerance){
  med <- stats::median(x, na.rm = T)
  mad <- stats::mad(x, center = med, 
                    constant = 1, na.rm = TRUE, 
                    low = FALSE, high = FALSE)
  med - (tolerance * mad)
}

#' Generate user message for MAD filtering.
#'
#' This function generates a user message explaining the MAD filtering process.
#'
#' @param var The name of the metric used for filtering.
#' @param x Vector of values used for filtering.
#' @param left_threshold The left threshold for filtering.
#' @param keep Whether to keep "top" or "low" observations.
#' @return A message explaining the MAD filtering process.
#' @export
#'
#' @examples
#' # Generate user message for MAD filtering with top fitting observations
#' get_mad_user_message("AIC", c(100, 200, 300), 150, "top")
#' 
#' # Generate user message for MAD filtering with low fitting observations
#' get_mad_user_message("BIC", c(50, 75, 100), 75, "low")
get_mad_user_message <- function(var, x , left_threshold, keep) {
  message("Based on the specified metric (", var, ") and the MAD filtering method, the following selection criteria were applied:")
  message("1. The MAD-based threshold for considering outliers was calculated.")
  keep_str <- ifelse(keep == "top", "above", "bellow")
  message(paste("2. Values",  keep_str , "the threshold were identified, threshold:",  left_threshold))
  message("3. Summary of selection:")
  nb_keep <- ifelse(keep == "top", length(x[which(x > left_threshold)]), length(x[which(x < left_threshold)]) )
  sub_msg2 <- paste("values", keep_str,  "the threshold")
  message(paste("-", nb_keep, "out of", length(x), "observations had", sub_msg2, "for the", var, "metric."))
}

```


```{r  test-filtering_fit}

# Tests for identifyTopFit function
test_that("identifyTopFit correctly identifies top-fitting objects", {
    input_var_list <- init_variable()
    ## -- simulate RNAseq data 
    set.seed(101)
    mock_data <- mock_rnaseq(input_var_list, 
                         n_genes = 5,
                       min_replicates  = 3,
                       max_replicates = 3,
                       basal_expression = 2,
                       sequencing_depth = 1e5)
    ## -- prepare data & fit a model with mixed effect
    data2fit = prepareData2fit(countMatrix = mock_data$counts, 
                         metadata =  mock_data$metadata, 
                         normalization = NULL)
    l_tmb <- fitModelParallel(formula = kij ~ myVariable, data = data2fit, 
                     group_by = "geneID", family = glmmTMB::nbinom2(link = "log"), 
                     n.cores = 1)
    glance_tmb(l_tmb)
    # Identify top fitting observations based on AIC with MAD filtering
    top_genes <- identifyTopFit(l_tmb, metric = "AIC", filter_method = "mad", keep = "top", 
                   sort = TRUE, decreasing = TRUE, mad_tolerance = 3)
    expect_equal(top_genes, c("gene3", "gene5", "gene4", "gene1", "gene2"))
    # Identify low fitting observations based on BIC without sorting
    top_genes <-identifyTopFit(l_tmb, metric = "BIC", filter_method = "mad", keep = "low", sort = FALSE)
    expect_equal(top_genes, character())

    # Identify top fitting observations based on log-likelihood with MAD filtering and custom tolerance
    top_genes <- identifyTopFit(l_tmb, metric = "logLik", filter_method = "mad", keep = "top", mad_tolerance = 2)
    expect_equal(top_genes, c("gene1", "gene2", "gene4", "gene5"))
})

# Tests for get_mad_left_threshold function
test_that("get_mad_left_threshold correctly calculates MAD-based left threshold", {
  left_threshold <- get_mad_left_threshold(c(100, 200, 300), 3)
  expect_equal(left_threshold, 200 - (3 * 100))
})


# Tests for get_mad_message
test_that('get_mad_user_message return correct output', {
  expect_message(get_mad_user_message("BIC", c(50, 75, 100), 75, "low"))
  expect_message(get_mad_user_message("BIC", c(50, 75, 100), 75, "top"))
})


```



```{r function-update_fittedmodel, filename = "update_fittedmodel"}


#' Update glmmTMB models in parallel.
#'
#' This function updates glmmTMB models in parallel using multiple cores, allowing for faster computation. 
#' It updates the models with new reference labels if specified. 
#' It can also be used to fit a new formula or to change additional parameters of glmmTMB (param : "...").
#'
#' @param formula Formula for the GLMNB model.
#' @param list_tmb List of glmmTMB objects.
#' @param reference_labels Vector of reference labels. Default is c(), selecting the first alphanumeric label as reference.
#' @param n.cores Number of cores to use for parallel processing. If NULL, the function will use all available cores.
#' @param log_file File path for the log output (default: Rtmpdir/htrfit.log).
#' @param cl_type cluster type (default "PSOCK"). "FORK" is recommended for linux.
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function.
#' @export
#' @return A list of updated GLMNB models.
#'
#' @examples
#' # -- Example usage: update formula
#' data(iris)
#' groups <- unique(iris$Species)
#' group_by <- "Species"
#' formula <- Sepal.Length ~ Sepal.Width + Petal.Length
#' fitted_models <- fitModelParallel(formula, iris, group_by, n.cores = 1)
#' new_formula <- Sepal.Length ~ Sepal.Width 
#' results <- updateParallel(new_formula, fitted_models, n.cores = 1)
#' #' # Example usage: update reference
#' # -- Load the mtcars dataset
#' data("mtcars")
#' # -- Specify categorical variables
#' mtcars$vs <- factor(mtcars$vs) ## Engine (0 = V-shaped, 1 = straight)
#' levels(mtcars$vs) <- c("V-shaped", "straight")
#' mtcars$am <- factor(mtcars$am) ## Transmission (0 = automatic, 1 = manual)
#' levels(mtcars$am) <- c("automatic", "manual")
#' # -- For each group of number of cylinders:
#' # -- Explain fuel consumption with engine shape, Gross horsepower, and transmission type 
#' list_tmb <- fitModelParallel(formula = mpg ~ hp + vs + am, 
#'                    data = mtcars, group_by = "cyl", n.cores = 1)
#' # -- Relevel transmission and engine shape variables
#' list_tmb <- updateParallel(formula = mpg ~ hp + vs + am, list_tmb,
#'                   reference_labels = c("straight", "manual"), n.cores = 1)
updateParallel <- function (formula, list_tmb, reference_labels = c(), n.cores = NULL, cl_type = "PSOCK", 
                             log_file = paste(tempdir(check = FALSE), "htrfit.log", sep = "/"), 
                             ...) {
  invisible(isValidList_tmb(list_tmb))
  non_null_idx <- first_non_null_index(list_tmb)
  
  if (!is.null(non_null_idx)){
      stopifnot(is.data.frame(list_tmb[[non_null_idx]]$frame))
      isValidInput2fit(list_tmb[[non_null_idx]]$frame, formula)
      list_tmb <- relevelling_factors(list_tmb, reference_labels)
      message(paste("Log file location", log_file, sep = ": "))
      list_tmb <- parallel_update(formula, list_tmb, n.cores, log_file, cl_type, ...)
      clear_memory(except_obj = list_tmb)
  }
  return(list_tmb)
}


#' Re-levels categorical variables in a the frame of a list of glmmTMB objects.
#'
#' This function re-levels categorical variables in a list of glmmTMB objects using the specified reference labels.
#'
#' @param list_tmb List of glmmTMB objects.
#' @param categorical_vars Names of the categorical variables to be re-leveled.
#' @param ref_labels Vector of reference labels corresponding to the categorical variables.
#' @return A list of glmmTMB objects with re-leveled categorical variables.
#' @export
#'
#' @examples
#' # Example usage:
#' # -- Load the mtcars dataset
#' data("mtcars")
#' ## -- specify categorical var
#' mtcars$vs <- factor(mtcars$vs) ## Engine (0 = V-shaped, 1 = straight)
#' levels(mtcars$vs) <- c("V-shaped", "straight")
#' mtcars$am <- factor(mtcars$am) ## Transmission (0 = automatic, 1 = manual)
#' levels(mtcars$am) <- c("automatic", "manual")
#' # -- For each group of number of cylinders,
#' # -- Explain fuel consumption with engine shape, Gross horsepower, and transmission type 
#' list_tmb <- fitModelParallel(formula = mpg ~ hp + vs + am, 
#'                data = mtcars, group_by = "cyl", n.cores = 1)
#' # -- Relevel transmission and engine shape variables
#' relevel_list_tmb_frame(list_tmb, c("am", "vs"), c("manual", "straight"))
relevel_list_tmb_frame <- function(list_tmb, categorical_vars, ref_labels){
  names(ref_labels) <- categorical_vars 
  lapply(list_tmb, function( tmb_obj ) {
    if (!is.null(tmb_obj)){
      for (categorical_var in names(ref_labels)) {
        reference_label <- ref_labels[categorical_var]
        tmb_obj$frame[[categorical_var]] <- stats::relevel(tmb_obj$frame[[categorical_var]], ref = reference_label)
      }
    }
    return(tmb_obj)
  })
}


#' Detects categorical variables based on reference labels in a glmmTMB object.
#'
#' This function detects categorical variables based on reference labels in a glmmTMB object's frame.
#'
#' @param tmb_frame The data frame of a glmmTMB object.
#' @param ref_labels Vector of reference labels corresponding to categorical variables.
#' @return Names of the categorical variables detected.
#' @export
#'
#' @examples
#' data("mtcars")
#' ## -- specify categorical var
#' mtcars$vs <- factor(mtcars$vs) ## Engine (0 = V-shaped, 1 = straight)
#' levels(mtcars$vs) <- c("V-shaped", "straight")
#' mtcars$am <- factor(mtcars$am) ## Transmission (0 = automatic, 1 = manual)
#' levels(mtcars$am) <- c("automatic", "manual")
#' ## -- For each group of number of cylinder,
#' ## -- explain fuel consumption with engine shape, Gross horsepower,  and transmission type
#' list_tmb <- fitModelParallel(formula = mpg ~ hp + vs + am, 
#'                  data = mtcars, group_by = "cyl" , n.cores = 1)
#' detect_categoricals_vars(list_tmb[["6"]]$frame, c("straight", "manual"))
detect_categoricals_vars <- function(tmb_frame, ref_labels){
  idx_col <- c()
  for (reference_label in ref_labels){
      
      catego_var_idx_col <- unique(which(tmb_frame == reference_label, arr.ind = T)[, "col"])
      
      if (length(catego_var_idx_col) > 1) {
        message_err <- paste("Label", reference_label, " detected in the metadata across different columns.\nUnable to determine the correct columns for re-leveling.\nPlease ensure that reference labels are specific to individual columns for re-leveling")
        stop(message_err)
      }
      
      if (length(catego_var_idx_col) == 0) {
      ref_label_str <- paste(ref_labels , collapse = ", ")
      message_err <- paste("Label", reference_label, "not found in metadata.")
      stop(message_err)
      }
      
  idx_col <- c(idx_col, catego_var_idx_col)
  }
    
  return(colnames(tmb_frame)[idx_col])
}

#' Relevels factors in a list of glmmTMB objects using specified reference labels.
#'
#' This function re-levels factors in a list of glmmTMB objects based on the specified reference labels.
#'
#' @param list_tmb List of glmmTMB objects.
#' @param reference_labels Vector of reference labels.
#' @return A list of glmmTMB objects with re-leveled factors.
#' @export
#'
#' @examples
#' data("mtcars")
#' ## -- specify categorical var
#' mtcars$vs <- factor(mtcars$vs) ## Engine (0 = V-shaped, 1 = straight)
#' levels(mtcars$vs) <- c("V-shaped", "straight")
#' mtcars$am <- factor(mtcars$am) ## Transmission (0 = automatic, 1 = manual)
#' levels(mtcars$am) <- c("automatic", "manual")
#' ## -- For each group of number of cylinder,
#' ## -- explain fuel consumption with engine shape, Gross horsepower,  and transmission type
#' list_tmb <- fitModelParallel(formula = mpg ~ hp + vs + am, 
#'                data = mtcars, group_by = "cyl", n.cores = 1 )
#' relevelling_factors(list_tmb , c("straight", "manual"))
relevelling_factors <- function(list_tmb, reference_labels ){
  if (length(reference_labels) > 0){
    l_categorical_vars <- detect_categoricals_vars(list_tmb[[1]]$frame, reference_labels)
    list_tmb <- relevel_list_tmb_frame(list_tmb, l_categorical_vars, reference_labels)
  }
  return(list_tmb)
}





#' Internal function to fit glmmTMB models in parallel.
#'
#' This function is used internally by \code{\link{updateParallel}} to fit glmmTMB models in parallel.
#'
#' @param formula Formula for the GLMNB model.
#' @param list_tmb List of glmmTMB objects.
#' @param n.cores Number of cores to use for parallel processing.
#' @param log_file File path for the log output (default : Rtmpdir/htrfit.log).
#' @param cl_type cluster type (defautl "PSOCK"). "FORK" is recommanded for linux.
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function.
#' @export
#' @return A list of updated GLMNB models.
#' @examples
#' data(iris)
#' groups <- unique(iris$Species)
#' group_by <- "Species"
#' formula <- Sepal.Length ~ Sepal.Width + Petal.Length
#' fitted_models <- fitModelParallel(formula, iris, group_by, n.cores = 1)
#' new_formula <- Sepal.Length ~ Sepal.Width 
#' results <- parallel_update(new_formula, fitted_models, n.cores = 1)
parallel_update <- function(formula, list_tmb, n.cores = NULL, 
                            log_file = paste(tempdir(check = FALSE), "htrfit.log", sep = "/"), 
                            cl_type = "PSOCK" , ...) {
  if (is.null(n.cores)) n.cores <-  max(1, parallel::detectCores(logical = FALSE) - 1)
  message(paste("CPU(s) number :", n.cores, sep = " "))
  message(paste("Cluster type :", cl_type, sep = " "))
  clust <- parallel::makeCluster(n.cores, type= cl_type, outfile = log_file)
  parallel::clusterExport(clust, c("launchUpdate", "fitUpdate", "is_fullrank", "drop_randfx"))
  updated_res <- parallel::parLapply(clust, X = list_tmb, fun = launchUpdate , formula = formula, ...)
  parallel::stopCluster(clust) ; invisible(gc(reset = T, verbose = F, full = T));
  return(updated_res)
}


#' Fit and update a GLMNB model.
#'
#' This function fits and updates a GLMNB model using the provided formula.
#' @param group group id to save in glmmTMB obj (usefull for update !)
#' @param glmm_obj A glmmTMB object to be updated.
#' @param formula Formula for the updated GLMNB model.
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function.
#' @export
#' @return An updated GLMNB model.
#'
#' @examples
#' data(iris)
#' groups <- unique(iris$Species)
#' group_by <- "Species"
#' formula <- Sepal.Length ~ Sepal.Width + Petal.Length
#' fitted_models <- fitModelParallel(formula, iris, group_by, n.cores = 1)
#' new_formula <- Sepal.Length ~ Sepal.Width 
#' updated_model <- fitUpdate("setosa", fitted_models[[1]], new_formula)
fitUpdate <- function(group, glmm_obj, formula , ...){
  data <- glmm_obj$frame
  is_fullrank(data, formula)
  resUpdt <- stats::update(glmm_obj, formula, ...)
  resUpdt$frame <- data
  ## save groupID => avoid error in future update
  resUpdt$groupId <- group
  ## family in ... => avoid error in future update
  additional_args <- list(...)
  familyArgs <- additional_args[['family']]
  if (!is.null(familyArgs)) resUpdt$call$family <- familyArgs
  ## control in ... => avoid error in future update
  controlArgs <- additional_args[['control']]
  if (!is.null(controlArgs)) resUpdt$call$control <- controlArgs
  return(resUpdt)
}


#' Launch the update process for a GLMNB model.
#'
#' This function launches the update process for a GLMNB model, capturing and handling warnings and errors.
#'
#' @param glmm_obj A glmmTMB object to be updated.
#' @param formula Formula for the updated GLMNB model.
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function.
#' @export
#' @return An updated GLMNB model or NULL if an error occurs.
#'
#' @examples
#' data(iris)
#' groups <- unique(iris$Species)
#' group_by <- "Species"
#' formula <- Sepal.Length ~ Sepal.Width + Petal.Length
#' fitted_models <- fitModelParallel(formula, iris, group_by, n.cores = 1)
#' new_formula <- Sepal.Length ~ Sepal.Width 
#' updated_model <- launchUpdate(fitted_models[[1]], new_formula)
launchUpdate <- function(glmm_obj, formula,  ...) {
  if (is.null(glmm_obj)) return(NULL)
  group <- glmm_obj$groupId
  tryCatch(
    expr = {
      withCallingHandlers(
        fitUpdate(group, glmm_obj, formula, ...),
        warning = function(w) {
          message(paste(Sys.time(), "warning for group", group ,":", conditionMessage(w)))
          invokeRestart("muffleWarning")
        })
    },
    error = function(e) {
    message(paste(Sys.time(), "error for group", group,":", conditionMessage(e)))
    return(NULL)
    }
  )
}

```


```{r  test-update_fittedmodel}
# Test updateParallel function
test_that("updateParallel function returns correct results", {
  # Load the required data
  data(iris)
  groups <- unique(iris$Species)
  group_by <- "Species"
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  fitted_models <- fitModelParallel(formula, iris, group_by, n.cores = 1)
  new_formula <- Sepal.Length ~ Sepal.Width 
  results <- updateParallel(new_formula, fitted_models, n.cores = 1)
  expect_is(results, "list")
  expect_equal(length(results), length(fitted_models))
  expect_is(results$setosa, "glmmTMB")

  #uncorrect formula 
  new_formula <- Sepal.Length ~ blabla
  expect_error(updateParallel(new_formula, fitted_models, n.cores = 1))
  
  # Additional parameters: 
   #change family + formula
  new_formula <- Sepal.Length ~ Sepal.Width 
  updated_model <- suppressWarnings(updateParallel(fitted_models, 
                                                    formula = new_formula,
                                                    n.cores = 1,
                                                    family = glmmTMB::nbinom1(link = "log") ))
  expect_s3_class(updated_model$setosa$call$family, "family")
  expect_equal(updated_model$setosa$call$formula, new_formula)
  #change control settings
  updated_model <- suppressWarnings(updateParallel(fitted_models, 
                                                 formula = new_formula, 
                                                 family = glmmTMB::nbinom1(link = "log"), 
                                                  n.cores = 1,
                                                control = glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,
                                                                                               eval.max=1e3))))
  expect_equal(updated_model$setosa$call$control,  glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,eval.max=1e3)))
  
  # Update an updated model
  updated_updated_model <- suppressWarnings(updateParallel(updated_model, 
                                                 formula = new_formula, 
                                                  n.cores = 1,
                                                 family = glmmTMB::ziGamma(link = "inverse")))
  expect_s3_class(updated_updated_model$setosa$call$family,  "family")
  
  
  ## -- update label reference
  data("mtcars")
  # -- Specify categorical variables
  mtcars$vs <- factor(mtcars$vs) ## Engine (0 = V-shaped, 1 = straight)
  levels(mtcars$vs) <- c("V-shaped", "straight")
  mtcars$am <- factor(mtcars$am) ## Transmission (0 = automatic, 1 = manual)
  levels(mtcars$am) <- c("automatic", "manual")
  # -- For each group of number of cylinders:
  # -- Explain fuel consumption with engine shape, Gross horsepower, and transmission type 
  list_tmb <- fitModelParallel(formula = mpg ~ hp + vs + am, data = mtcars, group_by = "cyl", n.cores = 1)
  # -- Relevel transmission and engine shape variables
  reference_labels <- c("straight", "manual")
  result <- updateParallel(formula = mpg ~ hp + vs + am, list_tmb = list_tmb, reference_labels = reference_labels, n.cores = 1)
  
  # Check if the returned list has the same length as the mock list
  expect_equal(length(result), length(list_tmb))
  expect_equal(levels(result[["6"]]$frame$am)[1] , "manual")
  expect_equal(levels(result[["4"]]$frame$am)[1] , "manual")
  expect_equal(levels(result[["6"]]$frame$vs)[1] , "straight")
  expect_equal(levels(result[["4"]]$frame$vs)[1] , "straight")

})



# Test for relevel_list_tmb_frame function
test_that("relevel_list_tmb_frame re-levels categorical variables correctly", {
  data("mtcars")
  # -- Specify categorical variables
  mtcars$vs <- factor(mtcars$vs) ## Engine (0 = V-shaped, 1 = straight)
  levels(mtcars$vs) <- c("V-shaped", "straight")
  mtcars$am <- factor(mtcars$am) ## Transmission (0 = automatic, 1 = manual)
  levels(mtcars$am) <- c("automatic", "manual")
  # -- For each group of number of cylinders:
  # -- Explain fuel consumption with engine shape, Gross horsepower, and transmission type 
  list_tmb <- fitModelParallel(formula = mpg ~ hp + vs + am, data = mtcars, group_by = "cyl", n.cores = 1)
  # -- Relevel transmission and engine shape variables
  reference_labels <- c("straight", "manual")
  result <- relevel_list_tmb_frame(list_tmb, c("vs", "am"), reference_labels)
  
  # Check if the returned list has the same length as the mock list
  expect_equal(length(result), length(list_tmb))
  # Check if categorical variables have been re-leveled correctly
  expect_equal(levels(result[["6"]]$frame$am)[1] , "manual")
  expect_equal(levels(result[["4"]]$frame$am)[1] , "manual")
  expect_equal(levels(result[["6"]]$frame$vs)[1] , "straight")
  expect_equal(levels(result[["4"]]$frame$vs)[1] , "straight")
})

# Test for relevelling_factors function
test_that("relevelling_factors relevel categorical variables correctly", {
  data("mtcars")
  # -- Specify categorical variables
  mtcars$vs <- factor(mtcars$vs) ## Engine (0 = V-shaped, 1 = straight)
  levels(mtcars$vs) <- c("V-shaped", "straight")
  mtcars$am <- factor(mtcars$am) ## Transmission (0 = automatic, 1 = manual)
  levels(mtcars$am) <- c("automatic", "manual")
  # -- For each group of number of cylinders:
  # -- Explain fuel consumption with engine shape, Gross horsepower, and transmission type 
  list_tmb <- fitModelParallel(formula = mpg ~ hp + vs + am, data = mtcars, group_by = "cyl", n.cores = 1)
  # -- Relevel transmission and engine shape variables
  reference_labels <- c("straight", "manual")
  result <- relevelling_factors(list_tmb, reference_labels)
  
   # Check if the returned list has the same length as the mock list
  expect_equal(length(result), length(list_tmb))
  # Check if categorical variables have been re-leveled correctly
  expect_equal(levels(result[["6"]]$frame$am)[1] , "manual")
  expect_equal(levels(result[["4"]]$frame$am)[1] , "manual")
  expect_equal(levels(result[["6"]]$frame$vs)[1] , "straight")
  expect_equal(levels(result[["4"]]$frame$vs)[1] , "straight")
})



# Test for detect_categoricals_vars function
test_that("detect_categoricals_vars detect categorical variables correctly", {
  data("mtcars")
  # -- Specify categorical variables
  mtcars$vs <- factor(mtcars$vs) ## Engine (0 = V-shaped, 1 = straight)
  levels(mtcars$vs) <- c("V-shaped", "straight")
  mtcars$am <- factor(mtcars$am) ## Transmission (0 = automatic, 1 = manual)
  levels(mtcars$am) <- c("automatic", "manual")
  # -- For each group of number of cylinders:
  # -- Explain fuel consumption with engine shape, Gross horsepower, and transmission type 
  list_tmb <- fitModelParallel(formula = mpg ~ hp + vs + am, data = mtcars, group_by = "cyl", n.cores = 1)
  # -- Relevel transmission and engine shape variables
  reference_labels <- c("straight", "manual")
  result <- detect_categoricals_vars(list_tmb[["6"]]$frame, reference_labels)
  
  # Check if the returned list has the same length as the mock list
  expect_equal(length(result), length(reference_labels))
  # Check if categorical variables have been re-leveled correctly
  expect_equal(result, c("vs", "am"))
})


# Test parallel_update function
test_that("parallel_update function returns correct results", {
# Load the required data
  data(iris)
  groups <- unique(iris$Species)
  group_by <- "Species"
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  fitted_models <- fitModelParallel(formula, iris, group_by, n.cores = 1)
  new_formula <- Sepal.Length ~ Sepal.Width 
  results <- parallel_update(new_formula, fitted_models, n.cores = 1)
  expect_is(results, "list")
  expect_equal(length(results), length(fitted_models))
  expect_is(results$setosa, "glmmTMB")

  #uncorrect formula 
  new_formula <- Sepal.Length ~ blabla
  results <- parallel_update(new_formula, fitted_models, n.cores = 1)
  expect_is(results, "list")
  expect_equal(length(results), length(fitted_models))
  expect_equal(results$setosa, NULL)
  
  # Additional parameters: 
   #change family + formula
  new_formula <- Sepal.Length ~ Sepal.Width 
  updated_model <- suppressWarnings(parallel_update(fitted_models, 
                                                     formula = new_formula,
                                                      n.cores = 1,
                                                      family = glmmTMB::nbinom1(link = "log") ))
  expect_s3_class(updated_model$setosa$call$family, "family")
  expect_equal(updated_model$setosa$call$formula, new_formula)
  #change control
  updated_model <- suppressWarnings(parallel_update(fitted_models, 
                                                 formula = new_formula, 
                                                  n.cores = 1,
                                                 family = glmmTMB::nbinom1(link = "log"), 
                                                control = glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,
                                                                                               eval.max=1e3))))
  expect_equal(updated_model$setosa$call$control,  glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,eval.max=1e3)))
})

# Test fitUpdate function
test_that("fitUpdate function returns correct results", {
  #Load the required data
  data(iris)
  groups <- unique(iris$Species)
  group_by <- "Species"
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  fitted_models <- fitModelParallel(formula, iris, group_by, n.cores = 1)
  new_formula <- Sepal.Length ~ Sepal.Width 

  updated_model <- fitUpdate("setosa",fitted_models[[1]], new_formula)
  expect_is(updated_model, "glmmTMB")
  
  ## -- check groupId presence
  expect_equal(updated_model$groupId, "setosa")
  
  # Additional parameters: 
   #change family + formula
  updated_model <- suppressWarnings(fitUpdate("setosa", fitted_models[[1]], new_formula, 
                                              family = glmmTMB::nbinom1(link = "log") ))
  expect_s3_class(updated_model$call$family, "family")
  expect_equal(updated_model$call$formula, new_formula)
  #change control
  updated_model <- suppressWarnings(fitUpdate("setosa", fitted_models[[1]], 
                                              new_formula, 
                                              family = glmmTMB::nbinom1(link = "log"), 
                                              control = glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,
                                                                                               eval.max=1e3))))
  expect_equal(updated_model$call$control,  glmmTMB::glmmTMBControl(optCtrl=list(iter.max=1e3,eval.max=1e3)))
  
})


# Test launchUpdate function
test_that("launchUpdate function returns correct results", {
  data(iris)
  groups <- unique(iris$Species)
  group_by <- "Species"
  formula <- Sepal.Length ~ Sepal.Width + Petal.Length
  fitted_models <- fitModelParallel(formula, iris, group_by, n.cores = 1)
  new_formula <- Sepal.Length ~ Sepal.Width 
  updated_model <- launchUpdate(fitted_models[[1]], new_formula)
  expect_is(updated_model, "glmmTMB")
  # Additional parameters: 
   #change family + formula
  updated_model <- launchUpdate(fitted_models[[1]], new_formula, family = glmmTMB::nbinom1(link = "log") )
  expect_s3_class(updated_model$call$family, "family")
  expect_equal(updated_model$call$formula, new_formula)
  #change control
  updated_model <- launchUpdate(fitted_models[[1]], new_formula, family = glmmTMB::nbinom1(link = "log"), 
                                control = glmmTMB::glmmTMBControl(optimizer=optim, optArgs=list(method="BFGS")))
  expect_equal(updated_model$call$control,  glmmTMB::glmmTMBControl(optimizer=optim, optArgs=list(method="BFGS")))
  
})

```

```{r function-tidy_glmmTMB, filename = "tidy_glmmTMB"}


#' Extract Fixed Effects from a GLMMTMB Model Summary
#'
#' This function extracts fixed effects from the summary of a glmmTMB model.
#'
#' @param x A glmmTMB model object.
#' @return A dataframe containing the fixed effects and their corresponding statistics.
#' @export
#' @examples
#'
#' model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length, data = iris)
#' fixed_effects <- extract_fixed_effect(model)
extract_fixed_effect <- function(x){
  ss = summary(x)
  as.data.frame(ss$coefficients$cond)
  ss_reshaped <- lapply(ss$coefficients,
                        function(sub_obj){
                          if(is.null(sub_obj)) return(NULL)
                          sub_obj <- data.frame(sub_obj)
                          sub_obj$term <- removeDuplicatedWord(rownames(sub_obj))
                          rownames(sub_obj) <- NULL
                          sub_obj <- renameColumns(sub_obj)
                          sub_obj
                        }
  )

  ss_df <- do.call(rbind, ss_reshaped)
  ss_df$component <- sapply(rownames(ss_df), function(x) strsplit(x, split = "[.]")[[1]][1])
  ss_df$effect <- "fixed"
  rownames(ss_df) <- NULL
  ss_df
}



#' Extract Tidy Summary of glmmTMB Model
#'
#' This function extracts a tidy summary of the fixed and random effects from a glmmTMB model and binds them together in a data frame. Missing columns are filled with NA.
#'
#' @param glm_TMB A glmmTMB model object.
#' @param ID An identifier to be included in the output data frame.
#' @return A data frame containing a tidy summary of the fixed and random effects from the glmmTMB model.
#' @export
#' @examples
#'
#' model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length, data = iris)
#' tidy_summary <- getTidyGlmmTMB(glm_TMB = model, ID = "Model1")
getTidyGlmmTMB <- function(glm_TMB, ID){
  if(is.null(glm_TMB)) return(NULL)
  df1 <- extract_fixed_effect(glm_TMB)
  df1 <- build_missingColumn_with_na(df1)
  df2 <- extract_ran_pars(glm_TMB)
  df2 <- build_missingColumn_with_na(df2)
  df_2ret <- rbind(df1, df2)
  df_2ret[df_2ret == "NaN"] <- NA
  df_2ret <- df_2ret[rowSums(!is.na(df_2ret)) > 0, ] # drop rows full of NA
  df_2ret$ID <- ID
  df_2ret <- reorderColumns(df_2ret,  
                            c("ID","effect", "component", "group", "term", 
                              "estimate", "std.error", "statistic", "p.value"))
  return(df_2ret)
}



#' Extract Tidy Summary of Multiple glmmTMB Models
#'
#' This function takes a list of glmmTMB models and extracts a tidy summary of the fixed and random effects from each model. It then combines the results into a single data frame.
#'
#' @param list_tmb A list of glmmTMB model objects.
#' @return A data frame containing a tidy summary of the fixed and random effects from all glmmTMB models in the list.
#' @export
#' @examples
#' model1 <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1 | Species), data = iris)
#' model2 <- glmmTMB::glmmTMB(Petal.Length ~ Sepal.Length + Sepal.Width + (1 | Species), data = iris)
#' model_list <- list(Model1 = model1, Model2 = model2)
#' tidy_summary <- tidy_tmb(model_list)
tidy_tmb <- function(list_tmb){
  
  if (identical(class(list_tmb), "glmmTMB")) return(getTidyGlmmTMB(list_tmb, "glmmTMB"))
  attributes(list_tmb)$names
  l_tidyRes <- lapply(attributes(list_tmb)$names,
               function(ID)
                 {
                    glm_TMB <- list_tmb[[ID]]
                    getTidyGlmmTMB(glm_TMB, ID)
                }
              )
  ret <- do.call("rbind", l_tidyRes)
  return(ret) 
}
  

#' Build DataFrame with Missing Columns and NA Values
#'
#' This function takes a DataFrame and a list of column names and adds missing columns with NA values to the DataFrame.
#'
#' @param df The input DataFrame.
#' @param l_columns A character vector specifying the column names to be present in the DataFrame.
#' @return A DataFrame with missing columns added and filled with NA values.
#' @export
#' @examples
#'
#' df <- data.frame(effect = "fixed", term = "Sepal.Length", estimate = 0.7)
#' df_with_na <- build_missingColumn_with_na(df)
build_missingColumn_with_na <- function(df, l_columns = c("effect", "component", "group", 
                                                          "term", "estimate", "std.error", "statistic", "p.value")) {
  df_missing_cols <- setdiff(l_columns, colnames(df))
  # Ajouter les colonnes manquantes à df1
  if (length(df_missing_cols) > 0) {
    for (col in df_missing_cols) {
      df[[col]] <- NA
    }
  }
  return(df)
}



#' Convert Correlation Matrix to Data Frame
#'
#' This function converts a correlation matrix into a data frame containing the correlation values and their corresponding interaction names.
#'
#' @param corr_matrix A correlation matrix to be converted.
#' @return A data frame with the correlation values and corresponding interaction names.
#' @export
#' @examples
#' mat <- matrix(c(1, 0.7, 0.5, 0.7, 
#'                  1, 0.3, 0.5, 0.3, 1), 
#'                  nrow = 3, 
#'                  dimnames = list(c("A", "B", "C"), 
#'                                  c("A", "B", "C")))
#' correlation_matrix_2df(mat)
correlation_matrix_2df <- function(corr_matrix){
  vec_corr <- corr_matrix[lower.tri(corr_matrix)]
  col_names <- removeDuplicatedWord(colnames(corr_matrix))
  row_names <- removeDuplicatedWord(rownames(corr_matrix))
  interaction_names <- vector("character", length(vec_corr))
  k <- 1
  n <- nrow(corr_matrix)
  for (i in 1:(n - 1)) {
    for (j in (i + 1):n) {
      interaction_names[k] <- paste("cor__", paste(col_names[i], ".", row_names[j], sep = ""), sep ="")
      k <- k + 1
    }
  }
  names(vec_corr) <- interaction_names
  ret <- data.frame(estimate = vec_corr)
  ret$term <- rownames(ret)
  rownames(ret) <- NULL
  ret
}

#' Wrapper for Extracting Variance-Covariance Components
#'
#' This function extracts variance-covariance components from a glmmTMB model object for a specific grouping factor and returns them as a data frame.
#'
#' @param var_cor A variance-covariance object from the glmmTMB model.
#' @param elt A character indicating the type of effect, either "cond" or "zi".
#' @return A data frame containing the standard deviation and correlation components for the specified grouping factor.
#' @export
#' @examples
#' model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1|Species), 
#'                            data = iris, family = gaussian)
#' var_cor <- summary(model)$varcor$cond
#' ran_pars_df <- wrapper_var_cor(var_cor, "Species")
wrapper_var_cor <- function(var_cor, elt){
  var_group <- attributes(var_cor)$names
  l_ret <- lapply(var_group,
         function(group)
         {
           ## -- standard dev
           std_df <- data.frame(estimate = attributes(var_cor[[group]])$stddev)
           std_df$term <- paste("sd_", removeDuplicatedWord(rownames(std_df)), sep = "")
           ## -- correlation
           corr_matrix <- attributes(var_cor[[group]])$correlation
           #no correlation 2 return 
           if (nrow(corr_matrix) <= 1) ret <-  std_df
           else {
            corr_df <- correlation_matrix_2df(corr_matrix)
            ret <- rbind(std_df, corr_df)
          }
           ret$component <- elt
           ret$effect <- "ran_pars"
           ret$group <- group
           rownames(ret) <- NULL
           return(ret)
         })
  l_ret

}


#' Extract Random Parameters from a glmmTMB Model
#'
#' This function extracts the random parameters from a glmmTMB model and returns them as a data frame.
#'
#' @param x A glmmTMB model object.
#' @return A data frame containing the random parameters and their estimates.
#' @export
#' @importFrom stats setNames
#' @examples
#' model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1|Species), data = iris, 
#'          family = gaussian)
#' random_params <- extract_ran_pars(model)
extract_ran_pars <- function(x){
  ss <- summary(x)
  l_2parcour <- c("cond", "zi")
  l_res = lapply(stats::setNames(l_2parcour, l_2parcour),
          function(elt)
            {
              var_cor <- ss$varcor[[elt]]
              return(wrapper_var_cor(var_cor, elt))
  })

  ret <- rbind(do.call("rbind", l_res$cond),do.call("rbind", l_res$zi))
  return(ret)

}


#' Rename Columns in a Data Frame
#'
#' This function renames columns in a data frame based on specified old names and corresponding new names.
#'
#' @param df A data frame.
#' @param old_names A character vector containing the old column names to be replaced.
#' @param new_names A character vector containing the corresponding new column names.
#' @return The data frame with renamed columns.
#' @export
#' @examples
#' df <- data.frame(Estimate = c(1.5, 2.0, 3.2),
#'                  Std..Error = c(0.1, 0.3, 0.2),
#'                  z.value = c(3.75, 6.67, 4.90),
#'                  Pr...z.. = c(0.001, 0.0001, 0.002))
#'
#' renamed_df <- renameColumns(df, old_names = c("Estimate", "Std..Error", "z.value", "Pr...z.."),
#'                               new_names = c("estimate", "std.error", "statistic", "p.value"))
#'
renameColumns <- function(df, old_names  = c("Estimate","Std..Error", "z.value", "Pr...z.."), 
                           new_names = c("estimate","std.error", "statistic", "p.value")) {
  stopifnot(length(old_names) == length(new_names))

  for (i in seq_along(old_names)) {
    old_col <- old_names[i]
    new_col <- new_names[i]

    if (old_col %in% names(df)) {
      names(df)[names(df) == old_col] <- new_col
    } else {
      warning(paste("Column", old_col, "not found in the data frame. Skipping renaming."))
    }
  }

  return(df)
}

```


```{r  test-tidy_glmmTMB}

test_that("extract_fixed_effect returns the correct results for glmmTMB models", {
  data(iris)
  # Créer un modèle glmmTMB avec les données iris (exemple)
  model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1|Species), data = iris)
  
  # Appeler la fonction extract_fixed_effect sur le modèle
  result <- extract_fixed_effect(model)
  
  # Check les résultats attendus
  expect_s3_class(result, "data.frame")
  expect_equal(result$effect, c("fixed", "fixed", "fixed"))
  expect_equal(result$component , c("cond", "cond", "cond"))
  expect_equal(result$term , c("(Intercept)", "Sepal.Width", "Petal.Length"))
  
})


test_that("getTidyGlmmTMB returns the correct results for glmmTMB models", {
  data(iris)
  # Créer un modèle glmmTMB avec les données iris (exemple)
  model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length, data = iris)
  tidy_summary <- getTidyGlmmTMB(glm_TMB = model, ID = "Model1")
  
  # Check les résultats attendus
  expect_s3_class(tidy_summary, "data.frame")
  expect_equal(tidy_summary$effect, c("fixed", "fixed", "fixed"))
  expect_equal(tidy_summary$component , c("cond", "cond", "cond"))
  expect_equal(tidy_summary$term , c("(Intercept)", "Sepal.Width", "Petal.Length"))
  expect_equal(tidy_summary$ID , c("Model1", "Model1", "Model1"))

  #MODEL == NULL
  tidy_summary <- getTidyGlmmTMB(glm_TMB = NULL, ID = "Model1")
  expect_equal(tidy_summary, NULL)
})


test_that("build_missingColumn_with_na returns the correct results", {
  df <- data.frame(effect = "fixed", term = "Sepal.Length", estimate = 0.7)
  df_with_na <- build_missingColumn_with_na(df)
  expected_df <- data.frame(effect = "fixed",
                            term = "Sepal.Length",
                            estimate = 0.7,
                            component = NA,
                            group = NA,
                            std.error = NA,
                            statistic = NA,
                            p.value  = NA)
    
  expect_equal(df_with_na, expected_df)
})



test_that("correlation_matrix_2df returns expected output",{

  mat <- matrix(c(1, 0.7, 0.5, 0.7, 1, 0.3, 0.5, 0.3, 1), nrow = 3, dimnames = list(c("A", "B", "C"), c("A", "B", "C")))
  df_corr <- correlation_matrix_2df(mat)
  df_expected <- data.frame(estimate = c(0.7, 0.5, 0.3),
                            term = c("cor__A.B", "cor__A.C", "cor__B.C"))
  expect_equal(df_corr, df_expected)
})



test_that("wrapper_var_cor returns expected output",{
  data(iris)
  model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1|Species), data = iris, family = gaussian)
  var_cor <- summary(model)$varcor$cond
  ran_pars_df <- wrapper_var_cor(var_cor, "Species")
  expected_l = list(data.frame(estimate = 0.4978083, term = "sd_(Intercept)", 
                               component = "Species", effect = "ran_pars", group = "Species"))
  expect_equal(ran_pars_df , expected_l, tolerance = 0.0000001) 
})


test_that("extract_ran_pars returns expected output",{
  model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1|Species), 
                            data = iris, family = gaussian)
  random_params <- extract_ran_pars(model)
  
  expected_df = data.frame(estimate = 0.4978083, term = "sd_(Intercept)", 
                               component = "cond", effect = "ran_pars", group = "Species")
  expect_equal(random_params , expected_df, tolerance = 0.0000001) 
})


test_that("renameColumns returns expected output",{
  df <- data.frame(Estimate = c(1.5, 2.0, 3.2),
                  Std..Error = c(0.1, 0.3, 0.2),
                  z.value = c(3.75, 6.67, 4.90),
                  Pr...z.. = c(0.001, 0.0001, 0.002))

  new_colnames <- c("estimate", "std.error", "statistic", "p.value")
  renamed_df <- renameColumns(df, old_names = c("Estimate", "Std..Error", "z.value", "Pr...z.."),
                               new_names = new_colnames)
  expect_equal(colnames(renamed_df),c("estimate", "std.error", "statistic", "p.value"))
  expect_equal(dim(renamed_df), dim(df))
})
    
test_that("tidy_tmb returns expected output",{
  model1 <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1 | Species), data = iris)
  model2 <- glmmTMB::glmmTMB(Petal.Length ~ Sepal.Length + Sepal.Width + (1 | Species), data = iris)
  model_list <- list(Model1 = model1, Model2 = model2)
  result <- tidy_tmb(model_list)
  expect_equal(unique(result$ID), c("Model1", "Model2"))
  expect_equal(unique(result$effect), c("fixed", "ran_pars"))
  expect_equal(unique(result$component), "cond")
  expect_equal(unique(result$term), c("(Intercept)", "Sepal.Width", "Petal.Length", "sd_(Intercept)", "Sepal.Length"))
  expect_true("estimate" %in% colnames(result))
  expect_true("std.error" %in% colnames(result))
  expect_true("statistic" %in% colnames(result))
  expect_true("p.value" %in% colnames(result))
  
  
  # zi component
  model2 <- glmmTMB::glmmTMB(Petal.Length ~ Sepal.Length + Sepal.Width + (1 | Species), data = iris, ziformula = ~1)
  model_list <- list(Model1 = model1, Model2 = model2)
  result_withZi <- tidy_tmb(model_list)
  expect_equal(dim(result_withZi)[1], dim(result)[1] + 1 )
  expect_equal(unique(result_withZi$component), c("cond", "zi"))

   ## unique obect in list 
  model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1|Species), data = iris)
  result <- tidy_tmb(model)
  expect_true("effect" %in% colnames(result))
  expect_true("component" %in% colnames(result))
  expect_true("group" %in% colnames(result))
  expect_true("term" %in% colnames(result))
  expect_true("estimate" %in% colnames(result))
  expect_true("std.error" %in% colnames(result))
  expect_true("statistic" %in% colnames(result))
  expect_true("p.value" %in% colnames(result))
})
```

```{r function-glance_glmmTMB, filename = "glance_glmmTMB"}

#' Extracts the summary statistics from a list of glmmTMB models.
#'
#' This function takes a list of glmmTMB models and extracts the summary statistics (AIC, BIC, logLik, deviance,
#' df.resid, and dispersion) for each model and returns them as a single DataFrame.
#'
#' @param list_tmb A list of glmmTMB models or a unique glmmTMB obj model
#' @return A DataFrame with the summary statistics for all the glmmTMB models in the list.
#' @export
#' @importFrom stats setNames
#' @examples
#' data(mtcars)
#' models <-  fitModelParallel(Sepal.Length ~ Sepal.Width + Petal.Length, 
#'                            group_by = "Species",n.cores = 1, data = iris)
#' result <- glance_tmb(models)
glance_tmb <- function(list_tmb){
  
  if (identical(class(list_tmb), "glmmTMB")) return(getGlance(list_tmb))
  l_group <- attributes(list_tmb)$names
  l_glance <- lapply(stats::setNames(l_group, l_group), function(group) getGlance(list_tmb[[group]]))
  return(do.call("rbind", l_glance))
}


#' Extracts the summary statistics from a single glmmTMB model.
#'
#' This function takes a single glmmTMB model and extracts the summary statistics (AIC, BIC, logLik, deviance,
#' df.resid, and dispersion) from the model and returns them as a DataFrame.
#'
#' @param x A glmmTMB model.
#' @return A DataFrame with the summary statistics for the glmmTMB model.
#' @export
#'
#' @examples
#' data(mtcars)
#' model <- glmmTMB::glmmTMB(mpg ~ wt + (1|cyl), data = mtcars)
#' getGlance(model)
getGlance <- function(x){
  if (is.null(x)) return(NULL)
  ret <- data.frame(t(summary(x)$AICtab))
  ret$dispersion <- glmmTMB::sigma(x)
  ret
}


```


```{r test-glance_glmmTMB }

test_that("glance_tmb returns the summary statistics for multiple models", {
  data(iris)
  models <-  fitModelParallel(Sepal.Length ~ Sepal.Width + Petal.Length, group_by = "Species",n.cores = 1, data = iris)
  result <- glance_tmb(models)
  expect_true("AIC" %in% colnames(result))
  expect_true("BIC" %in% colnames(result))
  expect_true("logLik" %in% colnames(result))
  expect_true("deviance" %in% colnames(result))
  expect_true("df.resid" %in% colnames(result))
  expect_true("dispersion" %in% colnames(result))
  expect_true(sum(c("setosa","versicolor", "virginica") %in% rownames(result)) == 3) 
  
  ## unique obect in list 
  model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1|Species), data = iris)
  result <- glance_tmb(model)
  expect_true("AIC" %in% colnames(result))
  expect_true("BIC" %in% colnames(result))
  expect_true("logLik" %in% colnames(result))
  expect_true("deviance" %in% colnames(result))
  expect_true("df.resid" %in% colnames(result))
  expect_true("dispersion" %in% colnames(result))

})

test_that("getGlance returns the summary statistics for a single model", {
  model <- glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length + (1|Species), data = iris)
  result <- getGlance(model)
  expect_true("AIC" %in% colnames(result))
  expect_true("BIC" %in% colnames(result))
  expect_true("logLik" %in% colnames(result))
  expect_true("deviance" %in% colnames(result))
  expect_true("df.resid" %in% colnames(result))
  expect_true("dispersion" %in% colnames(result))
})
```


```{r function-plot_metrics, filename = "plot_metrics"}

#' Subset the glance DataFrame based on selected variables.
#'
#' This function subsets the glance DataFrame to keep only the specified variables.
#'
#' @param glance_df The glance DataFrame to subset.
#' @param focus A character vector of variable names to keep, including "AIC", "BIC", "logLik", "deviance",
#' "df.resid", and "dispersion".
#' @return A subsetted glance DataFrame with only the selected variables.
#' @export
#'
#' @examples
#' data(iris)
#' models <-  fitModelParallel(Sepal.Length ~ Sepal.Width + Petal.Length, 
#'                        group_by = "Species",n.cores = 1, data = iris)
#' glance_df <- glance_tmb(models)
#' glance_df$group_id <- rownames(glance_df)
#' subset_glance(glance_df, c("AIC", "BIC"))
subset_glance <- function(glance_df, focus){
  idx_existing_column <- focus %in% c("AIC", "BIC", "logLik", "deviance" ,"df.resid", "dispersion" )
  if(sum(!idx_existing_column) > 0) warning(paste(focus[!idx_existing_column], ": does not exist\n"))
  focus <- focus[idx_existing_column]
  if (identical(focus, character(0)))
    stop(paste0("Please select at least one variable to focus on : ", 
                "AIC, BIC, logLik, deviance, df.resid, dispersion" ))
  glance_df <- glance_df[ , c("group_id", focus)]
  return(glance_df)
}


#' Plot Metrics for Generalized Linear Mixed Models (GLMM)
#'
#' This function generates a density plot of the specified metrics for the given
#' list of generalized linear mixed models (GLMMs).
#'
#' @param list_tmb A list of GLMM objects to extract metrics from.
#' @param focus A character vector specifying the metrics to focus on. Possible
#'   values include "AIC", "BIC", "logLik", "deviance", "df.resid", and
#'   "dispersion". If \code{NULL}, all available metrics will be plotted.
#'
#' @return A ggplot object displaying histogram plots for the specified metrics.
#'
#' @importFrom reshape2 melt
#' @importFrom ggplot2 aes facet_wrap geom_histogram theme_bw theme ggtitle
#' @importFrom rlang .data
#' @export
#'
#' @examples
#' models_list <-  fitModelParallel(Sepal.Length ~ Sepal.Width + Petal.Length, 
#'                      group_by = "Species",n.cores = 1, data = iris)
#' diagnostic_plot(models_list, focus = c("AIC", "BIC", "deviance"))
diagnostic_plot <- function(list_tmb, focus = NULL) {

  invisible(isValidList_tmb(list_tmb))
  
  glance_df <- glance_tmb(list_tmb)
  glance_df$group_id <- rownames(glance_df)
  if (!is.null(focus)) {
    glance_df <- subset_glance(glance_df, focus)
  }
  long_glance_df <- reshape2::melt(glance_df, variable.name = "metric")
  p <- ggplot2::ggplot(long_glance_df) +
    ggplot2::geom_histogram(ggplot2::aes(x = .data$value, col = .data$metric, fill = .data$metric)) +
    ggplot2::facet_wrap(~metric, scales = "free") +
    ggplot2::theme_bw() +
    ggplot2::theme(legend.position = 'null') + 
    ggplot2::ggtitle("Metrics plot")
  return(p)
}


```

```{r test-plot_metrics }


test_that("subset_glance subsets the glance DataFrame correctly", {
  data(iris)
  models <-  fitModelParallel(Sepal.Length ~ Sepal.Width + Petal.Length, group_by = "Species",n.cores = 1, data = iris)
  glance_df <- glance_tmb(models)
  glance_df$group_id <- rownames(glance_df)
  result <- subset_glance(glance_df, c("AIC", "BIC"))
  expect_true("AIC" %in% colnames(result))
  expect_true("BIC" %in% colnames(result))
  expect_true("group_id" %in% colnames(result))
  expect_true(sum(c("setosa","versicolor", "virginica") %in% rownames(result)) == 3) 
})




test_that("diagnostic_plot returns a ggplot object", {
  
  data(iris)
  l_glmTMB <- list(
        setosa = glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length, 
                     data = subset(iris, Species == "setosa")),
        versicolor = glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length, 
                         data = subset(iris, Species == "versicolor")),
        virginica = glmmTMB::glmmTMB(Sepal.Length ~ Sepal.Width + Petal.Length, 
                          data = subset(iris, Species == "virginica"))
  )
  p <- diagnostic_plot(l_glmTMB)
  expect_true(inherits(p, "gg"))

})


```







```{r function-evaluate_dispersion, filename = "evaluate_dispersion"}

#' Get Dispersion Comparison
#'
#' Compares inferred dispersion values with actual dispersion values.
#'
#' @param inferred_dispersion A data frame containing inferred dispersion values.
#' @param actual_dispersion A numeric vector containing actual dispersion values.
#'
#' @return A data frame comparing actual and inferred dispersion values.
#' 
#' @export
#'
#' @examples
#' \dontrun{
#' dispersion_comparison <- getDispersionComparison(inferred_disp, actual_disp)
#' }
getDispersionComparison <- function(inferred_dispersion, actual_dispersion) {
  actual_disp <- data.frame(actual = actual_dispersion)
  actual_disp$ID <- rownames(actual_disp)
  rownames(actual_disp) <- NULL
  disp_comparison <- join_dtf(actual_disp, inferred_dispersion, c("ID"), c("ID"))
  disp_comparison$term <- 'dispersion'
  disp_comparison$description <- 'dispersion'
  return(disp_comparison)
}


#' Extract DESeq2 Dispersion Values
#'
#' Extracts inferred dispersion values from a DESeq2 wrapped object.
#'
#' @param dds_wrapped A DESeq2 wrapped object containing dispersion values.
#'
#' @return A data frame containing inferred dispersion values.
#' 
#' @export
#'
#' @examples
#' \dontrun{
#' dispersion_df <- extract_ddsDispersion(deseq2_object)
#' }
extract_ddsDispersion <- function(dds_wrapped) {
  inferred_dispersion <- data.frame(estimate = dds_wrapped$dispersion)
  inferred_dispersion$ID <- rownames(inferred_dispersion)
  rownames(inferred_dispersion) <- NULL
  return(inferred_dispersion)
}


#' Extract TMB Dispersion Values
#'
#' Extracts inferred dispersion values from a TMB result object.
#'
#' @param list_tmb A TMB result object containing dispersion values.
#'
#' @return A data frame containing inferred dispersion values.
#' 
#' @export
#'
#' @examples
#' \dontrun{
#' dispersion_df <- extract_tmbDispersion(tmb_result)
#' }
extract_tmbDispersion <- function(list_tmb) {
  glanceRes <- glance_tmb(list_tmb)
  inferred_dispersion <- data.frame(estimate = glanceRes$dispersion)
  inferred_dispersion$ID <- rownames(glanceRes)
  rownames(inferred_dispersion) <- NULL
  return(inferred_dispersion)
}


```

```{r test-evaluate_dispersion }



test_that("extract_tmbDispersion function extracts dispersion correctly", {
   N_GENES = 50
  MAX_REPLICATES = 5
  MIN_REPLICATES = 5
  input_var_list <- init_variable(name = "varA", mu = 10, sd = 0.1, level = 3)
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates = MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  data2fit <- prepareData2fit(countMatrix = mock_data$counts, metadata =  mock_data$metadata, normalization = 'MRN')
  l_res <- fitModelParallel(formula = kij ~ varA,
                          data = data2fit, group_by = "geneID",
                          family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
  extracted_disp <- extract_tmbDispersion(l_res)
  expect_identical(colnames(extracted_disp), c("estimate", "ID"))
})

test_that("extract_ddsDispersion function extracts dispersion correctly", {
  N_GENES = 100
  MAX_REPLICATES = 5
  MIN_REPLICATES = 5
  input_var_list <- init_variable(name = "varA", mu = 10, sd = 0.1, level = 3)
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates = MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  dds <- DESeq2::DESeqDataSetFromMatrix(
      countData = round(mock_data$counts),
      colData = mock_data$metadata,
      design = ~ varA)
  dds <- DESeq2::DESeq(dds, quiet = TRUE)
  deseq_wrapped = wrap_dds(dds, 2, "greaterAbs")
  
  extracted_disp <- extract_ddsDispersion(deseq_wrapped)
  expect_identical(colnames(extracted_disp), c("estimate", "ID"))
})

test_that("getDispersionComparison function works correctly", {
   N_GENES = 100
  MAX_REPLICATES = 5
  MIN_REPLICATES = 5
  input_var_list <- init_variable(name = "varA", mu = 10, sd = 0.1, level = 3)
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates = MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  data2fit <- prepareData2fit(countMatrix = mock_data$counts, metadata =  mock_data$metadata, normalization = 'MRN')
  l_res <- fitModelParallel(formula = kij ~ varA,
                          data = data2fit, group_by = "geneID",
                          family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
  
  tmb_disp_inferred <- extract_tmbDispersion(l_res)
    
  comparison <- getDispersionComparison(tmb_disp_inferred, mock_data$groundTruth$gene_dispersion)
  expect_identical(colnames(comparison), c("actual",  "ID", "estimate",  "term","description"))
})


```





```{r function-sequencing_depth_scaling, filename =  "sequencing_depth_scaling"}

#' Scale Counts Table
#'
#' This function scales a counts table based on the expected sequencing depth per sample.
#'
#' @param countsTable A counts table containing raw read counts.
#' @param scalingDepth_factor  sequencing depth factor vector
#' @return A scaled counts table.
#'
#' @export
#' @examples
#' mock_data <- list(counts = matrix(c(10, 20, 30, 20, 30, 10, 10, 20, 20, 20, 30, 10), ncol = 4))
#' scaled_counts <- scaleCountsTable(countsTable = mock_data$counts, 2)
#'
scaleCountsTable <- function(countsTable, scalingDepth_factor){
  counts_scaled <- as.data.frame(sweep(as.matrix(countsTable), 2,  scalingDepth_factor, "*"))
  return(counts_scaled)
}



#' Get scaling factor for count normalization
#'
#' Calculates the scaling factor for count normalization based on the sequencing depth.
#'
#' @param countsTable Matrix or data frame of counts data.
#' @param seq_depth Numeric vector containing the sequencing depths.
#' @return Numeric vector of scaling factors for count normalization.
#' @export
get_scaling_factor <- function(countsTable, seq_depth){
  seq_depth_simu <- colSums(countsTable)
  
  if (length(seq_depth) > length(seq_depth_simu))
    message("INFO: The length of the sequencing_depth vector exceeds the number of samples. Only the first N values will be utilized.")
  if (length(seq_depth) < length(seq_depth_simu))
    message("INFO: The length of the sequencing_depth vector is shorter than the number of samples. Values will be recycled.")

  scalingDepth_factor <- suppressWarnings(seq_depth/seq_depth_simu)
  return(scalingDepth_factor)
}



```

```{r  test-sequencing_depth_scaling}

# Test case 1: Scaling with valid min_seq_depth and max_seq_depth
test_that("Valid scaling of counts table", {
      # Test data
      mock_data <- list(counts = matrix(c(10, 20, 30, 20, 30, 10, 10, 20, 20, 20, 30, 10), ncol = 4))
      # Test function
      scaled_counts <- scaleCountsTable(countsTable = mock_data$counts, 2)
      
      # Expected scaled counts
      expected_scaled_counts <- data.frame(matrix(c(20, 40, 60, 40, 60,20, 20,40,40, 40,60,20), ncol = 4))
      colnames(expected_scaled_counts) <- c("V1", "V2", "V3", "V4")
      # Check if the scaled counts match the expected scaled counts
      expect_identical(scaled_counts, expected_scaled_counts)

})





test_that("get_scaling_factor returns correct scaling factors", {
  # Mock data
  counts <- matrix(c(10, 20, 30, 20, 30, 10, 10, 20, 20, 20, 30, 10), ncol = 4)  # Simulated counts data
  seq_depth <- c(100, 200)  # Sequencing depths for each sample
  scaling_factors <- suppressWarnings(get_scaling_factor(counts, seq_depth))
  
  expect_equal(scaling_factors, c(1.666667,3.333333,2.000000,3.333333), tolerance = 0.001)
})

```



```{r function-basal_expression_scaling, filename =  "basal_expression_scaling"}




#' Get bin expression for a data frame.
#'
#' This function divides the values of a specified column in a data frame into \code{n_bins} bins of equal width.
#' The bin labels are then added as a new column in the data frame.
#'
#' @param dtf_coef A data frame containing the values to be binned.
#' @param n_bins The number of bins to create.
#' 
#' @return A data frame with an additional column named \code{binExpression}, containing the bin labels.
#' @export
#' @examples
#' dtf <- data.frame(mu_ij = c(10, 20, 30, 15, 25, 35, 40, 5, 12, 22))
#' dtf_with_bins <- getBinExpression(dtf, n_bins = 3)
#' 
getBinExpression <- function(dtf_coef, n_bins){
      col2bin <- "mu_ij"
      bin_labels <- cut(dtf_coef[[col2bin]], n_bins, labels = paste("BinExpression", 1:n_bins, sep = "_"))
      dtf_coef$binExpression <-  bin_labels     
      return(dtf_coef)
}




#' Generate BE data.
#' 
#' This function generates basal expression data for a given number of genes, in a vector of basal expression values.
#' 
#' @param n_genes The number of genes to generate BE data for.
#' @param basal_expression a numeric vector from which sample BE for eacg genes
#' 
#' @return A data frame containing gene IDs, BE values
#' 
#' @examples
#' generate_basal_expression(n_genes = 100, 10)
#' 
#' @export
generate_basal_expression <- function(n_genes, basal_expression) {
  ## --avoid bug if one value in basal_expr
  pool2sample <- c(basal_expression, basal_expression)
  BE <- sample(x = pool2sample, size = n_genes, replace = T)
  l_geneID <- base::paste("gene", 1:n_genes, sep = "")
  ret <- list(geneID = l_geneID, basalExpr = BE) %>% as.data.frame()
  return(ret)
}



#' Compute basal expresion for gene expression based on the coefficients data frame.
#'
#' This function takes the coefficients data frame \code{dtf_coef} and computes
#' basal expression for gene expression. The scaling factors are generated 
#' using the function \code{generate_basal_expression}.
#'
#' @param dtf_coef A data frame containing the coefficients for gene expression.
#' @param n_genes number of genes in simulation
#' @param basal_expression gene basal expression vector
#'
#' @return A modified data frame \code{dtf_coef} with an additional column containing
#'         the scaling factors for gene expression.
#' @export
#' @examples 
#' list_var <- init_variable()
#' N_GENES <- 5
#' dtf_coef <- getInput2simulation(list_var, N_GENES)
#' dtf_coef <- getLog_qij(dtf_coef)
#' addBasalExpression(dtf_coef, N_GENES, 1)
addBasalExpression <- function(dtf_coef, n_genes, basal_expression){
    BE_df  <-  generate_basal_expression(n_genes, basal_expression )
    dtf_coef <- join_dtf(dtf_coef, BE_df, "geneID", "geneID")
    return(dtf_coef) 
}




```

```{r  test-basal_expression_scaling}

test_that("generate_basal_expression returns correct number of genes", {
  be_data <- generate_basal_expression(n_genes = 100, 1)
  expect_equal(nrow(be_data), 100)
})


test_that("generate_basal_expression returns BE values within specified vector", {
  BE_vec <- c(1, 2, 33, 0.4)
  be_data <- generate_basal_expression(n_genes = 100, BE_vec)
  expect_true(all(be_data$basalExpr %in% BE_vec))
})


test_that("Test for addbasalExpre function",{
  
  list_var <- init_variable()
  N_GENES <- 5
  dtf_coef <- getInput2simulation(list_var, N_GENES)
  dtf_coef <- getLog_qij(dtf_coef)

  # Test the function
  dtf_coef_with_BE <- addBasalExpression(dtf_coef, N_GENES, 5)

  # Check if the output is a data frame
  expect_true(is.data.frame(dtf_coef_with_BE))

  # Check if the number of rows is equal to number of row in dtf_coef
  expect_equal(nrow(dtf_coef_with_BE), nrow(dtf_coef))
  
  # Check if the number of rows is equal to number of row in dtf_coef +1
  expect_equal(ncol(dtf_coef_with_BE), ncol(dtf_coef)+1)
  
  # Check if the data frame has a new column "BE"
  expect_true("basalExpr" %in% colnames(dtf_coef_with_BE))
  
  # Check if the values in the "BE" column are numeric
  expect_true(all(is.numeric(dtf_coef_with_BE$basalExpr)))

})


# Test 1: Check if the function returns the correct number of bins
test_that("getBinExpression returns the correct number of bins", {
  dtf <- data.frame(mu_ij = c(10, 20, 30, 15, 25, 35, 40, 5, 12, 22))
  n_bins <- 3
  dtf_with_bins <- getBinExpression(dtf, n_bins)
  expect_equal(nrow(dtf_with_bins), nrow(dtf), label = "Number of rows should remain the same")
  expect_equal(ncol(dtf_with_bins), ncol(dtf) + 1, label = "Number of columns should increase by 1")
})

# Test 2: Check if the function adds the binExpression column correctly
test_that("getBinExpression adds the binExpression column correctly", {
  dtf <- data.frame(mu_ij = c(10, 20, 30, 15, 25, 35, 40, 5, 12, 22))
  n_bins <- 3
  dtf_with_bins <- getBinExpression(dtf, n_bins)
  expected_bins <- c("BinExpression_1", "BinExpression_2", "BinExpression_3", "BinExpression_1", "BinExpression_2",
                     "BinExpression_3", "BinExpression_3", "BinExpression_1", "BinExpression_1", "BinExpression_2")
  expect_equal(dtf_with_bins$binExpression, factor(expected_bins))
})

# Test 3: Check if the function handles negative values correctly
test_that("getBinExpression handles negative values correctly", {
  dtf <- data.frame(mu_ij = c(10, -20, 30, -15, 25, 35, -40, 5, 12, -22))
  n_bins <- 4
  dtf_with_bins <- getBinExpression(dtf, n_bins)
  expected_bins <- c("BinExpression_3", "BinExpression_2", "BinExpression_4", "BinExpression_2", "BinExpression_4",
                     "BinExpression_4", "BinExpression_1", "BinExpression_3", "BinExpression_3", "BinExpression_1")
  expect_equal(dtf_with_bins$binExpression, factor(expected_bins))
})



```



```{r function-actual_mainfixeffects, filename =  "actual_mainfixeffects" }

#' Calculate average values by group
#'
#' @param data The input data frame
#' @param column The name of the target variable
#' @param group_by The names of the grouping variables
#' @importFrom data.table setDT tstrsplit
#' @importFrom rlang :=
#' @return A data frame with average values calculated by group
#' @export
averageByGroup <- function(data, column, group_by) {
  group_values <- split(data[[column]], data[group_by])
  mean_values <- sapply(group_values, mean)
  result <- data.frame(Group = names(mean_values), logQij_mean = mean_values)
  data.table::setDT(result)[, {{ group_by }} := data.table::tstrsplit(Group, "[.]")]
  result <- subset(as.data.frame(result), select = -Group)
  return(result)
}

#' Subset Fixed Effect Inferred Terms
#'
#' This function subsets the tidy TMB object to extract the fixed effect inferred terms
#' along with their categorization into interaction and non-interaction terms.
#'
#' @param tidy_tmb The tidy TMB object containing the inferred terms.
#'
#' @return A list with two elements:
#' \describe{
#'   \item{fixed_term}{A list with two components - \code{nonInteraction} and \code{interaction},
#'   containing the names of the fixed effect inferred terms categorized as non-interaction and interaction terms, respectively.}
#'   \item{data}{A data frame containing the subset of tidy_tmb that contains the fixed effect inferred terms.}
#' }
#' @export
#' @examples
#' input_var_list <- init_variable()
#' mock_data <- mock_rnaseq(input_var_list, 10, 2, 2)
#' getData2computeActualFixEffect(mock_data$groundTruth$effect)
#' data2fit = prepareData2fit(countMatrix = mock_data$counts, metadata =  mock_data$metadata )
#' #-- fit data
#' resFit <- fitModelParallel(formula = kij ~ myVariable   ,
#'                            data = data2fit, group_by = "geneID",
#'                            family = glmmTMB::nbinom2(link = "log"), n.cores = 1) 
#' tidy_tmb <- tidy_tmb(resFit)
#' subsetFixEffectInferred(tidy_tmb)
subsetFixEffectInferred <- function(tidy_tmb){
  fixed_tidy <- tidy_tmb[tidy_tmb$effect == "fixed",]
  l_term <- unique(fixed_tidy$term)
  l_term <- l_term[!l_term %in% c("(Intercept)", NA)]
  index_interaction <- grepl(x = l_term, ":")
  l_term_nonInteraction <- l_term[!index_interaction]
  l_term_interaction <- l_term[index_interaction]
  l_term2ret <- list(nonInteraction = l_term_nonInteraction, interaction = l_term_interaction )
  return(list(fixed_term = l_term2ret, data = fixed_tidy))
}


#' Get data for calculating actual values
#'
#' @param groundTruth The ground truth data frame
#' @return A list containing required data for calculating actual values
#' @export
#' @examples
#' input_var_list <- init_variable()
#' mock_data <- mock_rnaseq(input_var_list, 10, 2, 2)
#' getData2computeActualFixEffect(mock_data$groundTruth$effect)
getData2computeActualFixEffect <- function(groundTruth){
  col_names <- colnames(groundTruth)
  categorical_vars <- col_names[grepl(col_names, pattern = "label_")]
  average_gt <- averageByGroup(groundTruth, "log_qij_scaled", c("geneID", categorical_vars))
  average_gt <- convert2Factor(data = average_gt, columns = categorical_vars )
  return(list(categorical_vars = categorical_vars, data = average_gt))
}


#' Get the intercept dataframe
#'
#' @param fixeEff_dataActual The input list containing  the categorical variables and the data 
#' @return The intercept dataframe
#' @export
getActualIntercept <- function(fixeEff_dataActual) {
  ## -- split list
  data<- fixeEff_dataActual$data
  categorical_vars <- fixeEff_dataActual$categorical_vars

  if (length(categorical_vars) == 1){
    l_labels <- list()
    l_labels[[categorical_vars]] <- levels(data[, categorical_vars])

  } else l_labels <- lapply(data[, categorical_vars], levels)
  index_ref <- sapply(categorical_vars, function(var) data[[var]] == l_labels[[var]][1])
  index_ref <- rowSums(index_ref) == dim(index_ref)[2]
  df_intercept <- data[index_ref, ]
  df_intercept$term <- "(Intercept)"
  colnames(df_intercept)[colnames(df_intercept) == "logQij_mean"] <- "actual"
  df_intercept$description <- "(Intercept)"

  index2keep <- !colnames(df_intercept) %in% categorical_vars
  df_intercept <- df_intercept[,index2keep]

  return(df_intercept)
}


#' Generate actual values for a given term
#'
#' @param term The term for which actual values are calculated
#' @param df_actualIntercept The intercept dataframe
#' @param dataActual The average ground truth dataframe
#' @param categorical_vars The names of the categorical variables
#' @return The data frame with actual values for the given term
#' @export
generateActualForMainFixEff <- function(term , df_actualIntercept , dataActual  , categorical_vars){
  
  computeActualValueForMainFixEff <- function(df_actualIntercept, df_term) {
        df_term$actual <- df_term$logQij_mean - df_actualIntercept$actual
        return(subset(df_term, select = -c(logQij_mean)))
  }
  
  df_term <- subsetByTermLabel(dataActual, categorical_vars , term  )
  df_term <- computeActualValueForMainFixEff(df_actualIntercept, df_term)
  df_term$description <- gsub("\\d+$", "", term)
  return(df_term)
}



#' subset data By Term Label
#'
#'
#' Get a subset of the data based on a specific term label in the categorical variables.
#'
#' @param data The data frame to subset
#' @param categorical_vars The categorical variables to consider
#' @param term_label The term label to search for
#' @return A subset of the data frame containing rows where the categorical variables match the specified term label
#' @export
#'
#' @examples
#' # Create a data frame
#' my_data <- data.frame(color = c("red", "blue", "green", "red"),
#'                       size = c("small", "medium", "large", "medium"),
#'                       shape = c("circle", "square", "triangle", "circle"))
#' my_data[] <- lapply(my_data, as.factor)
#'
#' # Get the subset for the term "medium" in the "size" variable
#' subsetByTermLabel(my_data, "size", "medium")
#' # Output: A data frame with rows where "size" is "medium"
#'
#' # Get the subset for the term "red" in the "color" variable
#' subsetByTermLabel(my_data, "color", "red")
#' # Output: A data frame with rows where "color" is "red"
subsetByTermLabel <- function(data, categorical_vars, term_label ) {
  if (length(categorical_vars) == 1) {
    l_labels <- list()
    l_labels[[categorical_vars]] <- levels(data[, categorical_vars])
  } else {
    l_labels <- lapply(data[, categorical_vars], levels)
  }

  term_variable <- findAttribute(term_label, l_labels)
  if(is.null(term_variable)) stop("term_label not in 'data'")

  index_ref <- sapply(categorical_vars, function(var) {
    if (var == term_variable) {
      data[[var]] == term_label
    } else {
      data[[var]] == l_labels[[var]][1]
    }
  })

  index_ref <- rowSums(index_ref) == dim(index_ref)[2]
  df_term <- data[index_ref, ]
  df_term$term <- term_label
  return(df_term)
}

#' Find Attribute
#'
#' Find the attribute containing the specified term in a given list.
#'
#' @param term The term to search for
#' @param list The list to search within
#' @return The attribute containing the term, or NULL if the term is not found in any attribute
#' @export
#'
#' @examples
#' # Create a list
#' my_list <- list(color = c("red", "blue", "green"),
#'                 size = c("small", "medium", "large"),
#'                 shape = c("circle", "square", "triangle"))
#'
#' # Find the attribute containing "medium"
#' findAttribute("medium", my_list)
findAttribute <- function(term, list) {
  for (attr in names(list)) {
    if (term %in% list[[attr]]) {
      return(attr)
    }
  }
  return(NULL)  # If the term is not found in any attribute
}

#' Get actual values for non-interaction terms
#'
#' @param l_term list of term to compute 
#' @param fixeEff_dataActual A list containing required data for calculating actual values
#' @param df_actualIntercept The data frame containing the actual intercept values
#' @return A data frame with actual values for non-interaction terms
#' @export
getActualMainFixEff <- function( l_term , fixeEff_dataActual , df_actualIntercept  ){
  ## -- split list
  categorical_vars <- fixeEff_dataActual$categorical_vars
  data_groundTruth <- fixeEff_dataActual$data
  ## -- iteration over term
  l_actual <- lapply(l_term,
                     function(term){
                       generateActualForMainFixEff(term, df_actualIntercept,
                                               data_groundTruth, categorical_vars)})
  df_actual <- do.call("rbind", l_actual)
  index2keep <- !colnames(df_actual) %in% categorical_vars
  df_actual <- df_actual[,index2keep]
  return(df_actual)
}

```

```{r test-actual_mainfixeffects}

test_that("Test for subsetFixEffectInferred function", {
  # Prepare the test data
  input_var_list <- init_variable(name = "varA", mu = c(1,2,3), level = 3) %>%
                    init_variable(name = "varB", mu = c(2,-6), level = 2) %>%
                    add_interaction(between_var = c("varA", "varB"), mu = 1, sd = 3)

  mock_data <- mock_rnaseq(input_var_list, 10, 2, 2)
  getData2computeActualFixEffect(mock_data$groundTruth$effect)
  data2fit <- prepareData2fit(countMatrix = mock_data$counts, metadata = mock_data$metadata, normalization = NULL)

  # Fit data
  resFit <- fitModelParallel(formula = kij ~ varA + varB + varA:varB,
                             data = data2fit, group_by = "geneID",
                             family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
  tidy_tmb <- tidy_tmb(resFit)

  # Test the subsetFixEffectInferred function
  result <- subsetFixEffectInferred(tidy_tmb)
  # Define expected output
  expected_nonInteraction <- c("varA2", "varA3", "varB2")
  expected_interaction <- c("varA2:varB2", "varA3:varB2")

  # Compare the output with the expected values
  expect_equal(result$fixed_term$nonInteraction, expected_nonInteraction)
  expect_equal(result$fixed_term$interaction, expected_interaction)
})


# Tests for averageByGroup
test_that("averageByGroup returns correct average values", {
  # Create a sample data frame
  data <- data.frame(
    Group1 = rep(c("A", "B", "C", "D"), each = 2),
    Group2 = rep(c("X", "Y"), times = 4),
    Value = 1:8
  )
  
  # Calculate average values by group
  result <- averageByGroup(data, column = "Value", group_by = c("Group1", "Group2"))
  
  # Check the output
  expect_equal(nrow(result), 8)  # Number of rows
  expect_equal(colnames(result), c("logQij_mean","Group1", "Group2" ))  # Column names
  expect_equal(result$logQij_mean, c(1, 3,5, 7, 2, 4, 6, 8))  # Average values
})



# Tests for findAttribute
test_that("findAttribute returns the correct attribute", {
  # Create a sample list
  my_list <- list(
    color = c("red", "blue", "green"),
    size = c("small", "medium", "large"),
    shape = c("circle", "square", "triangle")
  )
  
  # Find attributes
  attr1 <- findAttribute("medium", my_list)
  attr2 <- findAttribute("rectangle", my_list)
  
  # Check the output
  expect_equal(attr1, "size")  # Attribute containing "medium"
  expect_equal(attr2, NULL)  # Attribute containing "rectangle"
})

# Tests for getActualIntercept
test_that("getActualIntercept returns the correct intercept dataframe", {
  # Create a sample data frame
  data <- data.frame(
    Category1 = c("A", "B", "A", "B"),
    Category2 = c("X", "Y", "X", "Z"),
    logQij_mean = 1:4
  )
  data[, c("Category1", "Category2")] <- lapply(data[, c("Category1", "Category2")], as.factor )
  
  l_fixEffDataActual= list(categorical_vars = c("Category1", "Category2"), data = data)
  # Get the intercept dataframe
  result <- getActualIntercept(l_fixEffDataActual)
  
  # Check the output
  expect_equal(nrow(result), 2)  # Number of rows
  expect_equal(unique(result$term), "(Intercept)")  # Term column
  expect_equal(result$actual, c(1, 3))  # Actual column
})





# Test subsetByTermLabel with single categorical variable
test_that("subsetByTermLabel with single categorical variable", {
  my_data <- list(color = c("red", "blue", "green", "red"),
                        size = c("small", "medium", "large", "medium"),
                        shape = c("circle", "square", "triangle", "circle"))
  my_data <- expand.grid(my_data)
  my_data[] <- lapply(my_data, as.factor)

  subset_data <- subsetByTermLabel(my_data, categorical_vars = "size", term_label = "medium")
  expected_data <- my_data[my_data$size == "medium", ]
  expected_data$term <- "medium"
  
  expect_equal(subset_data, expected_data)
})

# Test subsetByTermLabel with single term label in multiple categorical variables
test_that("subsetByTermLabel with single term label in multiple categorical variables", {
   my_data <- list(color = c("red", "blue", "green", "red"),
                        size = c("small", "medium", "large", "medium"),
                        shape = c("circle", "square", "triangle", "circle"))
  my_data <- expand.grid(my_data)
  my_data[] <- lapply(my_data, as.factor)

  subset_data <- subsetByTermLabel(data = my_data, categorical_vars = c("color", "shape"), term_label = "circle")
  expected_data <- my_data[my_data$shape == "circle" & my_data$color == "red" , ]
  expected_data$term <- "circle"

  expect_equal(subset_data, expected_data)
})

# Test subsetByTermLabel with non-existent term label expect error
test_that("subsetByTermLabel with non-existent term label", {
   my_data <- list(color = c("red", "blue", "green", "red"),
                        size = c("small", "medium", "large", "medium"),
                        shape = c("circle", "square", "triangle", "circle"))
  my_data <- expand.grid(my_data)
  my_data[] <- lapply(my_data, as.factor)

  expect_error(subsetByTermLabel(data = my_data, categorical_vars = "size", term_label = "extra-large"))
})



# Test getActualMainFixEff
test_that("getActualMainFixEff", {
  input_var_list <- init_variable() 
  set.seed(101)
  mock_data <- mock_rnaseq(input_var_list, 2, 2, 2, normal_distr = "multivariate")
  data2fit <- prepareData2fit(mock_data$counts, mock_data$metadata, normalization = NULL)
  inference <- fitModelParallel(kij ~ myVariable , 
                                  group_by = "geneID", data2fit, n.cores = 1)
  tidy_inference <- tidy_tmb(inference)
  tidy_fix <- subsetFixEffectInferred(tidy_inference)
  fixEff_dataActual <- getData2computeActualFixEffect(mock_data$groundTruth$effects)
  actual_intercept <- getActualIntercept(fixEff_dataActual)
  ## -- main = non interaction
  actual_mainFixEff <- getActualMainFixEff(tidy_fix$fixed_term$nonInteraction,
                            fixEff_dataActual, actual_intercept)
  
  expected_actual <- data.frame(geneID = c("gene1", "gene2"),
                                term = c("myVariable2", "myVariable2"),
                                actual = c(0.3209061, 0.3248530),
                                description = "myVariable")
  rownames(actual_mainFixEff) <- NULL
  rownames(actual_mainFixEff) <- NULL
  expect_equal(actual_mainFixEff, expected_actual, tolerance = 1e-3)
})



test_that("getData2computeActualFixEffect return correct output",{
  # Prepare the test data
  input_var_list <- init_variable() 
  set.seed(101)
  mock_data <- mock_rnaseq(input_var_list, 2, 2, 2, normal_distr = "multivariate")
  data2fit <- prepareData2fit(mock_data$counts, mock_data$metadata, normalization = NULL)
  inference <- fitModelParallel(kij ~ myVariable, group_by = "geneID", data2fit, n.cores = 1)
  tidy_inference <- tidy_tmb(inference)
  tidy_fix <- subsetFixEffectInferred(tidy_inference)

  # Call the function to test
  fixEff_dataActual <- getData2computeActualFixEffect(mock_data$groundTruth$effects)

  # Define expected output
  expected_data <- data.frame(logQij_mean = c(-0.09771371,-0.22516585,
                                              0.22319238,0.09968718), geneID = c("gene1", "gene2", "gene1", "gene2"), label_myVariable = factor(c("myVariable1", "myVariable1", "myVariable2", "myVariable2")))
  expected_categorical_vars <- "label_myVariable"
  # Compare the output with the expected values
  expect_equal(fixEff_dataActual$data, expected_data, tolerance = 1e-3)
  expect_equal(fixEff_dataActual$categorical_vars, expected_categorical_vars)
})


test_that("generateActualForMainFixEff returns correct values for main fixed effect term", {
  # Prepare the test data
  input_var_list <- init_variable() 
  set.seed(101)
  mock_data <- mock_rnaseq(input_var_list, 2, 2, 2, normal_distr = "multivariate")
  data2fit <- prepareData2fit(mock_data$counts, mock_data$metadata, normalization = NULL )
  fixEff_dataActual <- getData2computeActualFixEffect(mock_data$groundTruth$effects)
  actual_intercept <- getActualIntercept(fixEff_dataActual)
  df_term <- generateActualForMainFixEff("myVariable2", actual_intercept, fixEff_dataActual$data, fixEff_dataActual$categorical_vars)

  # Define expected output
  expected <- data.frame(
    geneID = c("gene1", "gene2"),
    label_myVariable = factor(c("myVariable2", "myVariable2"), levels = c("myVariable1", "myVariable2")),
    term = c("myVariable2", "myVariable2"),
    actual = c(0.3209061, 0.3248530),
    description = c("myVariable", "myVariable")
  )
  rownames(df_term) <- NULL
  rownames(expected) <- NULL
  # Compare the output with the exp-ted values
  expect_equal(df_term, expected, tolerance = 1e3)
})

```

```{r function-actual_interactionfixeffects, filename =  "actual_interactionfixeffects" }
#' Filter DataFrame
#'
#' Filter a DataFrame based on the specified filter list.
#'
#' @param df The DataFrame to be filtered
#' @param filter_list A list specifying the filters to be applied
#' @return The filtered DataFrame
#' @export
#'
#' @examples
#' # Create a DataFrame
#' df <- data.frame(ID = c(1, 2, 3, 4),
#'                  Name = c("John", "Jane", "Mike", "Sarah"),
#'                  Age = c(25, 30, 28, 32),
#'                  Gender = c("Male", "Female", "Male", "Female"))
#'
#' # Create a filter list
#' filter_list <- list(Name = c("John", "Mike"), Age = c(25, 28))
#'
#' # Filter the DataFrame
#' filter_dataframe(df, filter_list)
filter_dataframe <- function(df, filter_list ) {
  filtered_df <- df

  for (attr_name in attributes(filter_list)$names) {
    attr_value <- filter_list[[attr_name]]

    filtered_df <- filtered_df[filtered_df[[attr_name]] %in% attr_value, ]
  }

  return(filtered_df)
}


#' Calculate actual interaction values between two terms in a data frame.
#'
#' This function calculates the actual interaction values between two terms, \code{lbl_term_1} and \code{lbl_term_2},
#' in the given data frame \code{data}. The interaction values are computed based on the mean log expression levels
#' of the conditions satisfying the specified term combinations, and also considering a reference condition.
#'
#' @param data A data frame containing the expression data and associated terms.
#' @param l_reference A data frame representing the reference condition for the interaction.
#' @param clmn_term_1 The name of the column in \code{data} representing the first term.
#' @param lbl_term_1 The label of the first term to compute interactions for.
#' @param clmn_term_2 The name of the column in \code{data} representing the second term.
#' @param lbl_term_2 The label of the second term to compute interactions for.
#'
#' @return A numeric vector containing the actual interaction values between the specified terms.
#' @export
#' @examples
#' average_gt <- data.frame(clmn_term_1 = c("A", "A", "B", "B"), 
#'                          clmn_term_2 = c("X", "Y", "Y", "X"),
#'                          logQij_mean = c(1.5, 8.0, 0.5, 4.0))
#' # Définir les paramètres de la fonction
#' l_label <- list(clmn_term_1 = c("A", "B"), clmn_term_2 = c("X", "Y"))
#' clmn_term_1 <- "clmn_term_1"
#' lbl_term_1 <- "B"
#' clmn_term_2 <- "clmn_term_2"
#' lbl_term_2 <- "Y"
#' # Calculer la valeur d'interaction réelle
#' actual_interaction <- calculate_actual_interactionX2_values(average_gt, 
#'                                        l_label, clmn_term_1, lbl_term_1, 
#'                                        clmn_term_2, lbl_term_2)
calculate_actual_interactionX2_values <- function(data, l_reference , clmn_term_1, lbl_term_1, clmn_term_2, lbl_term_2) {
  A <- data[data[[clmn_term_1]] == lbl_term_1 & 
              data[[clmn_term_2]] == lbl_term_2, ]
  B <- data[data[[clmn_term_1]] == lbl_term_1 & 
              data[[clmn_term_2]] == l_reference[[clmn_term_2]][1], ]
  C <- data[data[[clmn_term_1]] == l_reference[[clmn_term_1]][1] & 
              data[[clmn_term_2]] == lbl_term_2, ]
  D <- data[data[[clmn_term_1]] == l_reference[[clmn_term_1]][1] &
              data[[clmn_term_2]] == l_reference[[clmn_term_2]][1], ]
  actual_interaction <- (A$logQij_mean - B$logQij_mean) - (C$logQij_mean - D$logQij_mean)
  return(actual_interaction)
}


#' Prepare data for computing interaction values.
#'
#' This function prepares the data for computing interaction values between variables.
#' It filters the \code{dataActual} data frame by selecting only the rows where the categorical variables
#' specified in \code{categorical_vars} are at their reference levels.
#'
#' @param categorical_vars A character vector containing the names of categorical variables.
#' @param categorical_varsInInteraction A character vector containing the names of categorical variables involved in interactions.
#' @param dataActual A data frame containing the actual data with categorical variables and associated expression levels.
#'
#' @return A data frame containing the filtered data for computing interaction values.
#' @export
prepareData2computeInteraction <- function(categorical_vars, categorical_varsInInteraction, dataActual){
  l_RefInCategoricalVars <- lapply(dataActual[, categorical_vars], function(vector) levels(vector)[1])
  l_categoricalVars_NOT_InInteraction <-  categorical_vars[! categorical_vars %in% categorical_varsInInteraction ]
  l_filter <- l_RefInCategoricalVars[l_categoricalVars_NOT_InInteraction]
  dataActual_2computeInteractionValues <- filter_dataframe(dataActual, l_filter)
  return(dataActual_2computeInteractionValues)
}



#' Generate actual values for the interaction fixed effect.
#'
#' This function calculates the actual values for the interaction fixed effect
#' based on the input labels in the interaction, categorical variables in the interaction,
#' data to compute interaction values, actual intercept, and the reference levels in
#' categorical variables.
#'
#' @param labelsInInteraction A vector containing the labels of the interaction terms.
#' @param l_categoricalVarsInInteraction A vector containing the names of categorical variables
#'                                        involved in the interaction.
#' @param data2computeInteraction The data frame used to compute interaction values.
#' @param l_RefInCategoricalVars A list containing the reference levels of categorical variables.
#'
#' @return A data frame with the actual values for the interaction fixed effect.
#' The data frame includes columns: term, actual, and description.
#'
#' @export
generateActualInteractionX2_FixEff <- function(labelsInInteraction, l_categoricalVarsInInteraction, 
                                               data2computeInteraction, l_RefInCategoricalVars ){
  clmn_term_1 <- l_categoricalVarsInInteraction[1]
  lbl_term_1 <- labelsInInteraction[1]
  clmn_term_2 <- l_categoricalVarsInInteraction[2]
  lbl_term_2 <- labelsInInteraction[2]
  interactionValues <- calculate_actual_interactionX2_values(data2computeInteraction,
                                                              l_RefInCategoricalVars, clmn_term_1,
                                                              lbl_term_1, clmn_term_2, lbl_term_2)


  df_actualForMyInteraction <- data.frame(geneID = unique(data2computeInteraction$geneID))
  df_actualForMyInteraction$term <- paste(labelsInInteraction, collapse = ":")
  df_actualForMyInteraction$actual <- interactionValues
  df_actualForMyInteraction$description <- paste(gsub("\\d+$", "", lbl_term_1) , 
                                                 gsub("\\d+$", "", lbl_term_2), sep = ":")

  return(df_actualForMyInteraction)

}


#' Generate Actual Interaction Values for Three Fixed Effects
#'
#' This function generates actual interaction values for three fixed effects in a dataset. It takes the labels of the three fixed effects, the dataset, and the reference values for the categorical variables. The function computes the actual interaction values and returns a data frame containing the geneID, the term description, and the actual interaction values.
#'
#' @param labelsInInteraction A character vector of labels for the three fixed effects.
#' @param l_categoricalVarsInInteraction A list of categorical variable names corresponding to the three fixed effects.
#' @param data2computeInteraction The dataset on which to compute the interaction values.
#' @param l_RefInCategoricalVars A list of reference values for the categorical variables.
#'
#' @return A data frame with geneID, term description, and actual interaction values.
#'
#' @export
generateActualInteractionX3_FixEff <- function(labelsInInteraction, l_categoricalVarsInInteraction,
                                            data2computeInteraction, l_RefInCategoricalVars) {

   clmn_term_1 <- l_categoricalVarsInInteraction[1]
  lbl_term_1 <- labelsInInteraction[1]
  clmn_term_2 <- l_categoricalVarsInInteraction[2]
  lbl_term_2 <- labelsInInteraction[2]
  clmn_term_3 <- l_categoricalVarsInInteraction[3]
  lbl_term_3 <- labelsInInteraction[3]
  interactionValues <- calculate_actual_interactionX3_values(data2computeInteraction,
                                                          l_RefInCategoricalVars, clmn_term_1,
                                                           lbl_term_1, clmn_term_2, lbl_term_2, lbl_term_3, clmn_term_3)


  df_actualForMyInteraction <- data.frame(geneID = unique(data2computeInteraction$geneID))
  df_actualForMyInteraction$term <- paste(labelsInInteraction, collapse = ":")
  df_actualForMyInteraction$actual <- interactionValues
  df_actualForMyInteraction$description <- paste(gsub("\\d+$", "", lbl_term_1) ,
                                                 gsub("\\d+$", "", lbl_term_2),
                                                 gsub("\\d+$", "", lbl_term_3), sep = ":")

  return(df_actualForMyInteraction)
  
}


#' Calculate Actual Interaction Values for Three Fixed Effects
#'
#' This function calculates actual interaction values for three fixed effects in a dataset. It takes the data, reference values for categorical variables, and the specifications for the fixed effects. The function computes the interaction values and returns the result.
#'
#' @param data The dataset on which to calculate interaction values.
#' @param l_reference A list of reference values for categorical variables.
#' @param clmn_term_1 The name of the first categorical variable.
#' @param lbl_term_1 The label for the first categorical variable.
#' @param clmn_term_2 The name of the second categorical variable.
#' @param lbl_term_2 The label for the second categorical variable.
#' @param lbl_term_3 The label for the third categorical variable.
#' @param clmn_term_3 The name of the third categorical variable.
#'
#' @return The computed actual interaction values.
#'
#' @export
calculate_actual_interactionX3_values <- function(data, l_reference, clmn_term_1, lbl_term_1, 
                                                  clmn_term_2, lbl_term_2, lbl_term_3, clmn_term_3) {
  ## Label term 3
  A <- data[data[[clmn_term_1]] == lbl_term_1 & 
              data[[clmn_term_2]] == lbl_term_2 & 
              data[[clmn_term_3]] == lbl_term_3, ]
  
  B <- data[data[[clmn_term_1]] == l_reference[[clmn_term_1]][1] & 
              data[[clmn_term_2]] == lbl_term_2 & 
              data[[clmn_term_3]] == lbl_term_3 , ]
  
  C <- data[data[[clmn_term_1]] == lbl_term_1 & 
              data[[clmn_term_2]] == l_reference[[clmn_term_2]][1] & 
              data[[clmn_term_3]] == lbl_term_3, ]
  
  D <- data[data[[clmn_term_1]] == l_reference[[clmn_term_1]][1] & 
              data[[clmn_term_2]] == l_reference[[clmn_term_2]][1] & 
              data[[clmn_term_3]] == lbl_term_3, ]
  
  termA = (A$logQij_mean-B$logQij_mean) - (C$logQij_mean - D$logQij_mean)
  
  ## Label term 3 == reference !
  A <- data[data[[clmn_term_1]] == lbl_term_1 & 
              data[[clmn_term_2]] == lbl_term_2 & 
              data[[clmn_term_3]] == l_reference[[clmn_term_3]][1], ]
  
  B <- data[data[[clmn_term_1]] == l_reference[[clmn_term_1]][1] & 
              data[[clmn_term_2]] == lbl_term_2 & 
              data[[clmn_term_3]] == l_reference[[clmn_term_3]][1] , ]
  
  C <- data[data[[clmn_term_1]] == lbl_term_1 & 
              data[[clmn_term_2]] == l_reference[[clmn_term_2]][1] & 
              data[[clmn_term_3]] == l_reference[[clmn_term_3]][1], ]
  
  D <- data[data[[clmn_term_1]] == l_reference[[clmn_term_1]][1] & 
              data[[clmn_term_2]] == l_reference[[clmn_term_2]][1] & 
              data[[clmn_term_3]] == l_reference[[clmn_term_3]][1], ]
  
  termB = (A$logQij_mean-B$logQij_mean) - (C$logQij_mean - D$logQij_mean)
  actual_interaction <- termA - termB
  return(actual_interaction)
}



#' Get the actual interaction values for a given interaction term in the data.
#'
#' This function takes an interaction term, the dataset, and the names of the categorical variables 
#' as inputs. It calculates the actual interaction values based on the difference in log-transformed 
#' mean expression levels for the specified interaction term. The function first prepares the data for 
#' computing the interaction values and then generates the actual interaction values.
#'
#' @param labelsInInteraction A character vector containing the labels of the categorical levels 
#'     involved in the interaction.
#' @param data The dataset containing the gene expression data and categorical variables.
#' @param categorical_vars A character vector containing the names of the categorical variables in 
#'     the dataset.
#' @return A data frame containing the actual interaction values.
#' @export 
getActualInteractionFixEff <- function(labelsInInteraction, data, categorical_vars ){
  l_RefInCategoricalVars <- lapply(data[, categorical_vars], function(vector) levels(vector)[1])
  l_labelsInCategoricalVars <- lapply(data[, categorical_vars], levels)
  l_categoricalVarsInInteraction <- lapply(labelsInInteraction,
                                           function(label) findAttribute(label, 
                                                        l_labelsInCategoricalVars)) %>% 
                                    unlist()
  data2computeInteraction <- prepareData2computeInteraction(categorical_vars, l_categoricalVarsInInteraction,  data )

  ## Interaction x3
  if (length(labelsInInteraction) == 3){
        actualInteractionValues <- generateActualInteractionX3_FixEff(labelsInInteraction,
                                                                     l_categoricalVarsInInteraction ,
                                                                     data2computeInteraction, 
                                                                     l_RefInCategoricalVars)
  }
  # Interaction x2
  if (length(labelsInInteraction) == 2){
    actualInteractionValues <- generateActualInteractionX2_FixEff(labelsInInteraction,
                                                               l_categoricalVarsInInteraction ,
                                                               data2computeInteraction, 
                                                               l_RefInCategoricalVars)
  }
  return(actualInteractionValues)
}


#' Compute actual interaction values for multiple interaction terms.
#'
#' This function calculates the actual interaction values for multiple interaction terms 
#' using the provided data.
#'
#' @param l_interactionTerm A list of interaction terms in the form of "term1:term2".
#' @param categorical_vars A character vector containing the names of categorical variables in the data.
#' @param dataActual The data frame containing the actual gene expression values and metadata.
#'
#' @return A data frame containing the actual interaction values for each interaction term.
#' @export
#' @examples
#' N_GENES <- 4
#' MIN_REPLICATES <- 3
#' MAX_REPLICATES <- 3
#' init_var <- init_variable(name = "varA", mu = 8, sd = 0.1, level = 3) %>%
#'   init_variable(name = "varB", mu = c(5,-5), NA , level = 2) %>%
#'   init_variable(name = "varC", mu = 1, 3, 3) %>%
#'   add_interaction(between_var = c("varA", "varC"), mu = 5, 0.1)
#' mock_data <- mock_rnaseq(init_var, N_GENES, 
#'                          MIN_REPLICATES, MAX_REPLICATES )
#' data2fit <- prepareData2fit(countMatrix = mock_data$counts, 
#'                              metadata =  mock_data$metadata )
#' results_fit <- fitModelParallel(formula = kij ~ varA + varB + varC + varA:varC,
#'                              data = data2fit, group_by = "geneID",
#'                              family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
#' tidy_tmb <- tidy_tmb(results_fit)
#' fixEff_dataInference  <- subsetFixEffectInferred(tidy_tmb)
#' fixEff_dataActual <- getData2computeActualFixEffect(mock_data$groundTruth$effects)
#' interactionTerm <- fixEff_dataInference$fixed_term$interaction[[1]]
#' categorical_vars <- fixEff_dataActual$categorical_vars
#' dataActual <- fixEff_dataActual$data
#' l_labelsInCategoricalVars <- lapply(dataActual[, categorical_vars], levels)
#' l_interaction <- strsplit(interactionTerm, split = ":")[[1]]
#' l_categoricalVarsInInteraction <- lapply(l_interaction,
#'                                          function(label) findAttribute(label, 
#'                                          l_labelsInCategoricalVars)) %>% 
#'                                          unlist()
#' data_prepared <- prepareData2computeInteraction(categorical_vars, 
#'                    l_categoricalVarsInInteraction, dataActual)
#' # Compute actual interaction values for multiple interactions
#' actualInteraction <- computeActualInteractionFixEff(interactionTerm, categorical_vars, dataActual)
computeActualInteractionFixEff <- function(l_interactionTerm, categorical_vars, dataActual){

  l_interaction <- strsplit(l_interactionTerm, split = ":")
  l_interactionActualValues <- lapply(l_interaction, function(labelsInInteraction)
                                getActualInteractionFixEff(labelsInInteraction, dataActual, categorical_vars))
  actualInteraction_df <- do.call('rbind', l_interactionActualValues)
  return(actualInteraction_df)
}
```

```{r test-actual_interactionfixeffects }

test_that("filter_dataframe retourne le dataframe filtré correctement", {
  # Créer un exemple de dataframe
  df <- data.frame(
  col1 = c(1, 2, 3, 4, 5),
  col2 = c("A", "B", "C", "D", "E"),
  col3 = c("X", "Y", "Z", "X", "Y")
  )
  
  # Créer une liste de filtres
  filter_list <- list(
    col1 = c(2),
    col2 = "B",
    col3 = c("Y")
  )

  # Appliquer les filtres sur le dataframe
  filtered_df <- filter_dataframe(df, filter_list)

  # Vérifier que les lignes correspondantes sont présentes dans le dataframe filtré
  expect_equal(nrow(filtered_df), 1)
  expect_true(all(filtered_df$col1 %in% c(2)))
  expect_true(all(filtered_df$col2 == "B"))
  expect_true(all(filtered_df$col3 %in% c("Y")))
})

test_that("filter_dataframe retourne le dataframe d'origine si aucun filtre n'est spécifié", {
  # Créer une liste de filtres vide
  filter_list <- list()

  # Appliquer les filtres sur le dataframe
  filtered_df <- filter_dataframe(df, filter_list)

  # Vérifier que le dataframe filtré est identique au dataframe d'origine
  expect_identical(filtered_df, df)
})

test_that("calculate_actual_interactionX2_values retourne la valeur d'interaction réelle correctement", {
  average_gt <- data.frame(
  clmn_term_1 = c("A", "A", "B", "B"),
  clmn_term_2 = c("X", "Y", "X", "Y"),
  logQij_mean = c(1.5, 2.0, 85, 1.0)
  )

  # Définir les paramètres de la fonction
  l_label <- list(clmn_term_1 = c("A", "B"), clmn_term_2 = c("X", "Y"))
  clmn_term_1 <- "clmn_term_1"
  lbl_term_1 <- "B"
  clmn_term_2 <- "clmn_term_2"
  lbl_term_2 <- "Y"

  # Calculer la valeur d'interaction réelle
  actual_interaction <- calculate_actual_interactionX2_values(average_gt, l_label, clmn_term_1, lbl_term_1, clmn_term_2, lbl_term_2)

  # Vérifier que la valeur d'interaction réelle est correcte
  expect_equal(actual_interaction, -84.5)
})



test_that("prepareData2computeInteraction filters data correctly", {
  
  data <- data.frame(
  geneID = c("gene1", "gene2", "gene3", "gene4"),
  label_varA = factor(c("A", "A", "B", "B")),
  label_varB = factor(c("X", "X", "Y", "Y")),
  label_varC = factor(c("P", "P", "Q", "Q")),
  logQij_mean = c(1.2, 3.4, 5.6, 7.8)
  )
  categorical_vars <- c("label_varA", "label_varB", "label_varC")
  categorical_varsInInteraction <- c("label_varA", "label_varC")

  dataActual_2computeInteractionValues <- prepareData2computeInteraction(categorical_vars, categorical_varsInInteraction, data)

  expect_equal(nrow(dataActual_2computeInteractionValues), 2)
  expect_true(all(dataActual_2computeInteractionValues$label_varA %in% c("A", "A")))
  expect_true(all(dataActual_2computeInteractionValues$label_varB %in% c("X", "X")))
  expect_true(all(dataActual_2computeInteractionValues$label_varC %in% c("P", "P")))
  expect_equal(dataActual_2computeInteractionValues$logQij_mean, c(1.2, 3.4 ))
})



## TEST
test_that("Generate actual interaction fixed effect correctly", {
  
  ########################################################################"
  N_GENES <- 4
  MIN_REPLICATES <- 3
  MAX_REPLICATES <- 3
  
  init_var <- init_variable(name = "varA", mu = 8, sd = 0.1, level = 3) %>%
  init_variable(name = "varB", mu = c(5, -5), NA, level = 2) %>%
  init_variable(name = "varC", mu = 1, 3, 3) %>%
  add_interaction(between_var = c("varA", "varC"), mu = 5, 0.1)
  
  # -- simulation
  mock_data <- mock_rnaseq(init_var, N_GENES, min_replicates = MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  
  # -- fit data
  data2fit <- prepareData2fit(countMatrix = mock_data$counts, metadata = mock_data$metadata, normalization = 'MRN')
  
  dtf_countsLong <- countMatrix_2longDtf(mock_data$counts, "k_ij")
  metadata_columnForjoining <- getColumnWithSampleID(dtf_countsLong, mock_data$metadata)
  
  example_spleID <- as.character(dtf_countsLong[1, "sampleID"])
  regex <- paste("^", as.character(dtf_countsLong[1, "sampleID"]), "$", sep = "")
  
 
  
  
  results_fit <- fitModelParallel(formula = kij ~ varA + varB + varC + varA:varC,
                                data = data2fit, group_by = "geneID",
                                family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
  
  # -- inputs
  tidy_tmb <- tidy_tmb(results_fit)
  fixEff_dataInference <- subsetFixEffectInferred(tidy_tmb)
  fixEff_dataActual <- getData2computeActualFixEffect(mock_data$groundTruth$effects)
  
  interactionTerm <- fixEff_dataInference$fixed_term$interaction[[1]]
  categorical_vars <- fixEff_dataActual$categorical_vars
  dataActual <- fixEff_dataActual$data
  l_labelsInCategoricalVars <- lapply(dataActual[, categorical_vars], levels)
  l_interaction <- strsplit(interactionTerm, split = ":")[[1]]
  l_categoricalVarsInInteraction <- lapply(l_interaction,
                                          function(label) findAttribute(label, l_labelsInCategoricalVars)) %>% unlist()
  
  data_prepared <- prepareData2computeInteraction(categorical_vars, l_categoricalVarsInInteraction, dataActual)
  actual_intercept <- getActualIntercept(fixEff_dataActual)
  l_RefInCategoricalVars <- lapply(dataActual[, categorical_vars], function(vector) levels(vector)[1])
  #######################################################################
  
  actualInteraction <- generateActualInteractionX2_FixEff(l_interaction, l_categoricalVarsInInteraction, 
                                                          data_prepared, l_RefInCategoricalVars)

  # Add your assertions here based on the expected values
  # For example:
  expect_true(nrow(actualInteraction) == 4)
  expect_equal(actualInteraction$geneID,  c("gene1", "gene2", "gene3", "gene4"))
  expect_true(all(actualInteraction$term %in%  'varA2:varC2'))
  #expect_true(all(actualInteraction$description %in%  'interaction'))
  expect_true(is.numeric(actualInteraction$actual))

  # Add more assertions as needed...
})


# Test the function `generateActualInteractionX2_FixEff`
test_that("Test generateActualInteractionX2_FixEff function", {
  # Generate example data
  data <- data.frame(
    geneID = rep(x = c("gene1", "gene2"), each = 8),
    logQij_mean = 1:16
    
  )
  metadata = expand.grid(list(varA = factor(c("A1", "A2")),
    varB = factor(c("B1", "B2")),
    varC = factor(c("C1", "C2"))))
  metadata = rbind(metadata, metadata)
  
  data <- cbind(metadata, data)
  
  categorical_vars <- c("varA", "varB", "varC")
  labelsInInteraction <- c("A2", "C2")
  
  actual_intercept <- data.frame(actual = c(23, 21 ), 
                                 geneID = c("gene1", "gene2"), 
                                 term = c("(Intercept)", "(Intercept)"), 
                                 description = c("(Intercept)", "(Intercept)"))
  # Run the function
  
  actualInteractionValues <- getActualInteractionFixEff(labelsInInteraction, data, categorical_vars  )

  
  # Define the expected output based on the example data
  expected_output <- data.frame(
    term = "A2:C2",
    geneID = c("gene1", "gene2"),
    actual = c(0, 0),
    description = c("A:C", "A:C")
  )
  
  # Add your assertions here to compare the actual output with the expected output
  expect_equal(nrow(actualInteractionValues), nrow(expected_output))
  expect_equal(actualInteractionValues$geneID, expected_output$geneID)
  expect_equal(actualInteractionValues$term, expected_output$term)
  expect_equal(actualInteractionValues$actual, expected_output$actual)
  #expect_equal(actualInteractionValues$description, expected_output$description)

})



# Test for generateActualInteractionX3FixEff
test_that("generateActualInteractionX3FixEff returns correct data frame", {
  
  # Create reference values
  reference <- list(
    varA = c("A1", "A2"),
    varB = c("B1", "B2"),
    varC = c("C1", "C2")
  )
  # Generate example data
  set.seed(123)
  data <- data.frame(
    geneID = rep(x = c("gene1", "gene2"), each = 8),
    logQij_mean = sample(x = -3:12, 16)
    
  )
  metadata = expand.grid(list(varA = factor(c("A1", "A2")),
    varB = factor(c("B1", "B2")),
    varC = factor(c("C1", "C2"))))
  metadata = rbind(metadata, metadata)
  
  data <- cbind(metadata, data)
  
  # Call the function
  result <- generateActualInteractionX3_FixEff(
    labelsInInteraction = c("A2", "B2", "C2"),
    l_categoricalVarsInInteraction = c("varA", "varB", "varC"),
    data2computeInteraction = data,
    l_RefInCategoricalVars = reference
  )
  
  # Check the result
  expect_equal(nrow(result), 2)
  expect_equal(ncol(result), 4)
  expect_identical(result$term, c("A2:B2:C2","A2:B2:C2"))
  expect_equal(result$actual, c(-3, 13))
  expect_identical(result$description, c("A:B:C", "A:B:C"))
})

# Test for calculate_actual_interactionX3_values
test_that("calculate_actual_interactionX3_values returns correct values", {
  # Create reference values
  reference <- list(
    varA = c("A1", "A2"),
    varB = c("B1", "B2"),
    varC = c("C1", "C2")
  )
  # Generate example data
  set.seed(123)
  data <- data.frame(
    geneID = rep(x = c("gene1", "gene2"), each = 8),
    logQij_mean = sample(x = -8:8, 16)
    
  )
  metadata = expand.grid(list(varA = factor(c("A1", "A2")),
    varB = factor(c("B1", "B2")),
    varC = factor(c("C1", "C2"))))
  metadata = rbind(metadata, metadata)
  
  data <- cbind(metadata, data)
  # Call the function
  result <- calculate_actual_interactionX3_values(
    data = data,
    l_reference = reference,
    clmn_term_1 = "varA",
    lbl_term_1 = "A2",
    clmn_term_2 = "varB",
    lbl_term_2 = "B2",
    lbl_term_3 = "C2",
    clmn_term_3 = "varC"
  )
  
  # Check the result
  expect_equal(result, c(-7, 11))
})



## Test interaction X2
test_that("Test getActualInteractionFixEff", {

  # Exemple de données d'entrée
  N_GENES <- 4
  MIN_REPLICATES <- 3
  MAX_REPLICATES <- 3
  
  init_var <- init_variable(name = "varA", mu = 8, sd = 0.1, level = 3) %>%
    init_variable(name = "varB", mu = c(5,-5), NA, level = 2) %>%
    init_variable(name = "varC", mu = 1, 3, 3) %>%
    add_interaction(between_var = c("varA", "varC"), mu = 5, 0.1)
  
  # Simulation
  mock_data <- mock_rnaseq(init_var, N_GENES, min_replicates = MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  
  # Données de fit
  data2fit <- prepareData2fit(countMatrix = mock_data$counts, metadata = mock_data$metadata, normalization = 'MRN')
  results_fit <- fitModelParallel(formula = kij ~ varA + varB + varC + varA:varC,
                                  data = data2fit, group_by = "geneID",
                                  family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
  
  # Données d'entrée
  tidy_tmb <- tidy_tmb(results_fit)
  fixEff_dataInference <- subsetFixEffectInferred(tidy_tmb)
  fixEff_dataActual <- getData2computeActualFixEffect(mock_data$groundTruth$effects)
  interactionTerm <- fixEff_dataInference$fixed_term$interaction[[1]]
  categorical_vars <- fixEff_dataActual$categorical_vars
  dataActual <- fixEff_dataActual$data
  l_labelsInCategoricalVars <- lapply(dataActual[, categorical_vars], levels)
  l_interaction <- strsplit(interactionTerm, split = ":")[[1]]
  l_categoricalVarsInInteraction <- lapply(l_interaction,
                                           function(label) findAttribute(label, l_labelsInCategoricalVars)) %>% unlist()
  
  data_prepared <- prepareData2computeInteraction(categorical_vars, l_categoricalVarsInInteraction, dataActual)
  #actual_intercept <- getActualIntercept(fixEff_dataActual)
  
  # Appel de la fonction à tester
  actualInteraction <- getActualInteractionFixEff(l_interaction, data_prepared, categorical_vars)
  

  expect_true(nrow(actualInteraction) == 4)
  expect_equal(actualInteraction$geneID,  c("gene1", "gene2", "gene3", "gene4"))
  expect_true(all(actualInteraction$term %in%  'varA2:varC2'))
  #expect_true(all(actualInteraction$description %in%  'interaction'))
  expect_true(is.numeric(actualInteraction$actual))
})


## Test interaction X3
test_that("Test getActualInteractionFixEff", {

  # Exemple de données d'entrée
  N_GENES <- 4
  MIN_REPLICATES <- 20
  MAX_REPLICATES <- 20
  
 init_var <- init_variable( name = "varA", mu = 3,sd = 1, level = 2) %>%
    init_variable( name = "varB", mu = 2, sd = 2, level = 2) %>%
      init_variable( name = "varC", mu = 2, sd = 1, level = 2) %>%
      add_interaction(between_var = c("varA", 'varC'), mu = 0.3, sd = 1) %>%
      add_interaction(between_var = c("varB", 'varC'), mu = 2, sd = 1) %>%
      add_interaction(between_var = c("varA", 'varB'), mu = -2, sd = 1) %>%
      add_interaction(between_var = c("varA", 'varB', "varC"), mu = 1, sd = 1)
    
  
  # Simulation
  mock_data <- mock_rnaseq(init_var, N_GENES, 
                           min_replicates = MIN_REPLICATES, 
                           max_replicates = MAX_REPLICATES, dispersion = 100)
  
  # Données de fit
  data2fit <- prepareData2fit(countMatrix = mock_data$counts, metadata = mock_data$metadata, normalization = 'MRN')
  results_fit <- fitModelParallel(formula = kij ~ varA + varB + varC + varA:varB + varB:varC + varA:varC + varA:varB:varC,
                                  data = data2fit, group_by = "geneID",
                                  family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
  
  # Données d'entrée
  tidy_tmb <- tidy_tmb(results_fit)
  fixEff_dataInference <- subsetFixEffectInferred(tidy_tmb)
  fixEff_dataActual <- getData2computeActualFixEffect(mock_data$groundTruth$effects)
  interactionTerm <- fixEff_dataInference$fixed_term$interaction[[4]]
  categorical_vars <- fixEff_dataActual$categorical_vars
  dataActual <- fixEff_dataActual$data
  l_labelsInCategoricalVars <- lapply(dataActual[, categorical_vars], levels)
  l_interaction <- strsplit(interactionTerm, split = ":")[[1]]
  l_categoricalVarsInInteraction <- lapply(l_interaction,
                                           function(label) findAttribute(label, l_labelsInCategoricalVars)) %>% unlist()
  
  data_prepared <- prepareData2computeInteraction(categorical_vars, l_categoricalVarsInInteraction, dataActual)

  actualInteraction <- getActualInteractionFixEff(l_interaction, data_prepared, categorical_vars)
  

  expect_true(nrow(actualInteraction) == 4)
  expect_equal(actualInteraction$geneID,  c("gene1", "gene2", "gene3", "gene4"))
  expect_true(all(actualInteraction$term %in%  'varA2:varB2:varC2'))
  expect_true(all(actualInteraction$description %in%  'varA:varB:varC'))
  expect_true(is.numeric(actualInteraction$actual))
})


```

```{r function-inferenceToExpected, filename =  "inferenceToExpected" }

#' Compare the results of inference with the ground truth data.
#'
#' This function takes the data frames containing the inference results and the ground truth data
#' and generates a table to compare the inferred values with the expected values.
#'
#' @param tidy_tmb A data frame containing the results of inference.
#' @param df_ground_truth A data frame containing the ground truth data used for simulation.
#'
#' @return A data frame
#'
#' @examples
#' \dontrun{
#' inferenceToExpected_withFixedEff(tidy_tmb, df_ground_truth)
#' }
#'
#' @export
inferenceToExpected_withFixedEff <- function(tidy_tmb , df_ground_truth) {

  ## -- get data
  fixEff_dataInference  <- subsetFixEffectInferred(tidy_tmb)
  fixEff_dataActual <- getData2computeActualFixEffect(df_ground_truth)
  actual_intercept <- getActualIntercept(fixEff_dataActual)

  ## -- main = non interaction
  l_mainEffectTerm <- fixEff_dataInference$fixed_term$nonInteraction
  actual_mainFixEff <- getActualMainFixEff(l_mainEffectTerm, fixEff_dataActual, actual_intercept)

  ## -- interaction term
  l_interactionTerm <- fixEff_dataInference$fixed_term$interaction
  categorical_vars <- fixEff_dataActual$categorical_vars
  data <- fixEff_dataActual$data
  actualInteractionFixEff <- computeActualInteractionFixEff(l_interactionTerm, categorical_vars, data)

  ## -- rbind Interaction & Main
  actual_fixEff <- rbind(actual_mainFixEff , actualInteractionFixEff, actual_intercept )

  ## -- join inference & actual
  inference_fixEff <- fixEff_dataInference$data
  res <- join_dtf(inference_fixEff, actual_fixEff  ,  c("ID", "term"), c("geneID", "term"))
  return(res)
}

```


```{r function-waldtest, filename =  "waldtest" }

#' Wald test for hypothesis testing
#'
#' This function performs a Wald test for hypothesis testing by comparing an estimation
#' to a reference value using the provided standard error. It allows testing for
#' one-tailed alternatives: "greater" - β > reference_value, "less" - β < −reference_value,
#' or two-tailed alternative: "greaterAbs" - |β| > reference_value.
#' If the p-value obtained is greater than 1, it is set to 1 to avoid invalid p-values.
#'
#' @param estimation The estimated coefficient value.
#' @param std_error The standard error of the estimation.
#' @param reference_value The reference value for comparison (default is 0).
#' @param alternative The type of alternative hypothesis to test (default is "greaterAbs").
#' @return A list containing the test statistic and p-value.
#' @importFrom stats pnorm
#' @export
#' @examples
#' # Perform a Wald test with the default "greaterAbs" alternative
#' wald_test(estimation = 0.1, std_error = 0.02, reference_value = 0.2)
wald_test <- function(estimation, std_error, reference_value = 0, alternative = "greaterAbs") {
  if (alternative == "greater") {
    test_statistic <- (estimation - reference_value) / std_error
    p_value <- 1 - stats::pnorm(test_statistic, mean = 0, sd = 1, lower.tail = TRUE)
  } else if (alternative == "less") {
    test_statistic <- (estimation - reference_value) / std_error
    p_value <- pnorm(test_statistic, mean = 0, sd = 1, lower.tail = TRUE)
  } else if (alternative == "greaterAbs") {
    test_statistic <- (abs(estimation) - reference_value) / std_error
    p_value <- 2 * (1 - pnorm(test_statistic, mean = 0, sd = 1, lower.tail = TRUE))
  } else {
    stop("Invalid alternative type. Use 'greater', 'less', or 'greaterAbs'.")
  }

  # Set p-value to 1 if it exceeds 1
  p_value <- pmin(p_value, 1)
  return(list(statistic = test_statistic, p.value = p_value))
}




#' Perform statistical tests and return tidy results
#'
#' This function takes a list of glmmTMB objects and performs statistical tests based on the estimated coefficients and their standard errors. The results are returned in a tidy data frame format.
#'
#' @param list_tmb A list of glmmTMB objects representing the fitted models.
#' @param coeff_threshold  A non-negative value which specifies a ln(fold change) threshold. The Threshold  is used for the Wald test to determine whether the  coefficient (β) is significant or not, depending on \code{alt_hypothesis} parameter. Default is 0, ln(FC = 1).
#' @param alternative_hypothesis Alternative hypothesis for the Wald test (default is "greaterAbs").
#' Possible choice: 
#' "greater" 
#' - β > coeff_threshold, 
#' "less" 
#' - β < −coeff_threshold,
#' or two-tailed alternative: 
#' "greaterAbs" 
#' - |β| > coeff_threshold
#' @param correction_method a character string indicating the correction method to apply to p-values. Possible values are: 
#'                          "holm", "hochberg", "hommel", #' "bonferroni", "BH", "BY", "fdr", and "none".
#'
#' @return A tidy data frame containing the results of statistical tests for the estimated coefficients.
#'
#' @importFrom stats p.adjust
#' @export
#'
#' @examples
#' data(iris)
#' model_list <- fitModelParallel(formula = Sepal.Length ~ Sepal.Width + Petal.Length, 
#'                  data = iris, group_by = "Species", n.cores = 1) 
#' results_df <- tidy_results(model_list, coeff_threshold = 0.1, alternative_hypothesis = "greater")
tidy_results <- function(list_tmb, coeff_threshold = 0, alternative_hypothesis = "greaterAbs", correction_method = "BH") {
  
  invisible(isValidList_tmb(list_tmb))
  tidy_tmb_df <- tidy_tmb(list_tmb)
  if (coeff_threshold != 0 || alternative_hypothesis != "greaterAbs") {
    waldRes <- wald_test(tidy_tmb_df$estimate, tidy_tmb_df$std.error, coeff_threshold, alternative_hypothesis)
    tidy_tmb_df$statistic <- waldRes$statistic
    tidy_tmb_df$p.value <- waldRes$p.value
  }
  tidy_tmb_df$p.adj <- stats::p.adjust(tidy_tmb_df$p.value, method = correction_method)
  return(tidy_tmb_df)
}


```



```{r test-waldtest}

# Test unitaires
test_that("wald_test performs correct tests", {
  # Test with "greater" alternative
  result_greater <- wald_test(estimation = 0.1, std_error = 0.02, reference_value = 0.05, alternative = "greater")
  expect_equal(result_greater$p.value, 1 - pnorm((0.1 - 0.05) / 0.02, mean = 0, sd = 1, lower.tail = TRUE))

  # Test with "less" alternative
  result_less <- wald_test(estimation = 0.1, std_error = 0.02, reference_value = 0.05, alternative = "less")
  expect_equal(result_less$p.value, pnorm((0.1 - 0.05) / 0.02, mean = 0, sd = 1, lower.tail = TRUE))

  # Test with "greaterAbs" alternative
  result_greaterAbs <- wald_test(estimation = 0.1, std_error = 0.02, reference_value = 0.05, alternative = "greaterAbs")
  expect_equal(result_greaterAbs$p.value, (2 * (1 - pnorm((abs(0.1) - 0.05) / 0.02, mean = 0, sd = 1, lower.tail = TRUE))))

  # Test with invalid alternative
  expect_error(wald_test(estimation = 0.1, std_error = 0.02, reference_value = 0.05, alternative = "invalid"))
})



test_that("results function performs statistical tests correctly", {
  # Charger les données iris pour les tests
  data(iris)
  # Fit models and perform statistical tests
  model_list <- fitModelParallel(formula = Sepal.Length ~ Sepal.Width + Petal.Length, 
                                 data = iris, group_by = "Species", n.cores = 1) 
  results_df <- tidy_results(model_list, coeff_threshold = 0.1, alternative_hypothesis = "greater")

  # Vérifier que les colonnes 'statistic' et 'p.value' ont été ajoutées au dataframe
  expect_true("statistic" %in% colnames(results_df))
  expect_true("p.value" %in% colnames(results_df))

  # Vérifier que les tests statistiques ont été effectués correctement
  # Ici, nous ne vérifierons pas les valeurs exactes des résultats car elles peuvent varier en fonction de la machine et des packages utilisés.
  # Nous nous assurerons seulement que les résultats sont dans le format attendu.
  expect_is(results_df$statistic, "numeric")
  expect_is(results_df$p.value, "numeric")
  expect_is(results_df$p.adj, "numeric")


  # Vérifier que les p-values ne dépassent pas 1
  expect_true(all(results_df$p.value <= 1))

  # Vérifier que les valeurs sont correctes pour les colonnes 'statistic' et 'p.value'
  # (Cela dépend des données iris et des modèles ajustés)
  # Remarque : Vous devrez peut-être ajuster ces tests en fonction des valeurs réelles des données iris et des modèles ajustés.
  expect_true(all(!is.na(results_df$statistic)))
  expect_true(all(!is.na(results_df$p.value)))

  # Vérifier que le seuil des coefficients et l'hypothèse alternative sont correctement appliqués
  # Ici, nous nous attendons à ce que les p-values soient uniquement pour les coefficients dépassant le seuil
  expect_true(all(ifelse(abs(results_df$estimate) > 0.1, results_df$p.value <= 1, results_df$p.value == 1)))
  expect_true(all(ifelse(abs(results_df$estimate) > 0.1, results_df$p.adj <= 1, results_df$p.adj == 1)))

  })




```



```{r function-receiver_operating_characteristic, filename = "receiver_operating_characteristic"}


#' Get Labels for Expected Differential Expression
#'
#' This function assigns labels to genes based on whether their actual effect estimates
#' indicate differential expression according to a given threshold and alternative hypothesis.
#'
#' @param comparison_df A data frame containing comparison results with actual effect estimates.
#' @param coeff_threshold The threshold value for determining differential expression.
#' @param alt_hypothesis The alternative hypothesis for comparison. Possible values are "greater",
#'                      "less", and "greaterAbs".
#' @return A modified data frame with an additional column indicating if the gene is differentially expressed.
#'
#' @examples
#' # Generate a sample comparison data frame
#' comparison_data <- data.frame(
#'   geneID = c("gene1", "gene2", "gene3"),
#'   actual = c(0.5, -0.3, 0.8)
#' )
#'
#' # Get labels for expected differential expression
#' labeled_data <- getLabelExpected(comparison_data, coeff_threshold = 0.2, alt_hypothesis = "greater")
#'
#' @export
getLabelExpected <- function(comparison_df, coeff_threshold, alt_hypothesis) {
  if (alt_hypothesis == "greater") {
    idx_DE <- comparison_df$actual > coeff_threshold
    comparison_df$isDE <- idx_DE
  } else if (alt_hypothesis == "less") {
    idx_DE <- comparison_df$actual < coeff_threshold
    comparison_df$isDE <- idx_DE
  } else if (alt_hypothesis == "greaterAbs") {
    idx_DE <- abs(comparison_df$actual) > coeff_threshold
    comparison_df$isDE <- idx_DE
  }
  ## isDE for random params == NA
  idx_ran_pars <- comparison_df$effect == "ran_pars"
  comparison_df$isDE[idx_ran_pars] <- NA
  return(comparison_df)
}


#' Computes the ROC curve.
#'
#' This function takes a data frame with binary truth values and predicted scores,
#' computes the ROC curve, and returns a data frame containing specificity, sensitivity, and threshold values.
#' This function is inspired by the yardstick package.
#'
#' @param dt A data frame with columns truth (first column) and score (second column).
#' @return A data frame with specificity, sensitivity, and threshold values.
#' @export
compute_roc_curve <- function(dt){
  ## -- replace 0 by minimum machine 
  dt$p.adj[ dt$p.adj == 0 ] <- 1e-217
  pred_obj <- prediction( -log10(dt$p.adj), dt$isDE)
  perf_obj <- performance(pred_obj,"tpr","fpr")
  data2curve <- data.frame(x.name = perf_obj@x.values[[1]], y.name = perf_obj@y.values[[1]])
  names(data2curve) <- c(unname(perf_obj@x.name), unname(perf_obj@y.name))
  return(data2curve)
}


#' Computes area under the ROC curve (AUC).
#'
#' This function calculates the area under the ROC curve (AUC) using specificity and sensitivity values.
#'
#' @param dt A data table with columns for True positive rate and False positive rate
#' @return A numeric value representing the AUC.   
#' @export
compute_roc_auc <- function(dt) Area_Under_Curve(x  = dt$`False positive rate`, y = dt$`True positive rate`)




#' Gets ROC objects for a given parameter.
#'
#' This function takes a data table of evaluation parameters and returns ROC curves for each term
#' and an aggregate ROC curve along with corresponding AUC values.
#'
#' @param evaldata_params Data table containing evaluation parameters.
#' @param col_param Column name specifying the parameter for grouping.
#' @param col_truth Column name for binary ground truth values.
#' @param col_score Column name for predicted scores.
#' @return A list containing ROC curves and AUCs for each group and an aggregate ROC curve and AUC.
#' @importFrom data.table setDT .SD
#' @export
get_roc_object <- function(evaldata_params, col_param = "description", col_truth = "isDE", col_score = "p.adj"  ) {
  
  ## -- subset fixed eff
  evaldata_params <- subset(evaldata_params, effect == "fixed")
  
  ## -- data.table conversion
  dt_evaldata_params <- data.table::setDT(evaldata_params)

  ## -- by params
  roc_curve_params <- dt_evaldata_params[, compute_roc_curve(.SD), by=c("from", col_param), .SDcols=c(col_truth, col_score)]
  roc_auc_params <- roc_curve_params[, compute_roc_auc(.SD), by=c("from", col_param), .SDcols=c("False positive rate", "True positive rate")]
  names(roc_auc_params)[ names(roc_auc_params) == "V1" ] <- "roc_AUC"
  roc_auc_params$roc_randm_AUC <- 0.5

  ## -- aggregate
  roc_curve_agg <- dt_evaldata_params[, compute_roc_curve(.SD), by= "from", .SDcols=c(col_truth, col_score)]
  roc_auc_agg <- roc_curve_agg[, compute_roc_auc(.SD), by="from", .SDcols=c("False positive rate", "True positive rate")]
  names(roc_auc_agg)[ names(roc_auc_agg) == "V1" ] <- "roc_AUC"
  roc_auc_agg$roc_randm_AUC <- 0.5
  
  return(list(byparams = list(roc_curve = as.data.frame(roc_curve_params),
                              roc_auc = as.data.frame(roc_auc_params)),
              aggregate = list(roc_curve = as.data.frame(roc_curve_agg),
                               roc_auc = as.data.frame(roc_auc_agg)))
  )
  
}


#' Builds a ggplot ROC curve.
#'
#' This function takes data frames for ROC curve and AUC and builds a ggplot ROC curve.
#'
#' @param data_curve Data frame with ROC curve.
#' @param data_auc Data frame with AUC.
#' @param palette_color List of colors used.
#' @param ... Additional arguments to be passed to ggplot2::geom_path.
#' @return A ggplot object representing the ROC curve.
#' @importFrom ggplot2 ggplot geom_path geom_text theme_bw coord_fixed scale_color_manual aes 

#' @export 
build_gg_roc_curve <- function(data_curve, data_auc, palette_color = c("#500472", "#79cbb8") ,  ...){

 
  data_auc <- get_label_y_position(data_auc)
  
  ggplot2::ggplot(data_curve) +
    ggplot2::geom_path(ggplot2::aes(x = `False positive rate` , y = `True positive rate`, ...), linewidth = 1) +
    ggplot2::geom_text(data_auc,
                       mapping = ggplot2::aes(x = 0.75, y = pos_y,
                                              label = paste("AUC :", round(roc_AUC, 2) , sep = ""), col = from)
    ) +
    ggplot2::theme_bw() +
    ggplot2::coord_fixed() +
    ggplot2::scale_color_manual(values = palette_color)
}



#' Computes y-axis position for text labels.
#'
#' This function calculates the y-axis position for text labels in a ggplot based on the levels of a factor.
#' It is specifically designed for use with ROC curve plotting.
#'
#' @param data_auc Data frame with AUC values and factor levels.
#' @return A modified data frame with an additional column pos_y representing y-axis positions.
#' @export
get_label_y_position <- function(data_auc){
  ## -- y text  
  l_y_pos <- c(0.15, 0.05)
  lvls <- levels(as.factor(data_auc$from))
  vec_pos_y <- data_auc$from == lvls[[1]]
  vec_pos_y[vec_pos_y] <- l_y_pos[1]
  vec_pos_y[!vec_pos_y] <- l_y_pos[2]
  data_auc$pos_y <- vec_pos_y
  return(data_auc)
}


#' Gets ROC curves and AUC for both aggregated and individual parameters.
#'
#' This function takes a ROC object and returns ROC curves and AUCs for both aggregated and individual parameters.
#'
#' @param roc_obj ROC object.
#' @param ... Additional arguments to be passed to \code{ggplot2::geom_path}.
#' @return ROC curves and AUCs for both aggregated and individual parameters.
#' @importFrom ggplot2 facet_wrap
#' @export
get_roc_curve <- function(roc_obj, ...){
  
  ## -- aggreg
  data_curve <- roc_obj$aggregate$roc_curve
  data_auc <- roc_obj$aggregate$roc_auc
  roc_obj$aggregate$roc_curve <- build_gg_roc_curve(data_curve, data_auc, col = from , ... )
  
  ## -- indiv
  data_curve <- roc_obj$byparams$roc_curve
  data_auc <- roc_obj$byparams$roc_auc
  roc_obj$byparams$roc_curve <- build_gg_roc_curve(data_curve, data_auc, col = from,  ... ) +
    ggplot2::facet_wrap(~description) 
  
  return(roc_obj)
}






```

```{r test-receiver_operating_characteristic}


# Test cases for getLabelExpected function
test_that("getLabelExpected assigns labels correctly", {
  

    # Sample comparison data frame
  comparison_data <- data.frame(
      geneID = c("gene1", "gene2", "gene3"),
      actual = c(0.5, -0.3, 0.8)
  )
  
  # Test case 1: Alt hypothesis = "greater"
  labeled_data_greater <- getLabelExpected(comparison_data, coeff_threshold = 0.2, alt_hypothesis = "greater")
  expect_identical(labeled_data_greater$isDE, c(TRUE, FALSE, TRUE))
  
  # Test case 2: Alt hypothesis = "less"
  labeled_data_less <- getLabelExpected(comparison_data, coeff_threshold = -0.2, alt_hypothesis = "less")
  expect_identical(labeled_data_less$isDE, c(FALSE, TRUE, FALSE))
  
  # Test case 3: Alt hypothesis = "greaterAbs"
  labeled_data_greaterAbs <- getLabelExpected(comparison_data, coeff_threshold = 0.6, alt_hypothesis = "greaterAbs")
  expect_identical(labeled_data_greaterAbs$isDE, c(FALSE, FALSE, TRUE))
  
})





test_that("compute_roc_curve computes ROC curve from data frame", {
  # Test data
  dt <- data.frame(
    isDE = c(1, 0, 1, 0, 1),
    p.adj = c(0.8, 0.6, 0.7, 0.4, 0.9)
  )

  # Test
  result <- compute_roc_curve(dt)
  expect_equal(names(result), c("False positive rate", "True positive rate"))
})

test_that("compute_roc_auc computes AUC from data frame", {
  # Test data
  dt <- data.frame(
    isDE = c(0, 1, 1, 0, 0, 0, 1),
    p.adj = c(1, 0.75, 0.75, 0.5, 0.25, 0, 0.1)
  )
  
  roc_curve_obj <- compute_roc_curve(dt)
  # Test
  result <- compute_roc_auc(roc_curve_obj)
  
   # Expected output
  expected_result <- 0.42
  expect_equal(round(result,2 ), expected_result)
})

test_that("get_roc_object returns ROC curves and AUCs", {
  # Test data
  set.seed(101)
  evaldata_params <- data.frame(
    from = rep(c("HTRfit", "DESeq2"), each = 5),
    description = rep(c("param1", "param2"), each = 5),
    isDE = sample(0:1, 10, replace = TRUE),
    p.adj = runif(10),
    effect = rep("fixed", 10)
  )

  # Test
  result <- get_roc_object(evaldata_params)
  expect_equal(names(result), c("byparams", "aggregate"))
  expect_equal(names(result$byparams), c("roc_curve", "roc_auc"))
  expect_equal(names(result$aggregate), c("roc_curve", "roc_auc"))
  
  ## -- not only fixed effect 
  set.seed(101)
  evaldata_params <- data.frame(
    from = rep(c("HTRfit", "DESeq2"), each = 5),
    description = rep(c("param1", "param2"), each = 5),
    isDE = sample(0:1, 10, replace = TRUE),
    p.adj = runif(10),
    effect = c(rep("fixed", 8), 'ran_pars', 'ran_pars')
  )
    
  result <- get_roc_object(evaldata_params)
  expect_equal(names(result), c("byparams", "aggregate"))
  expect_equal(names(result$byparams), c("roc_curve", "roc_auc"))
  expect_equal(names(result$aggregate), c("roc_curve", "roc_auc"))
  

    
})

test_that("build_gg_roc_curve builds ggplot ROC curve", {
  # Test data
  data_curve <- data.frame(
    .threshold = c(-Inf, 0.9, 0.8, 0.7, 0.6, 0.4, Inf),
    specificity = c(0, 1, 0.75, 0.5, 0.25, 0, 1),
    sensitivity = c(1, 0.75, 0.75, 0.5, 0.25, 0, 0),
    from = rep("HTRfit", 7)
  )
  data_auc <- data.frame(from = "HTRfit", AUC = 0.6875)

  # Test
  result <- build_gg_roc_curve(data_curve, data_auc)
  # Ensure that the ggplot object is created without errors
  expect_true("gg" %in% class(result))
})

test_that("get_label_y_position computes y-axis positions for labels", {
  # Test data
  data_auc <- data.frame(from = rep(c("HTRfit", "DESeq2"), each = 2), AUC = c(1, 0.90))

  # Expected output
  expected_result <- data.frame(from = rep(c("HTRfit", "DESeq2"), each = 2), AUC = c(1, 0.90), pos_y = c(0.05, 0.05, 0.15, 0.15))

  # Test
  result <- get_label_y_position(data_auc)
  expect_equal(result, expected_result)
})



```


```{r function-counts_plot, filename = "counts_plot"}

#' Generate a density plot of gene counts
#'
#' This function generates a density plot of gene counts from mock data.
#'
#' @param mock_obj The mock data object containing gene counts.
#'
#' @return A ggplot2 density plot.
#'
#' @importFrom ggplot2 aes geom_density theme_bw ggtitle scale_x_log10 element_blank
#' @export
#'
#' @examples
#' mock_data <- list(counts = matrix(c(1, 2, 3, 4, 5, 6, 7, 8, 9), ncol = 3))
#' counts_plot(mock_data)
counts_plot <- function(mock_obj){

  counts <- unname(unlist(mock_obj$counts))
  p <- ggplot2::ggplot() +
      ggplot2::aes(x = "Genes", y = counts) +
      ggplot2::geom_point(position = "jitter", alpha = 0.6, size = 0.4, col = "#F0B27A") +
      ggplot2::geom_violin(fill = "#F8F9F9", alpha = 0.4) +
      ggplot2::stat_summary(fun = "mean", geom = "point", color = "#B63A0F", size = 5) +
      ggplot2::theme_bw() +
      ggplot2::ggtitle("Gene expression plot") +
      ggplot2::theme(axis.title.x =  ggplot2::element_blank())
  return(p)
}


```

```{r test-counts_plot}



# Test cases
test_that("Counts plot is generated correctly", {
  mock_data <- list(
    counts = matrix(c(1, 2, 3, 4, 5, 6, 7, 8, 9), ncol = 3)
  )
  
  plot <- counts_plot(mock_data)
  
  expect_true("gg" %in% class(plot))
})



```



```{r function-evaluation_identity, filename = "evaluation_identity"}

#' Gets R-squared values for plotting.
#'
#' This function takes a data frame with R-squared and RMSE values,
#' computes position coordinates, and prepares data for plotting.
#' @param data Data frame with R-squared values and RMSE values.
#' @return A data frame with additional columns for labeling in the plot.
#' @export
#' @examples
#' data_metrics <- data.frame(from = c("A", "B", "C"), 
#'                            description = c("Desc1", "Desc2", "Desc3"), 
#'                            R2 = c(0.9, 0.8, 0.7),
#'                            RMSE = c(0.6, 0.2, 0.1))
#' result <- get_metrics_2plot(data_metrics)
get_metrics_2plot <- function(data){
  data$pos_x <- -Inf
  data$pos_y <- Inf
  data$label_italic <- sprintf("italic(R^2) == %1.2f ~ phantom() ~  phantom() ~ italic(RMSE) == %2.2f", round(data$R2, 3), round(data$RMSE, 3))
  data$label_vjust <- as.numeric(factor(data$from))
  return(data)
}


#' Generate an identity term plot and get metrics associated
#'
#' This function generates an identity plot for comparing actual values with estimates
#'
#' @param data_identity A data frame containing comparison results with "actual" and "estimate" columns.
#' @param palette_color dict-like palette default: palette_color = c(DESeq2 = "#500472", HTRfit ="#79cbb8")
#' @param palette_shape Optional parameter that sets the point shape for plots.Default : c(DESeq2 = 17, HTRfit = 19).
#' @param ... additional parameters to pass geom_point aes 
#' @return A ggplot2 identity plot and R2 metric associated
#'
#' @importFrom ggplot2 sym aes geom_point geom_abline facet_wrap theme_bw ggtitle scale_color_manual geom_text scale_shape_manual
#' @importFrom rlang .data new_environment
#' @export
#' @examples
#' comparison_data <- data.frame(
#'        actual = c(1, 2, 3, 4, 5),
#'        estimate = c(0.9, 2.2, 2.8, 4.1, 5.2),
#'        description = rep("Category A", 5),
#'        term = rep("Category A", 5),
#'        from = c("A", "B", "B", "A", "B"))
#' eval_identityTerm(comparison_data, 
#'                  palette_color = c(A = "#500472", B ="#79cbb8"),
#'                  palette_shape = c(A = 17, B = 19))
eval_identityTerm <- function(data_identity, palette_color = c(DESeq2 = "#500472", HTRfit ="#79cbb8"),  
                               palette_shape = c(DESeq2 = 17, HTRfit = 19), ...){

  data_rsquare <- compute_rsquare(data_identity)
  data_rmse <- compute_rmse(data_identity)
  data_metrics <- join_dtf(data_rsquare, data_rmse, k1 = c("from", "description"), k2 = c("from", "description"))
  data_metrics2plot <- get_metrics_2plot(data_metrics)
  
  p <- ggplot2::ggplot(data_identity, mapping = ggplot2::aes(x = .data$actual, y = .data$estimate, col = from, shape = from,  ...) )+
    ggplot2::geom_point(alpha = 0.6, size = 2) +
    ggplot2::geom_abline(intercept = 0, slope = 1, lty = 3, col = 'red', linewidth = 1) +
    ggplot2::facet_wrap(~description, scales = "free") +
    ggplot2::theme_bw()  +
    ggplot2::geom_text(data = data_metrics2plot,
                       mapping = ggplot2::aes(x = pos_x, y = pos_y, label = label_italic, col = from, vjust = label_vjust),
                       parse = TRUE, hjust = -0.3 ) +
    ggplot2::ggtitle("Identity plot") +
    ggplot2::scale_color_manual(values = palette_color ) + 
    ggplot2::scale_shape_manual(values = palette_shape )

  p$plot_env <- rlang::new_environment()

  obj_idTerm <- list(R2 = data_metrics, p = p )

  return(obj_idTerm)
}


```

```{r test-evaluation_identity}


# Test cases
test_that("Identity plot is generated correctly", {
  comparison_data <- data.frame(
    actual = c(1, 2, 3, 4, 5),
    estimate = c(0.9, 2.2, 2.8, 4.1, 5.2),
    description = rep("Category A", 5),
    from = c("A", "B", "A", "B", "A"),
    term = rep("Category A", 5)
  )
  
  idTerm_obj <- eval_identityTerm(comparison_data)
  
  expect_true("gg" %in% class(idTerm_obj$p))
  expect_equal(c("from", "description", "R2", "RMSE"), colnames(idTerm_obj$R2))  
})



#' Unit Test for get_rsquare_2plot function.
test_that("get_metrics_2plot returns expected result", {
  data_metrics <- data.frame(from = c("A", "B", "C"), description = c("Desc1", "Desc2", "Desc3"), R2 = c(0.9, 0.8, 0.7), RMSE = c(0.03, 0.9, 0.18))
  result <- get_metrics_2plot(data_metrics)
  expect_equal(names(result), c("from","description","R2" , "RMSE", "pos_x", "pos_y", "label_italic","label_vjust"))
  expect_equal(result$from, c("A","B", "C"))
  expect_equal(result$description, c("Desc1","Desc2", "Desc3"))
  expect_equal(result$label_vjust, c(1,2, 3))

})



```

```{r function-MLmetrics, filename = "mlmetrics"}



#' @title accuracy
#'
#' @description
#' Compute the accuracy classification score.
#'
#' @param y_pred Predicted labels vector, as returned by a classifier
#' @param y_true Ground truth (correct) 0-1 labels vector
#' @return accuracy
#' @examples
#' data(cars)
#' logreg <- glm(formula = vs ~ hp + wt,
#'               family = binomial(link = "logit"), data = mtcars)
#' pred <- ifelse(logreg$fitted.values < 0.5, 0, 1)
#' accuracy(y_pred = pred, y_true = mtcars$vs)
#' @export

accuracy <- function(y_pred, y_true) {
  accuracy <- mean(y_true == y_pred)
  return(accuracy)
}


#' @title Confusion Matrix
#'
#' @description
#' Compute confusion matrix to evaluate the accuracy of a classification.
#'
#' @param y_pred Predicted labels vector, as returned by a classifier
#' @param y_true Ground truth (correct) 0-1 labels vector
#' @return a table of Confusion Matrix
#' @examples
#' data(cars)
#' logreg <- glm(formula = vs ~ hp + wt,
#'               family = binomial(link = "logit"), data = mtcars)
#' pred <- ifelse(logreg$fitted.values < 0.5, 0, 1)
#' ConfusionMatrix(y_pred = pred, y_true = mtcars$vs)
#' @export

ConfusionMatrix <- function(y_pred, y_true) {
  Confusion_Mat <- table(y_true, y_pred)
  return(Confusion_Mat)
}


#' @title Confusion Matrix (Data Frame Format)
#'
#' @description
#' Compute data frame format confusion matrix for internal usage.
#'
#' @param y_pred Predicted labels vector, as returned by a classifier
#' @param y_true Ground truth (correct) 0-1 labels vector
#' @return a data.frame of Confusion Matrix
#' @examples
#' data(cars)
#' logreg <- glm(formula = vs ~ hp + wt,
#'               family = binomial(link = "logit"), data = mtcars)
#' pred <- ifelse(logreg$fitted.values < 0.5, 0, 1)
#' ConfusionDF(y_pred = pred, y_true = mtcars$vs)
#' @keywords internal
#' @export
ConfusionDF <- function(y_pred, y_true) {
  Confusion_DF <- transform(as.data.frame(ConfusionMatrix(y_pred, y_true)),
                            y_true = as.character(y_true),
                            y_pred = as.character(y_pred),
                            Freq = as.integer(Freq))
  return(Confusion_DF)
}

#' @title precision
#'
#' @description
#' Compute the precision score.
#'
#' @param y_pred Predicted labels vector, as returned by a classifier
#' @param y_true Ground truth (correct) 0-1 labels vector
#' @param positive An optional character string for the factor level that
#'   corresponds to a "positive" result
#' @return precision
#' @examples
#' data(cars)
#' logreg <- glm(formula = vs ~ hp + wt,
#'               family = binomial(link = "logit"), data = mtcars)
#' pred <- ifelse(logreg$fitted.values < 0.5, 0, 1)
#' precision(y_pred = pred, y_true = mtcars$vs, positive = "0")
#' precision(y_pred = pred, y_true = mtcars$vs, positive = "1")
#' @export
precision <- function(y_true, y_pred, positive = NULL) {
  Confusion_DF <- ConfusionDF(y_pred, y_true)
  if (is.null(positive) == TRUE) positive <- as.character(Confusion_DF[1,1])
  TP <- as.integer(subset(Confusion_DF, y_true==positive & y_pred==positive)["Freq"])
  FP <- as.integer(sum(subset(Confusion_DF, y_true!=positive & y_pred==positive)["Freq"]))
  precision <- TP/(TP+FP)
  return(precision)
}


#' @title recall
#'
#' @description
#' Compute the recall score.
#'
#' @param y_pred Predicted labels vector, as returned by a classifier
#' @param y_true Ground truth (correct) 0-1 labels vector
#' @param positive An optional character string for the factor level that
#'   corresponds to a "positive" result
#' @return recall
#' @examples
#' data(cars)
#' logreg <- glm(formula = vs ~ hp + wt,
#'               family = binomial(link = "logit"), data = mtcars)
#' pred <- ifelse(logreg$fitted.values < 0.5, 0, 1)
#' recall(y_pred = pred, y_true = mtcars$vs, positive = "0")
#' recall(y_pred = pred, y_true = mtcars$vs, positive = "1")
#' @export
recall <- function(y_true, y_pred, positive = NULL) {
  Confusion_DF <- ConfusionDF(y_pred, y_true)
  if (is.null(positive) == TRUE) positive <- as.character(Confusion_DF[1,1])
  TP <- as.integer(subset(Confusion_DF, y_true==positive & y_pred==positive)["Freq"])
  FN <- as.integer(sum(subset(Confusion_DF, y_true==positive & y_pred!=positive)["Freq"]))
  recall <- TP/(TP+FN)
  return(recall)
}


#' @title sensitivity
#'
#' @description
#' Compute the sensitivity score.
#'
#' @param y_pred Predicted labels vector, as returned by a classifier
#' @param y_true Ground truth (correct) 0-1 labels vector
#' @param positive An optional character string for the factor level that
#'   corresponds to a "positive" result
#' @return sensitivity
#' @examples
#' data(cars)
#' logreg <- glm(formula = vs ~ hp + wt,
#'               family = binomial(link = "logit"), data = mtcars)
#' pred <- ifelse(logreg$fitted.values < 0.5, 0, 1)
#' sensitivity(y_pred = pred, y_true = mtcars$vs, positive = "0")
#' sensitivity(y_pred = pred, y_true = mtcars$vs, positive = "1")
#' @export
sensitivity  <- function(y_true, y_pred, positive = NULL) {
  Confusion_DF <- ConfusionDF(y_pred, y_true)
  if (is.null(positive) == TRUE) positive <- as.character(Confusion_DF[1,1])
  TP <- as.integer(subset(Confusion_DF, y_true==positive & y_pred==positive)["Freq"])
  FN <- as.integer(sum(subset(Confusion_DF, y_true==positive & y_pred!=positive)["Freq"]))
  sensitivity <- TP/(TP+FN)
  return(sensitivity)
}


#' @title specificity
#'
#' @description
#' Compute the specificity score.
#'
#' @param y_pred Predicted labels vector, as returned by a classifier
#' @param y_true Ground truth (correct) 0-1 labels vector
#' @param positive An optional character string for the factor level that
#'   corresponds to a "positive" result
#' @return specificity
#' @examples
#' data(cars)
#' logreg <- glm(formula = vs ~ hp + wt,
#'               family = binomial(link = "logit"), data = mtcars)
#' pred <- ifelse(logreg$fitted.values < 0.5, 0, 1)
#' specificity(y_pred = pred, y_true = mtcars$vs, positive = "0")
#' specificity(y_pred = pred, y_true = mtcars$vs, positive = "1")
#' @export
specificity  <- function(y_true, y_pred, positive = NULL) {
  Confusion_DF <- ConfusionDF(y_pred, y_true)
  if (is.null(positive) == TRUE) positive <- as.character(Confusion_DF[1,1])
  TN <- as.integer(subset(Confusion_DF, y_true!=positive & y_pred!=positive)["Freq"])
  FP <- as.integer(sum(subset(Confusion_DF, y_true!=positive & y_pred==positive)["Freq"]))
  specificity <- TN/(TN+FP)
  return(specificity)
}



#' @title Calculate the Area Under the Curve
#'
#' @description
#' Calculate the area under the curve.
#'
#' @param x the x-points of the curve
#' @param y the y-points of the curve
#' @param method can be "trapezoid" (default), "step" or "spline"
#' @return Area Under the Curve (AUC)
#' @examples
#' x <- seq(0, pi, length.out = 200)
#' plot(x = x, y = sin(x), type = "l")
#' Area_Under_Curve(x = x, y = sin(x), method = "trapezoid")
#' @importFrom stats splinefun
#' @export
Area_Under_Curve <- function(x, y, method = "trapezoid"){
  idx <- order(x)
  x <- x[idx]
  y <- y[idx]
  if (method == 'trapezoid'){
    auc <- sum((rowMeans(cbind(y[-length(y)], y[-1]))) * (x[-1] - x[-length(x)]))
  }else if (method == 'step'){
    auc <- sum(y[-length(y)] * (x[-1] - x[-length(x)]))
  }else if (method == 'spline'){
    auc <- integrate(splinefun(x, y, method = "natural"), lower = min(x), upper = max(x))
    auc <- auc$value
  }
  return(auc)
}



.performance.positive.predictive.value <-
  function(predictions, labels, cutoffs, fp, tp, fn, tn,
           n.pos, n.neg, n.pos.pred, n.neg.pred) {

    ppv <- tp / (fp + tp)
    list( cutoffs, ppv )
  }

.performance.false.positive.rate <-
  function(predictions, labels, cutoffs, fp, tp, fn, tn,
           n.pos, n.neg, n.pos.pred, n.neg.pred) {

    list( cutoffs, fp / n.neg )
  }


.performance.true.positive.rate <-
  function(predictions, labels, cutoffs, fp, tp, fn, tn,
           n.pos, n.neg, n.pos.pred, n.neg.pred) {

    list( cutoffs, tp / n.pos )
  }

.sarg <- function( arglist, ...) {
    ll <- list(...)
    for (argname in names(ll) ) {
        arglist[[ argname ]] <- ll[[ argname ]]
    }
    return(arglist)
}

## return list of selected arguments, skipping those that
## are not present in arglist
.select.args <- function( arglist, args.to.select, complement=FALSE) {
    match.bool <- names(arglist) %in% args.to.select
    if (complement==TRUE) match.bool <- !match.bool
    return( arglist[ match.bool] )
}
```


```{r function-precision_recall , filename = "precision_recall"}


#' Computes the precision-recall curve (AUC).
#'
#'
#' @param dt A data frame with columns truth (first column) and score (second column).
#' @return A dataframe with precision recall.
#' @export
compute_pr_curve <- function(dt){
  ## -- replace 0 by minimum machine 
  dt$p.adj[ dt$p.adj == 0 ] <- 1e-217
  ## --see .SDcols for order
  pred_obj <- prediction( -log10(dt$p.adj) , dt$isDE)
  perf_obj = performance(pred_obj, measure = "prec", x.measure = "rec")
  data2curve <- data.frame(x.name = perf_obj@x.values[[1]], y.name = perf_obj@y.values[[1]])
  names(data2curve) <- c(unname(perf_obj@x.name), unname(perf_obj@y.name))
  ## -- drop NA
  data2curve <- na.omit(data2curve)
  ## -- add start point
  data2curve <- rbind(c(0,1), data2curve)
  return(data2curve)
}


#' Computes area under the precision-recall curve (AUC).
#'
#' This function calculates the area under the precision-recall curve (AUC).
#'
#' @param dt A data table with columns for recall and precision.
#' @return A numeric value representing the AUC.
#' @export
compute_pr_auc <- function(dt) Area_Under_Curve( dt$recall, dt$precision )


#' Gets precision-recall objects for a given parameter.
#'
#' This function takes a data table of evaluation parameters and returns precision-recall curves
#' for each term and an aggregate precision-recall curve.
#'
#' @param evaldata_params Data table containing evaluation parameters.
#' @param col_param Column name specifying the parameter for grouping.
#' @param col_truth Column name for binary ground truth values.
#' @param col_score Column name for predicted scores.
#' @return A list containing precision-recall curves and AUCs for each group and an aggregate precision-recall curve and AUC.
#' @importFrom data.table setDT
#' @export
get_pr_object <- function(evaldata_params, col_param = "description", col_truth = "isDE", col_score = "p.adj"  ) {

  ## -- subset fixed eff
  evaldata_params <- subset(evaldata_params, effect == "fixed")
   
  ## -- by params -- random class AUC
  prop_table <- table(evaldata_params[[col_param]], evaldata_params[[col_truth]])
  random_classifier_auc_params <- prop_table[,"TRUE"]/rowSums(prop_table)
  random_classifier_auc_params <- as.data.frame(random_classifier_auc_params)
  random_classifier_auc_params[col_param] <- rownames(random_classifier_auc_params)
  
  ## -- aggregate -- random class AUC
  prop_table <- table(evaldata_params[[col_truth]])
  random_classifier_auc_agg <- unname(prop_table["TRUE"]/sum(prop_table))
  
  ## -- data.table conversion
  dt_evaldata_params <- data.table::setDT(evaldata_params)
  
  ## -- by params
  pr_curve_params <- dt_evaldata_params[, compute_pr_curve(.SD), by=c("from", col_param), .SDcols=c(col_truth, col_score)]
  pr_auc_params <- pr_curve_params[, compute_pr_auc(.SD), by=c("from", col_param), .SDcols=c("recall", "precision")]
  names(pr_auc_params)[ names(pr_auc_params) == "V1" ] <- "pr_AUC"
  pr_auc_params <- join_dtf(pr_auc_params, random_classifier_auc_params , 
                            k1 = col_param, k2 = col_param)
  names(pr_auc_params)[ names(pr_auc_params) == "random_classifier_auc_params" ] <- "pr_randm_AUC"
  pr_auc_params$pr_performance_ratio <- pr_auc_params$pr_AUC/pr_auc_params$pr_randm_AUC

  ## -- aggregate
  pr_curve_agg <- dt_evaldata_params[, compute_pr_curve(.SD), by = "from", .SDcols=c(col_truth, col_score)]
  pr_auc_agg <- pr_curve_agg[, compute_pr_auc(.SD), by = "from", .SDcols=c("recall", "precision")]
  names(pr_auc_agg)[ names(pr_auc_agg) == "V1" ] <- "pr_AUC"
  pr_auc_agg$pr_randm_AUC <- random_classifier_auc_agg
  pr_auc_agg$pr_performance_ratio <- pr_auc_agg$pr_AUC/pr_auc_agg$pr_randm_AUC

  return(list(byparams = list(pr_curve = as.data.frame(pr_curve_params),
                              pr_auc = as.data.frame(pr_auc_params)),
              aggregate = list(pr_curve = as.data.frame(pr_curve_agg),
                               pr_auc = as.data.frame(pr_auc_agg)))
  )

}



#' Builds a ggplot precision-recall curve.
#'
#' This function takes data frames for precision-recall curve and AUC and builds a ggplot precision-recall curve.
#'
#' @param data_curve Data frame with precision-recall curve.
#' @param data_auc Data frame with AUC.
#' @param palette_color list of colors used.
#' @param ... Additional arguments to be passed to \code{ggplot2::geom_path}.
#' @return A ggplot object representing the precision-recall curve.
#' @importFrom ggplot2 ggplot geom_path geom_text theme_bw coord_fixed scale_color_manual aes sym 
#' @export
build_gg_pr_curve <- function(data_curve, data_auc, palette_color = c("#500472", "#79cbb8"), ...){
  
  #print(list(...))
  #print(ggplot2::sym(list(...)))
  #args <- lapply(list(...), function(x) if(!is.null(x)) ggplot2::sym(x) )
  
  
  data_auc <- get_label_y_position(data_auc)
  
  
  ggplot2::ggplot(data_curve) +
    ggplot2::geom_path(ggplot2::aes(x = recall, y = precision , ... ), linewidth = 1) +
    ggplot2::geom_text(data_auc,
                       mapping = ggplot2::aes(x = 0.75, y = pos_y,
                                              label = paste("AUC :", round(pr_AUC, 2) , sep = ""), col = from )
    ) +
    ggplot2::theme_bw() +
    ggplot2::coord_fixed() +
    ggplot2::scale_color_manual(values = palette_color)

}



#' Gets precision-recall curves and AUC for both aggregated and individual parameters.
#'
#' This function takes a precision-recall object and returns precision-recall curves and AUCs for both aggregated and individual parameters.
#'
#' @param pr_obj precision-recall object.
#' @param ... Additional arguments to be passed to \code{ggplot2::geom_path}.
#' @return precision-recall curves and AUCs for both aggregated and individual parameters.
#' @importFrom ggplot2 facet_wrap 
#' @export
get_pr_curve <- function(pr_obj, ...){

  ## -- aggreg
  data_curve <- pr_obj$aggregate$pr_curve
  data_auc <- pr_obj$aggregate$pr_auc
  pr_obj$aggregate$pr_curve <- build_gg_pr_curve(data_curve, data_auc,  col = from , ... )

  ## -- indiv
  data_curve <- pr_obj$byparams$pr_curve
  data_auc <- pr_obj$byparams$pr_auc
  pr_obj$byparams$pr_curve <- build_gg_pr_curve(data_curve, data_auc,  col = from , ... ) +
                                ggplot2::facet_wrap(~description)

  return(pr_obj)
}




```

```{r test-precision_recall}



test_that("compute_pr_curve computes precision-recall curve", {
  # Mock data for testing
  set.seed(42)
  dt <- data.table::data.table(
    description = rep(c("param1", "param2"), each = 50),
    isDE = sample(0:1, 100, replace = TRUE),
    p.adj = runif(100)
  )

  # Test the compute_pr_curve function
  result <- compute_pr_curve(dt)
  expect_true("recall" %in% names(result))
  expect_true("precision" %in% names(result))
})

test_that("compute_pr_auc computes area under the precision-recall curve", {
  # Mock data for testing
  set.seed(42)
  dt <- data.table::data.table(
    description = rep(c("param1", "param2"), each = 50),
    isDE = sample(0:1, 100, replace = TRUE),
    p.adj = runif(100)
  )

  # Test the compute_pr_auc function
  pr_curve <- compute_pr_curve(dt)
  result <- compute_pr_auc(pr_curve)
  expect_equal(expected = 0.60, round(result, 2))
})

test_that("get_pr_object gets precision-recall objects", {
  # Mock data for testing
  set.seed(42)
  dt_evaldata_params <- data.table::data.table(
    description = rep(c("param1", "param2"), each = 50),
    isDE = sample(c("FALSE","TRUE"), 100, replace = TRUE),
    from = c("A", "B"),
    p.adj = runif(100),
    effect = "fixed"
  )

  # Test the get_pr_object function
  result <- get_pr_object(dt_evaldata_params)
  
  expect_true("byparams" %in% names(result))
  expect_true("aggregate" %in% names(result))
  expect_true("data.frame" %in% class(result$byparams$pr_curve))
  expect_true("data.frame" %in% class(result$byparams$pr_auc))
  expect_true("data.frame" %in% class(result$aggregate$pr_curve))
  expect_true("data.frame" %in% class(result$aggregate$pr_auc))

  ## -- not only fixed effect
  # Mock data for testing
  set.seed(42)
  dt_evaldata_params <- data.table::data.table(
    description = rep(c("param1", "param2"), each = 50),
    isDE = sample(c(TRUE,FALSE), 100, replace = TRUE),
    from = c("A", "B"),
    p.adj = runif(100),
    effect = sample(c("fixed","ran_pars"), 100, replace = TRUE)
  )

  # Test the get_pr_object function
  result <- get_pr_object(dt_evaldata_params)
  
  expect_true("byparams" %in% names(result))
  expect_true("aggregate" %in% names(result))
  expect_true("data.frame" %in% class(result$byparams$pr_curve))
  expect_true("data.frame" %in% class(result$byparams$pr_auc))
  expect_true("data.frame" %in% class(result$aggregate$pr_curve))
  expect_true("data.frame" %in% class(result$aggregate$pr_auc))
  
    
})

test_that("build_gg_pr_curve builds ggplot precision-recall curve", {
  # Mock data for testing
  set.seed(42)
  data_curve <- data.frame(
    from = "A",
    recall = seq(0, 1, length.out = 100),
    precision = runif(100)
  )
  data_auc <- data.frame(from = "A", AUC = 0.75)

  # Test the build_gg_pr_curve function
  result <- build_gg_pr_curve(data_curve, data_auc)
  expect_true("gg" %in% class(result))
})

test_that("get_pr_curve gets precision-recall curves and AUCs", {
  # Mock data for testing
  set.seed(42)
  pr_obj <- list(
    byparams = list(
      pr_curve = data.frame(
        from = c("A"),
        description = rep(c("param1", "param2"), each = 50),
        recall = seq(0, 1, length.out = 100),
        precision = runif(100)
      ),
      pr_auc = data.frame(from = c("A"), description = c("param1", "param2"), AUC = c(0.75, 0.80))
    ),
    aggregate = list(
      pr_curve = data.frame(
        from = c("A"),
        recall = seq(0, 1, length.out = 100),
        precision = runif(100)
      ),
      pr_auc = data.frame(from = c("A"), description = c("param1", "param2"), AUC = c(0.75, 0.80))
    )
  )

  # Test the get_pr_curve function
  result <- get_pr_curve(pr_obj)
  build_gg_pr_curve(pr_obj$aggregate$pr_curve, pr_obj$aggregate$pr_auc )
  expect_true("byparams" %in% names(result))
  expect_true("aggregate" %in% names(result))
  expect_true("gg" %in% class(result$byparams$pr_curve))
  expect_true("gg" %in% class(result$aggregate$pr_curve))
})

```


```{r function-simulation_report, filename =  "simulation_report" }


#' Check Validity of Truth Labels
#'
#' This function checks the validity of truth labels for HTRfit evaluation, specifically designed for binary classification.
#'
#' @param eval_data Data frame containing evaluation parameters.
#' @param col_param Column name specifying the parameter for grouping.
#' @param col_truth Column name for binary ground truth values (default is "isDE").
#' @return Logical value indicating the validity of truth labels.
#' @export
is_truthLabels_valid <- function(eval_data, col_param = "description", col_truth = "isDE" ) {
  
    ## --init
    isValid <- TRUE
  
    ## -- subset fixed effect
    eval_data_fixed <- subset(eval_data, effect == "fixed")
    
    table_summary <- table(eval_data_fixed[[col_param]], eval_data_fixed[[col_truth]])
    ## -- 2 lines needed (FALSE/TRUE)
    n_labels <- dim(table_summary)[2]
    if(n_labels != 2) {
      labels_str <- paste(colnames(table_summary), collapse = ", ")
      msg <- paste("Both FALSE/TRUE truth labels (isDE) are required for classification evaluation.\nFound : ", labels_str )
      message(msg)
      isValid <- FALSE
    }
    ## -- one label found 0 time !
    label_not_found <- which(table_summary == 0, arr.ind=TRUE)
    if(dim(label_not_found)[1] > 0){
      description <- rownames(label_not_found)
      label_not_found <- colnames(table_summary)[label_not_found[,"col"]]
      msg <- "Both TRUE and FALSE labels are required for HTRfit evaluation.\n"
      msg2 <- paste("Label isDE ==", label_not_found, "not found for description ==", description, collapse = '\n')
      msg3 <- "Please review your threshold or alternative hypothesis, and consider checking the identity plot for further insights."
      msg <- paste(msg, msg2, ".\n", msg3 , sep = "")
      message(msg)
      isValid <- FALSE
    }
    return(isValid)
}


#' Validate input parameters for evaluation
#'
#' This function validates the input parameters used for evaluation, ensuring that they meet the required criteria.
#'
#' @param mock_obj Mock object containing data for evaluation.
#' @param list_gene Character vector of gene names to evaluate. Default is NULL.
#' @param list_tmb List of glmmTMB objects to evaluate. Default is NULL.
#' @param dds DESeqDataSet object to evaluate. Default is NULL.
#' @param coeff_threshold Numeric value representing the threshold for coefficients.
#' @param alt_hypothesis Numeric value representing the alternative hypothesis. Should be one of 'greaterAbs', 'greater', or 'less'.
#' @param alpha_risk Numeric value representing the alpha risk for hypothesis testing.
#' @return TRUE or error message
#' @export
isValidEvalInput <- function(mock_obj, list_gene, list_tmb, dds, coeff_threshold, alt_hypothesis, alpha_risk) {
    # Validation de l'objet mock
    invisible(isValidMock_obj(mock_obj))

    # Vérification de la présence d'au moins un objet à évaluer
    if (is.null(list_tmb) && is.null(dds)) 
        stop("Both 'list_tmb' and 'dds' are NULL. There is nothing to evaluate.")

    # Vérification du type d'objet DESeqDataSet
    if (!is.null(dds)) {
        stopifnot("DESeqDataSet" %in% class(dds))
    }

    # Validation de la liste d'objets glmmTMB
    if (!is.null(list_tmb)) {
        isValidList_tmb(list_tmb)
    }

    # Vérification de la liste de gènes
    if (!is.null(list_gene)) {
        stopifnot(is.character(list_gene))
    }

      
    stopifnot(length(coeff_threshold) == 1)
    stopifnot(is.numeric(coeff_threshold))
    
    stopifnot(length(alt_hypothesis) == 1)
    stopifnot(is.character(alt_hypothesis))
    stopifnot(alt_hypothesis %in% c("greaterAbs", "greater", "less"))
    
    stopifnot(length(alpha_risk) == 1)
    stopifnot(is.numeric(alpha_risk))
    
    return(TRUE)
}



#' Compute evaluation report for TMB/DESeq2 analysis
#'
#' This function computes an evaluation report for TMB/DESeq2 analysis using several graphical
#' summaries like precision-recall (PR) curve, Receiver operating characteristic (ROC) curve
#' and others. It takes as input several parameters like TMB results (\code{l_tmb}), DESeq2
#' result (\code{dds}), mock object (\code{mock_obj}), coefficient threshold (\code{coeff_threshold}) and
#' alternative hypothesis (\code{alt_hypothesis}). 
#'
#' @param mock_obj Mock object that represents the distribution of measurements corresponding
#'   to mock samples.
#' @param list_gene A character vector specifying the genes id to be retained for evaluation. If NULL (default) all genes are used for evaluation
#' @param list_tmb TMB results from analysis.
#' @param dds DESeq2 results from differential gene expression analysis.
#' @param coeff_threshold  A non-negative value which specifies a ln(fold change) threshold. The Threshold  is used for the Wald test to determine whether the  coefficient (β) is significant or not, depending on \code{alt_hypothesis} parameter. Default is 0.69, ln(FC = 2).
#' @param alt_hypothesis Alternative hypothesis for the Wald test (default is "greaterAbs").
#' Possible choice: 
#' "greater" 
#' - β > coeff_threshold, 
#' "less" 
#' - β < −coeff_threshold,
#' or two-tailed alternative: 
#' "greaterAbs" 
#' - |β| > coeff_threshold
#' @param alpha_risk parameter that sets the threshold for alpha risk level while testing coefficient (β). Default: 0.05.
#' @param palette_color Optional parameter that sets the color palette for plots.Default : c(DESeq2 = "#500472", HTRfit ="#79cbb8").
#' @param palette_shape Optional parameter that sets the point shape for plots.Default : c(DESeq2 = 17, HTRfit = 19).
#' @param skip_eval_intercept indicate whether to calculate precision-recall and ROC metrics for the intercept (default skip_eval_intercept = TRUE).
#' @param ... Additional parameters to be passed to aesthetics \code{get_pr_curve} and \code{get_roc_curve}.
#'
#' @return A list containing the following components:
#' \item{identity}{A list containing model parameters and dispersion data.}
#' \item{precision_recall}{A PR curve object generated from TMB and DESeq2 results.}
#' \item{roc}{A ROC curve object generated from TMB and DESeq2 results.}
#' \item{counts}{A counts plot generated from mock object.}
#'  \item{performances}{A summary of the performances obtained.}
#' @export
#' @examples
#' \dontrun{
#' report <- evaluation_report(list_tmb = l_res, dds = NULL, 
#'                            mock_obj = mock_data, 
#'                            coeff_threshold = 0.45, 
#'                            alt_hypothesis = "greaterAbs")
#' }
#' 
evaluation_report <- function(mock_obj, list_gene = NULL,
                              list_tmb = NULL, dds = NULL,
                              coeff_threshold = 0.69, alt_hypothesis = "greaterAbs", alpha_risk = 0.05, 
                              palette_color = c(DESeq2 = "#500472", HTRfit ="#79cbb8"), 
                              palette_shape = c(DESeq2 = 17, HTRfit = 19), 
                              skip_eval_intercept = TRUE, ...) {
  
  ## -- verif
 invisible(isValidEvalInput(mock_obj, list_gene, list_tmb, dds, 
                            coeff_threshold, alt_hypothesis, alpha_risk))
  
  ## -- subset genes
  if (!is.null(list_gene)) {
        mock_obj <- subsetGenes(list_gene, mock_obj)
        list_tmb <- list_tmb[list_gene]
  }
  
  ## -- eval data
  eval_data <- get_eval_data(list_tmb, dds, mock_obj, coeff_threshold, alt_hypothesis)
  ## -- identity plot
  params_identity_eval <- eval_identityTerm( eval_data$modelparams, palette_color, palette_shape )
  dispersion_identity_eval <- eval_identityTerm(eval_data$modeldispersion, palette_color, palette_shape)
  
  if (isTRUE(skip_eval_intercept)) {
    eval_data2metrics <- subset(eval_data$modelparams, term != "(Intercept)")
  } else {
    eval_data2metrics <- eval_data$modelparams
  }

  ## aggregate RMSE + R2
  rmse_modelparams <- compute_rmse(eval_data2metrics, grouping_by = "from")
  rsquare_modelparams <- compute_rsquare(eval_data2metrics, grouping_by = "from")
  rsquare_modelparams$from <- NULL
  aggregate_metrics <- cbind(rmse_modelparams, rsquare_modelparams)
  
  
  ## -- counts plot
  counts_violinplot <- counts_plot(mock_obj)
  
  ## -- check if eval data ok 
  if(isFALSE(is_truthLabels_valid(eval_data2metrics))){
    
    message("The required truth labels for HTRfit classification metrics evaluation are not met. Only the identity plot and counts plot will be returned.")
    ## -- clear memory
    invisible(gc(reset = T, verbose = F, full = T)) ; 
    
    return(
          list(
                data = eval_data, 
                identity = list(params = params_identity_eval$p,
                                dispersion = dispersion_identity_eval$p ), 
                counts = counts_violinplot,
                performances = list(byparams = rbind(dispersion_identity_eval$R2, params_identity_eval$R2),
                                    aggregate = aggregate_metrics )
                
          ))
      
      
  }

  
  
  ## -- pr curve
  pr_curve_obj <- get_pr_object(eval_data2metrics)
  pr_curve_obj <- get_pr_curve(pr_curve_obj, palette_color = palette_color,  ...)
  
  ## -- auc curve
  roc_curve_obj <- get_roc_object(eval_data2metrics)
  roc_curve_obj <- get_roc_curve(roc_curve_obj, palette_color = palette_color, ...)
  
  ## -- acc, recall, sensib, speci, ...
  metrics_obj <- get_ml_metrics_obj(eval_data2metrics, alpha_risk )
  ## -- merge all metrics in one obj
  model_perf_obj <- get_performances_metrics_obj( r2_params = params_identity_eval$R2,   
                                                  r2_agg = aggregate_metrics, 
                                                  dispersion_identity_eval$R2, 
                                                  pr_curve_obj,
                                                  roc_curve_obj,
                                                  metrics_obj )
  
  ## -- counts plot
  counts_violinplot <- counts_plot(mock_obj)
  
  ## -- clear memory
  invisible(gc(reset = T, verbose = F, full = T)) ;  
  
  return(
        list(
              data = eval_data, 
              identity = list( params = params_identity_eval$p,
                              dispersion = dispersion_identity_eval$p ) ,
              precision_recall = list( params = pr_curve_obj$byparams$pr_curve,
                                        aggregate = pr_curve_obj$aggregate$pr_curve ),
              roc = list( params = roc_curve_obj$byparams$roc_curve,
                                        aggregate = roc_curve_obj$aggregate$roc_curve ),
              counts = counts_violinplot,
              performances = model_perf_obj)
        )
}



#' Compute classification and regression performance metrics object
#' 
#' This function computes metrics object for both classification and regression performance
#' from evaluation objects generated by \code{evaluation_report} function. Metrics object
#' contains the by-parameter and aggregate metrics for PR AUC, ROC AUC, R-squared and other
#' classification metrics for precision, recall, sensitivity, and specificity. The function
#' takes as input various evaluation objects including R-squared values (\code{r2_params}),
#' dispersion values (\code{r2_dispersion}), PR object (\code{pr_obj}), ROC object
#' (\code{roc_obj}), and machine learning performance metrics object (\code{ml_metrics_obj}).
#' The function generates separate data frames for metric values by parameter value and for the
#' aggregated metric values.
#' 
#' @param r2_params R-squared and RMSE values from model parameters evaluation object.
#' @param r2_agg R-squared and RMSE values aggregated.
#' @param r2_dispersion R-squared and RMSE values from dispersion evaluation object.
#' @param pr_obj PR object generated from evaluation report.
#' @param roc_obj ROC object generated from evaluation report.
#' @param ml_metrics_obj Machine learning performance metrics object.
#' 
#' @return A list containing separate data frames for by-parameter and aggregated metric values.
#' @export 
get_performances_metrics_obj <- function(r2_params , r2_agg ,r2_dispersion,
                                         pr_obj, roc_obj, ml_metrics_obj ){
  ## -- by params  
  auc_mtrics_params <- join_dtf(pr_obj$byparams$pr_auc , roc_obj$byparams$roc_auc,
                         k1 = c("from", "description"), k2 = c("from", "description"))
  metrics_params <- join_dtf(auc_mtrics_params,  ml_metrics_obj$byparams, 
                        k1 = c("from", "description"), k2 = c("from", "description")) 
  rsquare_mtrics <- rbind(r2_params, r2_dispersion)
  metrics_params <- join_dtf(metrics_params, rsquare_mtrics, 
                        k1 = c("from", "description"), k2 = c("from", "description")) 
  ## -- aggregate
  auc_mtrics_agg <- join_dtf(pr_obj$aggregate$pr_auc , roc_obj$aggregate$roc_auc,
                      k1 = c("from"), k2 = c("from"))
  metrics_agg <- join_dtf(auc_mtrics_agg , ml_metrics_obj$aggregate,
                             k1 = c("from"), k2 = c("from"))
  metrics_agg <- join_dtf(metrics_agg, r2_agg,
                          k1 = c("from"), k2 = c("from"))
  return(list(byparams = metrics_params, aggregate = metrics_agg ))  
}


#' Compute summary metrics on classification results
#'
#' This function computes several classification metrics like accuracy, precision, recall,
#' sensitivity and specificity on classification results. The input to the function is a data frame
#' (\code{dt}) containing the predicted classification result as \code{y_pred} and the actual
#' classification as \code{isDE}. The function returns a data frame with the computed metrics.
#'
#' @param dt Data frame containing the predicted and actual classification results.
#'
#' @return A data frame with the computed classification metrics of accuracy, precision, recall,
#' sensitivity and specificity.
#' @export
compute_metrics_summary <- function(dt) {
  
  data.frame(
    accuracy = accuracy( dt$y_pred, dt$isDE ),
    precision = precision(dt$y_pred, y_true = dt$isDE, positive = "TRUE"),
    recall = recall(dt$y_pred, y_true = dt$isDE, positive = "TRUE"),
    sensitivity = sensitivity(dt$y_pred, y_true = dt$isDE, positive = "TRUE"),
    specificity = specificity(dt$y_pred , y_true = dt$isDE, positive = "TRUE")
  )
}

#' Get classification metrics for evaluation object
#'
#' This function extracts the classification metrics. It takes as input (\code{evaldata_params}) 
#' and an optional risk level for the alpha risk (\code{alpha_risk}).
#' It retrieves the p-values from the identity term and computes the binary classification
#' result by thresholding with the alpha risk level. It then computes the classification metrics
#' using \code{compute_metrics_summary} function separately for each parameter value as well as
#' for the aggregated results.
#'
#' @param evaldata_params Identity term of the evaluation object.
#' @param alpha_risk parameter that sets the threshold for alpha risk level (default 0.05).
#' @param col_param  parameter that sets the column name for the parameter (default "description").
#'
#' @return A list containing separate data frames for classification metrics by parameter value
#' and for aggregated classification metrics.
#'
#' @importFrom data.table setDT .SD
#' @export
get_ml_metrics_obj <- function(evaldata_params, alpha_risk = 0.05, col_param = "description"){
  
   ## -- subset fixed eff
  evaldata_params <- subset(evaldata_params, effect == "fixed")
  
  
  evaldata_params$y_pred <-  evaldata_params$p.adj <  alpha_risk
  
  ## by params
  dt_evaldata_params <- data.table::setDT(evaldata_params)
  byparam_metrics <- dt_evaldata_params[, compute_metrics_summary(.SD), by=c("from", col_param), .SDcols=c("y_pred", "isDE")]
  
  ## aggreg
  agg_metrics <- dt_evaldata_params[, compute_metrics_summary(.SD), by=c("from"), .SDcols=c("y_pred", "isDE")]
  
  return(list( byparams =  as.data.frame(byparam_metrics), aggregate = as.data.frame(agg_metrics)))
}

```



```{r test-simulation_report}



# Test invalid data
test_that("is_truthLabels_valid returns TRUE for valid data with both labels present", {
  eval_data <- data.frame(description = c("A", "B", "C"), isDE = c(TRUE, FALSE, FALSE), effect = c("fixed", "fixed", "fixed"))
  result <- is_truthLabels_valid(eval_data)
  expect_equal(result, FALSE)
})

# Test valid data
test_that("is_truthLabels_valid returns FALSE for valid data with only one label present", {
eval_data <- data.frame(description = c("A", "A", "A"), isDE = c(TRUE, NA, FALSE), effect = c("fixed", "fixed", "fixed"))
  result <- is_truthLabels_valid(eval_data)
  expect_equal(result, TRUE)
})




# Tester isValidEvalInput avec des données valides
test_that("isValidEvalInput with valid input", {
  
    # Définir des données de test
    mock_obj <- list(settings = NULL, init = NULL, groundTruth = NULL, counts = NULL, metadata = NULL, scaling_factors = NULL)
    list_gene <- c("gene1", "gene2", "gene3")
    list_tmb <- list(glmmTMB1 = NULL, glmmTMB2 = NULL)
    dds <- NULL
    coeff_threshold <- 0.05
    alt_hypothesis <- "greaterAbs"
    alpha_risk <- 0.01
  
  expect_error(isValidEvalInput(mock_obj, list_gene, list_tmb, dds, coeff_threshold, alt_hypothesis, alpha_risk),
               "All elements in 'list_tmb' are NULL")

  # Tester isValidEvalInput avec list_tmb et dds à NULL
  expect_error(isValidEvalInput(mock_obj, list_gene, NULL, NULL, coeff_threshold, alt_hypothesis, alpha_risk),
               "Both 'list_tmb' and 'dds' are NULL. There is nothing to evaluate.")


  # Tester isValidEvalInput avec alt_hypothesis incorrect
  l_tmb <- list("model1" = glmmTMB::glmmTMB(mpg ~ hp + vs + am + (1|cyl), data = mtcars),
                 "model2" = glmmTMB::glmmTMB(mpg ~ hp + vs + am + (1|cyl), data = mtcars))
  expect_error(isValidEvalInput(mock_obj, list_gene, l_tmb, dds, coeff_threshold, "greaterThan", alpha_risk))
  
  # alpha risk incoret
  expect_error(isValidEvalInput(mock_obj, list_gene, l_tmb, dds, coeff_threshold, "greater", "incorrect_alpha_risk"))

  # coeff_threshold  incoret
  expect_error(isValidEvalInput(mock_obj, list_gene, l_tmb, dds, c(), "greater", alpha_risk))

   # mock_obj  incoret
  expect_error(isValidEvalInput(c(), list_gene, l_tmb, dds, coeff_threshold, "greater", alpha_risk))

  
  ## no error 
  expect_true(isValidEvalInput(mock_obj, list_gene, l_tmb, dds, coeff_threshold, "greater", alpha_risk))
})




test_that("evaluation_report returns correct output", {
  N_GENES <- 100
  MAX_REPLICATES <- 5
  MIN_REPLICATES <- 5
  input_var_list <- init_variable(name = "varA", mu = 10, sd = 0.1, level = 3)
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                           min_replicates = MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  
  data2fit <- prepareData2fit(countMatrix = mock_data$counts, metadata = mock_data$metadata, normalization = 'MRN')
  l_res <- fitModelParallel(formula = kij ~ varA,
                            data = data2fit, group_by = "geneID",
                            family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
  
  # Tests here
  report <- evaluation_report(list_tmb = l_res, 
                              dds = NULL, 
                              mock_obj = mock_data, 
                              coeff_threshold = 0.01, 
                              alt_hypothesis = "greater")
  
  expect_true(is.list(report))
  expect_true("data" %in% names(report))
  expect_true("identity" %in% names(report))
  expect_true("precision_recall" %in% names(report))
  expect_true("roc" %in% names(report))
  expect_true("counts" %in% names(report))
  expect_true("performances" %in% names(report))

  expect_error(evaluation_report(list_tmb = l_res, 
                    dds = NULL, 
                    mock_obj = NULL, # throw error 
                    coeff_threshold = 0.01, 
                    alt_hypothesis = "greater")  )
})




test_that("get_performances_metrics_obj returns correct output", {
  # Define the input data
  r2_params <- data.frame(from = c("Glm", "Hglm"),
                          description = c("query_1|Intercept", "query_1|Intercept", "query_1|varA",  "query_1|varA"),
                          R2 = c(0.9, 0.7))
  r2_dispersion <- data.frame(from = c("Glm", "Hglm"),
                               description = c("dispersion"),
                               R2 = c(0.8, 0.6))
  pr_obj <- list(byparams = list(pr_auc = data.frame(from = c("Glm", "Hglm"),
                                                      description = c("query_1|Intercept", "query_1|Intercept", "query_1|varA",  "query_1|varA"),
                                                      pr_auc = 0.9)),
                 aggregate = list(pr_auc = data.frame(from = c("Glm", "Hglm"),
                                                      pr_auc = 0.8)))
  roc_obj <- list(byparams = list(roc_auc = data.frame(from = c("Glm", "Hglm"),
                                                        description = c("query_1|Intercept", "query_1|Intercept", "query_1|varA",  "query_1|varA"),
                                                        roc_auc = 0.7)),
                  aggregate = list(roc_auc = data.frame(from = c("Glm", "Hglm"),
                                                        roc_auc = 0.6)))
  ml_metrics_obj <- list(byparams = data.frame(from = c("Glm", "Hglm"),
                                                description = c("query_1|Intercept", "query_1|Intercept", "query_1|varA", "query_1|varA"),
                                                accuracy = c(0.9, 0.8),
                                                recall = c(0.09, 0.88)),
                         aggregate = data.frame(from = c("Glm", "Hglm"),
                                                accuracy = c(0.7, 0.6), recall = c(0.1, 0.8)))
  r2_agg <- data.frame(from = c("Glm", "Hglm"),
                          RMSE = c(0.22, 0.55),
                          R2 = c(0.9, 0.7))
  
  # Call the function
  result <- get_performances_metrics_obj(r2_params, r2_agg, r2_dispersion,
                                          pr_obj, roc_obj, ml_metrics_obj)
  
  # Test the output
  expect_true(is.list(result))
  expect_true("byparams" %in% names(result))
  expect_true("aggregate" %in% names(result))
  expect_equal(nrow(result$byparams), 6)
  expect_equal(ncol(result$byparams), 7)
  expect_equal(nrow(result$aggregate), 2)
  expect_equal(ncol(result$aggregate), 7)
})





test_that("compute_metrics_summary returns correct output",{
  # Define input data
  dt <- data.frame(
    y_pred = c(TRUE, TRUE, FALSE, TRUE, FALSE, FALSE, TRUE, TRUE),
    isDE = c(TRUE, FALSE, TRUE, TRUE, FALSE, TRUE, FALSE, FALSE)
  )
  # Compute expected output
  acc_exp <- 0.375
  prec_exp <- 0.4
  rec_exp <- 0.5
  sens_exp <- 0.5
  spec_exp <- 0.25
  exp_df <- data.frame(
    accuracy = acc_exp,
    precision = prec_exp,
    recall = rec_exp,
    sensitivity = sens_exp,
    specificity = spec_exp
  )
  # Compute actual output with tested function
  act_df <- compute_metrics_summary(dt)
  # Test that compute_metrics_summary returns expected output
  expect_identical(act_df, exp_df)
})


test_that("get_ml_metrics_obj returns correct output", {
  # Define input data
  seed = 100
  evaldata_params <- data.frame(
    from = rep(c("Glm", "Hglm", "Hglm", "Glm"), 1000),
    description = rep(c("intercept", "varA", "intercept", "varA"), 1000),
    p.adj = sort(runif(1000), decreasing = T),
    isDE = sort(sample(c(TRUE, FALSE), 1000, replace = T)),
    effect = rep("fixed", 1000)
    )
 
  # Compute actual output with tested function
  act_output <- get_ml_metrics_obj(evaldata_params, alpha_risk = 0.05, col_param = "description")
  # Test actual output against expected output
  expect_identical(names(act_output$byparams), c( "from","description","accuracy","precision","recall","sensitivity","specificity"))
  expect_identical(names(act_output$aggregate), c( "from","accuracy","precision","recall","sensitivity","specificity"))

})




```


```{r function-evaluation, filement = "evaluation"}



#' Extracts evaluation data from a list of TMB models.
#'
#' This function takes a list of TMB models, performs tidy evaluation, extracts model parameters,
#' and compares them to the ground truth effects. Additionally, it evaluates and compares dispersion
#' inferred from TMB with the ground truth gene dispersion. The results are organized in two data frames,
#' one for model parameters and one for dispersion, both labeled as "HTRfit".
#'
#' @param list_tmb A list of TMB models.
#' @param mock_obj A mock object containing ground truth information.
#' @param coeff_threshold The coefficient threshold for wald test
#' @param alt_hypothesis The alternative hypothesis for wald test
#' @return A list containing data frames for model parameters and dispersion.
#' @export
get_eval_data_from_ltmb <- function(list_tmb, mock_obj, coeff_threshold, alt_hypothesis ){

  ## -- reshape 2 dataframe
  tidyRes  <- tidy_results(list_tmb, coeff_threshold, alt_hypothesis)

  ## -- model params
  formula_used <- list_tmb[[1]]$modelInfo$allForm$formula
  params_df <- compareInferenceToExpected(tidyRes, mock_obj$groundTruth$effects, formula_used)
  params_df <- getLabelExpected(params_df, coeff_threshold, alt_hypothesis)
  params_df$from <- "HTRfit"

  ## -- dispersion
  dispersion_inferred <- extract_tmbDispersion(list_tmb)
  dispersion_df <- getDispersionComparison(dispersion_inferred, mock_obj$groundTruth$gene_dispersion)
  dispersion_df$from <- "HTRfit"

  return(list(modelparams = params_df, modeldispersion = dispersion_df ))
}


#' Extracts evaluation data from a DESeqDataSet (dds) object.
#'
#' This function takes a DESeqDataSet object, performs tidy evaluation, extracts model parameters
#' (beta in the case of DESeqDataSet), and compares them to the ground truth effects. Additionally,
#' it evaluates and compares dispersion inferred from DESeqDataSet with the ground truth gene dispersion.
#' The results are organized in two data frames, one for model parameters and one for dispersion, both #' labeled as "HTRfit".
#'
#' @param dds A DESeqDataSet object.
#' @param mock_obj A mock object containing ground truth information.
#' @param coeff_threshold The coefficient threshold wald test
#' @param alt_hypothesis The alternative hypothesis wald test
#' @return A list containing data frames for model parameters and dispersion.
#' @export
get_eval_data_from_dds <- function(dds, mock_obj, coeff_threshold, alt_hypothesis){

  ## -- reshape 2 dataframe
  tidy_dds <- wrap_dds(dds, coeff_threshold, alt_hypothesis)

  ## -- model params (beta in case of dds)
  params_df <- inferenceToExpected_withFixedEff(tidy_dds$fixEff, mock_obj$groundTruth$effects)
  params_df <- getLabelExpected(params_df, coeff_threshold, alt_hypothesis)
  params_df$component <- NA
  params_df$group <- NA
  params_df$from <- "DESeq2"

  ## -- dispersion
  dispersion_inferred <- extract_ddsDispersion(tidy_dds)
  dispersion_df <- getDispersionComparison(dispersion_inferred , mock_obj$groundTruth$gene_dispersion)
  dispersion_df$from <- "DESeq2"

  return(list(modelparams = params_df, modeldispersion = dispersion_df ))

}




#' Combines evaluation data from TMB and DESeqDataSet (dds) objects.
#'
#' This function combines model parameters and dispersion data frames from TMB and DESeqDataSet (dds) evaluations.
#'
#' @param evaldata_tmb Evaluation data from TMB models.
#' @param evaldata_dds Evaluation data from DESeqDataSet (dds) object.
#' @return A list containing combined data frames for model parameters and dispersion.
#' @export
rbind_evaldata_tmb_dds <- function(evaldata_tmb, evaldata_dds){
  ## -- rbind
  evaldata_dispersion <- rbind(evaldata_tmb$modeldispersion, evaldata_dds$modeldispersion)
  evaldata_params <- rbind(evaldata_tmb$modelparams, evaldata_dds$modelparams)

  ## -- res
  return(list(modelparams = evaldata_params, modeldispersion = evaldata_dispersion ))
}


#' Combines model parameters and dispersion data frames.
#'
#' This function combines model parameters and dispersion data frames, ensuring proper alignment.
#'
#' @param eval_data Evaluation data containing model parameters and dispersion.
#' @return A combined data frame with model parameters and dispersion.
#' @export
rbind_model_params_and_dispersion <- function(eval_data){
  ## -- split
  disp_df <- eval_data$modeldispersion
  params_df <- eval_data$modelparams
  ## -- merging model and dispersion param
  disp_df[setdiff(names(params_df), names(disp_df))] <- NA
  disp_df <- disp_df[names(params_df)]
  ## -- rbind
  res_df <- rbind(params_df, disp_df)
  return(res_df)
}



#' Gets evaluation data from both TMB and DESeqDataSet (dds) objects.
#'
#' This function retrieves evaluation data from TMB and DESeqDataSet (dds) objects, combining
#' the results into a list containing data frames for model parameters and dispersion.
#'
#' @param l_tmb A list of TMB models (default is NULL).
#' @param dds A DESeqDataSet object (default is NULL).
#' @param mock_obj A mock object containing ground truth information.
#' @param coefficient Threshold value for coefficient testing (default is 0). This threshold corresponds to the natural logarithm of the fold change (ln(FC)).
#' @param alt_hypothesis The alternative hypothesis for wald test
#' @return A list containing data frames for model parameters and dispersion.
#' @export
get_eval_data <- function(l_tmb = NULL, dds = NULL , mock_obj, coefficient, alt_hypothesis){
  ## -- init 
  eval_data_tmb <- NULL
  eval_data_dds <- NULL
  
  ## -- evaluation data
  eval_data_tmb <- if (!is.null(l_tmb)) get_eval_data_from_ltmb(l_tmb, mock_obj, coefficient, alt_hypothesis )
  eval_data_dds <- if (!is.null(dds)) get_eval_data_from_dds(dds, mock_obj, coefficient, alt_hypothesis )
  ## -- merge/rbind
  eval_data <- rbind_evaldata_tmb_dds(eval_data_tmb, eval_data_dds)
  
  return(eval_data)
}


```

```{r test-evaluation}


# Test get_eval_data_from_ltmb
test_that("get_eval_data_from_ltmb returns correct output", {
  
  input_var_list <- init_variable( name = "varA", mu = 3, sd = 2, level = 3) 
  
  ## -- Required parameters
  N_GENES = 3
  MIN_REPLICATES = 3
  MAX_REPLICATES = 3
  ########################
  
  ## -- simulate RNAseq data based on input_var_list, minimum input required
  ## -- number of replicate randomly defined between MIN_REP and MAX_REP
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates  = MIN_REPLICATES,
                         max_replicates = MAX_REPLICATES)
  ## -- data from simulation 
  count_matrix <- mock_data$counts
  metaData <- mock_data$metadata
  ##############################
  ## -- convert counts matrix and samples metadatas in a data frame for fitting
  data2fit = prepareData2fit(countMatrix = count_matrix,
                           metadata =  metaData,
                           normalization = NULL)
  l_tmb <- fitModelParallel(formula = kij ~ varA  ,
                          data = data2fit,
                          group_by = "geneID",
                          family = glmmTMB::nbinom2(link = "log"),
                          n.cores = 1)
 
  eval_data_ltmb <- get_eval_data_from_ltmb(l_tmb, mock_data, 0.27, 'greater')
  expect_is(eval_data_ltmb, "list")
  expect_named(eval_data_ltmb, c("modelparams", "modeldispersion"))
})

# Test get_eval_data_from_dds
test_that("get_eval_data_from_dds returns correct output", {
  
  input_var_list <- init_variable( name = "varA", mu = 3, sd = 2, level = 3) 
  
  ## -- Required parameters
  N_GENES = 100
  MIN_REPLICATES = 3
  MAX_REPLICATES = 3
  ########################
  
  ## -- simulate RNAseq data based on input_var_list, minimum input required
  ## -- number of replicate randomly defined between MIN_REP and MAX_REP
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates  = MIN_REPLICATES,
                         max_replicates = MAX_REPLICATES)
  ## -- data from simulation 
  count_matrix <- mock_data$counts
  metaData <- mock_data$metadata
  ##############################
  
  dds <- DESeq2::DESeqDataSetFromMatrix(count_matrix, colData = metaData, ~ varA )
  dds <- DESeq2::DESeq(dds)
  class(dds)
  eval_data_dds <- get_eval_data_from_dds(dds, mock_data, 0.27, "greater")
  
  expect_is(eval_data_dds, "list")
  expect_named(eval_data_dds, c("modelparams", "modeldispersion"))
})

# Test rbind_evaldata_tmb_dds
test_that("rbind_evaldata_tmb_dds returns correct output", {
  
  input_var_list <- init_variable( name = "varA", mu = 3, sd = 2, level = 3) 
  
  ## -- Required parameters
  N_GENES = 15
  MIN_REPLICATES = 3
  MAX_REPLICATES = 3
  BASAL_EXPR <- 3
  ########################
  
  ## -- simulate RNAseq data based on input_var_list, minimum input required
  ## -- number of replicate randomly defined between MIN_REP and MAX_REP
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates  = MIN_REPLICATES,
                         max_replicates = MAX_REPLICATES,
                         basal_expression = BASAL_EXPR)
  ## -- data from simulation 
  count_matrix <- mock_data$counts
  metaData <- mock_data$metadata
  ##############################
  ## -- convert counts matrix and samples metadatas in a data frame for fitting
  data2fit = prepareData2fit(countMatrix = count_matrix,
                           metadata =  metaData,
                           normalization = NULL)
  l_tmb <- fitModelParallel(formula = kij ~ varA  ,
                          data = data2fit,
                          group_by = "geneID",
                          family = glmmTMB::nbinom2(link = "log"),
                          n.cores = 1)
  dds <- DESeq2::DESeqDataSetFromMatrix(count_matrix, colData = metaData, ~ varA )
  dds <- DESeq2::DESeq(dds)
  
  eval_data_dds <- get_eval_data_from_dds(dds, mock_data, 0.27, "greater")
  eval_data_ltmb <- get_eval_data_from_ltmb(l_tmb, mock_data, 0.27, 'greater')
  
  combined_eval_data <- rbind_evaldata_tmb_dds(eval_data_ltmb, eval_data_dds)
  
  expect_is(combined_eval_data, "list")
  expect_named(combined_eval_data, c("modelparams", "modeldispersion"))
})

# Test rbind_model_params_and_dispersion
test_that("rbind_model_params_and_dispersion returns correct output", {
  
  input_var_list <- init_variable( name = "varA", mu = 3, sd = 2, level = 3) 
  
  ## -- Required parameters
  N_GENES = 100
  MIN_REPLICATES = 3
  MAX_REPLICATES = 3
  ########################
  
  ## -- simulate RNAseq data based on input_var_list, minimum input required
  ## -- number of replicate randomly defined between MIN_REP and MAX_REP
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates  = MIN_REPLICATES,
                         max_replicates = MAX_REPLICATES)
  ## -- data from simulation 
  count_matrix <- mock_data$counts
  metaData <- mock_data$metadata
  ##############################

  dds <- DESeq2::DESeqDataSetFromMatrix(count_matrix, colData = metaData, ~ varA )
  dds <- DESeq2::DESeq(dds)
  
  
  eval_data <- get_eval_data_from_dds(dds, mock_data, 0.27, "greater")
  
  combined_data <- rbind_model_params_and_dispersion(eval_data)
  
  expect_is(combined_data, "data.frame")
})

# Test get_eval_data
test_that("get_eval_data returns correct output", {
  
  input_var_list <- init_variable( name = "varA", mu = 0, sd = 1, level = 3) 
  
  ## -- Required parameters
  N_GENES = 50
  MIN_REPLICATES = 3
  MAX_REPLICATES = 3
  ########################
  
  ## -- simulate RNAseq data based on input_var_list, minimum input required
  ## -- number of replicate randomly defined between MIN_REP and MAX_REP
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates  = MIN_REPLICATES,
                         max_replicates = MAX_REPLICATES)
  ## -- data from simulation 
  count_matrix <- mock_data$counts
  metaData <- mock_data$metadata
  ##############################
  ## -- convert counts matrix and samples metadatas in a data frame for fitting
  data2fit = prepareData2fit(countMatrix = count_matrix,
                           metadata =  metaData,
                           normalization = NULL)
  l_tmb <- fitModelParallel(formula = kij ~ varA  ,
                          data = data2fit,
                          group_by = "geneID",
                          family = glmmTMB::nbinom2(link = "log"),
                          n.cores = 1)
  dds <- DESeq2::DESeqDataSetFromMatrix(count_matrix, colData = metaData, ~ varA )
  dds <- DESeq2::DESeq(dds)
  
  
  eval_data <- get_eval_data(l_tmb, dds, mock_data, 0.27, "greater")
  
  expect_is(eval_data, "list")
  expect_named(eval_data, c("modelparams", "modeldispersion"))
  expect_equal(unique(eval_data$modelparams$from), c("HTRfit", "DESeq2"))
  
  
  ## -- dds == NULL
  eval_data <- get_eval_data(l_tmb, NULL, mock_data, 0.27, "greater")
  expect_is(eval_data, "list")
  expect_named(eval_data, c("modelparams", "modeldispersion"))
  expect_equal(unique(eval_data$modelparams$from), c("HTRfit"))
  
  
    ## -- l_tmb == NULL
  eval_data <- get_eval_data(NULL, dds, mock_data, 0.27, "greater")
  expect_is(eval_data, "list")
  expect_named(eval_data, c("modelparams", "modeldispersion"))
  expect_equal(unique(eval_data$modelparams$from), c("DESeq2"))

})



```


```{r function-wrapper_dds, filename =  "wrapper_dds" }

#' Wrapper Function for DESeq2 Analysis
#'
#' This function performs differential expression analysis using DESeq2 based on the provided
#' DESeqDataSet (dds) object. It calculates the dispersion values from the dds object and then
#' performs inference on the log-fold change (LFC) values using the specified parameters.
#'
#' @param dds A DESeqDataSet object containing the count data and experimental design.
#' @param lfcThreshold The threshold for minimum log-fold change (LFC) to consider differentially expressed.
#' @param altHypothesis The alternative hypothesis for the analysis, indicating the direction of change.
#'                      Options are "greater", "less", or "two.sided".
#' @param correction_method The method for p-value correction. Default is "BH" (Benjamini-Hochberg).
#'
#' @return A list containing the dispersion values and the results of the differential expression analysis.
#'         The dispersion values are calculated from the dds object and named according to sample names.
#'         The inference results include adjusted p-values and log2 fold changes for each gene.
#'
#' @examples
#' N_GENES = 100
#' MAX_REPLICATES = 5
#' MIN_REPLICATES = 5
#' BASAL_EXP = 10
#' ## --init variable
#' input_var_list <- init_variable( name = "genotype", sd = 1, level = 3) %>%
#'                    init_variable(name = "environment", sd = 0.9 , level = 2) 
#' mock_data <- mock_rnaseq(input_var_list, N_GENES, MIN_REPLICATES, 
#'                          MAX_REPLICATES, basal_expression = BASAL_EXP)
#' dds <- DESeq2::DESeqDataSetFromMatrix(mock_data$counts , 
#'                    mock_data$metadata, ~ genotype + environment)
#' dds <- DESeq2::DESeq(dds, quiet = TRUE)
#' result <- wrap_dds(dds, lfcThreshold = 1, altHypothesis = "greater")
#' @export
wrap_dds <- function(dds, lfcThreshold , altHypothesis, correction_method = "BH") {
  dds_full <- as.data.frame(S4Vectors::mcols(dds))
  
  ## -- dispersion
  message("INFO: The dispersion values from DESeq2 were reparametrized to their reciprocals (1/dispersion).")
  dispersion <- 1/dds_full$dispersion
  names(dispersion) <- rownames(dds_full)

  ## -- coeff
  inference_df <- get_inference_dds(dds_full, lfcThreshold, altHypothesis, correction_method)
  res <- list(dispersion = dispersion, fixEff = inference_df)
  return(res)
}



#' Calculate Inference for Differential Expression Analysis
#'
#' This function calculates inference for differential expression analysis based on the results of DESeq2.
#'
#' @param dds_full A data frame containing DESeq2 results, including estimate and standard error information.
#' @param lfcThreshold Log fold change threshold for determining differentially expressed genes.
#' @param altHypothesis Alternative hypothesis for testing, one of "greater", "less", or "two.sided".
#' @param correction_method Method for multiple hypothesis correction, e.g., "BH" (Benjamini-Hochberg).
#'
#' @return A data frame containing inference results, including statistics, p-values, and adjusted p-values.
#'
#' @examples
#' \dontrun{
#' # Example usage of the function
#' inference_result <- get_inference_dds(dds_full, lfcThreshold = 0.5, 
#'                                    altHypothesis = "greater", 
#'                                    correction_method = "BH")
#' }
#' @importFrom stats p.adjust
#' @export
get_inference_dds <- function(dds_full, lfcThreshold, altHypothesis, correction_method){

  ## -- build subdtf
  stdErr_df <- getSE_df(dds_full)
  estim_df <- getEstimate_df(dds_full)
  ## -- join
  df2ret <- join_dtf(estim_df, stdErr_df, k1 = c("ID", "term") , k2 = c("ID", "term"))

  ## -- convert to ln
  message("INFO: The log2-fold change estimates and standard errors from DESeq2 were converted to the natural logarithm scale.")
  df2ret$estimate <- df2ret$estimate*log(2)
  df2ret$std.error <- df2ret$std.error*log(2)

  ## -- some details reshaped
  df2ret$term <- gsub("_vs_.*","", df2ret$term)
  df2ret$term <- gsub(pattern = "_", df2ret$term, replacement = "")
  df2ret$term <- removeDuplicatedWord(df2ret$term)
  df2ret$term <- gsub(pattern = "[.]", df2ret$term, replacement = ":")
  df2ret$effect <- "fixed"
  idx_intercept <- df2ret$term == "Intercept"
  df2ret$term[idx_intercept] <- "(Intercept)"

  ## -- statistical part
  waldRes <- wald_test(df2ret$estimate, df2ret$std.error, lfcThreshold, altHypothesis)
  df2ret$statistic <- waldRes$statistic
  df2ret$p.value <- waldRes$p.value
  df2ret$p.adj <- stats::p.adjust(df2ret$p.value, method = correction_method)

  return(df2ret)
}


#' Extract Standard Error Information from DESeq2 Results
#'
#' This function extracts the standard error (SE) information from DESeq2 results.
#'
#' @param dds_full A data frame containing DESeq2 results, including standard error columns.
#'
#' @return A data frame with melted standard error information, including gene IDs and terms.
#'
#' @examples
#' \dontrun{
#' # Example usage of the function
#' se_info <- getSE_df(dds_full)
#' }
#' @importFrom reshape2 melt
#' @export
getSE_df <- function(dds_full){
  columnsInDds_full <- colnames(dds_full)
  SE_columns <- columnsInDds_full [ grepl("SE" , columnsInDds_full) ]
  SE_df <- dds_full[, SE_columns]
  SE_df$ID <- rownames(SE_df)
  SE_df_long <- reshape2::melt(SE_df,
                                       measure.vars = SE_columns,
                                       variable.name  = "term", value.name = "std.error", drop = F)
  SE_df_long$term <- gsub(pattern = "SE_", SE_df_long$term, replacement = "")
  return(SE_df_long)

}


#' Extract Inferred Estimate Information from DESeq2 Results
#'
#' This function extracts the inferred estimate values from DESeq2 results.
#'
#' @param dds_full A data frame containing DESeq2 results, including estimate columns.
#'
#' @return A data frame with melted inferred estimate information, including gene IDs and terms.
#'
#' @examples
#' \dontrun{
#' # Example usage of the function
#' estimate_info <- getEstimate_df(dds_full)
#'  }
#' @importFrom reshape2 melt
#' @export
getEstimate_df <- function(dds_full){
  columnsInDds_full <- colnames(dds_full)
  SE_columns <- columnsInDds_full [ grepl("SE" , columnsInDds_full) ]
  inferedVal_columns <- gsub("SE_", "" , x = SE_columns)

  estimate_df <- dds_full[, inferedVal_columns]
  estimate_df$ID <- rownames(estimate_df)
  estimate_df_long <- reshape2::melt(estimate_df,
                                 measure.vars = inferedVal_columns,
                                 variable.name  = "term", value.name = "estimate", drop = F)
  return(estimate_df_long)

}

```


```{r test-wrapper_dds}


test_that("get_inference_dds returns a data frame with correct columns", {
  # Create a sample dds_full data frame
  N_GENES = 100
  MAX_REPLICATES = 5
  MIN_REPLICATES = 5
  ## --init variable
  input_var_list <- init_variable( name = "genotype", mu = 12, sd = 0.1, level = 3) %>%
                    init_variable(name = "environment", mu = c(0,1), NA , level = 2) 

  mock_data <- mock_rnaseq(input_var_list, N_GENES, MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  dds <- DESeq2::DESeqDataSetFromMatrix(mock_data$counts , mock_data$metadata, ~ genotype + environment)
  dds <- DESeq2::DESeq(dds, quiet = TRUE)
  dds_full <- S4Vectors::mcols(dds) %>% as.data.frame()
  
  # Call the function
  inference_results <- get_inference_dds(dds_full, lfcThreshold = 0.5, altHypothesis = "greater", correction_method = "BH")
  
  # Check if the returned object is a data frame
  expect_true(is.data.frame(inference_results))
  
  # Check if the data frame contains the correct columns
  expect_true("ID" %in% colnames(inference_results))
  expect_true("estimate" %in% colnames(inference_results))
  expect_true("std.error" %in% colnames(inference_results))
  expect_true("term" %in% colnames(inference_results))
  expect_true("effect" %in% colnames(inference_results))
  expect_true("statistic" %in% colnames(inference_results))
  expect_true("p.value" %in% colnames(inference_results))
  expect_true("p.adj" %in% colnames(inference_results))
})






test_that("getEstimate_df function works correctly", {
  
 # Create a sample dds_full data frame
  N_GENES = 100
  MAX_REPLICATES = 5
  MIN_REPLICATES = 5
  ## --init variable
  input_var_list <- init_variable( name = "genotype", mu = 12, sd = 0.1, level = 3) %>%
                    init_variable(name = "environment", mu = c(0,1), NA , level = 2) 

  mock_data <- mock_rnaseq(input_var_list, N_GENES, MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  dds <- DESeq2::DESeqDataSetFromMatrix(mock_data$counts , mock_data$metadata, ~ genotype + environment)
  dds <- DESeq2::DESeq(dds, quiet = TRUE)
  dds_full <- S4Vectors::mcols(dds) %>% as.data.frame()
  
  # Call the function
  estimate_df_long <- getEstimate_df(dds_full)
  
  # Check if the resulting data frame has the expected structure
  expect_true("ID" %in% colnames(estimate_df_long))
  expect_true("term" %in% colnames(estimate_df_long))
  expect_true("estimate" %in% colnames(estimate_df_long))
})



# Define a test context
test_that("getSE_df function works correctly", {
  
 # Create a sample dds_full data frame
  N_GENES = 100
  MAX_REPLICATES = 5
  MIN_REPLICATES = 5
  ## --init variable
  input_var_list <- init_variable( name = "genotype", mu = 12, sd = 0.1, level = 3) %>%
                    init_variable(name = "environment", mu = c(0,1), NA , level = 2) 

  mock_data <- mock_rnaseq(input_var_list, N_GENES, MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  dds <- DESeq2::DESeqDataSetFromMatrix(mock_data$counts , mock_data$metadata, ~ genotype + environment)
  dds <- DESeq2::DESeq(dds, quiet = TRUE)
  dds_full <- S4Vectors::mcols(dds) %>% as.data.frame()
  
  # Call the function
  SE_df_long <- getSE_df(dds_full)
  
  # Check if the resulting data frame has the expected structure
  expect_true("ID" %in% colnames(SE_df_long))
  expect_true("term" %in% colnames(SE_df_long))
  expect_true("std.error" %in% colnames(SE_df_long))
})


# Define a test context
test_that("wrapperDESeq2 function works correctly", {
  
 # Create a sample dds_full data frame
  N_GENES = 100
  MAX_REPLICATES = 5
  MIN_REPLICATES = 5
  ## --init variable
  input_var_list <- init_variable( name = "genotype", mu = 12, sd = 0.1, level = 3) %>%
                    init_variable(name = "environment", mu = c(0,1), NA , level = 2) 

  mock_data <- mock_rnaseq(input_var_list, N_GENES, MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  dds <- DESeq2::DESeqDataSetFromMatrix(mock_data$counts , mock_data$metadata, ~ genotype + environment)
  dds <- DESeq2::DESeq(dds, quiet = TRUE)
  deseq2_wrapped <- wrap_dds(dds, 0.2, "greaterAbs")
  
  expect_true(is.list(deseq2_wrapped))
  
  # Check if the resulting data frame has the expected structure
  expect_true("ID" %in% colnames(deseq2_wrapped$fixEff))
  expect_true("term" %in% colnames(deseq2_wrapped$fixEff))
  expect_true("std.error" %in% colnames(deseq2_wrapped$fixEff))
  expect_true("estimate" %in% colnames(deseq2_wrapped$fixEff))
  expect_true("statistic" %in% colnames(deseq2_wrapped$fixEff))
  expect_true("p.value" %in% colnames(deseq2_wrapped$fixEff))
  expect_true("p.adj" %in% colnames(deseq2_wrapped$fixEff))

})


```


```{r function-anova, filename =  "anova"}

#' Handle ANOVA Errors
#'
#' This function handles ANOVA errors and warnings during the ANOVA calculation process.
#'
#' @param list_tmb A list of fitted glmmTMB models.
#' @param group A character string indicating the group for which ANOVA is calculated.
#' @param ... Additional arguments to be passed to the \code{car::Anova} function.
#' 
#' @return A data frame containing ANOVA results for the specified group.
#' @export
#' 
#' @examples
#' list_tmb <- fitModelParallel(Sepal.Length ~ Sepal.Width + Petal.Length,
#'                           data = iris, group_by = "Species", n.cores = 1)
#' anova_res <- handleAnovaError(list_tmb, "setosa", type = "III")
#'
#' @importFrom car Anova
#' @export
handleAnovaError <- function(list_tmb, group, ...) {
  tryCatch(
    expr = {
      withCallingHandlers(
        car::Anova(list_tmb[[group]], ...),
        warning = function(w) {
          message(paste(Sys.time(), "warning for group", group, ":", conditionMessage(w)))
          invokeRestart("muffleWarning")
        })
    },
    error = function(e) {
      message(paste(Sys.time(), "error for group", group, ":", conditionMessage(e)))
      NULL
    }
  )
}


#' Perform ANOVA on Multiple glmmTMB Models in Parallel
#'
#' This function performs analysis of variance (ANOVA) on a list of \code{glmmTMB}
#' models in parallel for different groups specified in the list. It returns a list
#' of ANOVA results for each group.
#'
#' @param list_tmb A list of \code{glmmTMB} models, with model names corresponding to the groups.
#' @param ... Additional arguments passed to \code{\link[stats]{anova}} function.
#'
#' @return A list of ANOVA results for each group.
#' @importFrom stats setNames
#' @examples
#' # Perform ANOVA
#' data(iris)
#' list_tmb<- fitModelParallel( Sepal.Length ~ Sepal.Width  + Petal.Length, 
#'                          data = iris, group_by = "Species", n.cores = 1 )
#' anov_res <- anovaParallel(list_tmb , type = "III")
#' @importFrom stats anova
#' @export
anovaParallel <- function(list_tmb, ...) {
  invisible(isValidList_tmb(list_tmb))
  l_group <- attributes(list_tmb)$names
  lapply(stats::setNames(l_group, l_group), function(group) handleAnovaError(list_tmb, group, ...))
}


```


```{r  test-anova}


test_that("handleAnovaError return correct ouptut", {
  data(iris)
  l_tmb <- fitModelParallel(Sepal.Length ~ Sepal.Width + Petal.Length,
                            data = iris, group_by = "Species", n.cores = 1)
  anova_res <- handleAnovaError(l_tmb, "setosa", type = "III")
  
  expect_s3_class(anova_res, "data.frame")
  expect_equal(nrow(anova_res), 3)  # Number of levels
})

test_that("handleAnovaError return correct ouptut", {
  data(iris)
  l_tmb <- fitModelParallel(Sepal.Length ~ Sepal.Width + Petal.Length,
                            data = iris, group_by = "Species", n.cores = 1)
  anova_res <- handleAnovaError(l_tmb, "INALID_GROUP", type = "III")
  
  expect_null(anova_res)
})



test_that("anovaParallel returns valid ANOVA results", {
  data(iris)
  l_tmb <- fitModelParallel(Sepal.Length ~ Sepal.Width + Petal.Length,
                            data = iris, group_by = "Species", n.cores = 1)
  anov_res <- anovaParallel(l_tmb, type = "III")
  
  expect_is(anov_res, "list")
  expect_equal(length(anov_res), length(unique(iris$Species)))
  
})





```

```{r function-R2, filename = "rsquare"}



#' Compute R-squared values for linear regression on grouped data
#'
#' This function takes a data frame, performs linear regression on specified grouping variables,
#' and computes R-squared values for each group.
#'
#' @param data A data frame containing the variables 'actual' and 'estimate' for regression.
#' @param grouping_by A character vector specifying the grouping variables for regression.
#' @return A data frame with columns 'from', 'term', and 'R2' representing the grouping variables
#' and the corresponding R-squared values.
#' @export
#' @examples
#' data <- data.frame(from = c("A", "A", "A", "A"),
#'                    term = c("X", "Y", "X", "Y"),
#'                    actual = c(1, 2, 3, 4),
#'                    estimate = c(1.5, 2.5, 3.5, 4.5))
#' compute_rsquare(data, grouping_by = c("from", "term"))
#'
#' @importFrom data.table data.table
compute_rsquare <- function(data, grouping_by =  c("from", "description") ){
  ## -- convert to data.table
  dat <- data.table::data.table(data)
  ## -- calculate the regression coefficient r^2
  r_square_df <- as.data.frame( 
                              dat[ , summary(lm(actual~estimate))$r.squared, 
                              by = grouping_by ]
                              )
  names(r_square_df)[names(r_square_df) == "V1"] <- "R2"
  return(r_square_df)
}


```


```{r test-rsquare}


# Test case 1: Check if the function returns a data frame
test_that("compute_rsquare returns a data frame", {
  data <- data.frame(from = c("A", "A", "A", "A"),
                    description = c("X", "Y", "X", "Y"),
                    actual = c(1, 2, 3, 4),
                    estimate = c(10, 20, 30, 40))
  df_rsquare <- compute_rsquare(data, grouping_by = c("from", "description"))
  expect_s3_class(df_rsquare, "data.frame")
  expect_equal(df_rsquare$from, c("A", "A"))
  expect_equal(df_rsquare$description, c("X", "Y"))
  expect_equal(df_rsquare$R2, c(1, 1))

})


```


```{r function-RMSE, filename =  "rmse"}

#' Root Mean Squared Error (RMSE)
#'
#' Calculates the root mean squared error (RMSE) between two vectors.
#'
#' @param y Vector of actual values.
#' @param y_hat Vector of estimates/predicted values.
#' @return RMSE value.
#' @export
rmse <- function(y, y_hat) {
  if (length(y) != length(y_hat)) {
    stop("RMSE: Vectors y and y_hat must have the same length.")
  }
  
  rmse <- sqrt(mean((y - y_hat)^2, na.rm = T))
  return(rmse)
}




#' Compute RMSE values on grouped data
#'
#' This function takes a data frame, performs RMSE between estimate and actual values on specified grouping variables,
#'
#' @param data A data frame containing the variables 'actual' and 'estimate' for regression.
#' @param grouping_by A character vector specifying the grouping variables
#' @return A data frame with columns 'from', 'term', and 'RMSE' representing the grouping variables
#' and the corresponding RMSE values.
#' @export
#' @examples
#' data <- data.frame(from = c("A", "A", "A", "A"),
#'                    term = c("X", "Y", "X", "Y"),
#'                    actual = c(1, 2, 3, 4),
#'                    estimate = c(1.5, 2.5, 3.5, 4.5))
#' compute_rmse(data, grouping_by = c("from", "term"))
#'
#' @importFrom data.table data.table
compute_rmse <- function(data, grouping_by = c("from", "description")) {
  ## -- convert to data.table
  dat <- data.table::data.table(data)
  ## -- calculate the RMSE
  rmse_df <- as.data.frame( 
                dat[ , rmse(actual, estimate), 
                 by = grouping_by ]
            )
  names(rmse_df)[names(rmse_df) == "V1"] <- "RMSE"
  return(rmse_df)
}


```



```{r test-rmse}



# Test case 1: Check if the function returns a data frame
test_that("compute_rmse returns a data frame", {
  data <- data.frame(from = c("A", "A", "A", "A"),
                    description = c("X", "Y", "X", "Y"),
                    actual = c(1, 2, 3, 4),
                    estimate = c(10, 20, 30, 40))
  df_rmse <- compute_rmse(data, grouping_by = c("from", "description"))
  expect_s3_class(df_rmse, "data.frame")
  expect_equal(df_rmse$from, c("A", "A"))
  expect_equal(df_rmse$description, c("X", "Y"))
  expect_equal(df_rmse$RMSE, c(20.12461, 28.46050), tolerance = 1e-3)
  
  
  ## -- exact match 
  data <- data.frame(from = c("A", "A", "A", "A"),
                    description = c("X", "Y", "X", "Y"),
                    actual = c(1, 2, 3, 4),
                    estimate = c(1, 2, 3, 4))
  df_rmse <- compute_rmse(data, grouping_by = c("from", "description"))
  expect_s3_class(df_rmse, "data.frame")
  expect_equal(df_rmse$from, c("A", "A"))
  expect_equal(df_rmse$description, c("X", "Y"))
  expect_equal(df_rmse$RMSE, c(0, 0), tolerance = 0)

})



test_that("RMSE function calculates correct RMSE values", {
  # Test case 1: Same vectors
  y1 <- c(1, 2, 3, 4, 5)
  y_hat1 <- c(1, 2, 3, 4, 5)
  expect_equal(rmse(y1, y_hat1), 0, 
               info = "RMSE between identical vectors should be 0.")
  
  # Test case 2: Different vectors
  y2 <- c(1, 2, 3, 4, 5)
  y_hat2 <- c(2, 3, 4, 5, 6)
  expect_equal(round(rmse(y2, y_hat2), 2), 1, 
               info = "RMSE between different vectors should be approximately 1")
  
  # Test case 3: vector with NA
  y3 <- c(1, 2, 3, 4)
  y_hat3 <- c(1, 2, NA, 4)
  expect_equal(rmse(y3, y_hat3), 0)
  
  # Test case 4: Vectors with different lengths
  y4 <- c(1, 2, 3, 4, 5)
  y_hat4 <- c(1, 2, 3, 4)
  expect_error(rmse(y4, y_hat4), 
               info = "RMSE should throw an error for vectors with different lengths.")
})



```



```{r function-subsetGenes, filename =  "subsetGenes"}

#' Subset Genes in Genomic Data
#'
#' This function filters and adjusts genomic data within the Roxygeb project, based on a specified list of genes.
# It is designed to enhance precision and customization in transcriptomics analysis by retaining only the genes of interest.
# 
#' @param l_genes A character vector specifying the genes to be retained in the dataset.
#' @param mockObj An object containing relevant genomic information to be filtered.
#'
#' @return A modified version of the 'mockObj' data object, with genes filtered according to 'l_genes'.
#'
#' @description The 'subsetGenes' function selects and retains genes from 'mockObj' that match the genes specified in 'l_genes'.
# It filters the 'groundTruth$effects' data to keep only the rows corresponding to the selected genes. 
# Additionally, it updates 'gene_dispersion' and the count data, ensuring that only the selected genes are retained.
# The function also replaces the total number of genes in 'settings$values' with the length of 'l_genes'.
# The result is a more focused and tailored genomic dataset, facilitating precision in subsequent analyses.
#'
#' @export
#' @examples
#' N_GENES = 100
#' MAX_REPLICATES = 5
#' MIN_REPLICATES = 5
#' input_var_list <- init_variable(name = "varA", sd = 0.1, level = 3)
#' mock_data <- mock_rnaseq(input_var_list, N_GENES,
#'                         min_replicates = MIN_REPLICATES, 
#'                         max_replicates = MAX_REPLICATES)
#' subset_mockobj <- subsetGenes(mock_data, l_genes = c("gene1", "gene4"))
subsetGenes <- function(l_genes, mockObj) {
  # Selects the indices of genes in 'groundTruth$effects$geneID' that are present in 'l_genes'.
  idx_gt_effects <- mockObj$groundTruth$effects$geneID %in% l_genes
  
  # Filters 'groundTruth$effects' to keep only the rows corresponding to the selected genes.
  mockObj$groundTruth$effects <- mockObj$groundTruth$effects[idx_gt_effects, ]
  
  # Updates 'gene_dispersion' by retaining values corresponding to the selected genes.
  mockObj$groundTruth$gene_dispersion <- mockObj$groundTruth$gene_dispersion[l_genes]
  
  # Filters the count data to keep only the rows corresponding to the selected genes.
  mockObj$counts <- as.data.frame(mockObj$counts[l_genes, ])
  
  # Replaces the total number of genes in 'settings$values' with the length of 'l_genes'.
  mockObj$settings$values[1] <- length(l_genes)
  
  # Returns the modified 'mockObj'.
  return(mockObj)
}


```


```{r  test-subsetGenes}
test_that("subsetGenes return correct ouptut", {
  N_GENES = 100
  MAX_REPLICATES = 5
  MIN_REPLICATES = 5
  input_var_list <- init_variable(name = "varA", mu = 10, sd = 0.1, level = 3)
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates = MIN_REPLICATES, max_replicates = MAX_REPLICATES)
  
  subset_mockobj <- subsetGenes(mock_data, l_genes = c("gene1", "gene4"))
  
  mock_data$settings$values[1] <- 2
  expect_equal(mock_data$settings, subset_mockobj$settings )
  expect_equal(mock_data$init, subset_mockobj$init )
  
  expect_equal(subset(mock_data$groundTruth$effects, geneID %in% c("gene1", "gene4")) , subset_mockobj$groundTruth$effects )
  expect_equal(mock_data$groundTruth$gene_dispersion[c("gene1", "gene4")] , subset_mockobj$groundTruth$gene_dispersion )
  expect_equal(mock_data$counts[c("gene1", "gene4"),] , subset_mockobj$counts )
  expect_equal(mock_data$metadata , subset_mockobj$metadata )
})





```



```{r function-evaluation_withmixedeffect, filename =  "evaluation_withmixedeffect"}

#' Check if the formula contains a mixed effect structure.
#'
#' This function checks if the formula contains a mixed effect structure indicated by the presence of "|".
#'
#' @param formula A formula object.
#'
#' @return \code{TRUE} if the formula contains a mixed effect structure, \code{FALSE} otherwise.
#'
#' @examples
#' is_mixedEffect_inFormula(y ~ x + (1|group))
#'
#' @export
is_mixedEffect_inFormula <- function(formula) {
  return("|" %in% all.names(formula))
}

#' Check if the formula follows a specific type I mixed effect structure.
#'
#' This function checks if the formula follows a specific type I mixed effect structure, which consists of a fixed effect and a random effect indicated by the presence of "|".
#'
#' @param formula A formula object.
# 
#' @return \code{TRUE} if the formula follows the specified type I mixed effect structure, \code{FALSE} otherwise.
# 
#' @examples
#' is_formula_mixedTypeI(formula = y ~ x + (1|group))
# 
#' @export
is_formula_mixedTypeI <- function(formula) {
  if (length(all.vars(formula)) != 3) return(FALSE)
  if (sum(all.names(formula) == "+") > 1) return(FALSE)
  if (sum(all.names(formula) == "/") > 0) return(FALSE)
  all_var_in_formula <- all.vars(formula, unique = F)
  if (length(all_var_in_formula) == 4 && all_var_in_formula[2] != all_var_in_formula[3]) return(FALSE)
  return(TRUE)
}


#' Get the categorical variable associated with the fixed effect in a type I formula.
#'
#' This function extracts the categorical variable associated with the fixed effect in a type I formula from a tidy tibble.
# The categorical variable is constructed by taking the label of the second main fixed effect term (ignoring any numeric suffix) and prefixing it with "label_".
#
#' @param tidy_tmb A tidy tibble containing model terms.
# 
#' @return The categorical variable associated with the fixed effect in the type I formula.
# 
#' @examples
#' \dontrun{
#' getCategoricalVar_inFixedEffect(tidy_tmb)
#' } 
#' @export
getCategoricalVar_inFixedEffect <- function(tidy_tmb) {
  main_fixEffs <- unique(subset(tidy_tmb, effect == "fixed")$term)
  categorical_var_inFixEff <- paste("label", gsub("\\d+$", "", main_fixEffs[2]), sep = "_")
  return(categorical_var_inFixEff)
}


#' Group log_qij values per genes and labels.
#'
#' This function groups log_qij values in a ground truth tibble per genes and labels using a specified categorical variable.
#
#' @param ground_truth A tibble containing ground truth data.
#' @param categorical_var The categorical variable to use for grouping.
# 
#' @return A list of log_qij values grouped by genes and labels.
#' @importFrom stats as.formula
#' @importFrom reshape2 dcast
#' 
# 
#' @examples
#' \dontrun{
#' group_logQij_per_genes_and_labels(ground_truth, categorical_var)
#' }
#' @export
group_logQij_per_genes_and_labels <- function(ground_truth, categorical_var) {
  str_formula <- paste(c(categorical_var, "geneID"), collapse = " ~ ")
  formula <- stats::as.formula(str_formula)
  list_logqij <- ground_truth %>%
    reshape2::dcast(
      formula,
      value.var = "log_qij_scaled",
      fun.aggregate = list
    )
  list_logqij[categorical_var] <- NULL
  return(list_logqij)
}

#' Calculate actual mixed effect values for each gene.
#'
#' This function calculates actual mixed effect values for each gene using the provided data, reference labels, and other labels in a categorical variable.
#
#' @param list_logqij A list of log_qij values grouped by genes and labels.
#' @param genes_iter_list A list of genes for which to calculate the actual mixed effect values.
#' @param categoricalVar_infos Information about the categorical variable, including reference labels and other labels.
# 
#' @return A data frame containing the actual mixed effect values for each gene.
# 
#' @examples
#' \dontrun{
#' getActualMixed_typeI(list_logqij, genes_iter_list, categoricalVar_infos)
#' }
#' @export
getActualMixed_typeI <- function(list_logqij, genes_iter_list, categoricalVar_infos) {
  labelRef_InCategoricalVar <- categoricalVar_infos$ref
  labels_InCategoricalVar <- categoricalVar_infos$labels
  labelOther_inCategoricalVar <- categoricalVar_infos$labelsOther

  data_per_gene <- lapply(genes_iter_list, function(g) {
    data_gene <- data.frame(list_logqij[[g]])
    colnames(data_gene) <- labels_InCategoricalVar
    return(data_gene)
  })
  
  l_actual_per_gene <- lapply(genes_iter_list, function(g) {
    data_gene <- data_per_gene[[g]]
    res <- calculate_actualMixed(data_gene, labelRef_InCategoricalVar, labelOther_inCategoricalVar)
    res$geneID <- g
    return(res)
  })
  
  actual_mixedEff <- do.call("rbind", l_actual_per_gene)
  rownames(actual_mixedEff) <- NULL
  return(actual_mixedEff)
}



#' Compare the mixed-effects inference to expected values.
#'
#' This function compares the mixed-effects inference obtained from a mixed-effects model to expected values derived from a ground truth dataset. 
#' The function assumes a specific type I mixed-effect structure in the input model.
# 
#' @param tidy_tmb  tidy model results obtained from fitting a mixed-effects model.
#' @param ground_truth_eff A data frame containing ground truth effects.
# 
#' @return A data frame with the comparison of estimated mixed effects to expected values.
#' @importFrom stats setNames
#' @examples
#' \dontrun{
#' inferenceToExpected_withMixedEff(tidy_tmb(l_tmb), ground_truth_eff)
#' } 
#' @export
inferenceToExpected_withMixedEff <- function(tidy_tmb, ground_truth_eff){

  # -- CategoricalVar involve in fixEff
  categorical_var <- getCategoricalVar_inFixedEffect(tidy_tmb)
  labels_InCategoricalVar <- levels(ground_truth_eff[, categorical_var])
  labelRef_InCategoricalVar <- labels_InCategoricalVar[1]
  labelOther_inCategoricalVar <- labels_InCategoricalVar[2:length(labels_InCategoricalVar)]
  categoricalVar_infos <- list(ref = labelRef_InCategoricalVar,
                               labels = labels_InCategoricalVar,
                               labelsOther = labelOther_inCategoricalVar )

  ## -- prepare data 2 get actual
  l_logqij <- group_logQij_per_genes_and_labels(ground_truth_eff, categorical_var)
  l_genes <- unique(ground_truth_eff$geneID)
  genes_iter_list <- stats::setNames(l_genes,l_genes)
  actual_mixedEff <- getActualMixed_typeI(l_logqij, genes_iter_list, categoricalVar_infos)
  res <- join_dtf(actual_mixedEff, tidy_tmb,   c("geneID", "term"), c("ID", "term"))
  names(res)[names(res) == 'geneID'] <- 'ID'
  ## -- reorder for convenience
  actual <- res$actual
  res <- res[, -1]
  res$actual <- actual
  return(res)
}


#' Calculate actual mixed effects.
#'
#' This function calculates actual mixed effects based on the given data for a specific type I mixed-effect structure.
#  It calculates the expected values, standard deviations, and correlations between the fixed and random effects.
#  The function is designed to work with specific input data for type I mixed-effect calculations.
# 
#' @param data_gene Data for a specific gene.
#' @param labelRef_InCategoricalVar The reference label for the categorical variable.
#' @param labelOther_inCategoricalVar Labels for the categorical variable other than the reference label.
#' @importFrom stats sd cor
# 
#' @return A data frame containing the calculated actual mixed effects.
# 
#' @examples
#' \dontrun{
#'  calculate_actualMixed(data_gene, labelRef_InCategoricalVar, labelOther_inCategoricalVar)
#' }
#' @export
calculate_actualMixed <- function(data_gene, labelRef_InCategoricalVar, labelOther_inCategoricalVar ){
   log_qij_scaled_intercept <- data_gene[labelRef_InCategoricalVar]
  colnames(log_qij_scaled_intercept) <- '(Intercept)'

  if (length(labelOther_inCategoricalVar == 1 )) {
    log_qij_scaled_other <- data_gene[labelOther_inCategoricalVar]
  } else log_qij_scaled_other <- data_gene[,labelOther_inCategoricalVar]
  log_qij_scaled_transf <- log_qij_scaled_other - log_qij_scaled_intercept[,"(Intercept)"]

  log_qij_scaled_transf <- cbind(log_qij_scaled_intercept, log_qij_scaled_transf)
  ## -- fix eff
  actual_fixedValues <- colMeans(log_qij_scaled_transf)

  ## -- stdev values
  std_values <- sapply(log_qij_scaled_transf, function(x) stats::sd(x))
  names(std_values) <- paste("sd", names(std_values), sep = '_')

  ## -- correlation
  corr_mat <- stats::cor(log_qij_scaled_transf)
  indx <- which(upper.tri(corr_mat, diag = FALSE), arr.ind = TRUE)
  corr2keep = corr_mat[indx]
  name_corr <- paste(rownames(corr_mat)[indx[, "row"]], colnames(corr_mat)[indx[, "col"]], sep = ".")
  names(corr2keep) <- paste("cor", name_corr, sep = "__")

  ## -- output 
  actual <- c(actual_fixedValues, std_values, corr2keep)
  res <- as.data.frame(actual)
  res$term <- rownames(res)
  rownames(res) <- NULL
  res$description <- gsub("\\d+$", "" , res$term)
  return(res)
  
  
}


#' Compare inference results to expected values for a given model.
#'
#' This function compares the inference results from a model to the expected values based on a ground truth dataset with the simulated effects. The function handles models with mixed effects and fixed effects separately, ensuring that the comparison is appropriate for the specific model type.
#'
#' If a model includes mixed effects, the function checks for support for the specific mixed effect structure and provides an informative error message if the structure is not supported.
#'
#' @param tidy_tmb A fitted model object convert to tidy dataframe.
#' @param ground_truth_eff A ground truth dataset with the simulated effects.
#' @param formula_used formula used in model 
#'
#' @return A data frame containing the comparison results, including the term names, inference values, and expected values.
#'
#' @examples
#' \dontrun{
#' evalData <- compareInferenceToExpected(l_tmb, ground_truth_eff)
#' }
#' @export
compareInferenceToExpected <- function(tidy_tmb, ground_truth_eff, formula_used) {
  ## -- parsing formula & check mixed effect
  involvMixedEffect <- is_mixedEffect_inFormula(formula_used)

  msg_e_formula_type <- "This simulation evaluation supports certain types of formulas with mixed effects, but not all.
    Please refer to the package documentation for information on supported formula structures.
    You are welcome to implement additional functions to handle specific formula types with mixed effects that are not currently supported."

  ## -- if mixed effect
  if (involvMixedEffect){
    message("Mixed effect detected in the formula structure.")

    if(!is_formula_mixedTypeI(formula_used)){
      stop(msg_e_formula_type)
    }
    evalData <- inferenceToExpected_withMixedEff(tidy_tmb, ground_truth_eff)

  ## -- only fixed effect
  } else {
    
    message("Only fixed effects are detected in the formula structure.")
    evalData <- inferenceToExpected_withFixedEff(tidy_tmb, ground_truth_eff)
  }

  return(evalData)
}


```

```{r  test-evaluation_withmixedeffect}



test_that("Test is_mixedEffect_inFormula", {
  formula1 <- y ~ a + (1 | B)
  formula2 <- ~ a + (1 | B)
  formula3 <- x ~ c + d

  expect_true(is_mixedEffect_inFormula(formula1))
  expect_true(is_mixedEffect_inFormula(formula2))
  expect_false(is_mixedEffect_inFormula(formula3))
})

test_that("Test is_formula_mixedTypeI", {
  formula1 <- y ~ x + (1 | group)
  formula2 <- y ~ z + group1 + (1 | group1)
  formula3 <- y ~ z + (1 | group1 + group2)
  formula4 <- y ~ z + (1 | group1/z)
  formula5 <- y ~ z + ( group | z ) ## z is fixed then expected on the left in parenthesis
  
  expect_true(is_formula_mixedTypeI(formula1))
  expect_false(is_formula_mixedTypeI(formula2))
  expect_false(is_formula_mixedTypeI(formula3))
  expect_false(is_formula_mixedTypeI(formula4))
  expect_false(is_formula_mixedTypeI(formula5))


})


test_that("getCategoricalVar_inFixedEffect returns the correct result", {
  
    ###### PREPARE DATA
    N_GENES = 2
    MAX_REPLICATES = 4
    MIN_REPLICATES = 4

    input_var_list <- init_variable( name = "genotype", mu = 2, sd = 0.5, level = 10) %>%
      init_variable( name = "environment", mu = c(1, 3), sd = NA, level = 2) %>%
      add_interaction(between_var = c("genotype", 'environment'), mu = 1, sd = 0.39)
    
    mock_data <- mock_rnaseq(input_var_list, N_GENES,
                             min_replicates = MIN_REPLICATES,
                             max_replicates = MAX_REPLICATES,
                             basal_expression = 3, dispersion = 100)
    
    data2fit = prepareData2fit(countMatrix = mock_data$counts, metadata =  mock_data$metadata, normalization = NULL)
    
    l_tmb <- fitModelParallel(formula = kij ~  environment  + (environment | genotype ),
                              data = data2fit, group_by = "geneID",
                              family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
      
  
    tidy_tmb <- tidy_tmb(l_tmb)
    categorical_var <- getCategoricalVar_inFixedEffect(tidy_tmb)
    expect_equal(categorical_var, "label_environment")
})

test_that("group_logQij_per_genes_and_labels returns the correct result", {
    
    ############ PREPARE DATA
    N_GENES = 2
    MAX_REPLICATES = 4
    MIN_REPLICATES = 4
    input_var_list <- init_variable( name = "genotype", mu = 2, sd = 0.5, level = 10) %>%
      init_variable( name = "environment", mu = c(1, 3), sd = NA, level = 2) %>%
      add_interaction(between_var = c("genotype", 'environment'), mu = 1, sd = 0.39)
    
    mock_data <- mock_rnaseq(input_var_list, N_GENES,
                             min_replicates = MIN_REPLICATES,
                             max_replicates = MAX_REPLICATES,
                             basal_expression = 3, dispersion = 100)
    
    data2fit = prepareData2fit(countMatrix = mock_data$counts, metadata =  mock_data$metadata, normalization = NULL)
    
    l_tmb <- fitModelParallel(formula = kij ~  environment  + (environment | genotype ),
                              data = data2fit, group_by = "geneID",
                              family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
    
    ground_truth_eff <- mock_data$groundTruth$effects
    categorical_var <- "label_environment"
    logqij_list <- group_logQij_per_genes_and_labels(ground_truth_eff, categorical_var)
    
    expect_is(logqij_list, "data.frame")
    expect_equal(attributes(logqij_list)$names , c("gene1", "gene2"))
    expect_equal(length(logqij_list$gene1), 2)
    expect_equal(length(logqij_list$gene2), 2)
    expect_equal(length(logqij_list$gene2[[1]]), 10)
})

test_that("getActualMixed_typeI returns the correct result", {
   ############ PREPARE DATA
    N_GENES = 2
    MAX_REPLICATES = 4
    MIN_REPLICATES = 4
    input_var_list <- init_variable( name = "genotype", mu = 2, sd = 0.5, level = 10) %>%
      init_variable( name = "environment", mu = c(1, 3), sd = NA, level = 2) %>%
      add_interaction(between_var = c("genotype", 'environment'), mu = 1, sd = 0.39)
    
    mock_data <- mock_rnaseq(input_var_list, N_GENES,
                             min_replicates = MIN_REPLICATES,
                             max_replicates = MAX_REPLICATES,
                             basal_expression = 3, dispersion = 100)
    
    data2fit = prepareData2fit(countMatrix = mock_data$counts, metadata =  mock_data$metadata, normalization = NULL)
    
    l_tmb <- fitModelParallel(formula = kij ~  environment  + (environment | genotype ),
                              data = data2fit, group_by = "geneID",
                              family = glmmTMB::nbinom2(link = "log"), n.cores = 1)
    
    ground_truth_eff <- mock_data$groundTruth$effects
    categorical_var <- "label_environment"
    logqij_list <- group_logQij_per_genes_and_labels(ground_truth_eff, categorical_var)
    l_genes <- unique(ground_truth_eff$geneID)
    genes_iter_list <- stats::setNames(l_genes, l_genes)
    categoricalVar_infos= list(ref = "environment1", 
                             labels = c("environment1", "environment2"), 
                             labelsOther = "environment2")
    
    ## -- test
    actual_mixedEff <- getActualMixed_typeI(logqij_list, 
                                              genes_iter_list, 
                                                categoricalVar_infos)
    
    ## -- verif
    expect_is(actual_mixedEff, "data.frame")
    expect_equal(colnames(actual_mixedEff), c("actual", "term", "description", "geneID"))
    expect_equal(unique(actual_mixedEff$geneID), c("gene1", "gene2"))
    expect_equal(unique(actual_mixedEff$term), c("(Intercept)", "environment2", 
                                                 "sd_(Intercept)", "sd_environment2", "cor__(Intercept).environment2"))

})


# Test for InferenceToExpected_withMixedEff
test_that("inferenceToExpected_withMixedEff correctly compares inference to expected values", {
  
  ## -- PREPARE DATA
  N_GENES = 2
  MAX_REPLICATES = 4
  MIN_REPLICATES = 4
  
  input_var_list <- init_variable(name = "genotype", mu = 2, sd = 0.5, level = 10) %>%
  init_variable(name = "environment", mu = c(1, 3), sd = NA, level = 2) %>%
  add_interaction(between_var = c("genotype", 'environment'), mu = 1, sd = 0.39)
  
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates = MIN_REPLICATES,
                         max_replicates = MAX_REPLICATES,
                         basal_expression = 3, dispersion = 100)
  
  data2fit <- prepareData2fit(countMatrix = mock_data$counts, metadata = mock_data$metadata, normalization = NULL)
  
  l_tmb <- fitModelParallel(formula = kij ~ environment + (environment | genotype),
                          data = data2fit, group_by = "geneID",
                          family = glmmTMB::nbinom2(link = "log"), n.cores = 1)

  ## -- call fonction to test
  compared_df <- inferenceToExpected_withMixedEff(tidy_tmb(l_tmb), mock_data$groundTruth$effects)

  ## -- TEST VERIF
  expect_equal(c("term", "description", "ID", "effect", 
                "component", "group", "estimate", "std.error", 
                "statistic", "p.value", "actual" ) , colnames(compared_df))
  expect_equal(c("gene1", "gene2" ) , unique(compared_df$ID))
  expect_equal(unique(compared_df$term), c("(Intercept)", "cor__(Intercept).environment2", "environment2", 
                                                 "sd_(Intercept)", "sd_environment2"))

})

# Test for calculate_actualMixed
test_that("calculate_actualMixed calculates actual mixed effects as expected", {
   ## -- PREPARE DATA
  N_GENES = 2
  MAX_REPLICATES = 4
  MIN_REPLICATES = 4
  
  input_var_list <- init_variable(name = "genotype", mu = 2, sd = 0.5, level = 10) %>%
  init_variable(name = "environment", mu = c(1, 3), sd = NA, level = 2) %>%
  add_interaction(between_var = c("genotype", 'environment'), mu = 1, sd = 0.39)
  
  mock_data <- mock_rnaseq(input_var_list, N_GENES,
                         min_replicates = MIN_REPLICATES,
                         max_replicates = MAX_REPLICATES,
                         basal_expression = 3, dispersion = 100)
  
  data2fit <- prepareData2fit(countMatrix = mock_data$counts, metadata = mock_data$metadata, normalization = NULL)
  
  
  ground_truth_eff <- mock_data$groundTruth$effects
  categorical_var <- "label_environment"
  logqij_list <- group_logQij_per_genes_and_labels(ground_truth_eff, categorical_var)
  l_genes <- unique(ground_truth_eff$geneID)
  genes_iter_list <- stats::setNames(l_genes, l_genes)
  categoricalVar_infos= list(ref = "environment1", 
                           labels = c("environment1", "environment2"), 
                           labelsOther = "environment2")
    
  ## -- call function & test
  data_per_gene <- lapply(genes_iter_list, function(g) {
                          data_gene <- data.frame(logqij_list[[g]])
                          colnames(data_gene) <- categoricalVar_infos$labels
                          return(data_gene)
                    })
  data_gene <- data_per_gene$gene1
  actual_mixed <- calculate_actualMixed(data_gene, 
                                        labelRef_InCategoricalVar = categoricalVar_infos$ref ,
                                        labelOther_inCategoricalVar = categoricalVar_infos$labelsOther)
  expect_equal( colnames(actual_mixed), c("actual", "term", "description"))
  expect_equal(actual_mixed$term, c("(Intercept)", "environment2", 
                                    "sd_(Intercept)", "sd_environment2", 
                                    "cor__(Intercept).environment2"))
  expect_equal(actual_mixed$description, c("(Intercept)", "environment", 
                                    "sd_(Intercept)", "sd_environment", 
                                    "cor__(Intercept).environment"))
})



```


```{r function-export_evaluation_report, filename =  "export_evaluation_report" }

#' Checks if an eval_report object is valid
#'
#' @param obj The object to be checked.
#' @return TRUE if the object is valid, otherwise stops the script.
#' @details This function verifies if the 'eval_report_obj' object corresponds to an eval_report object generated by the 'evaluation_report()' function. It also ensures that the object is a list containing the expected elements: 'data', 'identity', 'precision_recall', 'roc', 'counts', and 'performances'. If the object does not match these expectations, an error message is displayed.
#' @examples
#' \dontrun{
#' # Using isValidEval_report
#' eval_report <- evaluation_report()
#' isValidEval_report(eval_report)
#' }
#' @export
isValidEval_report <- function(obj){
  message_err <- "'eval_report_obj' does not correspond to HTRfit eval_report_obj. 'eval_report_obj' can be generated using evaluation_report()."
  stopifnot(is.list(obj))
  
  if (all(sapply(obj, is.null))) {
    stop("All elements in 'obj' are NULL")
  }
  ## level 1 
  expected_names <- c("data", "identity", "counts", "performances")
  ## N.B : "precision_recall", "roc" names are not mandatory
  if (!all(expected_names %in% names(obj))) {
    stop(message_err)
  }
  if (!all(names(obj) %in% expected_names)){
    warning("Unexpected list element in 'eval_report_obj'")
  }
  return(TRUE)
}

#' Exports a dataframe to a specified file
#'
#' @param data The dataframe to be exported.
#' @param outfile The name of the output file.
#' @return None (the function writes to a file).
#' @examples
#' \dontrun{
#' # Using export_dataframe
#' export_dataframe(dataframe, "output.tsv")
#' }
#' @export
export_dataframe <- function(data, outfile){
  write.table(data, outfile, quote = FALSE, sep = "\t", row.names = FALSE)
}

#' Exports an eval_report object to a specified folder
#'
#' @param eval_report_obj The eval_report object to be exported.
#' @param outfolder The output folder to export the results to. Default "." for current working directory.
#' @param plot_format The format of the plots to export (default: "png").
#' @param ... Additional arguments to be passed to ggsave while exporting identity_modelparams. 
#' For instance use `width = 20` to adjust the width of the plot identity_modelparams.
#' @return A message indicating that the export is done.
#' @details This function exports the data contained in the eval_report object to individual files in the specified folder. The exported data includes model parameters, model dispersion, aggregated performances, performances by parameters, and evaluation plots.
#' @examples
#' \dontrun{
#' eval_report <- evaluation_report()
#' export_evaluation_report(eval_report, "output_folder")
#' }
#' @export
export_evaluation_report <- function(eval_report_obj, outfolder = ".", plot_format = "png", ...){
      invisible(isValidEval_report(eval_report_obj))
      dir.create(file.path(outfolder), showWarnings = TRUE)
      user_workdir <- getwd()
      setwd(outfolder)
      message(paste("INFO: Exporting results in :", outfolder, sep = " " ))
      ## data export
      export_dataframe(eval_report_obj$data$modelparams, "data_modelparams.tsv")
      export_dataframe(eval_report_obj$data$modeldispersion, "data_modeldispersion.tsv")
      ## perf export
      export_dataframe(eval_report_obj$performances$aggregate, "performances_aggregated.tsv")
      export_dataframe(eval_report_obj$performances$byparams, "performances_byparams.tsv")
      export_eval_plots(eval_report_obj, plot_format, ...)
      setwd(user_workdir)
      return("Export: Done")
}

#' Exports evaluation plots from an eval_report object
#'
#' @param eval_report_obj The eval_report object containing evaluation plots.
#' @param extension The file extension/format of the exported plots.
#' @param ... Additional arguments to be passed to ggsave while exporting identity_modelparams. For instance use `width = 20` to adjust the width of the plot identity_modelparams.
#' @return None (the function writes the plots to files).
#' @details This function exports evaluation plots, including identity model parameters, identity model dispersion, precision-recall by parameters, precision-recall aggregated, ROC by parameters, ROC aggregated, and genes expression plots from the eval_report object to individual files in a 'plots' folder within the current working directory.
#' @examples
#' \dontrun{
#' export_eval_plots(eval_report, "png", width = 20)
#' }
#' @importFrom ggplot2 ggsave
#' @export
export_eval_plots <- function(eval_report_obj, extension, ...){
  
  dir.create("plots", showWarnings = FALSE)
  setwd("./plots")
  
  ggplot2::ggsave(filename = paste("identity_modelparams" , extension, sep = "."), eval_report_obj$identity$params, ...)
  ggplot2::ggsave(filename = paste("identity_modeldispersion" , extension, sep = "."), eval_report_obj$identity$dispersion )
  ggplot2::ggsave(filename = paste("genes_expression" , extension, sep = "."), eval_report_obj$counts)
  
  if (all(c("roc", "precision_recall") %in% names(eval_report_obj))){
      ggplot2::ggsave(filename = paste("precision_recall_byparams" , extension, sep = "."), eval_report_obj$precision_recall$params )
      ggplot2::ggsave(filename = paste("precision_recall_aggregated" , extension, sep = "."), eval_report_obj$precision_recall$aggregate)
      ggplot2::ggsave(filename = paste("roc_byparams" , extension, sep = "."), eval_report_obj$roc$params )
      ggplot2::ggsave(filename = paste("roc_aggregated" , extension, sep = "."), eval_report_obj$roc$aggregate )
  }

}


```

```{r test-export_evaluation_report}

# Test pour la fonction isValidEval_report
test_that("isValidEval_report returns TRUE for a valid eval_report object", {
  
  eval_report_obj <- list(
  data = list(
    modelparams = data.frame(a = 1:5, b = letters[1:5]),
    modeldispersion = data.frame(x = runif(10), y = rnorm(10))
  ),
  performances = list(
    aggregate = data.frame(x = runif(10), y = rnorm(10)),
    byparams = data.frame(param = 1:10, x = runif(10), y = rnorm(10))
  ),
  identity = list(
    params = ggplot2::qplot(x = rnorm(100)),
    dispersion = ggplot2::qplot(y = runif(100))
  ),
  precision_recall = list(
    params = ggplot2::qplot(x = runif(100)),
    aggregate = ggplot2::qplot(y = rnorm(100))
  ),
  roc = list(
    params = ggplot2::qplot(x = rnorm(100)),
    aggregate = ggplot2::qplot(y = runif(100))
  ),
  counts = ggplot2::qplot(x = rnorm(100))
)
  ################## -- TESTS
  expect_true(isValidEval_report(eval_report_obj))
})

# Test pour la fonction export_dataframe
test_that("export_dataframe writes dataframe to file without errors", {
  dataframe <- data.frame(x = 1:10, y = rnorm(10))
  outfile <- tempfile()
  export_dataframe(dataframe, outfile)
  expect_true(file.exists(outfile))
})

# Test pour la fonction export_eval_plots
test_that("export_eval_plots writes evaluation plots to files without errors", {
  
  eval_report_obj <- list(
  data = list(
    modelparams = data.frame(a = 1:5, b = letters[1:5]),
    modeldispersion = data.frame(x = runif(10), y = rnorm(10))
  ),
  performances = list(
    aggregate = data.frame(x = runif(10), y = rnorm(10)),
    byparams = data.frame(param = 1:10, x = runif(10), y = rnorm(10))
  ),
  identity = list(
    params = ggplot2::qplot(x = rnorm(100)),
    dispersion = ggplot2::qplot(y = runif(100))
  ),
  precision_recall = list(
    params = ggplot2::qplot(x = runif(100)),
    aggregate = ggplot2::qplot(y = rnorm(100))
  ),
  roc = list(
    params = ggplot2::qplot(x = rnorm(100)),
    aggregate = ggplot2::qplot(y = runif(100))
  ),
  counts = ggplot2::qplot(x = rnorm(100))
)
  
  ############################# -- TEST ###############"
  outfolder <- tempdir()
  setwd(outfolder)
  export_eval_plots(eval_report_obj, "png")
  expect_true(file.exists(file.path(outfolder, "plots/identity_modelparams.png")))
  expect_true(file.exists(file.path(outfolder, "plots/identity_modeldispersion.png")))
  expect_true(file.exists(file.path(outfolder, "plots/precision_recall_byparams.png")))
  expect_true(file.exists(file.path(outfolder, "plots/precision_recall_aggregated.png")))
  expect_true(file.exists(file.path(outfolder, "plots/roc_byparams.png")))
  expect_true(file.exists(file.path(outfolder, "plots/roc_aggregated.png")))
  expect_true(file.exists(file.path(outfolder, "plots/genes_expression.png")))
})


# Test pour la fonction export_evaluation_report
test_that("export_evaluation_report exports evaluation report to folder without errors", {
  outfolder <- tempdir()
  eval_report <- list(
    data = list(
      modelparams = data.frame(a = 1:5, b = letters[1:5]),
      modeldispersion = data.frame(x = runif(10), y = rnorm(10))
    ),
    performances = list(
      aggregate = data.frame(x = runif(10), y = rnorm(10)),
      byparams = data.frame(param = 1:10, x = runif(10), y = rnorm(10))
    ),
    identity = list(
      params = ggplot2::qplot(x = rnorm(100)),
      dispersion = ggplot2::qplot(y = runif(100))
    ),
    precision_recall = list(
      params = ggplot2::qplot(x = runif(100)),
      aggregate = ggplot2::qplot(y = rnorm(100))
    ),
    roc = list(
      params = ggplot2::qplot(x = rnorm(100)),
      aggregate = ggplot2::qplot(y = runif(100))
    ),
    counts = ggplot2::qplot(x = rnorm(100))
  )
  suppressWarnings(export_evaluation_report(eval_report, outfolder, plot_format = "png", width = 20))
  expect_true(file.exists(file.path(outfolder, "data_modelparams.tsv")))
  expect_true(file.exists(file.path(outfolder, "data_modeldispersion.tsv")))
  expect_true(file.exists(file.path(outfolder, "performances_aggregated.tsv")))
  expect_true(file.exists(file.path(outfolder, "performances_byparams.tsv")))
  expect_true(file.exists(file.path(outfolder, "plots", "identity_modelparams.png")))
  expect_true(file.exists(file.path(outfolder, "plots", "identity_modeldispersion.png")))
  expect_true(file.exists(file.path(outfolder, "plots", "precision_recall_byparams.png")))
  expect_true(file.exists(file.path(outfolder, "plots", "precision_recall_aggregated.png")))
  expect_true(file.exists(file.path(outfolder, "plots", "roc_byparams.png")))
  expect_true(file.exists(file.path(outfolder, "plots", "roc_aggregated.png")))
  expect_true(file.exists(file.path(outfolder, "plots", "genes_expression.png")))
})


```



```{r function-combine_mock_obj, filename =  "combine_mock_obj" }


#' Rename genes in a mock object by adding a specific index.
#'
#' @param mock_obj Mock object to be modified.
#' @param idx Index to be added to gene names.
#' @return The modified mock object with updated gene names.
#' @export
#' @examples
#' \dontrun{
#' rename_genes_in_mock_obj(mock_obj, 1)
#' }
rename_genes_in_mock_obj <- function(mock_obj, idx){
  names(mock_obj$groundTruth$gene_dispersion) <- paste(names(mock_obj$groundTruth$gene_dispersion), 
                                                       idx, sep = ".")
  mock_obj$groundTruth$effects$geneID <- paste(mock_obj$groundTruth$effects$geneID, 
                                                  idx, sep = ".")
  return(mock_obj)
}

#' Combine ground truth information from a list of mock objects.
#'
#' @param list_mock_obj List of mock objects.
#' @return A list containing the combined ground truth data.
#' @export
#' @examples
#' \dontrun{
#' l_mock_obj <- list(mock_data_1, mock_data_2)
#' combine_ground_truth(l_mock_obj)
#' }
combine_ground_truth <- function(list_mock_obj){
  list_ground_truth <- get_list_of_mock_attribute(list_mock_obj, 'groundTruth')
  list_dispersion <- get_list_of_mock_attribute(list_ground_truth, 'gene_dispersion')
  dispersions <- do.call(c, list_dispersion)
  list_effects <- get_list_of_mock_attribute(list_ground_truth, 'effects')
  effects <- do.call(rbind, list_effects)
  return(list(effects = effects, gene_dispersion = dispersions))
}


#' Get a list of specified attributes from a list of mock objects.
#'
#' @param list_mock_obj List of mock objects.
#' @param attr Name of the attribute to extract.
#' @return A list containing the values of the specified attribute for each mock object.
#' @export
#' @examples
#' \dontrun{
#' l_mock_obj <- list(mock_data_1, mock_data_2)
#' get_list_of_mock_attribute(l_mock_obj)
#' }
get_list_of_mock_attribute <- function(list_mock_obj, attr = 'init' ){
  list_attr <- list()
  for (mock_obj in list_mock_obj){
    list_attr <- append(list_attr , list(mock_obj[[attr]]))
  }
  return(list_attr)
}

#' Check if all elements in a list are identical.
#'
#' @param x List to be checked.
#' @return TRUE if all elements are identical, otherwise FALSE.
#' @export
#' @examples
#' list_non_identik <- list(1, 2, 3, 4)
#' are_all_elements_identical(list_non_identik)
#' list_identik <- list(1, 1, 1, 1)
#' are_all_elements_identical(list_identik)
are_all_elements_identical <- function(x) {
  # Vérifier si tous les éléments de la liste sont égaux au premier élément
  all_equal <- all(sapply(x[-1], function(el) identical(el, x[[1]])))
  return(all_equal)
}

#' Combine multiple mock objects into a single mock object.
#'
#' @param list_mock_obj List of mock objects to combine.
#' @param min_replicates  Minimum number of replicates (mandatory when generate_counts = TRUE).
#' If min_replicates is different from max replicates, the number of replicates is randomly selected 
#' from a uniform distribution between min and max replicates.
#' @param max_replicates Maximum replicates number (mandotory only if generate_counts = TRUE)
#' If min_replicates is different from max replicates, the number of replicates is randomly selected 
#' from a uniform distribution between min and max replicates.
#' @param sequencing_depth Sequencing depth parameter (optional, default NULL).
#' @return A combined mock object containing various simulated data.
#' @export
#' @examples
#' input_var_list <- init_variable(name = "varA", sd = 0.2, level = 3) 
#' ## -- simulate RNAseq data
#' mock_data_1 <- mock_rnaseq(input_var_list,
#'                         n_genes = 10,
#'                         min_replicates = 4,
#'                         max_replicates = 4,
#'                         generate_counts = FALSE)
#' input_var_list <- init_variable(name = "varA", sd = 0.6, level = 3) 
#' ## -- simulate RNAseq data
#' mock_data_2 <- mock_rnaseq(input_var_list,
#'                         n_genes = 10,
#'                         min_replicates = 4,
#'                         max_replicates = 4,
#'                         generate_counts = FALSE)                       
#' list_mock_obj <- list(mock_data_1, mock_data_2)
#' mock_data_cbine <- combine_mock(list_mock_obj, 4, 4, sequencing_depth = 1e6)
combine_mock <- function(list_mock_obj, min_replicates, max_replicates, sequencing_depth = NULL){
  
  ## -- verif 
  stopifnot(is.list(list_mock_obj))
  len_list_mock_obj <- length(list_mock_obj)
  message(paste("INFO:", "length(list_mock_obj):", len_list_mock_obj , sep = " "))
  stopifnot("All elements in list_mock_obj must be valid mock_obj." = all(sapply(list_mock_obj, isValidMock_obj)))
  list_init <- get_list_of_mock_attribute(list_mock_obj)
  list_levels <- lapply(list_init, function(init) getGivenAttribute(init, 'level'))
  list_names <- lapply(list_init, function(init) getGivenAttribute(init, 'name'))
  stopifnot("Name and number of levels for each variable must be identical between element in list_mock_obj." = are_all_elements_identical(list_levels))
  list_settings <- get_list_of_mock_attribute(list_mock_obj, 'settings')
  list_mock_obj <- lapply(1:length(list_mock_obj), function(idx) rename_genes_in_mock_obj(list_mock_obj[[idx]], idx) )
  mock_ground_truth <- combine_ground_truth(list_mock_obj)
  list_var <-  list_mock_obj[[1]]$init
  n_genes <- length(mock_ground_truth$gene_dispersion)
  
  message("Building mu_ij matrix")
  ## -- matrix
  matx_Muij <- getMu_ij_matrix(mock_ground_truth$effects)
  l_sampleID <- getSampleID(list_var)
  matx_bool_replication <- generateReplicationMatrix(list_var, min_replicates, max_replicates)
  mu_ij_matx_rep <- replicateMatrix(matx_Muij, matx_bool_replication)
  matx_dispersion <- getDispersionMatrix(list_var, n_genes, mock_ground_truth$gene_dispersion)
  ## same order as mu_ij_matx_rep
  matx_dispersion <- matx_dispersion[ order(row.names(matx_dispersion)), ]
  matx_dispersion_rep <- replicateMatrix(matx_dispersion, matx_bool_replication)
  
  if (!is.null(sequencing_depth)) {
    scaling_factors <- get_scaling_factor(mu_ij_matx_rep, sequencing_depth)
    invisible(get_messages_sequencing_depth(scaling_factors))
    mu_ij_dtf_rep <- scaleCountsTable(mu_ij_matx_rep, scaling_factors)
    mu_ij_matx_rep <- as.matrix(mu_ij_dtf_rep)
    ## -- rescaling effect
    mock_ground_truth$effects$log_qij_scaled <- mock_ground_truth$effects$log_qij_scaled + log(mean(scaling_factors, na.rm = T))
  } else{
    scaling_factors <- NULL
  }
  
  invisible(warning_too_low_mu_ij_row(mu_ij_matx_rep))
  matx_countsTable <- generateCountTable(mu_ij_matx_rep, matx_dispersion_rep)
  message("Counts simulation: Done")
  
  dtf_countsTable <- matx_countsTable %>% as.data.frame()
  checkFractionOfZero(dtf_countsTable)
  
  metaData <- getSampleMetadata(list_var, matx_bool_replication)
  libSize <- sum(colSums(dtf_countsTable))
  settings_df <- getSettingsTable(n_genes, min_replicates, max_replicates, libSize)
  list2ret <- list( settings = settings_df, init = list_init, 
                    groundTruth = mock_ground_truth,
                    counts = dtf_countsTable,
                    metadata = metaData,
                    scaling_factors = scaling_factors)
  ## -- clean garbage collector to save memory 
  invisible(gc(reset = TRUE, verbose = FALSE));
  return(list2ret)
}



```



```{r test-combine_mock_obj}

# Tester la fonction get_list_of_mock_attribute
test_that("get_list_of_mock_attribute returns correct attribute values", {
  
  input_var_list <- init_variable(name = "varA", sd = 0.2, level = 3) 
  #' ## -- simulate RNAseq data
  mock_data_1 <- mock_rnaseq(input_var_list,
                           n_genes = 10,
                           generate_counts = F)
   input_var_list <- init_variable(name = "varA", sd = 0.6, level = 3) 
  #' ## -- simulate RNAseq data
   mock_data_2 <- mock_rnaseq(input_var_list,
                           n_genes = 10,
                           generate_counts = F)                       
  list_mock_obj <- list(mock_data_1, mock_data_2)
  list_attr <- get_list_of_mock_attribute(list_mock_obj, 'groundTruth')
  
  # Vérifier que les valeurs des attributs sont correctes
  expect_equal(length(list_attr), 2)
  expect_true(is.list(list_attr[[1]]) && is.list(list_attr[[2]]))
  expect_equal(names(list_attr[[1]]), c("effects", "gene_dispersion"))
  expect_equal(names(list_attr[[2]]), c("effects", "gene_dispersion"))
  expect_equal(length(list_attr[[1]]$gene_dispersion), 10)
  expect_equal(length(list_attr[[2]]$gene_dispersion), 10)
})


# Tester la fonction are_all_elements_identical
test_that("are_all_elements_identical returns correct results", {
  # Créer des données pour les tests
  non_identical_list <- list(c(1, 1, 1), c("a", "a", "a"))
  identical_list <- list(c("a", "b", "c"), c("a", "b", "c"))
  expect_true(are_all_elements_identical(identical_list))
  expect_false(are_all_elements_identical(non_identical_list))
})


# Tester la fonction combine_mock
test_that("combine_mock returns the expected combined data", {
  
  input_var_list <- init_variable(name = "varA", sd = 0.2, level = 3) 
  #' ## -- simulate RNAseq data
  mock_data_1 <- mock_rnaseq(input_var_list,
                           n_genes = 10,
                           generate_counts = F)
   input_var_list <- init_variable(name = "varA", sd = 0.6, level = 3) 
  #' ## -- simulate RNAseq data
   mock_data_2 <- mock_rnaseq(input_var_list,
                           n_genes = 10,
                           generate_counts = F)                       
  list_mock_obj <- list(mock_data_1, mock_data_2)
  combined_data <- combine_mock(list_mock_obj, 3, 3, sequencing_depth = 1e6)
  
  # Vérifier que les éléments de la sortie sont corrects
  expect_true(is.list(combined_data))
  expect_true("settings" %in% names(combined_data))
  expect_true("init" %in% names(combined_data))
  expect_true("groundTruth" %in% names(combined_data))
  expect_true("counts" %in% names(combined_data))
  expect_true("metadata" %in% names(combined_data))
  expect_true("scaling_factors" %in% names(combined_data))
  
  # expect error
  input_var_list <- init_variable(name = "varA", sd = 0.2, level = 3) 
  #' ## -- simulate RNAseq data
  mock_data_1 <- mock_rnaseq(input_var_list,
                           n_genes = 10,
                           generate_counts = F)
   input_var_list <- init_variable(name = "varA_false_name", sd = 0.6, level = 3) 
  #' ## -- simulate RNAseq data
   mock_data_2 <- mock_rnaseq(input_var_list,
                           n_genes = 10,
                           generate_counts = F)                       
  list_mock_obj <- list(mock_data_1, mock_data_2)
  expect_error(combined_data <- combine_mock(list_mock_obj,2, 3, sequencing_depth = 1e6))
  
  
    # expect error
  input_var_list <- init_variable(name = "varA", sd = 0.2, level = 3) 
  #' ## -- simulate RNAseq data
  mock_data_1 <- mock_rnaseq(input_var_list,
                           n_genes = 10,
                           generate_counts = F)
   input_var_list <- init_variable(name = "varA", sd = 0.6, level = 8) ## change nb level 
  #' ## -- simulate RNAseq data
   mock_data_2 <- mock_rnaseq(input_var_list,
                           n_genes = 10,
                           generate_counts = F)                       
  list_mock_obj <- list(mock_data_1, mock_data_2)
  expect_error(combined_data <- combine_mock(list_mock_obj, 4, 4, sequencing_depth = 1e6))
  
})


```




```{r development-inflate, eval=FALSE}
setwd("/home/adminarnaud/Documents/HTRfit/")
#usethis::create_package(path = "/Users/ex_dya/Documents/LBMC/HTRfit/")
fusen::fill_description(fields = list(Title = "HTRfit"), overwrite = T)
usethis::use_gpl_license(version = 3, include_future = TRUE)
usethis::use_pipe(export = TRUE)
devtools::document()
# Keep eval=FALSE to avoid infinite loop in case you hit the knit button
# Execute in the console directly
fusen::inflate(pkg = "/home/adminarnaud/Documents/HTRfit/", flat_file = "dev/flat_full.Rmd", 
               vignette_name = NA, open_vignette = F, overwrite = T)
```