diff --git a/src/find_interaction_cluster/community_figures/fig_functions.py b/src/find_interaction_cluster/community_figures/fig_functions.py index d7367e87eef4b4044a729f3b87a8656e8b8fd2ae..3fa685b08634ed53b85d023600b69adb8a30152f 100644 --- a/src/find_interaction_cluster/community_figures/fig_functions.py +++ b/src/find_interaction_cluster/community_figures/fig_functions.py @@ -77,8 +77,10 @@ def lm_maker_summary(df: pd.DataFrame, outfile: Path, target_col: str, mod = f"mod <- glm({target_col} ~ log(community_size) + community," \ f"data=data, family=binomial(link='logit'))" df[target_col] = df[target_col].astype(int) - tmp = df[[target_col, 'community']].groupby('community').mean().reset_index() - bad_groups = tmp.loc[tmp[target_col] == 0, "community"].to_list() + tmp = df[[target_col, 'community']].groupby('community').mean().\ + reset_index() + bad_groups = tmp.loc[(tmp[target_col] == 0) | (tmp[target_col] == 1), + "community"].to_list() if "C-CTRL" in bad_groups: print("Control group as a mean value equals to 0, exiting...") exit(1) @@ -105,7 +107,10 @@ def lm_maker_summary(df: pd.DataFrame, outfile: Path, target_col: str, res_df.rename({'index': 'community'}, inplace=True, axis=1) res_df['community'] = res_df['community'].str.replace('community', '') res_df.loc[res_df['community'] == "(Intercept)", "community"] = "C-CTRL" - mean_df = df[[target_col, "community", "community_size"]]. \ + good_cols = [target_col, "community", "community_size"] + if "community_data" in df.columns: + good_cols.append("community_data") + mean_df = df[good_cols]. 
def handle_community_data(g: sns.FacetGrid, df_bar: pd.DataFrame,
                          display_size: bool) -> sns.FacetGrid:
    """
    Draw the mean ``community_data`` value of every community on a \
    secondary y-axis when that column exists, then hide the x tick labels.

    :param g: A seaborn FacetGrid
    :param df_bar: A dataframe with the enrichment of a \
    nucleotide frequency for every community (without control). The \
    secondary axis is only drawn if it has a ``community_data`` column.
    :param display_size: When True, the right spine of the secondary \
    axis is shifted outward (to 1.03 in axes coordinates); presumably so \
    it does not overlap the axis used to display the community sizes \
    - to be confirmed against display_size_fig.
    :return: The seaborn FacetGrid given in input
    """
    com_col = "community_data"
    if com_col in df_bar.columns:
        # One row per community, kept in order of first appearance in
        # df_bar, holding the mean of the column for that community.
        means = df_bar.groupby("community")[com_col].mean()
        communities = df_bar[["community"]].drop_duplicates()
        df_val = communities.assign(
            **{com_col: communities["community"].map(means)})
        ax3 = g.ax.twinx()
        if display_size:
            ax3.spines["right"].set_position(("axes", 1.03))
            ax3.spines["right"].set_visible(True)
        df_val.plot(x="community", y=com_col, kind="scatter", ax=ax3,
                    legend=False, zorder=55,
                    color=(0.8, 0.2, 0.8, 0.4))
        # Label the axis after plotting and with the plotted column name:
        # the pandas plot call sets the axis labels itself, so a label set
        # before it (previously 'community_size') would not be the one
        # displayed.
        ax3.set_ylabel(com_col, color="purple")
        ax3.tick_params(axis='y', labelcolor="purple")
        ax3.grid(False)
    g.set(xticklabels=[])
    return g
(default False) """ + df.to_csv(str(outfile_ctrl).replace(".pdf", ".tmp.txt"), sep="\t", + index=False) if dic_com is None: dic_com = {} if test_type != 'permutation' \ else get_feature_by_community(df, feature)