# WARNING - Generated by {fusen} from dev/flat_full.Rmd: do not edit by hand

#' Check if Data is Valid for Model Fitting
#'
#' Verifies that every variable referenced in a model formula (response,
#' fixed effects and grouping factors of random effects alike, as returned by
#' \code{all.vars()}) is a column of the data used for fitting.
#'
#' @param data2fit The data frame or tibble containing the variables to be used for model fitting.
#' @param formula The formula specifying the model to be fitted.
#'
#' @return \code{TRUE} if all the variables required in the formula are present in \code{data2fit};
#'   otherwise an error naming the first missing variable (in formula order) is raised.
#'
#' @examples
#' data(iris)
#' formula <- Sepal.Length ~ Sepal.Width + Petal.Length
#' isValidInput2fit(iris, formula) # Returns TRUE if all required variables are present
#' @export
isValidInput2fit <- function(data2fit, formula){
  formula_vars <- all.vars(formula)
  absent_vars <- formula_vars[!(formula_vars %in% colnames(data2fit))]
  if (length(absent_vars) > 0L) {
    # Only the first missing variable is reported
    stop(paste("Variable", absent_vars[1L], "not found"))
  }
  TRUE
}



#' Check if group by exist in data
#'
#' Verifies that \code{group_by} names exactly one column of \code{data}.
#'
#' @param data The data frame containing the variables to be used for model fitting.
#' @param group_by Column name in data representing the grouping variable
#'   (a single character string).
#'
#' @return \code{TRUE} if the column exists; otherwise an error is raised
#'   (also when \code{group_by} is empty or has more than one element).
#'
#' @examples
#' is_validGroupBy(mtcars, 'mpg')
#' @export
is_validGroupBy <- function(data, group_by){
  # A zero-length or multi-element `group_by` would otherwise make `if ()`
  # fail with an unrelated error ("argument is of length zero" /
  # "the condition has length > 1"); downstream code uses data[[group_by]],
  # which only makes sense for a single column name.
  validGroupBy <- length(group_by) == 1L && group_by %in% names(data)
  if (!validGroupBy) {
    stop("<Group by> doen't exist in data !")
  }
  TRUE
}




#' Drop Random Effects from a Formula
#'
#' Removes every random-effect term from a formula. A term is treated as a
#' random effect when its label contains a vertical bar ('|'). If all terms
#' are random effects, the right-hand side is replaced by an intercept only.
#' Formulas without any random-effect term (including intercept-only formulas
#' such as \code{y ~ 1}) are returned unchanged.
#'
#' @param form The formula from which random effects should be dropped.
#'
#' @return The formula with all random-effect terms removed (the response is kept).
#'
#' @examples
#' # Create a formula with random effects
#' formula <- y ~ x1 + (1 | group) + (1 | subject)
#' # Drop all random effects: returns y ~ x1
#' modified_formula <- drop_randfx(formula)
#'
#' @importFrom stats terms drop.terms
#' @importFrom stats update
#'
#' @export
drop_randfx <- function(form) {
  form.t <- stats::terms(form)
  is_randfx <- grepl("|", labels(form.t), fixed = TRUE)
  # No random-effect term. any() also handles formulas with no term at all
  # (e.g. y ~ 1), where the former mean(logical(0)) gave NaN and made if() fail.
  if (!any(is_randfx)) {
    return(form)
  }
  # Only random-effect terms: keep the response with an intercept-only RHS
  if (all(is_randfx)) {
    return(stats::update(form, . ~ 1))
  }
  # drop.terms() drops the response by default; update() with the resulting
  # one-sided formula restores the original response.
  form.td <- stats::drop.terms(form.t, which(is_randfx))
  stats::update(form, form.td)
}



#' Check if a Model Matrix is Full Rank
#'
#' Random-effect terms are first removed from \code{formula} with
#' \code{drop_randfx()}, then the fixed-effect model matrix is built from
#' \code{metadata}. The matrix is considered full rank when the largest
#' eigenvalue of its crossproduct is positive and the ratio of the smallest
#' to the largest eigenvalue (in absolute value) exceeds \code{1e-13}.
#'
#' This function is inspired by a similar function found in the Limma package.
#'
#' @param metadata The metadata used to create the model matrix.
#' @param formula The formula used to specify the model matrix.
#'
#' @return \code{TRUE} if the model matrix is full rank; otherwise a warning is
#'   emitted and \code{FALSE} is returned.
#'
#' @examples
#' metadata <- data.frame(x = rnorm(10), y = rnorm(10))
#' formula <- y ~ x
#' is_fullrank(metadata, formula)
#'
#'
#' @importFrom stats model.matrix
#' @export
is_fullrank <- function(metadata, formula) {
  # Only fixed effects belong in the design matrix
  fixed_formula <- drop_randfx(formula)
  design <- stats::model.matrix(fixed_formula, data = metadata)
  # symmetric = TRUE returns eigenvalues in decreasing order
  eig <- eigen(crossprod(design), symmetric = TRUE, only.values = TRUE)$values
  n_eig <- length(eig)
  full_rank <- eig[1] > 0 && abs(eig[n_eig] / eig[1]) > 1e-13

  if (full_rank) {
    return(TRUE)
  }
  warning("The model matrix is not full rank. One or more variables or interaction terms in the design formula are linear combinations of the others.")
  FALSE
}





