#!/usr/bin/env python

import os, time, csv, sys, re
import argparse
import pysam as ps
import pandas as pd
import shutil as st
import subprocess as sp
import itertools
from hicstuff.log import logger
import hicstuff.io as hio
import hicstuff.digest as hcd
from Bio import SeqIO
import hicstuff.log as hcl


def pairs2matrix(
    pairs_file, mat_file, fragments_file, mat_fmt="graal", threads=1, tmp_dir=None
):
    """Generate the matrix by counting the number of occurences of each
    combination of restriction fragments in a pairs file.

    Consecutive records sharing the same (frag1, frag2) ids are collapsed
    into a single matrix entry, so the records that are read must already be
    sorted by frag1, then frag2.

    NOTE(review): the sorting step (``hio.sort_pairs``) that used to turn
    ``pairs_file`` into ``mat_file + ".pre.pairs"`` is disabled. As a result
    ``pairs_file``, ``threads`` and ``tmp_dir`` are currently unused, and the
    function reads ``mat_file + ".pre.pairs"``, which has to exist (and be
    sorted) before the call. That intermediate file is consumed: it is
    deleted in bg2 mode and overwritten then renamed to ``mat_file`` in graal
    mode. Confirm this is what callers expect.

    Parameters
    ----------
    pairs_file : str
        Path to a Hi-C pairs file, with frag1 and frag2 columns added.
        Currently unused (see note above).
    mat_file : str
        Path where the matrix will be written.
    fragments_file : str
        Path to the fragments_list.txt file. Used to know total
        matrix size in case some observations are not observed at the end.
    mat_fmt : str
        The format to use when writing the matrix. Can be graal or bg2 format.
    threads : int
        Number of threads to use in parallel. Currently unused.
    tmp_dir : str
        Temporary directory for sorting files. Currently unused.

    Raises
    ------
    ValueError
        If mat_fmt is neither "graal" nor "bg2".
    """
    if mat_fmt not in ("graal", "bg2"):
        # Fail before touching any file: an unknown format used to produce an
        # empty matrix and still delete the intermediate pairs file.
        raise ValueError(
            "Unsupported matrix format %r: expected 'graal' or 'bg2'." % (mat_fmt,)
        )

    # Number of fragments is N lines in frag list - 1 for the header
    with open(fragments_file, "r") as frags_handle:
        n_frags = sum(1 for _ in frags_handle) - 1
    frags = pd.read_csv(fragments_file, delimiter="\t")

    def format_mat_entry(frag1, frag2, contacts):
        """Format a single sparse matrix entry in either graal or bg2 format"""
        if mat_fmt == "graal":
            fields = [frag1, frag2, contacts]
        else:
            # bg2 entries carry the coordinates of both fragments, looked up
            # in the fragments table by fragment id.
            frag1, frag2 = int(frag1), int(frag2)
            fields = [
                frags.chrom[frag1],
                frags.start_pos[frag1],
                frags.end_pos[frag1],
                frags.chrom[frag2],
                frags.start_pos[frag2],
                frags.end_pos[frag2],
                contacts,
            ]
        return "\t".join(map(str, fields)) + "\n"

    def frag_ids(record):
        """Return the (frag1, frag2) ids of a pairs record (fields 8 and 9)"""
        return record[7], record[8]

    pre_mat_file = mat_file + ".pre.pairs"

    header_size = len(hio.get_pairs_header(pre_mat_file))
    n_nonzero = 0  # Total number of nonzero matrix entries
    n_pairs = 0  # Total number of pairs entered in the matrix
    with open(pre_mat_file, "r") as pairs, open(mat_file, "w") as mat:
        # Skip header lines
        for _ in range(header_size):
            next(pairs)
        # First line contains nrows, ncols and number of nonzero entries.
        # Number of nonzero entries is unknown for now: write a placeholder.
        if mat_fmt == "graal":
            mat.write("\t".join(map(str, [n_frags, n_frags, "-"])) + "\n")
        pairs_reader = csv.reader(pairs, delimiter="\t")
        # Each run of identical fragment ids becomes one matrix entry whose
        # value is the length of the run. An empty file simply yields no run.
        for (frag1, frag2), run in itertools.groupby(pairs_reader, key=frag_ids):
            contacts = sum(1 for _ in run)
            mat.write(format_mat_entry(frag1, frag2, contacts))
            n_nonzero += 1
            n_pairs += contacts

    if mat_fmt == "graal":
        # Rewrite the matrix with the final header. The intermediate pairs
        # file is reused as scratch space and then renamed over the matrix.
        with open(mat_file) as mat, open(pre_mat_file, "w") as tmp_mat:
            mat.readline()  # Drop the placeholder header
            tmp_mat.write("\t".join(map(str, [n_frags, n_frags, n_nonzero])) + "\n")
            st.copyfileobj(mat, tmp_mat)
        # Replace the matrix file with the one with corrected header
        os.rename(pre_mat_file, mat_file)
    else:
        os.remove(pre_mat_file)

    logger.info(
        "%d pairs used to build a contact map of %d bins with %d nonzero entries.",
        n_pairs,
        n_frags,
        n_nonzero,
    )


