#!/usr/bin/env python3

# -*- coding: UTF-8 -*-

"""
The goal of this script is to compute the ribosomal density for genes in \
experiment 245-6-7
"""

from pathlib import Path
from typing import List
import doctest
import pandas as pd
from functools import reduce
import seaborn as sns
import numpy as np
from rpy2.robjects import r, pandas2ri


class Config_file:
    """
    Locations of the input and output files used by this script.

    Every path is built from the parent of the directory that contains this
    file; that directory is expected to hold a `data` and a `results` folder.
    """
    # Folder holding the input data
    data = Path(__file__).parents[1].joinpath("data")
    # htseq count tables, with one sub-folder for RNA-seq counts and one
    # for ribosome counts
    htseq_fold = data.joinpath("htseq_count_manip_245-6-7_density")
    htseq_rnaseq = htseq_fold.joinpath("htseq_count_rna")
    htseq_ribo = htseq_fold.joinpath("htseq_count_ribo")
    # Folder holding the results
    result = Path(__file__).parents[1].joinpath("results")
    # Gene list files used to build the `group` column of the density table
    tdd_braf_itdd = result.joinpath(
        "TDD_analysis", "correlation", "TDD_BRAF_BRAF_DOWN.txt")
    tdd_dmso_dtdd = result.joinpath(
        "TDD_analysis", "correlation", "TDD_DMSO_BRAF_UP.txt")
    # Folder receiving the density figures and tables
    output = result.joinpath("TDD_analysis", "density_figures")


def load_htseqfiles(mfolder: Path) -> List[Path]:
    """
    List the htseq count files located in a folder

    Only the `*.tsv` files located directly inside `mfolder` are returned
    (sub-folders are not explored). The list is sorted so that the result
    does not depend on the order in which the file system lists the files.

    :param mfolder: A folder containing htseq count files
    :return: The sorted list of the `*.tsv` files found in `mfolder`

    >>> res = load_htseqfiles(Config_file.htseq_ribo)
    >>> len(res)
    6
    >>> res == sorted(res)
    True
    """
    # glob() yields entries in an arbitrary, platform-dependent order
    return sorted(Path(mfolder).glob("*.tsv"))


def load_df(list_files: List[Path]) -> pd.DataFrame:
    """
    Load all count files and gather them in a single table

    Each file is read as a headerless, tab-separated, two-column table: the
    first column is named `gene` and the second one takes the stem of the
    file. The tables are then joined on `gene` (inner join), so only the
    genes present in every file are kept.

    The join is explicitly made on `gene`: if two files happen to share the
    same stem, their count columns are kept side by side with the pandas
    `_x` / `_y` suffixes instead of being used as an extra join key (which
    would silently drop the genes whose counts differ).

    :param list_files: A list of htseq files (at least one is required: \
    `reduce` raises a TypeError on an empty list)
    :return: A dataframe containing the counts of every file

    >>> mfiles = [Config_file.htseq_ribo / "DMSO_245_ribo.tsv",
    ...           Config_file.htseq_ribo / "DMSO_247_ribo.tsv"]
    >>> load_df(mfiles).head()
          gene  DMSO_245_ribo  DMSO_247_ribo
    0     A1BG             44             84
    1     A1CF              0              0
    2      A2M           4376           5682
    3    A2ML1              0              0
    4  A3GALT2              0              0

    """
    list_df = [pd.read_csv(mfile, sep="\t", names=["gene", mfile.stem])
               for mfile in list_files]
    return reduce(
        lambda left, right: left.merge(right, on="gene", how="inner"),
        list_df)


