#!/usr/bin/env python

import os
import time
import csv
import sys
import re
import argparse
import itertools
import logging

import pysam as ps
import pandas as pd
from Bio import SeqIO

from hicstuff_log import logger
import hicstuff_log as hcl
import hicstuff_io as hio
import hicstuff_digest as hcd
from hicstuff_version import __version__



def bam2pairs(bam1, bam2, out_pairs, info_contigs, min_qual=30):
    """
    Make a .pairs file from two Hi-C BAM files sorted by read names.
    The Hi-C mates are matched by read identifier. Pairs where at least one
    read maps with MAPQ below the min_qual threshold are discarded. Pairs are
    sorted by readID and stored in upper triangle (smallest coordinate first).
    Parameters
    ----------
    bam1 : str
        Path to the name-sorted BAM file with aligned Hi-C forward reads.
    bam2 : str
        Path to the name-sorted BAM file with aligned Hi-C reverse reads.
    out_pairs : str
        Path to the output tab-separated .pairs file with columns:
        readID, chr1, pos1, chr2, pos2, strand1, strand2.
    info_contigs : str
        Path to the info contigs file, to get info on chromosome sizes and order.
    min_qual : int
        Minimum mapping quality required to keep a Hi-C pair.
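
    Examples
    --------
    A minimal sketch with illustrative paths; both BAM files must be sorted
    by read name beforehand (e.g. with ``samtools sort -n``):

    >>> bam2pairs(  # doctest: +SKIP
    ...     "end1.name_sorted.bam",
    ...     "end2.name_sorted.bam",
    ...     "library.pairs",
    ...     "info_contigs.txt",
    ...     min_qual=30,
    ... )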
    """
    forward = ps.AlignmentFile(bam1, "rb")
    reverse = ps.AlignmentFile(bam2, "rb")

    # Generate header lines
    format_version = "## pairs format v1.0\n"
    sorting = "#sorted: readID\n"
    cols = "#columns: readID chr1 pos1 chr2 pos2 strand1 strand2\n"
    # Chromosome order will be identical in info_contigs and pair files
    chroms = pd.read_csv(info_contigs, sep="\t").apply(
        lambda x: "#chromsize: %s %d\n" % (x.contig, x.length), axis=1
    )
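    # For reference, the generated header looks like this (chromsize values
    # are illustrative, one line per contig):
    #   ## pairs format v1.0
    #   #sorted: readID
    #   #columns: readID chr1 pos1 chr2 pos2 strand1 strand2
    #   #chromsize: seq1 230218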
    with open(out_pairs, "w") as pairs:
        pairs.writelines([format_version, sorting, cols] + chroms.tolist())
        pairs_writer = csv.writer(pairs, delimiter="\t")
        n_reads = {"total": 0, "mapped": 0}
        # Remember if some read IDs were missing from either file
        unmatched_reads = 0
        # Remember if all reads in one bam file have been read
        exhausted = [False, False]
        # Iterate on both BAM simultaneously
        end_regex = re.compile(r'/[12]$')
        for end1, end2 in itertools.zip_longest(forward, reverse):
            # Remove the end-specific suffix (/1 or /2), if any. Either end
            # can be None once the shorter BAM file is exhausted.
            if end1 is not None:
                end1.query_name = re.sub(end_regex, "", end1.query_name)
            if end2 is not None:
                end2.query_name = re.sub(end_regex, "", end2.query_name)
            # Check if reads pass the mapping quality filter
            try:
                end1_passed = end1.mapping_quality >= min_qual
            # Happens if end1 bam file has been exhausted
            except AttributeError:
                exhausted[0] = True
                end1_passed = False
            try:
                end2_passed = end2.mapping_quality >= min_qual
            # Happens if end2 bam file has been exhausted
            except AttributeError:
                exhausted[1] = True
                end2_passed = False
            # Skip the unmatched read until names match again, or until one
            # of the files is exhausted
            while sum(exhausted) == 0 and end1.query_name != end2.query_name:
                # Count the single unmatched read and fetch the next one from
                # the file that is lagging behind, checking filters again
                unmatched_reads += 1
                n_reads["total"] += 1
                if end1.query_name < end2.query_name:
                    try:
                        end1 = next(forward)
                        end1.query_name = re.sub(end_regex, "", end1.query_name)
                        end1_passed = end1.mapping_quality >= min_qual
                    # If EOF is reached in BAM 1
                    except (StopIteration, AttributeError):
                        exhausted[0] = True
                        end1_passed = False
                    n_reads["mapped"] += end1_passed
                elif end1.query_name > end2.query_name:
                    try:
                        end2 = next(reverse)
                        end2.query_name = re.sub(end_regex, "", end2.query_name)
                        end2_passed = end2.mapping_quality >= min_qual
                    # If EOF is reached in BAM 2
                    except (StopIteration, AttributeError):
                        exhausted[1] = True
                        end2_passed = False
                    n_reads["mapped"] += end2_passed

            # 2 reads processed per iteration, unless one file is exhausted
            n_reads["total"] += 2 - sum(exhausted)
            n_reads["mapped"] += sum([end1_passed, end2_passed])
            # Keep only pairs where both reads have good quality
            if end1_passed and end2_passed:
                # Flipping to get upper triangle
                if (
                    end1.reference_id == end2.reference_id
                    and end1.reference_start > end2.reference_start
                ) or end1.reference_id > end2.reference_id:
                    end1, end2 = end2, end1
                pairs_writer.writerow(
                    [
                        end1.query_name,
                        end1.reference_name,
                        end1.reference_start + 1,
                        end2.reference_name,
                        end2.reference_start + 1,
                        "-" if end1.is_reverse else "+",
                        "-" if end2.is_reverse else "+",
                    ]
                )
    forward.close()
    reverse.close()
    if unmatched_reads > 0:
        logger.warning(
            "%d reads were only present in one BAM file. Make sure you sorted reads by name before running the pipeline.",
            unmatched_reads,
        )
    logger.info(
        "{perc_map}% reads (single ends) mapped with Q >= {qual} ({mapped}/{total})".format(
            total=n_reads["total"],
            mapped=n_reads["mapped"],
            # Guard against a division by zero when both BAM files are empty
            perc_map=round(100 * n_reads["mapped"] / max(n_reads["total"], 1)),
            qual=min_qual,
        )
    )

def generate_log_header(log_path, input1, input2, genome, enzyme):
    hcl.set_file_handler(log_path, formatter=logging.Formatter(""))
    logger.info("## hicstuff: v%s log file", __version__)
    logger.info("## date: %s", time.strftime("%Y-%m-%d %H:%M:%S"))
    logger.info("## enzyme: %s", str(enzyme))
    logger.info("## input1: %s ", input1)
    logger.info("## input2: %s", input2)
    logger.info("## ref: %s", genome)
    logger.info("---")
    #hcl.set_file_handler(log_path, formatter=hcl.logfile_formatter)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build a fragment-indexed .pairs file from two name-sorted Hi-C BAM files."
    )
    parser.add_argument("-b1", "--bam1", help="Name-sorted BAM file with the forward Hi-C reads.")
    parser.add_argument("-b2", "--bam2", help="Name-sorted BAM file with the reverse Hi-C reads.")
    parser.add_argument("-o", "--out_pairs", help="Output .pairs file.")
    parser.add_argument("-x", "--out_idx", help="Output fragment-indexed .pairs file.")
    parser.add_argument("-i", "--info_contigs", help="Contig information table with contig names and lengths.")
    parser.add_argument("-q", "--min_qual", type=int, default=30, help="Minimum mapping quality to keep a Hi-C pair.")
    parser.add_argument("-e", "--enzyme", help="Restriction enzyme or kit name (e.g. dpnii, arima).")
    parser.add_argument("-f", "--fasta", help="Reference genome in (possibly compressed) FASTA format.")
    parser.add_argument("-c", "--circular", help="Whether the genome is circular (true/false).")
    args = parser.parse_args()

    bam1 = args.bam1
    bam2 = args.bam2
    out_pairs = args.out_pairs
    out_idx = args.out_idx
    info_contigs = args.info_contigs
    min_qual = args.min_qual
    enzyme = args.enzyme
    fasta = args.fasta
    # argparse yields a string; map common truthy spellings to a boolean so
    # that e.g. "false" is not silently treated as True downstream
    circular = str(args.circular).lower() in ("true", "1", "yes", "y")

    # hicstuff enzyme names are case-sensitive; map the lowercase CLI values
    # back to their canonical spelling. "arima" stands for the Arima kit,
    # which uses a cocktail of two enzymes.
    enzyme_map = {
        "hindiii": "HindIII",
        "dpnii": "DpnII",
        "bglii": "BglII",
        "mboi": "MboI",
        "arima": ["DpnII", "HinfI"],
    }
    enzyme = enzyme_map.get(enzyme, enzyme)

    # Redirect stderr to the log file and attach the hicstuff file handler
    log_file = "hicstuff_pairs.log"
    sys.stderr = open(log_file, "wt")
    hcl.set_file_handler(log_file)
    generate_log_header(log_file, bam1, bam2, fasta, enzyme)

    bam2pairs(bam1, bam2, out_pairs, info_contigs, min_qual)

    # Build the restriction table (fragment positions) of each contig
    restrict_table = {}
    for record in SeqIO.parse(hio.read_compressed(fasta), "fasta"):
        restrict_table[record.id] = hcd.get_restriction_table(
            record.seq, enzyme, circular=circular
        )

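    # Attribute each end of a pair to a restriction fragment, producing the
    # fragment-indexed .pairs file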
    hcd.attribute_fragments(out_pairs, out_idx, restrict_table)

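    # Sort the indexed pairs by genomic coordinates, then replace the
    # unsorted file with the sorted one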
    hio.sort_pairs(
        out_idx,
        out_idx + ".sorted",
        keys=["chr1", "pos1", "chr2", "pos2"],
        threads=1,
        tmp_dir=None,
    )
    os.rename(out_idx + ".sorted", out_idx)