#' Fit a model using glmmTMB for one group.
#'
#' Runs \code{is_fullrank()} on the data/formula (which only warns when the
#' fixed-effect design is rank deficient), then fits the model with
#' \code{glmmTMB::glmmTMB()}. The returned object is augmented with:
#' \code{frame} (replaced by the full input \code{data}), \code{groupId}
#' (the \code{group} argument), and, when supplied through \code{...}, the
#' evaluated \code{family} and \code{control} values stored in \code{$call}
#' (presumably so that a later \code{update()} does not need to re-evaluate the
#' original symbols).
#'
#' @param group ID to fit
#' @param formula Formula specifying the model formula
#' @param data Data frame containing the data
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function
#' @return The fitted glmmTMB model object. Errors raised during fitting are
#'   not caught here and propagate to the caller.
#' @export
#' @examples
#' fitModel("mtcars" , formula = mpg ~ cyl + disp, data = mtcars)
fitModel <- function(group , formula, data, ...) {
  # Design sanity check: only warns, its return value is deliberately unused
  is_fullrank(data, formula)

  fit <- glmmTMB::glmmTMB(formula, ..., data = data )

  ## -- extra information kept on the model object
  fit$frame <- data
  fit$groupId <- group

  ## store evaluated family / control (if given) in the call
  dots <- list(...)
  for (arg_name in c("family", "control")) {
    arg_value <- dots[[arg_name]]
    if (!is.null(arg_value)) {
      fit$call[[arg_name]] <- arg_value
    }
  }

  fit
}



#' Split data into one subset per group for parallel fitting.
#'
#' For every value in \code{groups}, selects the rows of \code{data} whose
#' \code{group_by} column matches that value. Rows with a missing grouping
#' value are only returned for an \code{NA} entry in \code{groups}; they are
#' never injected as all-\code{NA} rows into other subsets.
#'
#' @param groups Vector of group IDs (normally unique values of the grouping column)
#' @param group_by Column name in data representing the grouping variable
#' @param data Data frame containing the data
#' @return A list of data frames, one per element of \code{groups}, named by
#'   the group IDs.
#' @export
#' @importFrom stats setNames
#' @examples
#' prepare_dataParallel(groups = unique(iris$Species), group_by = "Species",
#'                      data = iris )
prepare_dataParallel <- function(groups, group_by, data) {

  l_data2parallel <- lapply(stats::setNames(groups, groups), function(group_id) {
    # `%in%` never returns NA, unlike `==`, whose NA results would add
    # all-NA rows to the subset; drop = FALSE keeps a data frame even when
    # `data` has a single column.
    keep_rows <- data[[group_by]] %in% group_id
    data[keep_rows, , drop = FALSE]
  })
  l_data2parallel
}



#' Launch the model fitting process for a specific group.
#'
#' Fits the model for the group contained in \code{data} (the group label is
#' taken from the unique value(s) of the \code{group_by} column) by calling
#' \code{fitModel()}. Warnings are logged with a timestamp via \code{message()}
#' and muffled so fitting continues. Errors are logged the same way and turn
#' the result into \code{NULL}.
#'
#' @param data Data frame containing the data
#' @param group_by Column name in data representing the grouping variable
#' @param formula Formula specifying the model formula
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function
#' @return The fitted model object returned by \code{fitModel()}, or \code{NULL}
#'   if an error occurred.
#' @export
#' @examples
#' launchFit(group_by = "Species",
#'            formula = Sepal.Length ~ Sepal.Width + Petal.Length,
#'            data = iris[ iris[["Species"]] == "setosa" , ] )
launchFit <- function(data, group_by, formula, ...) {
  group <- unique(data[[ group_by ]])

  # Log the warning and resume the fit where it was signalled
  log_warning <- function(w) {
    message(paste(Sys.time(), "warning for group", group, ":", conditionMessage(w)))
    invokeRestart("muffleWarning")
  }
  # Log the error and give up on this group
  log_error <- function(e) {
    message(paste(Sys.time(), "error for group", group, ":", conditionMessage(e)))
    NULL
  }

  tryCatch(
    withCallingHandlers(fitModel(group, formula, data, ...), warning = log_warning),
    error = log_error
  )
}


