Skip to content
Snippets Groups Projects
get_projects_interaction.py 3.93 KiB
#!/usr/bin/env python3

# -*- coding: UTF-8 -*-

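"""
Description: The goal of this script is to get the total \
number of interactions per project and select 9 projects with the following \
requirements:

* 3 projects must be those with the minimum possible interactions.
* 3 projects must be those with the greatest number of interactions.
* 3 projects must contain an average number of interactions.

Note: ``select_projects`` currently implements a different rule: it picks \
up to two projects above each of four interaction-count thresholds \
(2000, 30000, 100000 and 400000).
"""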

import sqlite3
import pandas as pd
from .config import ConfigNt
import seaborn as sns
import matplotlib.pyplot as plt
from ..logging_conf import logging_def
import logging


def get_interaction_by_project(cnx: sqlite3.Connection, weight: int,
                               same_gene: bool) -> pd.DataFrame:
    """
    Get the number of interactions by project.

    Projects with no interaction passing the filters do not appear in the \
    result (they have no group in the ``GROUP BY``).

    :param cnx: Connection to chia-pet database
    :param weight: A weight threshold: only interactions whose weight is \
    greater than or equal to it are counted
    :param same_gene: If True, every interaction is counted (including \
    those between two exons of the same gene). If False, only interactions \
    between exons belonging to different genes are counted
    :return: The table containing the number of interactions by project \
    (columns ``projects`` and ``interaction_count``), sorted by increasing \
    interaction count
    """
    logging.debug('Getting interaction from database')
    # The threshold is bound as a query parameter rather than formatted
    # into the SQL text.
    if same_gene:
        query = """SELECT id_project, COUNT(*)
                   FROM cin_exon_interaction
                   WHERE weight >= ?
                   GROUP BY id_project"""
    else:
        # Self-join on cin_exon to get the gene of each exon of the
        # interaction, and keep only pairs spanning two different genes.
        query = """SELECT t1.id_project, COUNT(*)
                   FROM cin_exon_interaction t1, cin_exon t2, cin_exon t3
                   WHERE t1.weight >= ?
                   AND t1.exon1 = t2.id
                   AND t1.exon2 = t3.id
                   AND t2.id_gene != t3.id_gene
                   GROUP BY t1.id_project"""
    df = pd.read_sql_query(query, cnx, params=(weight,))
    df.columns = ['projects', 'interaction_count']
    df.sort_values('interaction_count', ascending=True, inplace=True)
    logging.debug(df.head())
    return df


def make_barplot(df: pd.DataFrame, weight: int):
    """
    Make a barplot displaying the number of interactions for every project.

    The figure is saved next to the file returned by \
    ``ConfigNt.get_interaction_file(weight)``, with its extension replaced \
    by ``.pdf``.

    :param df: The dataframe containing the number of interactions by \
    project (columns ``projects`` and ``interaction_count``)
    :param weight: The minimum weight of interactions to consider them
    """
    logging.debug("Creating barplot figure")
    ConfigNt.interaction.mkdir(parents=True, exist_ok=True)
    sns.set()
    sns.set_context('talk')
    plt.figure(figsize=(20, 12))
    sns.barplot(x="projects", y="interaction_count", data=df)
    plt.xticks(rotation=90)
    outfile = ConfigNt.get_interaction_file(weight)
    # Only swap the extension: str.replace('txt', 'pdf') on the name would
    # also rewrite any other "txt" occurring in the file name.
    plt.savefig(outfile.with_suffix('.pdf'), bbox_inches='tight')
    plt.close()


def select_projects(df: pd.DataFrame,
                    thresholds: tuple = (2000, 30000, 100000, 400000),
                    n_per_threshold: int = 2, outfile=None):
    """
    Select the wanted projects and write them in a file, one per line.

    For each threshold (in the given order), the ``n_per_threshold`` \
    projects with the lowest interaction count strictly above that \
    threshold are selected. Projects already selected for a previous \
    threshold are skipped, so no project is written twice.

    :param df: The dataframe containing the number of interactions by \
    project (columns ``projects`` and ``interaction_count``)
    :param thresholds: The interaction-count thresholds
    :param n_per_threshold: The maximum number of projects selected above \
    each threshold
    :param outfile: The file where the selected projects are written \
    (str or path-like); defaults to ``ConfigNt.selected_project``
    """
    logging.debug("Selecting projects")
    # Sort here instead of relying on the caller having done it. A stable
    # sort keeps the caller's order among projects with equal counts.
    sorted_df = df.sort_values('interaction_count', kind='mergesort')
    selected = []
    for threshold in thresholds:
        above = sorted_df.loc[sorted_df['interaction_count'] > threshold,
                              'projects']
        candidates = [p for p in above if p not in selected]
        selected += candidates[:n_per_threshold]
    target = ConfigNt.selected_project if outfile is None else outfile
    with open(target, 'w') as outf:
        # str() so that non-string project ids (e.g. integers) can be written
        outf.write("".join(f"{p}\n" for p in selected))


def get_interactions_number(weight: int = 1, same_gene: bool = False,
                            logging_level: str = "DISABLE"):
    """
    Get the number of interactions by project, plot it, save it as a \
    tab-separated table and write a selection of projects.

    :param weight: The minimum weight of interactions to consider them
    :param same_gene: Say if we are considering interactions within the \
    same gene
    :param logging_level: The logging level passed to ``logging_def``
    """
    logging_def(ConfigNt.interaction, __file__, logging_level)
    logging.info(f'Recovering interaction count with a weight of {weight}')
    cnx = sqlite3.connect(ConfigNt.db_file)
    try:
        df = get_interaction_by_project(cnx, weight, same_gene)
    finally:
        # A sqlite3 connection's context manager only commits/rolls back;
        # it does not close the connection, so close it explicitly.
        cnx.close()
    make_barplot(df, weight)
    df.to_csv(ConfigNt.get_interaction_file(weight),
              sep="\t", index=False)
    select_projects(df)


if __name__ == "__main__":
    # Run with the default parameters of get_interactions_number.
    # NOTE: this module uses relative imports (``.config`` and
    # ``..logging_conf``), so it must be run as part of its package
    # (e.g. with ``python -m``) rather than as a standalone script.
    get_interactions_number()