diff --git a/src/nt_composition/get_projects_interaction.py b/src/nt_composition/get_projects_interaction.py index d2aa5cde3b828d39e90a01b74ad871fa83a9971b..8a485638a0076c68852ae73353f08529a84b6208 100644 --- a/src/nt_composition/get_projects_interaction.py +++ b/src/nt_composition/get_projects_interaction.py @@ -21,17 +21,20 @@ from ..logging_conf import logging_def import logging -def get_interaction_by_project(cnx: sqlite3.Connection) -> pd.DataFrame: +def get_interaction_by_project(cnx: sqlite3.Connection, weight: int + ) -> pd.DataFrame: """ Get the number of interactions by projects. :param cnx: Connection to chia-pet database + :param weight: A weight threshold :return: The table containing the number of interaction by projects """ logging.debug('Getting interaction from database') - query = "SELECT id_project, COUNT(*) " \ - "FROM cin_exon_interaction " \ - "GROUP BY id_project" + query = f"SELECT id_project, COUNT(*) " \ + f"FROM cin_exon_interaction " \ + f"WHERE weight >= {weight} " \ + f"GROUP BY id_project" df = pd.read_sql_query(query, cnx) df.columns = ['projects', 'interaction_count'] df.sort_values('interaction_count', ascending=True, inplace=True) @@ -39,10 +42,11 @@ def get_interaction_by_project(cnx: sqlite3.Connection) -> pd.DataFrame: return df -def make_barplot(df: pd.DataFrame): +def make_barplot(df: pd.DataFrame, weight: int): """ Make a barplot displaying the number of interactions for every project. + :param weight: The minimum weight of interaction to concidere them :param df: The dataframe containing the number of interaction by \ projects """ @@ -53,8 +57,8 @@ def make_barplot(df: pd.DataFrame): plt.figure(figsize=(20, 12)) sns.barplot(x="projects", y="interaction_count", data=df) plt.xticks(rotation=90) - plt.savefig(ConfigNt.interaction_file.parent / - ConfigNt.interaction_file.name.replace('txt', 'pdf'), + outfile = ConfigNt.get_interaction_file(weight) + plt.savefig(outfile.parent / outfile.name.replace('txt', 'pdf'), bbox_inches='tight') plt.close() @@ -75,15 +79,18 @@ def select_projects(df: pd.DataFrame): outf.write("\n".join(sp) + "\n") -def get_interactions_number(logging_level: str = "DISABLE"): +def get_interactions_number(weight: int = 1, logging_level: str = "DISABLE"): """ Get the number of interaction by projects + + :param weight: The minimum weight of correlation to consider them """ logging_def(ConfigNt.interaction, __file__, logging_level) + logging.info(f'Recovering interaction count with a weight of {weight}') cnx = sqlite3.connect(ConfigNt.db_file) - df = get_interaction_by_project(cnx) - make_barplot(df) - df.to_csv(ConfigNt.interaction_file, + df = get_interaction_by_project(cnx, weight) + make_barplot(df, weight) + df.to_csv(ConfigNt.get_interaction_file(weight), sep="\t", index=False) sns.barplot() select_projects(df)