#!/usr/bin/env python

import os, time, csv, sys, re
import argparse
import pysam as ps
import pandas as pd
import shutil as st
import subprocess as sp
import itertools
from hicstuff_log import logger
import hicstuff_io as hio
import hicstuff_digest as hcd
from Bio import SeqIO
import hicstuff_log as hcl


def _count_frag_pairs(pairs_file):
    """Count the occurrences of each (frag1, frag2) combination in a pairs file.

    Lines starting with "#" (pairs header) and empty lines are skipped. The
    fragment ids are read from the 8th and 9th tab-separated fields.

    Parameters
    ----------
    pairs_file : str
        Path to a pairs file with frag1 and frag2 columns.

    Returns
    -------
    collections.Counter
        Maps (frag1, frag2) integer tuples to their number of occurrences.

    Raises
    ------
    ValueError
        If a fragment id field is not an integer.
    IndexError
        If a data line has fewer than 9 fields.
    """
    from collections import Counter

    counts = Counter()
    with open(pairs_file, "r") as pairs:
        data_lines = (line for line in pairs if not line.startswith("#"))
        for fields in csv.reader(data_lines, delimiter="\t"):
            if not fields:
                continue
            # Fragment ids are fields 8 and 9
            counts[(int(fields[7]), int(fields[8]))] += 1
    return counts


def pairs2matrix(
    pairs_file, mat_file, fragments_file, mat_fmt="graal", threads=1, tmp_dir=None
):
    """Generate the matrix by counting the number of occurrences of each
    combination of restriction fragments in a pairs file.

    Contacts are aggregated in memory per (frag1, frag2) combination, so the
    pairs file does not need to be sorted beforehand. Memory use grows with
    the number of nonzero matrix entries, which is also the number of lines
    written to the output. Entries are written sorted numerically by frag1,
    then frag2.

    Parameters
    ----------
    pairs_file : str
        Path to a Hi-C pairs file, with frag1 and frag2 columns added
        (8th and 9th tab-separated fields). Lines starting with "#" are
        treated as header lines and skipped.
    mat_file : str
        Path where the matrix will be written.
    fragments_file : str
        Path to the fragments_list.txt file. Used to know total
        matrix size in case some observations are not observed at the end,
        and (bg2 format) to look up the chrom, start_pos and end_pos columns
        of each fragment.
    mat_fmt : str
        The format to use when writing the matrix. Can be graal or bg2 format.
        Any other value produces a file without entries and logs a warning.
    threads : int
        Unused; kept for backward compatibility (no external sort is run).
    tmp_dir : str
        Unused; kept for backward compatibility (no temporary sort files
        are created).

    Raises
    ------
    ValueError
        If a fragment id in the pairs file is not an integer.
    """
    # Number of fragments is N lines in frag list - 1 for the header
    with open(fragments_file, "r") as frag_list:
        n_frags = sum(1 for _ in frag_list) - 1

    contacts = _count_frag_pairs(pairs_file)
    n_nonzero = len(contacts)
    n_pairs = sum(contacts.values())

    if mat_fmt == "bg2":
        frags = pd.read_csv(fragments_file, delimiter="\t")
        # Fragment ids are used as row labels of the fragments list (default
        # 0-based RangeIndex). Plain lists avoid a pandas lookup per entry.
        chroms = frags.chrom.tolist()
        starts = frags.start_pos.tolist()
        ends = frags.end_pos.tolist()
    elif mat_fmt != "graal":
        logger.warning(
            "Unknown matrix format %r: no matrix entries will be written.", mat_fmt
        )

    with open(mat_file, "w") as mat:
        # graal header: nrows, ncols and number of nonzero entries. The
        # nonzero count is known upfront, so no header rewrite is needed.
        if mat_fmt == "graal":
            mat.write("\t".join(map(str, [n_frags, n_frags, n_nonzero])) + "\n")
        for (frag1, frag2), n_occ in sorted(contacts.items()):
            if mat_fmt == "graal":
                entry = [frag1, frag2, n_occ]
            elif mat_fmt == "bg2":
                entry = [
                    chroms[frag1],
                    starts[frag1],
                    ends[frag1],
                    chroms[frag2],
                    starts[frag2],
                    ends[frag2],
                    n_occ,
                ]
            else:
                continue
            mat.write("\t".join(map(str, entry)) + "\n")

    logger.info(
        "%d pairs used to build a contact map of %d bins with %d nonzero entries.",
        n_pairs,
        n_frags,
        n_nonzero,
    )


