From a042944054658f49d8060312efdeeb61836ea9ce Mon Sep 17 00:00:00 2001
From: aduvermy <arnaud.duvermy@ens-lyon.fr>
Date: Thu, 4 Apr 2024 14:20:55 +0200
Subject: [PATCH] enhance report: make ROC/PR plots optional and add aggregated RMSE/R2 metrics

---
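Notes (kept outside the commit message): the export_eval_plots() change below only
writes the ROC and precision-recall figures when both elements exist in the report
object. A minimal, standalone sketch of that guard pattern follows; the helper name
save_optional_plots(), the file names and the toy report object are illustrative
assumptions, not package code.

library(ggplot2)

save_optional_plots <- function(report, extension = "png", outdir = tempdir()) {
  ## ROC / precision-recall figures are optional: save them only when both
  ## elements are present in the report object.
  if (all(c("roc", "precision_recall") %in% names(report))) {
    ggsave(file.path(outdir, paste("roc_aggregated", extension, sep = ".")),
           plot = report$roc$aggregate)
    ggsave(file.path(outdir, paste("precision_recall_aggregated", extension, sep = ".")),
           plot = report$precision_recall$aggregate)
  } else {
    message("No ROC / precision-recall elements; skipping those figures.")
  }
  invisible(outdir)
}

## Example: a report built without DE results has no "roc"/"precision_recall"
## slots, so only the message is emitted and nothing is written to disk.
p <- ggplot(data.frame(x = 1:10, y = 1:10), aes(x, y)) + geom_point()
save_optional_plots(list(identity = list(params = p), counts = p))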
 R/export_evaluation_report.R | 17 ++++++++++++-----
 R/simulation_report.R        | 27 ++++++++++++++++++---------
 2 files changed, 30 insertions(+), 14 deletions(-)
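The simulation_report.R change also aggregates RMSE and R-squared across all model
parameters (via compute_rmse() and compute_rsquare(), grouped by the "from" column)
and joins the result into the aggregate performance table. A standalone sketch of
that aggregation is below; the toy data, the "actual"/"estimate" column names and
the source labels are assumptions for illustration, not the package's internals.

rmse_by_from <- function(df) {
  ## Root-mean-square error of estimate vs. actual, one row per "from" level.
  do.call(rbind, lapply(split(df, df$from), function(d) {
    data.frame(from = d$from[1],
               RMSE = sqrt(mean((d$actual - d$estimate)^2)))
  }))
}

rsquare_by_from <- function(df) {
  ## Coefficient of determination, one row per "from" level.
  do.call(rbind, lapply(split(df, df$from), function(d) {
    ss_res <- sum((d$actual - d$estimate)^2)
    ss_tot <- sum((d$actual - mean(d$actual))^2)
    data.frame(from = d$from[1], R2 = 1 - ss_res / ss_tot)
  }))
}

toy <- data.frame(from = rep(c("modelA", "modelB"), each = 50),
                  actual = rnorm(100), estimate = rnorm(100))
rmse_tab <- rmse_by_from(toy)
r2_tab <- rsquare_by_from(toy)
r2_tab$from <- NULL                      ## drop the duplicated key before binding
aggregate_metrics <- cbind(rmse_tab, r2_tab)
aggregate_metrics                        ## one RMSE and one R2 per source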

diff --git a/R/export_evaluation_report.R b/R/export_evaluation_report.R
index cc52203..65cd61c 100644
--- a/R/export_evaluation_report.R
+++ b/R/export_evaluation_report.R
@@ -19,7 +19,8 @@ isValidEval_report <- function(obj){
     stop("All elements in 'obj' are NULL")
   }
   ## level 1 
-  expected_names <- c("data", "identity", "precision_recall", "roc", "counts", "performances")
+  expected_names <- c("data", "identity", "counts", "performances")
+  ## N.B.: the "precision_recall" and "roc" elements are optional
   if (!all(expected_names %in% names(obj))) {
     stop(message_err)
   }
@@ -90,15 +91,21 @@ export_evaluation_report <- function(eval_report_obj, outfolder, plot_format = "
 #' @importFrom ggplot2 ggsave
 #' @export
 export_eval_plots <- function(eval_report_obj, extension, ...){
+  
   dir.create("plots", showWarnings = FALSE)
   setwd("./plots")
+  
   ggplot2::ggsave(filename = paste("identity_modelparams" , extension, sep = "."), eval_report_obj$identity$params, ...)
   ggplot2::ggsave(filename = paste("identity_modeldispersion" , extension, sep = "."), eval_report_obj$identity$dispersion )
-  ggplot2::ggsave(filename = paste("precision_recall_byparams" , extension, sep = "."), eval_report_obj$precision_recall$params )
-  ggplot2::ggsave(filename = paste("precision_recall_aggregated" , extension, sep = "."), eval_report_obj$precision_recall$aggregate)
-  ggplot2::ggsave(filename = paste("roc_byparams" , extension, sep = "."), eval_report_obj$roc$params )
-  ggplot2::ggsave(filename = paste("roc_aggregated" , extension, sep = "."), eval_report_obj$roc$aggregate )
   ggplot2::ggsave(filename = paste("genes_expression" , extension, sep = "."), eval_report_obj$counts)
+  
+  if (all(c("roc", "precision_recall") %in% names(eval_report_obj))) {
+    ggplot2::ggsave(filename = paste("precision_recall_byparams" , extension, sep = "."), eval_report_obj$precision_recall$params )
+    ggplot2::ggsave(filename = paste("precision_recall_aggregated" , extension, sep = "."), eval_report_obj$precision_recall$aggregate)
+    ggplot2::ggsave(filename = paste("roc_byparams" , extension, sep = "."), eval_report_obj$roc$params )
+    ggplot2::ggsave(filename = paste("roc_aggregated" , extension, sep = "."), eval_report_obj$roc$aggregate )
+  }
+
 }
 
 
diff --git a/R/simulation_report.R b/R/simulation_report.R
index 2097180..92e48d4 100644
--- a/R/simulation_report.R
+++ b/R/simulation_report.R
@@ -157,12 +157,9 @@ evaluation_report <- function(mock_obj, list_gene = NULL,
         list_tmb <- list_tmb[list_gene]
   }
   
-  
   ## -- eval data
   eval_data <- get_eval_data(list_tmb, dds, mock_obj, coeff_threshold, alt_hypothesis)
-  
   ## -- identity plot
-  #identity_data <- rbind_model_params_and_dispersion(eval_data)
   params_identity_eval <- eval_identityTerm( eval_data$modelparams, palette_color, palette_shape )
   dispersion_identity_eval <- eval_identityTerm(eval_data$modeldispersion, palette_color, palette_shape)
   
@@ -172,6 +169,12 @@ evaluation_report <- function(mock_obj, list_gene = NULL,
     eval_data2metrics <- eval_data$modelparams
   }
 
+  ## aggregate RMSE + R2
+  rmse_modelparams <- compute_rmse(eval_data2metrics, grouping_by = "from")
+  rsquare_modelparams <- compute_rsquare(eval_data2metrics, grouping_by = "from")
+  rsquare_modelparams$from <- NULL
+  aggregate_metrics <- cbind(rmse_modelparams, rsquare_modelparams)
+  
   
   ## -- counts plot
   counts_violinplot <- counts_plot(mock_obj)
@@ -186,9 +189,11 @@ evaluation_report <- function(mock_obj, list_gene = NULL,
     return(
           list(
                 data = eval_data, 
-                identity = list( params = params_identity_eval$p,
+                identity = list(params = params_identity_eval$p,
                                 dispersion = dispersion_identity_eval$p ), 
-                counts = counts_violinplot
+                counts = counts_violinplot,
+                performances = list(byparams = rbind(dispersion_identity_eval$R2, params_identity_eval$R2),
+                                    aggregate = aggregate_metrics )
                 
           ))
       
@@ -208,7 +213,8 @@ evaluation_report <- function(mock_obj, list_gene = NULL,
   ## -- acc, recall, sensib, speci, ...
   metrics_obj <- get_ml_metrics_obj(eval_data2metrics, alpha_risk )
   ## -- merge all metrics in one obj
-  model_perf_obj <- get_performances_metrics_obj( r2_params = params_identity_eval$R2,                        
+  model_perf_obj <- get_performances_metrics_obj( r2_params = params_identity_eval$R2,
+                                                  r2_agg = aggregate_metrics,
                                                   dispersion_identity_eval$R2, 
                                                   pr_curve_obj,
                                                   roc_curve_obj,
@@ -248,15 +254,16 @@ evaluation_report <- function(mock_obj, list_gene = NULL,
 #' The function generates separate data frames for metric values by parameter value and for the
 #' aggregated metric values.
 #' 
-#' @param r2_params R-squared values from model parameters evaluation object.
-#' @param r2_dispersion R-squared values from dispersion evaluation object.
+#' @param r2_params R-squared and RMSE values from model parameters evaluation object.
+#' @param r2_agg Aggregated R-squared and RMSE values.
+#' @param r2_dispersion R-squared and RMSE values from dispersion evaluation object.
 #' @param pr_obj PR object generated from evaluation report.
 #' @param roc_obj ROC object generated from evaluation report.
 #' @param ml_metrics_obj Machine learning performance metrics object.
 #' 
 #' @return A list containing separate data frames for by-parameter and aggregated metric values.
 #' @export 
-get_performances_metrics_obj <- function(r2_params , r2_dispersion,
+get_performances_metrics_obj <- function(r2_params, r2_agg, r2_dispersion,
                                          pr_obj, roc_obj, ml_metrics_obj ){
   ## -- by params  
   auc_mtrics_params <- join_dtf(pr_obj$byparams$pr_auc , roc_obj$byparams$roc_auc,
@@ -271,6 +278,8 @@ get_performances_metrics_obj <- function(r2_params , r2_dispersion,
                       k1 = c("from"), k2 = c("from"))
   metrics_agg <- join_dtf(auc_mtrics_agg , ml_metrics_obj$aggregate,
                              k1 = c("from"), k2 = c("from"))
+  metrics_agg <- join_dtf(metrics_agg, r2_agg,
+                          k1 = c("from"), k2 = c("from"))
   return(list(byparams = metrics_params, aggregate = metrics_agg ))  
 }
 
-- 
GitLab