#' Fit models in parallel for each group using a parallel cluster.
#'
#' Splits \code{data} per group with \code{prepare_dataParallel()}, starts a
#' cluster of \code{n.cores} workers, and applies \code{launchFit()} to every
#' subset with \code{parallel::parLapply()}. The cluster is always stopped
#' when the function exits, including when an error occurs.
#'
#' @param groups Vector of unique group values
#' @param group_by Column name in data representing the grouping variable
#' @param formula Formula specifying the model formula
#' @param data Data frame containing the data
#' @param n.cores The number of CPU cores to use for parallel processing.
#'  If set to NULL (default), the number of detected CPU cores minus 1 (at least 1) is used.
#' @param log_file File to write log (default : Rtmpdir/htrfit.log)
#' @param cl_type cluster type (default "PSOCK"). "FORK" is recommended for linux.
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function
#' @return List of fitted model objects (NULL for groups whose fit failed)
#' @export
#' @examples
#' parallel_fit(group_by = "Species", groups =  unique(iris$Species),
#'                formula = Sepal.Length ~ Sepal.Width + Petal.Length,
#'                data = iris, n.cores = 1 )
parallel_fit <- function(groups, group_by, formula, data, n.cores = NULL, 
                         log_file = paste(tempdir(check = FALSE), "htrfit.log", sep = "/"), 
                         cl_type = "PSOCK",  ...) {

  if (is.null(n.cores)) {
    # detectCores() returns NA when the count cannot be determined
    n_detected <- parallel::detectCores(logical = FALSE)
    if (is.na(n_detected)) n_detected <- 1L
    n.cores <- max(1, n_detected - 1)
  }

  message(paste("CPU(s) number :", n.cores, sep = " "))
  message(paste("Cluster type :", cl_type, sep = " "))

  ## get data for parallelization
  l_data2parallel <- prepare_dataParallel(groups, group_by, data)

  clust <- parallel::makeCluster(n.cores, outfile = log_file, type = cl_type)
  # Release the workers even if clusterExport()/parLapply() fails
  on.exit({
    parallel::stopCluster(clust)
    invisible(gc(reset = TRUE, verbose = FALSE, full = TRUE))
  }, add = TRUE)

  # Look the helpers up from this function's environment (package namespace
  # when installed, global env when sourced) instead of only .GlobalEnv.
  parallel::clusterExport(clust, c("fitModel", "is_fullrank", "drop_randfx"),
                          envir = environment())
  results_fit <- parallel::parLapply(clust, X = l_data2parallel,
                                     fun = launchFit,
                                     group_by = group_by, formula = formula, ...)
  results_fit
}

#' Fit one model per group in parallel.
#'
#' Validates the inputs (formula variables and grouping column must exist in
#' \code{data}), prints the log file location, then delegates to
#' \code{parallel_fit()} with the unique values of the \code{group_by} column.
#'
#' @param formula Formula specifying the model formula
#' @param data Data frame containing the data
#' @param group_by Column name in data representing the grouping variable
#' @param n.cores The number of CPU cores to use for parallel processing.
#'               If set to NULL (default), the number of detected CPU cores minus 1 (at least 1) is used.
#' @param cl_type cluster type (default "PSOCK"). "FORK" is recommended for linux.
#' @param log_file File path to save the log messages (default : Rtmpdir/htrfit.log)
#' @param ... Additional arguments to be passed to the glmmTMB::glmmTMB function
#' @return List of fitted model objects (NULL for groups whose fit failed)
#' @export
#' @examples
#' fitModelParallel(formula = Sepal.Length ~ Sepal.Width + Petal.Length, 
#'                  data = iris, group_by = "Species", n.cores = 1) 
fitModelParallel <- function(formula, data, group_by, n.cores = NULL, cl_type = "PSOCK" , 
                             log_file = paste(tempdir(check = FALSE), "htrfit.log", sep = "/"), ...) {

  # makeCluster() needs a single cluster type
  stopifnot(length(cl_type) == 1L, cl_type %in% c("PSOCK", "FORK"))
  ## Some verification
  isValidInput2fit(data, formula)
  is_validGroupBy(data, group_by)

  ## -- print log location
  message(paste("Log file location", log_file, sep = ": "))

  # Fit models in parallel and capture the results. Arguments are passed by
  # name so a change in parallel_fit()'s parameter order cannot silently
  # misroute them.
  groups <- unique(data[[ group_by ]])
  results <- parallel_fit(groups = groups, group_by = group_by, formula = formula,
                          data = data, n.cores = n.cores, log_file = log_file,
                          cl_type = cl_type, ...)
  results
}