def pairs2cool(pairs_file, cool_file, bins_file):
    """
    Make a cooler file from the pairs file. See: https://github.com/mirnylab/cooler/ for more information.

    Parameters
    ----------
    pairs_file : str
        Path to the pairs file containing input contact data.
    cool_file : str
        Path to the output cool file name to generate.
    bins_file : str
        Path to the file containing genomic segmentation information. (fragments_list.txt).

    Raises
    ------
    subprocess.CalledProcessError
        If ``cooler cload pairs`` exits with a nonzero status.
    """

    # Make bins file compatible with cooler cload: drop the header line and
    # keep only the columns at 0-based positions 1-3 (expected to be
    # chrom, start and end in fragments_list.txt).
    bins_tmp = bins_file + ".cooler"
    bins = pd.read_csv(bins_file, sep="\t", usecols=[1, 2, 3], skiprows=1, header=None)
    bins.to_csv(bins_tmp, sep="\t", header=False, index=False)

    # Pairs format columns (1-based): readID, chr1, pos1, chr2, pos2, ...
    cooler_cmd = [
        "cooler",
        "cload",
        "pairs",
        "-c1",
        "2",
        "-p1",
        "3",
        "-c2",
        "4",
        "-p2",
        "5",
        bins_tmp,
        pairs_file,
        cool_file,
    ]
    try:
        # Argument list without a shell: paths containing spaces or shell
        # metacharacters are passed verbatim.
        sp.check_call(cooler_cmd)
    finally:
        # Remove the temporary bins file even if cooler fails.
        os.remove(bins_tmp)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build a contact matrix from a Hi-C pairs file."
    )
    parser.add_argument(
        "-p", "--pairs", required=True, help="Pairs file with frag1/frag2 columns."
    )
    parser.add_argument(
        "-f", "--fragments", required=True, help="fragments_list.txt file."
    )
    parser.add_argument(
        "-t",
        "--type",
        default="graal",
        choices=["graal", "bg2", "cool"],
        help="Output matrix format (default: graal).",
    )
    parser.add_argument("-o", "--output", required=True, help="Output matrix path.")
    args = parser.parse_args()

    pairs = args.pairs
    fragments_list = args.fragments
    mat_format = args.type
    mat_out = args.output

    log_file = "hicstuff_matrix.log"
    # Start from an empty log, then reopen it in append mode for stderr. The
    # logger's file handler writes through its own handle; append mode makes
    # every stderr write land at the end of the file instead of overwriting
    # logger records at this handle's own offset.
    open(log_file, "w").close()
    sys.stderr = open(log_file, "a", buffering=1)
    hcl.set_file_handler(log_file)

    # Log which pairs file is being used and how many pairs are listed
    # (header lines start with "#").
    with open(pairs, "r") as pairs_handle:
        pairs_count = sum(1 for line in pairs_handle if not line.startswith("#"))
    logger.info(
        "Generating matrix from pairs file %s (%d pairs in the file) ",
        pairs, pairs_count
    )

    if mat_format == "cool":
        # Name matrix file in .cool
        cool_file = os.path.splitext(mat_out)[0] + ".cool"
        pairs2cool(pairs, cool_file, fragments_list)
    else:
        pairs2matrix(
            pairs,
            mat_out,
            fragments_list,
            mat_fmt=mat_format,
            threads=1,
            tmp_dir=None,
        )