diff --git a/src/find_interaction_cluster/community_figures/fig_functions.py b/src/find_interaction_cluster/community_figures/fig_functions.py index 124bf0db76fd59256cf69caf18cfedf6d497548b..20747cae31f8df8677c4f55bb701387f6bbc9dea 100644 --- a/src/find_interaction_cluster/community_figures/fig_functions.py +++ b/src/find_interaction_cluster/community_figures/fig_functions.py @@ -315,7 +315,8 @@ def expand_results_perm(df: pd.DataFrame, rdf: pd.DataFrame, target_col: str, def make_barplot(df_bar: pd.DataFrame, outfile: Path, - target_col: str, feature: str, target_kind: str = "") -> None: + target_col: str, feature: str, target_kind: str = "", + sd_community: Optional[str] = "sd") -> None: """ Create a barplot showing the frequency of `nt` for every community \ of exons/gene in `df_bar`. @@ -327,11 +328,13 @@ def make_barplot(df_bar: pd.DataFrame, outfile: Path, target_col. :param target_col: The name of the column containing the data of interest :param feature: The king of feature of interest + :param sd_community: sd to display community error bar, None to display \ + nothing """ sns.set(context="poster") g = sns.catplot(x="community", y=target_col, data=df_bar, kind="point", - ci="sd", aspect=2.5, height=14, errwidth=0.5, capsize=.4, - scale=0.5, + ci=sd_community, aspect=2.5, height=14, errwidth=0.5, + capsize=.4, scale=0.5, palette=["red"] + ["darkgray"] * (df_bar.shape[0] - 1)) g2 = sns.catplot(x="community", y=target_col, data=df_bar, kind="bar", ci="sd", aspect=2.5, height=14, errwidth=0.5, capsize=.4, @@ -347,7 +350,9 @@ def make_barplot(df_bar: pd.DataFrame, outfile: Path, for i, p in enumerate(g2.ax.patches): stats = "*" if df_bara.iloc[i, :]["p-adj"] < 0.05 else "" com = df_bara.iloc[i, :]["community"] - csd = np.std(df_bar.loc[df_bar["community"] == com, target_col]) + csd = 0 + if sd_community == "sd": + csd = np.std(df_bar.loc[df_bar["community"] == com, target_col]) g.ax.annotate(stats, (p.get_x() + p.get_width() / 2., p.get_height() + csd), ha='center', va='center', xytext=(0, 10), fontsize=12, @@ -357,7 +362,8 @@ def make_barplot(df_bar: pd.DataFrame, outfile: Path, def make_barplot_perm(df_bar: pd.DataFrame, outfile: Path, target_col: str, feature: str, - target_kind: str = "") -> None: + target_kind: str = "", + sd_community: Optional[str] = "sd") -> None: """ Create a barplot showing the frequency of `nt` for every community \ of exons/gene in `df_bar`. @@ -369,6 +375,8 @@ def make_barplot_perm(df_bar: pd.DataFrame, outfile: Path, target_col. :param target_col: The name of the column containing the data of interest :param feature: The king of feature of interest + :param sd_community: sd to display community error bar, None to display \ + nothing """ sns.set(context="poster") df_ctrl = df_bar.loc[df_bar[f"id_{feature}"] == 'ctrl', :] @@ -377,13 +385,17 @@ def make_barplot_perm(df_bar: pd.DataFrame, outfile: Path, ci="sd", aspect=2.5, height=14, errwidth=0.5, capsize=.4, palette=["darkgray"] * (df_bar.shape[0])) g = sns.catplot(x="community", y=target_col, data=df_bar, kind="point", - ci="sd", aspect=2.5, height=14, errwidth=0.5, capsize=.4, - scale=0.5, palette=["darkgray"] * (df_bar.shape[0])) + ci=sd_community, aspect=2.5, height=14, errwidth=0.5, + capsize=.4, scale=0.5, + palette=["darkgray"] * (df_bar.shape[0])) xrange = g.ax.get_xlim() + yrange = g.ax.get_ylim() df_ctrl.plot(x="community", y=target_col, kind="scatter", ax=g.ax, yerr="ctrl_std", legend=False, zorder=10, color=(0.8, 0.2, 0.2, 0.4)) g.ax.set_xlim(xrange) + if sd_community is None: + g.ax.set_ylim(yrange) g.fig.subplots_adjust(top=0.9) target_kind = f" ({target_kind})" if target_kind else "" g.fig.suptitle(f"Mean frequency of {target_col}{target_kind}" @@ -395,7 +407,9 @@ def make_barplot_perm(df_bar: pd.DataFrame, outfile: Path, for i, p in enumerate(g2.ax.patches): stats = "*" if df_bara.iloc[i, :]["p-adj"] < 0.05 else "" com = df_bara.iloc[i, :]["community"] - csd = np.std(df_bar.loc[df_bar["community"] == com, target_col]) + csd = 0 + if sd_community == "sd": + csd = np.std(df_bar.loc[df_bar["community"] == com, target_col]) g.ax.annotate(stats, (p.get_x() + p.get_width() / 2., p.get_height() + csd), ha='center', va='center', xytext=(0, 10), fontsize=12, @@ -405,7 +419,7 @@ def make_barplot_perm(df_bar: pd.DataFrame, outfile: Path, def barplot_creation(df_bar: pd.DataFrame, outfig: Path, cpnt: str, test_type: str, feature: str, - target_kind) -> None: + target_kind: str, sd_community: bool) -> None: """ Reformat a dataframe with the enrichment of a nucleotide frequency \ for every feature for every community and then create a \ @@ -421,11 +435,15 @@ def barplot_creation(df_bar: pd.DataFrame, outfig: Path, :param test_type: The type of test to make (permutation or lm) :param target_kind: An optional name that describe a bit further \ target_col. + :param sd_community: True to display the errors bars for communities, + False else. """ + sd_community = "sd" if sd_community else None if test_type == "lm": make_barplot(df_bar, outfig, cpnt, feature, target_kind) else: - make_barplot_perm(df_bar, outfig, cpnt, feature, target_kind) + make_barplot_perm(df_bar, outfig, cpnt, feature, target_kind, + sd_community) def get_feature_by_community(df: pd.DataFrame, feature: str) -> Dict: @@ -461,7 +479,8 @@ def create_community_fig(df: pd.DataFrame, feature: str, outfile_ctrl: Path, test_type: str, dic_com: Optional[Dict] = None, target_kind: str = "", - iteration: int = 10000) -> None: + iteration: int = 10000, + sd_community: bool = True) -> None: """ Create a dataframe with a control community, save it as a table and \ as a barplot figure. @@ -480,6 +499,8 @@ def create_community_fig(df: pd.DataFrame, feature: str, :param target_kind: An optional name that describe a bit further \ target_col. :param iteration: The number of sub samples to create + :param sd_community: True to display the errors bars for communities, + False else. """ if dic_com is None: dic_com = {} if test_type == 'lm' \ @@ -495,4 +516,4 @@ def create_community_fig(df: pd.DataFrame, feature: str, bar_outfile = str(outfile_ctrl).replace(".pdf", "_bar.txt") df_bar.to_csv(bar_outfile, sep="\t", index=False) barplot_creation(df_bar, outfile_ctrl, target_col, test_type, feature, - target_kind) + target_kind, sd_community)