From b1b9de59da2b724c5da5280385e4a024507d4fde Mon Sep 17 00:00:00 2001
From: Laurent Modolo <laurent@modolo.fr>
Date: Fri, 15 May 2020 11:52:28 +0200
Subject: [PATCH] 00_function.R: add cell_type color palette

---
 src/00_functions.R | 192 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 191 insertions(+), 1 deletion(-)

diff --git a/src/00_functions.R b/src/00_functions.R
index c84dad1..b15ff9a 100644
--- a/src/00_functions.R
+++ b/src/00_functions.R
@@ -9,6 +9,7 @@ library(glmmTMB)
 library(parallel)
 library(pbmcapply)
 library(plsgenomics)
+library(scater)
 
 anscombe <- function(x){
   2.0 * sqrt(x + 3.0 / 8.0)
@@ -22,6 +23,28 @@ logspace <- function(d1, d2, n) {
   return(exp(log(10) * seq(d1, d2, length.out = n)))
 }
 
+cell_type_palette <- function(cell_type){
+  cell_type_color <- list(EM = "#c90000",
+                          TEMRA = "#f4a582",
+                          TSCM = "#92c5de",
+                          CM = "#0571b0",
+                          Naive = "gray",
+                          NAIVE = "gray",
+                          Effector = "#c90000",
+                          EFF = "#c90000",
+                          Memory = "#0571b0",
+                          MEM = "#0571b0",
+                          UNK = "gray")
+  cell_types_color <- lapply(as.list(cell_type),
+  FUN = function(x, cell_type_color){
+      cell_type_color[[x]]
+    }
+    , cell_type_color)
+  cell_types_color <- unlist(cell_types_color)
+  names(cell_types_color) <- cell_type
+  return(cell_types_color)
+}
+
 QC_sample <- function(is_good) {
   c(
     base::sample(
@@ -293,6 +316,173 @@ PLS_filter <- function(sce,  group_by, genes , features = NULL,
   return(sce)
 }
 
+plsgenomics_logit_spls_stab <- function (X, Y, lambda.ridge.range, lambda.l1.range, ncomp.range, 
+          adapt = TRUE, maxIter = 100, svd.decompose = TRUE, ncores = 1, 
+          nresamp = 100, center.X = TRUE, scale.X = FALSE, weighted.center = TRUE, 
+          seed = NULL, verbose = TRUE) 
+{
+  X <- as.matrix(X)
+  n <- nrow(X)
+  p <- ncol(X)
+  index.p <- c(1:p)
+  if (is.factor(Y)) {
+    Y <- as.numeric(levels(Y))[Y]
+  }
+  Y <- as.integer(Y)
+  Y <- as.matrix(Y)
+  q <- ncol(Y)
+  one <- matrix(1, nrow = 1, ncol = n)
+  if (!is.null(seed)) {
+    set.seed(seed)
+  }
+  cnames <- NULL
+  if (!is.null(colnames(X))) {
+    cnames <- colnames(X)
+  }
+  else {
+    cnames <- paste0(1:p)
+  }
+  if (length(table(Y)) > 2) {
+    warning("message from logit.spls.stab: multicategorical response")
+    results <- multinom.spls.stab(X = X, Y = Y, lambda.ridge.range = lambda.ridge.range, 
+                                  lambda.l1.range = lambda.l1.range, ncomp.range = ncomp.range, 
+                                  adapt = adapt, maxIter = maxIter, svd.decompose = svd.decompose, 
+                                  ncores = ncores, nresamp = nresamp, center.X = center.X, 
+                                  scale.X = scale.X, weighted.center = weighted.center, 
+                                  seed = seed, verbose = verbose)
+    return(results)
+  }
+  if ((!is.matrix(X)) || (!is.numeric(X))) {
+    stop("message from logit.spls.stab: X is not of valid type")
+  }
+  if (p == 1) {
+    stop("message from logit.spls.stab: p=1 is not valid")
+  }
+  if ((!is.matrix(Y)) || (!is.numeric(Y))) {
+    stop("message from logit.spls.stab: Y is not of valid type")
+  }
+  if (q != 1) {
+    stop("message from logit.spls.stab: Y must be univariate")
+  }
+  if (nrow(Y) != n) {
+    stop("message from logit.spls.stab: the number of observations in Y is not equal to the number of row in X")
+  }
+  if (sum(is.na(Y)) != 0) {
+    stop("message from logit.spls.stab: NA values in Ytrain")
+  }
+  if (sum(!(Y %in% c(0, 1))) != 0) {
+    stop("message from logit.spls.stab: Y is not of valid type")
+  }
+  if (sum(as.numeric(table(Y)) == 0) != 0) {
+    stop("message from logit.spls.stab: there are empty classes")
+  }
+  if (any(!is.numeric(lambda.ridge.range)) || any(lambda.ridge.range < 
+                                                  0) || any(!is.numeric(lambda.l1.range)) || any(lambda.l1.range < 
+                                                                                                 0) || any(lambda.l1.range > 1)) {
+    stop("Message from logit.spls.stab: lambda is not of valid type")
+  }
+  if (any(!is.numeric(ncomp.range)) || any(round(ncomp.range) - 
+                                           ncomp.range != 0) || any(ncomp.range < 0) || any(ncomp.range > 
+                                                                                            p)) {
+    stop("Message from logit.spls.stab: ncomp is not of valid type")
+  }
+  if ((!is.numeric(maxIter)) || (round(maxIter) - maxIter != 
+                                 0) || (maxIter < 1)) {
+    stop("message from logit.spls.stab: maxIter is not of valid type")
+  }
+  if ((!is.numeric(ncores)) || (round(ncores) - ncores != 0) || 
+      (ncores < 1)) {
+    stop("message from logit.spls.stab: ncores is not of valid type")
+  }
+  if ((!is.numeric(nresamp)) || (round(nresamp) - nresamp != 
+                                 0) || (nresamp < 1)) {
+    stop("message from logit.spls.stab: nresamp is not of valid type")
+  }
+  grid.resampling <- as.matrix(Reduce("rbind", pbmcapply::pbmclapply(1:nresamp, 
+                                                        function(id.samp) {
+                                                          ntrain = floor(0.5 * n)
+                                                          index.train = sort(sample(1:n, size = ntrain))
+                                                          Xtrain = X[index.train, ]
+                                                          Ytrain = Y[index.train]
+                                                          condition = length(table(Ytrain)) < 2
+                                                          test = 0
+                                                          while (condition & test < 100) {
+                                                            index.train = sort(sample(1:n, size = ntrain))
+                                                            Xtrain = X[index.train, ]
+                                                            Ytrain = Y[index.train]
+                                                            condition = length(table(Ytrain)) < 2
+                                                            test = test + 1
+                                                          }
+                                                          if (test == 100) {
+                                                            ind0 = which(Y == 0)
+                                                            ind1 = which(Y == 1)
+                                                            index.train = sample(ind0, size = 1)
+                                                            index.train = c(index.train, sample(ind1, size = 1))
+                                                            index.train = c(index.train, sample((1:n)[which(!(1:n) %in% 
+                                                                                                              index.train)], size = ntrain - 2))
+                                                            Xtrain = X[index.train, ]
+                                                            Ytrain = Y[index.train]
+                                                          }
+                                                          paramGrid <- expand.grid(lambdaL1 = lambda.l1.range, 
+                                                                                   lambdaL2 = lambda.ridge.range, ncomp = ncomp.range, 
+                                                                                   KEEP.OUT.ATTRS = FALSE)
+                                                          grid_out <- as.matrix(Reduce("rbind", lapply(1:nrow(paramGrid), 
+                                                                                                       function(gridRow) {
+                                                                                                         lambdaL1 <- paramGrid$lambdaL1[gridRow]
+                                                                                                         lambdaL2 <- paramGrid$lambdaL2[gridRow]
+                                                                                                         ncomp <- paramGrid$ncomp[gridRow]
+                                                                                                         fit_out <- logit.spls(Xtrain = Xtrain, Ytrain = Ytrain, 
+                                                                                                                               lambda.ridge = lambdaL2, lambda.l1 = lambdaL1, 
+                                                                                                                               ncomp = ncomp, Xtest = NULL, adapt = adapt, 
+                                                                                                                               maxIter = maxIter, svd.decompose = svd.decompose, 
+                                                                                                                               center.X = center.X, scale.X = scale.X, weighted.center = weighted.center)
+                                                                                                         sel_var <- fit_out$Anames
+                                                                                                         status_var <- rep(0, length(cnames))
+                                                                                                         status_var[which(cnames %in% sel_var)] <- rep(1, 
+                                                                                                                                                       length(which(cnames %in% sel_var)))
+                                                                                                         tmp <- c(lambdaL1, lambdaL2, ncomp, id.samp, 
+                                                                                                                  sum(status_var), status_var)
+                                                                                                         return(tmp)
+                                                                                                       })))
+                                                          rownames(grid_out) <- NULL
+                                                          return(grid_out)
+                                                        }, mc.cores = ncores, mc.silent = !verbose)))
+  grid.resampling <- data.frame(grid.resampling)
+  colnames(grid.resampling) <- c("lambdaL1", "lambdaL2", "ncomp", 
+                                 "id", "nbVar", cnames)
+  grid.resampling$point <- paste0(grid.resampling$lambdaL1, 
+                                  "_", grid.resampling$lambdaL2, "_", grid.resampling$ncomp)
+  o.grid <- order(grid.resampling$nbVar)
+  grid.resampling <- grid.resampling[o.grid, ]
+  if (any(table(grid.resampling$point) < nresamp)) {
+    warning("message from logit.spls.stab: empty classe in a resampling")
+    print(table(grid.resampling$point))
+  }
+  tmp_qLambda <- as.matrix(Reduce("rbind", pbmcapply::pbmclapply(1:nresamp, 
+                                                    function(id.samp) {
+                                                      tmp1 <- subset(grid.resampling, grid.resampling$id == 
+                                                                       id.samp)
+                                                      tmp2 <- apply(t(tmp1)[-c(1:5, tail(1:ncol(grid.resampling), 
+                                                                                         1)), ], 1, cumsum)
+                                                      tmp3 <- apply(t(tmp2), 2, function(x) return(sum(x != 
+                                                                                                         0)))
+                                                      tmp4 <- cbind(tmp1[, c(1:4)], unname(tmp3))
+                                                      return(tmp4)
+                                                    }, mc.cores = ncores, mc.silent = !verbose)))
+  tmp_qLambda <- data.frame(tmp_qLambda)
+  colnames(tmp_qLambda) <- c("lambdaL1", "lambdaL2", "ncomp", 
+                             "id.samp", "qLambda")
+  qLambda <- ddply(tmp_qLambda, c("lambdaL1", "lambdaL2", "ncomp"), 
+                   function(x) colMeans(x[c("qLambda")], na.rm = TRUE))
+  o.qLambda <- order(qLambda$qLambda)
+  qLambda <- qLambda[o.qLambda, ]
+  probs_lambda <- ddply(grid.resampling, c("lambdaL1", "lambdaL2", 
+                                           "ncomp"), function(x) colMeans(x[cnames], na.rm = TRUE))
+  probs_lambda <- probs_lambda[o.qLambda, ]
+  return(list(q.Lambda = qLambda, probs.lambda = probs_lambda, 
+              p = p))
+}
+
 PLS_fit <- function(sce,
                     group_by,
                     genes,
@@ -319,7 +509,7 @@ PLS_fit <- function(sce,
           "counts") %>%
         as.matrix() %>%
         t() %>%
-    plsgenomics::logit.spls.stab(
+    plsgenomics_logit_spls_stab(
       X = .,
       Y = colData(
         altExp(sce, altExp_name))$group_by[
-- 
GitLab