Skip to content
Snippets Groups Projects
Verified Commit a971a892 authored by nfontrod's avatar nfontrod
Browse files

src/figures_utils/violin_most_enriched_ft_com.py: script to create violin plot...

src/figures_utils/violin_most_enriched_ft_com.py: script to create violin plot of the most enriched component for each clusters
parent c5b30fde
Branches
No related tags found
No related merge requests found
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Description: The goal of this script is to create a violin plot displaying
the most enriched components for each communities. The violin can be \
displayed with or without genes dot inside them.
"""
from pathlib import Path
import lazyparser as lp
import pandas as pd
import plotly.graph_objects as go
from src.figures_utils.community_pca import merge_figures
from src.figures_utils.stacked_barplot_figures import get_color_pallette
from src.find_interaction_cluster.Heatmap_distance.distance_figure import (
update_layout_fig,
)
from ..find_interaction_cluster.nt_and_community import create_dataframe
from .config_figures import Config
def create_figure(
df: pd.DataFrame,
com_name: str,
outfile: Path,
cpnt_type: str,
df_m: pd.DataFrame,
) -> None:
"""
Create a figure that show the percentage of genes of each hub
defined in other communities.
:param df: DataFrame containing the percentage of genes of each hub.
:param com_name: Name of the community inside the second community file.
:param outfile: Path to the output file.
:param cpnt_type: The type of component of interest
:param df_m: DataFrame containing the mean percentage of a component.
"""
fig = go.Figure()
# Create a figuredf["SPIN"].unique()
colors = get_color_pallette("tab10", list(df["cpnt"].unique()))
for i, cpnt in enumerate(df["cpnt"].unique()):
pts = "all" if df[df["cpnt"] == cpnt].shape[0] < 300 else False
fig.add_trace(
go.Violin(
x=df[df["cpnt"] == cpnt]["cpnt"],
y=df[df["cpnt"] == cpnt]["percentage"],
name=cpnt,
customdata=df[df["cpnt"] == cpnt]["id_gene"],
hovertemplate="%{customdata}<br>Percentage: %{y:.2f}<br>cpnt: %{x}<extra></extra>",
box_visible=False,
points=pts,
meanline_visible=True,
line_color="black",
marker_color="black",
fillcolor=colors[i],
spanmode="hard",
)
)
fig.update_traces(marker=dict(size=2), pointpos=-0)
if df_m is not None and not df_m.empty:
fig.add_trace(
go.Scatter(
x=df_m["cpnt"],
y=df_m["percentage"],
mode="markers",
marker=dict(
color="pink", size=15, line=dict(color="pink", width=1)
),
marker_symbol="line-ew",
name="Mean all",
)
)
fig = update_layout_fig(fig)
yaxis = (
"Percentage genes"
if not cpnt_type.endswith("mer")
else "Number by kilobases"
)
fig.update_layout(
font=dict(size=20),
xaxis_title=cpnt_type,
yaxis_title=yaxis,
title=com_name,
)
# Save the figure
fig.write_image(outfile)
def filter_cpnt(df_mc: pd.DataFrame, com: str) -> list[str] | None:
"""
filter the components to display in the figure
:param df_mc: DataFrame containing the mean gene frequency data.
:param com: The community of interest
:return: The list of components to keep
"""
if len(df_mc["cpnt"].unique()) <= 5:
return None
tmp = df_mc.drop("community", axis=1).groupby("cpnt").mean().reset_index()
tmp = df_mc[df_mc["community"] == com].merge(tmp, on="cpnt", how="left")
tmp["rfreq"] = tmp["percentage_x"] - tmp["percentage_y"]
return list(
tmp.sort_values(by="rfreq", ascending=False)
.query("rfreq > 0")
.head(5)["cpnt"]
)
def generate_figures(
com_file: str, com_name: str, region: str = "gene", cpnt_type: str = "nt"
):
"""
Generate violin plots for each community displaying either all \
components or a subset of components
:param com_file: Path to the community file.
:param com_name: Name of the community defined in the community file.
:param region: Region of interest.
:param cpnt_type: The type of component of interest.
"""
df = create_dataframe(com_file, "gene", region, cpnt_type).drop(
"community_size", axis=1
)
df_m = df.drop(["community", "id_gene"], axis=1).mean()
df_cm = df.drop(["id_gene"], axis=1).groupby("community").mean()
df = df.melt(
id_vars=["community", "id_gene"],
var_name="cpnt",
value_name="percentage",
)
df_m = (
df.drop(["id_gene", "community"], axis=1)
.groupby(["cpnt"])
.mean()
.reset_index()
)
df_cm = (
df.drop(["id_gene"], axis=1)
.groupby(["community", "cpnt"])
.mean()
.reset_index()
)
list_figure = []
coms = pd.read_csv(com_file, sep="\t")["community"].unique()
for com in coms:
flt = filter_cpnt(df_cm, com)
df_flt = df[df["community"] == com].drop("community", axis=1).copy()
if flt:
df_flt = df_flt[df_flt["cpnt"].isin(flt)].copy()
df_flt["cpnt"] = pd.Categorical(
df_flt["cpnt"], categories=flt, ordered=True
)
df_flt.sort_values(by="cpnt", inplace=True)
df_flt["cpnt"] = df_flt["cpnt"].astype(str)
outfile = Config.violin_plot_cpnt / f"{com}_{cpnt_type}.pdf"
create_figure(
df_flt,
com,
outfile,
cpnt_type,
df_cm[df_cm["cpnt"].isin(flt)]
.drop("community", axis=1)
.groupby("cpnt")
.mean()
.reset_index(),
)
list_figure.append(outfile)
merge_figures(
Config.violin_plot_cpnt, list_figure, f"{com_name}_{cpnt_type}.pdf"
)
@lp.parse(com_file="file")
def main(
com_file: str,
com_name: str = "",
region: str = "gene",
cpnt_type: str = "nt",
):
"""
Generate violin plots for each community displaying either all \
components or a subset of components
:param com_file: Path to the community file.
:param com_name: Name of the community inside the community file.
:param region: Region of interest.
:param cpnt_type: The type of component of interest.
"""
Config.violin_plot_cpnt.mkdir(exist_ok=True)
if not com_name:
com_name = Path(com_file).stem
generate_figures(com_file, com_name, region, cpnt_type)
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment