diff --git a/src/find_interaction_cluster/community_file_tools/perform_topgo_analysis.py b/src/find_interaction_cluster/community_file_tools/perform_topgo_analysis.py index 3e92eb4fc3d828c7861d523ceb292b5446718ce5..345a87a13bd329d0d8e2d0b73ee31c23c35d6e49 100644 --- a/src/find_interaction_cluster/community_file_tools/perform_topgo_analysis.py +++ b/src/find_interaction_cluster/community_file_tools/perform_topgo_analysis.py @@ -7,10 +7,17 @@ Description: Perform a topGo analysis for each list of genes in the file \ produced by the script community_file_2_gene_list in results/clusters_gene_list """ +import math from pathlib import Path from typing import List +import numpy as np import pandas as pd +import plotly.express as px +import polars as pl + +from src.figures_utils.component_variance_figure import update_layout_fig +from src.figures_utils.stacked_barplot_figures import get_color_pallette from ..config import Config @@ -30,14 +37,12 @@ def write_input_file(mfile: Path, folder: Path) -> List[Path]: """ input_files = [] df = pd.read_csv(mfile, sep="\t") - df = df[-df["hg38_symbol"].isna()].copy() + df = df[-df["symbol"].isna()].copy() for cluster in df["cluster"].unique(): tmp = df[df["cluster"] == cluster] cl = cluster.replace(".", "_") outfile = folder / f"{cl}.txt" - tmp[["hg38_symbol"]].to_csv( - outfile, sep="\t", index=False, header=False - ) + tmp[["symbol"]].to_csv(outfile, sep="\t", index=False, header=False) input_files.append(outfile) return input_files @@ -53,9 +58,9 @@ def write_background(mfile: Path, folder: Path) -> Path: outf = folder / "input_topgo" outf.mkdir(exist_ok=True) df = pd.read_csv(mfile, sep="\t") - df = df[-df["hg38_symbol"].isna()].copy() + df = df[-df["symbol"].isna()].copy() outfile = folder / "background.txt" - ndf = pd.DataFrame({"hg38_symbol": df["hg38_symbol"].unique()}) + ndf = pd.DataFrame({"symbol": df["symbol"].unique()}) ndf.to_csv(outfile, sep="\t", index=False, header=False) return outfile @@ -80,6 +85,158 @@ def 
def get_enrichment_files(output_folder: Path) -> list[Path]:
    """
    Get the enrichment files generated by topgo

    The folder is searched recursively for files whose name matches
    ``*_genes_CC_*.txt``.

    :param output_folder: The output folder where the enrichment files are \
    generated
    :return: A list of Path objects representing the enrichment files, \
    sorted so that the result does not depend on the filesystem order
    """
    return sorted(output_folder.glob("**/*_genes_CC_*.txt"))


def open_enrichment_file(outfile: Path) -> "pl.DataFrame":
    """
    Open an enrichment file generated by topgo

    The file is read as a tab-separated table. The columns ``fish``,
    ``pvalue`` and ``padj`` are read as text, a leading ``"< "`` is removed
    and they are then cast to floats. A ``community`` column holding the
    name of the folder containing the file is added.

    :param outfile: The output file to open
    :return: A DataFrame containing the enrichment data
    """
    # Columns that may hold values such as "< 1e-30" (presumably a lower
    # bound written by topGO - to confirm) and cannot be parsed as floats
    # directly.
    text_pvalue_columns = ("fish", "pvalue", "padj")
    return pl.read_csv(
        outfile,
        separator="\t",
        dtypes={
            "GO.ID": pl.Utf8(),
            "Term": pl.Utf8(),
            "Annotated": pl.Int64(),
            "Significant": pl.Int64(),
            "Expected": pl.Float64(),
            "fish": pl.Utf8(),
            "pvalue": pl.Utf8(),
            "padj": pl.Utf8(),
        },
    ).with_columns(
        *[
            pl.col(name).str.replace("< ", "").cast(pl.Float64())
            for name in text_pvalue_columns
        ],
        community=pl.lit(outfile.parent.name),
    )


def load_many_files(outfiles: list[Path]) -> "pl.DataFrame":
    """Load many enrichment files generated by topgo

    :param outfiles: A list of Path objects representing the enrichment files
    :return: A DataFrame containing the enrichment data
    :raises ValueError: If ``outfiles`` is empty
    """
    if not outfiles:
        # Fail with an explicit message instead of the generic error raised
        # by the concatenation of an empty list.
        raise ValueError("No enrichment file to load: 'outfiles' is empty")
    return pl.concat([open_enrichment_file(outfile) for outfile in outfiles])