def pairs2cool(pairs_file, cool_file, bins_file):
    """
    Make a cooler file from the pairs file by running ``cooler cload pairs``.
    See: https://github.com/mirnylab/cooler/ for more information.

    A temporary bins file (``bins_file + ".cooler"``) holding columns 2 to 4
    of ``bins_file`` without its header line is written next to ``bins_file``
    and removed once cooler has finished, whether it succeeded or not.

    Parameters
    ----------
    pairs_file : str
        Path to the pairs file containing input contact data.
    cool_file : str
        Path to the output cool file name to generate.
    bins_file : str
        Path to the file containing genomic segmentation information. (fragments_list.txt).

    Raises
    ------
    subprocess.CalledProcessError
        If cooler exits with a non-zero status.
    OSError
        If the cooler executable cannot be started (e.g. it is not on PATH).
    """

    # Make bins file compatible with cooler cload
    bins_tmp = bins_file + ".cooler"
    bins = pd.read_csv(bins_file, sep="\t", usecols=[1, 2, 3], skiprows=1, header=None)
    bins.to_csv(bins_tmp, sep="\t", header=False, index=False)

    # NOTE(review): "-p2 4 -c2 5" looks swapped with respect to the usual
    # pairs column order (readID chr1 pos1 chr2 pos2 ...), where chr2 is
    # column 4 and pos2 column 5. Values are kept as they were; confirm
    # against the pairs files this script receives.
    cooler_cmd = [
        "cooler", "cload", "pairs",
        "-c1", "2",
        "-p1", "3",
        "-p2", "4",
        "-c2", "5",
        bins_tmp,
        pairs_file,
        cool_file,
    ]
    try:
        # Argument list without a shell: paths containing spaces or shell
        # metacharacters are passed through verbatim, and a failing cooler
        # run is reported instead of being silently ignored.
        sp.check_call(cooler_cmd)
    finally:
        os.remove(bins_tmp)


def main(argv=None):
    """Command-line entry point: build a contact matrix from a pairs file.

    Parameters
    ----------
    argv : list of str or None
        Command-line arguments, without the program name. If None, argparse
        reads them from sys.argv.
    """
    parser = argparse.ArgumentParser(
        description="Build a Hi-C contact matrix from a pairs file."
    )
    # Inputs and output are mandatory: without them the script used to fail
    # later with an unrelated TypeError when opening a None path.
    parser.add_argument("-p", "--pairs", required=True, help="Input pairs file.")
    parser.add_argument(
        "-f", "--fragments", required=True, help="fragments_list.txt file."
    )
    parser.add_argument(
        "-t",
        "--type",
        choices=("graal", "bg2", "cool"),
        default="graal",
        help="Output matrix format (default: graal).",
    )
    parser.add_argument("-o", "--output", required=True, help="Output matrix path.")
    args = parser.parse_args(argv)

    pairs = args.pairs
    fragments_list = args.fragments
    mat_format = args.type
    mat_out = args.output

    log_file = "hicstuff_matrix.log"
    # Redirect stderr to the log file. The handle is intentionally left open:
    # anything written to stderr until the interpreter exits (including an
    # uncaught traceback) ends up in this file.
    # NOTE(review): the logging file handler set below targets the same path
    # through a separate handle; confirm the two writers cannot clobber each
    # other's output.
    sys.stderr = open(log_file, "wt")
    hcl.set_file_handler(log_file)

    # Log which pairs file is being used and how many pairs are listed
    with open(pairs, "r") as handle:
        pairs_count = sum(1 for line in handle if not line.startswith("#"))
    logger.info(
        "Generating matrix from pairs file %s (%d pairs in the file) ",
        pairs, pairs_count
    )

    if mat_format == "cool":
        # Name matrix file in .cool
        cool_file = os.path.splitext(mat_out)[0] + ".cool"
        pairs2cool(pairs, cool_file, fragments_list)
    else:
        pairs2matrix(
            pairs,
            mat_out,
            fragments_list,
            mat_fmt=mat_format,
            threads=1,
            tmp_dir=None,
        )


if __name__ == "__main__":
    main()