def compute_cpm(df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute the cpm (counts per million) of the dataframe df

    Every column but `gene` is multiplied by `1e6 / sum(column)`. The
    computation is made on a copy: the dataframe given in input is left
    untouched.

    :param df: A dataframe containing raw read count
    :return: A new dataframe with the cpm read counts

    >>> dsdf = pd.DataFrame({"gene": [f"gn_{i}" for i in range(1, 6)],
    ... "x1": [10000, 20000, 50000, 12000, 8000],
    ... "x2": [200000, 400000, 1000000, 240000, 160000]})
    >>> compute_cpm(dsdf)
       gene        x1        x2
    0  gn_1  100000.0  100000.0
    1  gn_2  200000.0  200000.0
    2  gn_3  500000.0  500000.0
    3  gn_4  120000.0  120000.0
    4  gn_5   80000.0   80000.0
    >>> dsdf["x1"].tolist()
    [10000, 20000, 50000, 12000, 8000]
    """
    # Work on a copy so the raw counts of the caller are not overwritten
    cpm = df.copy()
    for column in cpm.columns:
        if column == "gene":
            continue
        factor = 1e6 / cpm[column].sum()
        cpm[column] = cpm[column] * factor
    return cpm


def compute_ribosome_density(df_cpm: pd.DataFrame):
    """
    Compute the density of ribosomes (ribo cpm divided by rna cpm)

    The columns of `df_cpm` (other than `gene`) must be named
    `<treatment>_<replicate>_<ribo|rna>`. For each gene and each
    treatment/replicate pair, the density is the `ribo` value divided by
    the `rna` value. Rows whose density is missing or infinite are removed.

    :param df_cpm: A dataframe containing CPM
    :return: The ribosome density of every gene for every replicate

    >>> dsdf = pd.DataFrame({"gene": list("ABCDE"),
    ... "DMSO_245_ribo": [5, 10, 15, 20, 25],
    ... "DMSO_245_rna": [5, 5, 5, 5, 5],
    ... "DMSO_246_ribo": [5, 5, 5, 5, 5],
    ... "DMSO_246_rna": [5, 10, 15, 20, 0]})
    >>> compute_ribosome_density(dsdf)
      gene treatment replicate   density
    0    A      DMSO       245  1.000000
    1    A      DMSO       246  1.000000
    2    B      DMSO       245  2.000000
    3    B      DMSO       246  0.500000
    4    C      DMSO       245  3.000000
    5    C      DMSO       246  0.333333
    6    D      DMSO       245  4.000000
    7    D      DMSO       246  0.250000
    8    E      DMSO       245  5.000000
    """
    # Long format: one row per (gene, sample)
    long_df = df_cpm.melt(id_vars=["gene"], var_name="sample",
                          value_name="cpm")
    sample = long_df["sample"]
    long_df["condition"] = sample.str.replace("_(ribo|rna)", "", regex=True)
    long_df["kind"] = sample.str.extract(r'(ribo|rna)')

    # Back to one row per (gene, condition) with a `ribo` and a `rna` column
    wide = long_df.pivot(index=["gene", "condition"], columns="kind",
                         values="cpm")
    wide = wide.reset_index()
    wide.columns.name = None

    condition = wide["condition"]
    density = wide["ribo"] / wide["rna"]
    wide["density"] = density
    wide["treatment"] = condition.str.extract(r"(DMSO|BRAF)_")
    wide["replicate"] = condition.str.extract(r"_(24.)")

    # Discard the rows without a usable density (NaN or +/- infinity)
    unusable = density.isna() | np.isinf(density)
    wide = wide[~unusable]
    return wide[["gene", "treatment", "replicate", "density"]]


def create_density_df() -> pd.DataFrame:
    """
    Build the ribosome density dataframe from the htseq count files

    The ribo and rna count files listed in `Config_file` are loaded in a
    single table, converted to cpm and then turned into densities.

    :return: The dataframe of ribosome densities
    """
    count_files = (load_htseqfiles(Config_file.htseq_ribo)
                   + load_htseqfiles(Config_file.htseq_rnaseq))
    return compute_ribosome_density(compute_cpm(load_df(count_files)))


def add_group_column(df_density: pd.DataFrame, gene_file: Path
                     ) -> pd.DataFrame:
    """
    Add a `group` column flagging the genes listed in `gene_file`

    The genes found in `gene_file` (one gene per line) get the stem of
    `gene_file` as group; every other gene gets the group `CTRL`.

    The column is added in place: `df_density` itself is modified and
    returned. Pass a copy to keep the original dataframe untouched.

    :param df_density: A dataframe containing densities (with a `gene` \
    column)
    :param gene_file: A file containing gene whose TDD changes in BRAF
    condition
    :return: `df_density` with the additional `group` column

    >>> genes_l = list("ABC")
    >>> temp_file = Path("/tmp/tmp.genes")
    >>> x = temp_file.write_text("\\n".join(genes_l))
    >>> dsdd = pd.DataFrame({
    ... 'gene': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D', 'E', 'E'],
    ... 'treatment': ['DMSO'] * 10,
    ... 'replicate': [245, 246, 245, 246, 245, 246, 245, 246, 245, 246],
    ... 'density': [1.0, 1.0, 2.0, 0.5, 3.0, 0.333333, 4.0, 0.25, 5.0, 0.2]})
    >>> add_group_column(dsdd, temp_file)
      gene treatment  replicate   density group
    0    A      DMSO        245  1.000000   tmp
    1    A      DMSO        246  1.000000   tmp
    2    B      DMSO        245  2.000000   tmp
    3    B      DMSO        246  0.500000   tmp
    4    C      DMSO        245  3.000000   tmp
    5    C      DMSO        246  0.333333   tmp
    6    D      DMSO        245  4.000000  CTRL
    7    D      DMSO        246  0.250000  CTRL
    8    E      DMSO        245  5.000000  CTRL
    9    E      DMSO        246  0.200000  CTRL
    >>> temp_file.unlink()
    """
    # read_text() opens AND closes the file (no dangling file handle)
    genes = gene_file.read_text().splitlines()
    df_density["group"] = "CTRL"
    df_density.loc[df_density["gene"].isin(genes), "group"] = gene_file.stem
    return df_density


def avg_replicate(df_density: pd.DataFrame) -> pd.DataFrame:
    """
    Average the densities over the replicates

    The `replicate` column is removed, then the remaining value columns are
    averaged for each (gene, treatment, group) combination.

    :param df_density: A dataframe containing ribosomes densities
    :return: A dataframe with one row per gene, treatment and group

    >>> dsdd = pd.DataFrame(
    ... {'gene': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D', 'E', 'E'],
    ... 'treatment': ['DMSO'] * 10 ,
    ... 'replicate': [245, 246] * 5,
    ... 'density': [1.0, 1.0, 2.0, 0.5, 3.0, 0.333333, 4.0, 0.25, 5.0, 0.2],
    ... 'group': ['tmp'] * 6 + ['CTRL'] * 4})
    >>> avg_replicate(dsdd)
      gene treatment group   density
    0    A      DMSO   tmp  1.000000
    1    B      DMSO   tmp  1.250000
    2    C      DMSO   tmp  1.666667
    3    D      DMSO  CTRL  2.125000
    4    E      DMSO  CTRL  2.600000
    """
    keys = ["gene", "treatment", "group"]
    without_replicate = df_density.drop(columns="replicate")
    averaged = without_replicate.groupby(keys).mean()
    return averaged.reset_index()


def statistical_analysis(df_density: pd.DataFrame, output: Path,
                         list_type: str) -> float:
    """
    Fit a model on the densities with R and return one value of its \
    coefficient table

    The R function defined below (run through rpy2) fits the glmmTMB model
    `log1p(density) ~ group + replicate` with the zero-inflation formula
    `~1` and the `ziGamma` family. It then:

    - simulates the residuals with DHARMa (n = 200) and plots them in a pdf,
    - writes the conditional coefficient table in a tab-separated file,
    - returns the cell `[2, 4]` of that table (presumably the p-value of
      the `group` term - to confirm against the glmmTMB summary layout).

    Both files are created in `output / "diag"` and are named
    `diag_<treatment>_<list_type>.pdf` and
    `diag_pval_<treatment>_<list_type>.txt`.

    :param df_density: A density dataframe (the R model reads its \
    `density`, `group` and `replicate` columns; `treatment` is read here)
    :type df_density: pd.DataFrame
    :param output: The folder where the `diag` folder holding the \
    diagnostics will be produced
    :param list_type: The kind of list studied (used in the file names)
    :return: The value `res$coefficients$cond[2, 4]` converted to float
    """
    # Presumably turns on the automatic pandas <-> R conversion needed to
    # hand `df_density` to the R function below - confirm for the rpy2
    # version in use
    pandas2ri.activate()
    output = output / "diag"
    # Only `diag` itself is created: its parent folder must already exist
    output.mkdir(exist_ok=True)
    # Only the first treatment found in the table is used to name the files
    treat = df_density["treatment"].unique()[0]
    outfig = output / f"diag_{treat}_{list_type}.pdf"
    outpval = output / f"diag_pval_{treat}_{list_type}.txt"
    # The string below is not an f-string: the doubled braces are sent to R
    # unchanged
    stat_s = r("""
    require("DHARMa")
    require("glmmTMB")
    function(data, outfig, outpval){{
        mod <- glmmTMB(log1p(density) ~ group + replicate, ziformula = ~1,
                       family="ziGamma", data=data)
        simulationOutput <- simulateResiduals(fittedModel = mod, n = 200)
        pdf(outfig)
        plot(simulationOutput)
        dev.off()
        res <- summary(mod)
        tmp <- as.data.frame(res$coefficients$cond)
        tmp['col'] <- rownames(tmp)
        write.table(tmp, outpval, sep="\t")
        return(res$coefficients$cond[2, 4])
    }}
    """)
    res = stat_s(df_density, str(outfig), str(outpval))
    return float(res)


def create_density_figure(df_density: pd.DataFrame, gene_file: Path,
                          output_file: Path) -> None:
    """
    Create a violinplot indicating the ribosome densities between the groups
    of genes studied between each treatment

    The densities are averaged over the replicates, transformed with
    log10(density + 1) and restricted to one treatment: `DMSO` when the stem
    of `gene_file` contains `UP`, `BRAF` otherwise. The value returned by
    `statistical_analysis` for that treatment is displayed in the title.

    :param df_density: A DataFrame of density
    :param gene_file: The gene_file used to build the group column of the \
    df_density dataframe
    :param output_file: File where the figure will be created (the \
    diagnostics of the statistical analysis are written next to it)
    """
    sns.set_theme(context="talk", font_scale=2, style="whitegrid")
    ctreatment = "DMSO" if "UP" in gene_file.stem else "BRAF"
    df_mean = avg_replicate(df_density)
    # The column name and the axis label both announce a base 10 logarithm:
    # use log10 (and not the natural logarithm) so the figure matches them
    df_mean["log10_density"] = np.log10(df_mean["density"] + 1)
    df_mean = df_mean[df_mean["treatment"] == ctreatment]
    pval = statistical_analysis(
        df_density[df_density["treatment"] == ctreatment].copy(),
        output_file.parent, gene_file.stem)
    # The palette is keyed on the stem of the gene file (the name given to
    # the group of listed genes) so that any gene file can be used
    # NOTE(review): seaborn documents `cut` as a number; True behaves as 1
    # here - confirm that this is the intended value
    g = sns.catplot(x="group", y="log10_density",
                    data=df_mean, kind="violin", height=10, aspect=1.7,
                    palette={gene_file.stem: "grey", "CTRL": "white"},
                    cut=True)
    title = f"Ribosome density " \
            f"between {gene_file.stem} and CTRL " \
            f"genes in {ctreatment} condition\n" \
            f"(pval = {pval:.2e})"
    g.fig.subplots_adjust(top=0.98)
    g.fig.suptitle(title, fontsize=10)
    g.set(ylabel="log10(density +1)", xlabel="")
    g.savefig(output_file)


def density_figure_maker():
    """
    Create the ribosome density violin figure of every gene list

    For each gene list of `Config_file`, the density table (with its group
    column) is saved as a tab-separated `.txt` file next to the `.pdf`
    figure, in the `Config_file.output` folder.
    """
    density = create_density_df()
    gene_files = (Config_file.tdd_braf_itdd, Config_file.tdd_dmso_dtdd)
    for gene_file in gene_files:
        # Each gene list works on its own copy of the density table
        grouped = add_group_column(density.copy(), gene_file)
        Config_file.output.mkdir(exist_ok=True)
        figure_name = f"density_figure_gene_{gene_file.stem}.pdf"
        figure_path = Config_file.output / figure_name
        table_path = figure_path.parent / f"{figure_path.stem}.txt"
        grouped.to_csv(table_path, sep="\t", index=False)
        create_density_figure(grouped, gene_file, figure_path)


if __name__ == "__main__":
    import sys

    # `test` as first argument runs the doctests of this file; anything
    # else (or no argument) builds the figures
    run_doctests = sys.argv[1:2] == ["test"]
    if not run_doctests:
        density_figure_maker()
    else:
        doctest.testmod()