diff --git a/src/find_interaction_cluster/community_file_tools/community_file_2_gene_list.py b/src/find_interaction_cluster/community_file_tools/community_file_2_gene_list.py index 63b9aed456b607612c4234dbfbb785ba16d58e21..c54acd89d293e5257ec420940c6049719e9c1c84 100644 --- a/src/find_interaction_cluster/community_file_tools/community_file_2_gene_list.py +++ b/src/find_interaction_cluster/community_file_tools/community_file_2_gene_list.py @@ -14,6 +14,7 @@ from typing import Dict, List import lazyparser as lp import pandas as pd +import polars as pl from ...figures_utils.config_figures import Config as ConfF from ..config import Config @@ -73,7 +74,6 @@ def create_df_4_a_community( community: str, size: int, dic_id: Dict[int, str], - hg38_dic: Dict[int, str], ) -> pd.DataFrame: """ Create a small dataframe based on a string containing gene id separated \ @@ -84,8 +84,6 @@ def create_df_4_a_community( belongs :param size: The size a the community :param dic_id: A dicitonary linking id of gene to their symbol - :param hg38_dic: A dicitonary linking id of gene to their hg38 symbol - :return: A dataframe containing >>> create_df_4_a_community('1, 2, 3', 'C4', 3, {1: 'DSC2', 2: 'DSC1', 3: ... 'DSG1', 4: 'DSG4', 5: 'KCTD4', 6: 'TPT1'}, {1: 'DSC2-38', 2: 'DSC1', 3: @@ -97,18 +95,15 @@ def create_df_4_a_community( """ gene_ids = get_gene_list(gene_str) gene_names = [dic_id[gn] for gn in gene_ids] - gene_names_hg38 = [hg38_dic.get(gn, "") for gn in gene_ids] if len(gene_names) != size: raise ValueError( - f"gene name size ({len(gene_names)})" - f" and size ({size}) differt! " + f"gene name size ({len(gene_names)}) and size ({size}) differt! " ) return pd.DataFrame( { "cluster": [community] * size, "size": [size] * size, - "fasterdb_symbol": gene_names, - "hg38_symbol": gene_names_hg38, + "symbol": gene_names, "gene_id": gene_ids, } ) @@ -132,10 +127,9 @@ def create_full_df(df: pd.DataFrame) -> pd.DataFrame: 4 C2 2 KCTD4 KCTD4 5 """ dic_id = create_gene_dic() - hg38_dic = create_hg38_dic() df_list = [ create_df_4_a_community( - row["genes"], row["community"], row["nodes"], dic_id, hg38_dic + row["genes"], row["community"], row["nodes"], dic_id ) for _, row in df.iterrows() ] @@ -156,7 +150,16 @@ def gene_table_creator(community_file: str, outname: str = "") -> None: df = load_community_file(Path(community_file)) df = create_full_df(df) outf = outname or Path(community_file).stem - df.to_csv(output / f"{outf}.csv", sep="\t", index=False) + df = pl.from_pandas(df) + df = df.with_columns( + symbol=pl.col("symbol") + .over("cluster", mapping_strategy="join") + .list.join(", "), + gene_id=pl.col("gene_id") + .over("cluster", mapping_strategy="join") + .list.join(", "), + ).unique() + df.write_csv(output / f"{outf}.csv", separator="\t") if __name__ == "__main__":