def generate_figure(
    df: pd.DataFrame,
    output: Path,
    go_term: str,
    term: str,
    cummunities_name: str,
) -> None:
    """Generate a figure for a given GO term

    The figure is a scatter plot of ``mlog10padj`` (x axis) against
    ``log2fc`` (y axis) whose points are sized by ``Significant``. Points are
    colored by the ``SPIN`` column when it exists, by ``community``
    otherwise. Dashed lines are drawn at ``y = 0`` and ``x = -log10(0.05)``.
    The figure is written as an html file.

    :param df: The pandas DataFrame containing the enrichment data of one \
    GO term. It must contain the columns mlog10padj, log2fc, padj, \
    Significant, community and color
    :param output: The output path for the figure
    :param go_term: The GO term to generate the figure for
    :param term: The term description
    :param cummunities_name: The name of the clusters studied
    """
    color_column = "SPIN" if "SPIN" in df.columns else "community"
    # Bind each category to its own color explicitly: a discrete color
    # sequence is consumed once per distinct category (not once per row), so
    # a row-wise list only matches when every row is a distinct category.
    color_map = dict(zip(df[color_column], df["color"]))
    fig = px.scatter(
        df,
        x="mlog10padj",
        y="log2fc",
        color=color_column,
        size="Significant",
        hover_data=["padj", "log2fc", "community", "Significant"],
        title=f"GO Term: {go_term} - {term} - {cummunities_name}",
        color_discrete_map=color_map,
    )
    fig = update_layout_fig(fig)
    fig.update_layout(font=dict(size=15))
    # No enrichment / no depletion boundary
    fig.add_hline(
        y=0, line_dash="dash", line_color="black", line_width=2, layer="below"
    )
    # Adjusted p-value threshold of 0.05 on the -log10 scale
    fig.add_vline(
        x=-math.log10(0.05),
        line_dash="dash",
        line_color="black",
        line_width=2,
        layer="below",
    )
    fig.write_html(output)


def add_color_columns(df: pd.DataFrame, order: list[str]) -> pd.DataFrame:
    """
    Add color columns to the dataframe based on the community column

    When at least one community name contains "Speckle", a ``SPIN`` column
    holding the last "_"-separated token of the community name is added and
    one color is attributed per distinct token. Otherwise one color is
    attributed per community. The dataframe is modified in place.

    :param df: The dataframe to add color columns to
    :param order: The order of the communities
    :return: The dataframe with color columns added
    """
    if df["community"].str.contains("Speckle").any():
        df["SPIN"] = df["community"].str.split("_").str[-1]
        # Keep the first community of every suffix, in the original order
        first_by_suffix: dict[str, str] = {}
        for name in order:
            first_by_suffix.setdefault(name.split("_")[-1], name)
        order = list(first_by_suffix.values())
        palette = get_color_pallette("turbo", order)
        # Colors are keyed by suffix so that every community sharing a SPIN
        # value receives the color of that value (keying them by community
        # name left the other communities without any color).
        df["color"] = df["SPIN"].map(dict(zip(first_by_suffix, palette)))
        return df
    c_dic = dict(zip(order, get_color_pallette("turbo", order)))
    df["color"] = df["community"].map(c_dic)
    return df


def create_figures(outfolder: Path, mfile: Path) -> None:
    """
    Generate the enrichment scatter plot for each GO term

    Two tables are also written in ``outfolder``: the enrichment data of
    every community, and the same data restricted to the GO terms having an
    adjusted p-value below 0.05 in at least one community. One figure per
    such GO term is written in ``outfolder / "scatter_fig"``.

    :param outfolder: The output folder where the figures will be saved
    :param mfile: A file containing cluster with hg38 names
    """
    clusters = pd.read_csv(mfile, sep="\t")["cluster"].unique().tolist()
    # Community names are taken from the folder names, which are built from
    # the cluster names with "." replaced by "_": apply the same replacement
    # so that both naming schemes agree (duplicates are removed because
    # categories must be unique).
    order = list(dict.fromkeys(c.replace(".", "_") for c in clusters))
    outf = outfolder / "scatter_fig"
    outf.mkdir(exist_ok=True)
    df = load_many_files(get_enrichment_files(outfolder))
    df = df.with_columns(
        log2fc=(pl.col("Significant") / pl.col("Expected")).log(base=2),
        mlog10padj=pl.col("padj").log(base=10) * -1,
    )
    df.write_csv(outfolder / f"GO_{mfile.stem}_enrichment.csv", separator="\t")
    # GO.ID -> Term for every GO term significant in at least one community.
    # One (GO.ID, Term) row is kept per GO.ID, so each row maps directly to
    # one dictionary item.
    gt = dict(
        df.filter(pl.col("padj") < 0.05)
        .select(["GO.ID", "Term"])
        .unique(subset="GO.ID")
        .iter_rows()
    )
    df = df.filter(pl.col("GO.ID").is_in(list(gt)))
    df.write_csv(
        outfolder / f"GO_{mfile.stem}_filtered_enrichment.csv", separator="\t"
    )
    df = df.to_pandas()
    df = add_color_columns(df, order)
    df["community"] = pd.Categorical(
        df["community"], categories=order, ordered=True
    )
    df = df.sort_values("community")
    for go_term, term in gt.items():
        tmp = df[df["GO.ID"] == go_term]
        # Short, filesystem-friendly version of the term description
        nterm = term.replace(" ", "_").replace("/", "-")[:20]
        outfile = outf / f"{go_term}_{nterm}_{mfile.stem}.html"
        generate_figure(tmp, outfile, go_term, term, mfile.stem)