From a042944054658f49d8060312efdeeb61836ea9ce Mon Sep 17 00:00:00 2001 From: aduvermy <arnaud.duvermy@ens-lyon.fr> Date: Thu, 4 Apr 2024 14:20:55 +0200 Subject: [PATCH] enhance report Former-commit-id: 3315d7b7a669f610970b9c041ca5a2c96dddad44 Former-commit-id: 12faaaf4fc44cca459c6d05fd8177e018394a65b Former-commit-id: 91f0afc716a5a87eb01a079ba77a18b85b712e54 --- R/export_evaluation_report.R | 17 ++++++++++++----- R/simulation_report.R | 27 ++++++++++++++++++--------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/R/export_evaluation_report.R b/R/export_evaluation_report.R index cc52203..65cd61c 100644 --- a/R/export_evaluation_report.R +++ b/R/export_evaluation_report.R @@ -19,7 +19,8 @@ isValidEval_report <- function(obj){ stop("All elements in 'obj' are NULL") } ## level 1 - expected_names <- c("data", "identity", "precision_recall", "roc", "counts", "performances") + expected_names <- c("data", "identity", "counts", "performances") + ## N.B : "precision_recall", "roc" names are not mandatory if (!all(expected_names %in% names(obj))) { stop(message_err) } @@ -90,15 +91,21 @@ export_evaluation_report <- function(eval_report_obj, outfolder, plot_format = " #' @importFrom ggplot2 ggsave #' @export export_eval_plots <- function(eval_report_obj, extension, ...){ + dir.create("plots", showWarnings = FALSE) setwd("./plots") + ggplot2::ggsave(filename = paste("identity_modelparams" , extension, sep = "."), eval_report_obj$identity$params, ...) ggplot2::ggsave(filename = paste("identity_modeldispersion" , extension, sep = "."), eval_report_obj$identity$dispersion ) - ggplot2::ggsave(filename = paste("precision_recall_byparams" , extension, sep = "."), eval_report_obj$precision_recall$params ) - ggplot2::ggsave(filename = paste("precision_recall_aggregated" , extension, sep = "."), eval_report_obj$precision_recall$aggregate) - ggplot2::ggsave(filename = paste("roc_byparams" , extension, sep = "."), eval_report_obj$roc$params ) - ggplot2::ggsave(filename = paste("roc_aggregated" , extension, sep = "."), eval_report_obj$roc$aggregate ) ggplot2::ggsave(filename = paste("genes_expression" , extension, sep = "."), eval_report_obj$counts) + + if (all(c("roc", "precision_recall") %in% names(eval_report_obj))){ + ggplot2::ggsave(filename = paste("precision_recall_byparams" , extension, sep = "."), eval_report_obj$precision_recall$params ) + ggplot2::ggsave(filename = paste("precision_recall_aggregated" , extension, sep = "."), eval_report_obj$precision_recall$aggregate) + ggplot2::ggsave(filename = paste("roc_byparams" , extension, sep = "."), eval_report_obj$roc$params ) + ggplot2::ggsave(filename = paste("roc_aggregated" , extension, sep = "."), eval_report_obj$roc$aggregate ) + } + } diff --git a/R/simulation_report.R b/R/simulation_report.R index 2097180..92e48d4 100644 --- a/R/simulation_report.R +++ b/R/simulation_report.R @@ -157,12 +157,9 @@ evaluation_report <- function(mock_obj, list_gene = NULL, list_tmb <- list_tmb[list_gene] } - ## -- eval data eval_data <- get_eval_data(list_tmb, dds, mock_obj, coeff_threshold, alt_hypothesis) - ## -- identity plot - #identity_data <- rbind_model_params_and_dispersion(eval_data) params_identity_eval <- eval_identityTerm( eval_data$modelparams, palette_color, palette_shape ) dispersion_identity_eval <- eval_identityTerm(eval_data$modeldispersion, palette_color, palette_shape) @@ -172,6 +169,12 @@ evaluation_report <- function(mock_obj, list_gene = NULL, eval_data2metrics <- eval_data$modelparams } + ## aggregate RMSE + R2 + rmse_modelparams <- compute_rmse(eval_data2metrics, grouping_by = "from") + rsquare_modelparams <- compute_rsquare(eval_data2metrics, grouping_by = "from") + rsquare_modelparams$from <- NULL + aggregate_metrics <- cbind(rmse_modelparams, rsquare_modelparams) + ## -- counts plot counts_violinplot <- counts_plot(mock_obj) @@ -186,9 +189,11 @@ evaluation_report <- function(mock_obj, list_gene = NULL, return( list( data = eval_data, - identity = list( params = params_identity_eval$p, + identity = list(params = params_identity_eval$p, dispersion = dispersion_identity_eval$p ), - counts = counts_violinplot + counts = counts_violinplot, + performances = list(byparams = rbind(dispersion_identity_eval$R2, params_identity_eval$R2), + aggregate = aggregate_metrics ) )) @@ -208,7 +213,8 @@ evaluation_report <- function(mock_obj, list_gene = NULL, ## -- acc, recall, sensib, speci, ... metrics_obj <- get_ml_metrics_obj(eval_data2metrics, alpha_risk ) ## -- merge all metrics in one obj - model_perf_obj <- get_performances_metrics_obj( r2_params = params_identity_eval$R2, + model_perf_obj <- get_performances_metrics_obj( r2_params = params_identity_eval$R2, + r2_agg = aggregate_metrics, dispersion_identity_eval$R2, pr_curve_obj, roc_curve_obj, @@ -248,15 +254,16 @@ evaluation_report <- function(mock_obj, list_gene = NULL, #' The function generates separate data frames for metric values by parameter value and for the #' aggregated metric values. #' -#' @param r2_params R-squared values from model parameters evaluation object. -#' @param r2_dispersion R-squared values from dispersion evaluation object. +#' @param r2_params R-squared and RMSE values from model parameters evaluation object. +#' @param r2_agg R-squared and RMSE values aggregated. +#' @param r2_dispersion R-squared and RMSE values from dispersion evaluation object. #' @param pr_obj PR object generated from evaluation report. #' @param roc_obj ROC object generated from evaluation report. #' @param ml_metrics_obj Machine learning performance metrics object. #' #' @return A list containing separate data frames for by-parameter and aggregated metric values. #' @export -get_performances_metrics_obj <- function(r2_params , r2_dispersion, +get_performances_metrics_obj <- function(r2_params , r2_agg ,r2_dispersion, pr_obj, roc_obj, ml_metrics_obj ){ ## -- by params auc_mtrics_params <- join_dtf(pr_obj$byparams$pr_auc , roc_obj$byparams$roc_auc, @@ -271,6 +278,8 @@ get_performances_metrics_obj <- function(r2_params , r2_dispersion, k1 = c("from"), k2 = c("from")) metrics_agg <- join_dtf(auc_mtrics_agg , ml_metrics_obj$aggregate, k1 = c("from"), k2 = c("from")) + metrics_agg <- join_dtf(metrics_agg, r2_agg, + k1 = c("from"), k2 = c("from")) return(list(byparams = metrics_params, aggregate = metrics_agg )) } -- GitLab