diff --git a/src/nt_composition/make_nt_correlation.py b/src/nt_composition/make_nt_correlation.py index bb35a13e04f91da4b200ce69ec974056ea22965c..4ff46811888eece4cb4e03630303b9e97332f427 100644 --- a/src/nt_composition/make_nt_correlation.py +++ b/src/nt_composition/make_nt_correlation.py @@ -17,16 +17,21 @@ import doctest from .config import ConfigNt from ..logging_conf import logging_def from tqdm import tqdm -from typing import Dict, Tuple +from typing import Dict, Tuple, Any, List import seaborn as sns import matplotlib.pyplot as plt from scipy.stats import linregress from itertools import product +from random import random +import multiprocessing as mp +import os class NoInteractionError(Exception): pass +POSITION = 0 + def get_project_colocalisation(cnx: sqlite3.Connection, project: str ) -> np.array: @@ -37,7 +42,7 @@ def get_project_colocalisation(cnx: sqlite3.Connection, project: str :param project: The project of interest :return: The table containing the number of interaction by projects """ - logging.debug('Recovering interaction') + logging.debug(f'Recovering interaction ({os.getpid()})') query = "SELECT exon1, exon2 " \ "FROM cin_exon_interaction " \ f"WHERE id_project = '{project}'" @@ -124,10 +129,11 @@ def create_density_table(arr_interaction: np.array, dic_freq: Dict[str, float], :param dic_freq: The frequency dataframe. :return: The density table """ - logging.debug('Calculating density table') + logging.debug(f'Calculating density table ({os.getpid()})') exons_list = get_interacting_exon(arr_interaction) dic = {'exon': [], 'freq_exon': [], 'freq_coloc_exon': [], 'oexon': []} - pbar = tqdm(exons_list, desc="Getting frequencies...") + pbar = tqdm(exons_list, desc=f"Getting frequencies...({os.getpid()})", + position=mp.current_process()._identity[0] - 1) for exon in pbar: freq_ex = dic_freq[exon] oexon = get_all_exon_interacting_with_another(exon, arr_interaction) @@ -199,10 +205,13 @@ def create_density_figure(nt: str, ft_type: str, dic_freq = get_frequency_dic(cnx, nt, ft_type) df = create_density_table(arr_interaction, dic_freq) df.to_csv(outfile, sep="\t", index=False) + r, p = create_density_fig(df, project, ft_type, nt) else: - logging.debug(f'The file {outfile} exist, recovering data ...') + logging.debug(f'The file {outfile} exist, recovering data ' + f'({os.getpid()})') df = pd.read_csv(outfile, sep="\t") - return create_density_fig(df, project, ft_type, nt) + s, i, r, p, stderr = linregress(df.freq_exon, df.freq_coloc_exon) + return r, p def create_scatterplot(df_cor: pd.DataFrame, ft_type: str, ft: str): @@ -221,7 +230,11 @@ def create_scatterplot(df_cor: pd.DataFrame, ft_type: str, ft: str): (df_cor['ft'] == ft)].copy() sns.scatterplot(x='cor', y='nb_interaction', data=df_cor) left, right = plt.xlim() - plt.text(df_cor.cor + right * 0.03, df_cor.nb_interaction, ) + bottom, top = plt.ylim() + for i in range(df_cor.shape[0]): + plt.text(df_cor.cor.values[i] + right * 0.005, + df_cor.nb_interaction.values[i] + random() * top / 60, + df_cor.project.values[i], fontsize=8) plt.xlabel(f"Correlation for {ft} ({ft_type}) in project") plt.ylabel("Number of total interaction in projects") plt.title(f'Project correlation for {ft} ({ft_type}) ' @@ -230,28 +243,61 @@ def create_scatterplot(df_cor: pd.DataFrame, ft_type: str, ft: str): plt.close() -def create_all_frequency_figures(logging_level: str = "DISABLE"): +def execute_density_figure_function(di: pd.DataFrame, project : str, + ft_type: str, ft: str) -> Dict[str, Any]: + """ + Execute create_density_figure and organized the results in a dictionary. + + :param project: The project of interest + :param ft_type: The feature type of interest + :param ft: The feature of interest + :return: + """ + logging.info(f'Working on {project}, {ft_type}, {ft} - {os.getpid()}') + r, p = create_density_figure(ft, ft_type, project) + tmp = {"project": project, "ft_type": ft_type, + "ft": ft, "cor": r, "pval": p, + 'nb_interaction': di[di['projects'] == project].iloc[0, 1]} + return tmp + + +def combine_dic(list_dic: Dict) -> Dict: + """ + Combine The dictionaries in list_dic. + + :param list_dic: A list of dictionaries + :return: The combined dictionary + """ + dic = {k: [] for k in list_dic[0]} + for d in list_dic: + for k in d: + dic[k].append(d[k]) + return dic + + +def create_all_frequency_figures(ps: int, + logging_level: str = "DISABLE"): """ Make density figure for every selected projects. :param logging_level: The level of data to display. + :param ps: The number of processes to create """ logging_def(ConfigNt.interaction, __file__, logging_level) di = pd.read_csv(ConfigNt.interaction_file, sep="\t") - dic = {"project": [], 'ft_type': [], 'ft': [], 'nb_interaction': [], - 'cor': [], 'pval': []} with open(ConfigNt.selected_project, 'r') as f: projects = f.read().splitlines() nt_list = ['A', 'C', 'G', 'T', 'S', 'W'] param = product(projects, nt_list, ['nt']) + pool = mp.Pool(processes=ps) + processes = [] for project, nt, ft_type in param: - logging.info(f'Working on {project}, {ft_type}, {nt}') - r, p = create_density_figure(nt, ft_type, project) - tmp = {"project": project, "ft_type": ft_type, - "ft": nt, "cor": r, "pval": p, - 'nb_interaction': di[di['projects'] == project].iloc[0, 1]} - for k in tmp.keys(): - dic[k].append(tmp[k]) + args = [di, project, ft_type, nt] + processes.append(pool.apply_async(execute_density_figure_function, args)) + results = [] + for proc in processes: + results.append(proc.get(timeout=None)) + dic = combine_dic(results) df_corr = pd.DataFrame(dic) df_corr.to_csv(ConfigNt.density_folder / "density_recap.txt", sep="\t") create_scatterplot(df_corr, "nt", "S")