#!/usr/bin/env python2.7
import sys
import time
import argparse
import pprint
import pickle
import gzip
import numpy as np
import validation
import parsers
import cnv_struct
# This file is part of CNVAnalysisToolkit.
# 
# CNVAnalysisToolkit is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# CNVAnalysisToolkit is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with CNVAnalysisToolkit.  If not, see <http://www.gnu.org/licenses/>.
# Module metadata: author, copyright holder and license of this script.
__author__ = "Marc-Andre Legault (StatGen)"
__copyright__ = "Copyright (C) 2013 StatGen"
__license__ = "GNU General Public License v3 (GPL-3)"
def main(args):
    """Generates a graph structure representing the familial status for a
    given locus.

    The signature matrix represents the status of every individual of a given
    family at a given locus. The status can be ``+``, ``-`` or ``0``,
    representing a gain, a loss or a `no call`, respectively. Given a
    particular locus, the Mendelian inheritance of a variant can thus be
    quickly assessed by looking at the status of every individual of the
    family. This is why we generate matrices with the status symbol for every
    individual in the family and count the number of times a signature occurs.

    As an example, if both twins and the mother have a deletion and the
    father had no CNV called at the given region, the signature is
    ``(-, -, -, 0)``, as the order for signatures is always
    ``(twin1, twin2, mother, father)``.

    The goal of such an analysis was to quickly assess the amount of inherited
    CNVs and to detect any algorithm-specific bias.

    A pileup file parsing utility is also integrated with this tool, allowing
    the validation of the regions by comparing the coverage inside and outside
    of the CNV loci. Such an analysis had modest success.

    :param args: Parsed command-line arguments. This function reads
                 ``threshold``, ``save``, ``twin1_pileup``, ``twin2_pileup``,
                 ``mother_pileup`` and ``father_pileup``; the whole namespace
                 is also handed to ``validation.get_parser_if_valid``.
    """
    p = validation.get_parser_if_valid(args)
    cnvs = p.get_cnvs()

    print("Creating the main graph.")
    connected_components = create_family_graph(cnvs, args.threshold)

    # Merge CNVs of the same type coming from the same sample.
    print("Cleaning connected components (merge same sample vertices).")
    connected_components = clean_ccs(connected_components)

    # Only announce the save when a destination was actually requested.
    if args.save is not None:
        print("Saving the basic graph pickle.")
        with open(args.save, "wb") as f:
            pickle.dump(connected_components, f)

    # Coverage information is fetched only when every pileup is provided.
    pileups = (args.twin1_pileup,
               args.twin2_pileup,
               args.mother_pileup,
               args.father_pileup)
    provided = [path for path in pileups if path is not None]
    if len(provided) == len(pileups):
        print("Loading the pileups.")
        print("Fetching coverage information for samples in family.")
        get_coverage(connected_components, *pileups)
        print("Writing the graph with coverage information.")
        with open("ccs_with_coverage.pickle", "wb") as f:
            pickle.dump(connected_components, f)
    elif provided:
        # A partial set cannot be used; say so instead of silently skipping.
        print("Warning: coverage requires the pileups of all four family "
              "members; skipping coverage computation.")

    print("Building the final matrix")
    check_profiles(connected_components)
 
def get_coverage(cc_list, twin1_pileup, twin2_pileup,
                 mother_pileup, father_pileup, window_size = 1000):
    """Computes the coverage inside and outside of every CNV locus represented
    by a connected component in the cc_list graph.

    :param cc_list: Graph represented by a list of its connected components.
    :type cc_list: list
    :param twin1_pileup: Path to the pileup for twin1.
    :type twin1_pileup: str
    :param twin2_pileup: Path to the pileup for twin2.
    :type twin2_pileup: str
    :param mother_pileup: Path to the pileup for mother.
    :type mother_pileup: str
    :param father_pileup: Path to the pileup for father.
    :type father_pileup: str
    :param window_size: The size of genomic window around the region for
                        coverage computation.
    :type window_size: int

    Concretely, this sets the ``region_doc`` and ``out_doc`` attributes of
    every connected component. ``region_doc`` is an array with the mean depth
    (4th pileup column) of each sample inside ``[start, end)``. ``out_doc`` is
    the mean of the average depths of the ``window_size`` flanks before and
    after the region. Both are set to ``None`` when no coverage can be
    computed: a non-numeric chromosome, or a pileup that reached a sex
    chromosome or its end. The components are processed in numeric genomic
    order. ``cc_list`` itself is not reordered.

    Window boundaries are decided by the first pileup (twin1), and the other
    pileups are read in lockstep with it.
    NOTE(review): pileups usually omit zero-coverage positions, so lockstep
    reads may drift apart between samples. This is kept from the original
    design.
    """
    ordered = sorted(cc_list, key = _cc_sort_key)
    paths = (twin1_pileup, twin2_pileup, mother_pileup, father_pileup)
    pileups = []
    try:
        for path in paths:
            pileups.append(open_pileup(path))
        normalize_pileups_starting_position(*pileups)
        indexes = create_seek_index(pileups)

        total = len(ordered)
        # Report progress roughly every 1%. The max() avoids a zero modulus
        # when there are fewer than 100 components.
        progress_step = max(total // 100, 1)
        for count, cc in enumerate(ordered, 1):
            if count % progress_step == 0:
                print("{} %".format(int(round(100.0 * count / total))))

            try:
                chromo = _chromosome_number(cc.chr)
            except ValueError:
                chromo = None
            # seek_to returns True when coverage is unavailable at this locus.
            if (chromo is None or
                    seek_to(cc.start - window_size, chromo, pileups, indexes)):
                cc.region_doc = None
                cc.out_doc = None
                continue

            # Each window starts from the first line the previous window did
            # not consume, so no line is lost at the boundaries.
            lines = _read_pileup_lines(pileups)
            cov_prefix, n_prefix, lines = _sum_depth_until(
                pileups, lines, chromo, cc.start)
            cov_region, n_region, lines = _sum_depth_until(
                pileups, lines, chromo, cc.end)
            cov_postfix, n_postfix, lines = _sum_depth_until(
                pileups, lines, chromo, cc.end + window_size)

            # Divide by the true number of lines; max() keeps empty windows at 0.
            cov_prefix /= max(n_prefix, 1)
            cov_region /= max(n_region, 1)
            cov_postfix /= max(n_postfix, 1)
            cc.region_doc = cov_region
            cc.out_doc = (cov_prefix + cov_postfix) / 2.0
    finally:
        for f in pileups:
            f.close()


def _chromosome_number(name):
    """Convert a chromosome label such as ``"chr7"`` (or ``7``) to ``7``.

    Raises ValueError for non-numeric chromosomes (e.g. ``"chrX"``).
    """
    label = str(name)
    if label.startswith("chr"):
        label = label[3:]
    return int(label)


def _cc_sort_key(cc):
    """Sort key: numeric chromosomes first in numeric order, then the rest."""
    try:
        return (0, _chromosome_number(cc.chr), cc.start)
    except ValueError:
        return (1, str(cc.chr), cc.start)


def _read_pileup_lines(pileups):
    """Read the next line of every pileup, split into tab-separated fields."""
    return [f.readline().split("\t") for f in pileups]


def _before_stop(lines, chromo, stop):
    """True while every pileup has a full record and the first pileup is
    still on ``chromo`` at a position below ``stop``."""
    if any(len(fields) < 4 for fields in lines):
        return False  # EOF (readline returned "") or a truncated record.
    first = lines[0]
    try:
        return _chromosome_number(first[0]) == chromo and int(first[1]) < stop
    except ValueError:
        return False  # Non-numeric chromosome: past the autosomes.


def _sum_depth_until(pileups, lines, chromo, stop):
    """Sum the depth column of every pileup until ``stop`` is reached.

    Starting with the already-read ``lines``, lines are consumed while
    :py:func:`_before_stop` holds.

    :returns: ``(depth_sum, n_lines, lines)``, where ``lines`` is the first
              record set that was *not* consumed.
    """
    depth_sum = np.zeros(len(pileups))
    n_lines = 0
    while _before_stop(lines, chromo, stop):
        depth_sum += np.array([int(fields[3]) for fields in lines])
        n_lines += 1
        lines = _read_pileup_lines(pileups)
    return depth_sum, n_lines, lines
 
def create_seek_index(pileups, step = int(5e6)):
    """Create an index from file offsets (``tell``) to genomic positions.

    This is used to quickly move around very large pileup files. Use it only
    on unzipped files: seeking in gzip streams is slow.

    Every ``step`` bytes, the partial line the seek landed in is discarded.
    The offset of the next line is then recorded with its chromosome and
    position. A pileup's index stops at its end of file, or at the first line
    whose chromosome is not numeric (e.g. ``chrX``). All pileups are rewound
    to offset 0 afterwards.

    :param pileups: Open, seekable pileup file objects.
    :param step: Distance in bytes between two index entries.
    :returns: One list per pileup, in the same order, of
              ``(chromosome, position, offset)`` tuples in file order.
    """
    print("Creating seek index.")
    indexes = []
    for pileup in pileups:
        index = []
        while True:
            pileup.seek(pileup.tell() + step)
            # Flush the line fragment we landed in.
            pileup.readline()
            offset = pileup.tell()
            fields = pileup.readline().split("\t")
            try:
                pos = int(fields[1])
                chromo = int(fields[0][3:])
            except IndexError:
                break  # EOF: readline returned "".
            except ValueError:
                break  # Sex chromosome (non-numeric): stop indexing.
            index.append((chromo, pos, offset))
        indexes.append(index)
    for pileup in pileups:
        pileup.seek(0)
    print("Finished creating seek index.")
    return indexes
 
def normalize_pileups_starting_position(*pileups):
    """Advance pileups so that they all start at the largest first position.

    The first line of every pileup is read. Every pileup whose first position
    is below the largest one is then read forward until it reaches a position
    greater than or equal to that maximum. The line at which each pileup
    stops has been consumed.

    :param pileups: Open pileup file objects.
    :raises ValueError: if the pileups do not start on the same (numeric)
                        chromosome, or if a pileup ends before reaching the
                        common starting position.
    """
    first_lines = [f.readline().split("\t") for f in pileups]
    # Explicit check rather than assert, so it also runs under ``python -O``.
    chromos = set(int(fields[0][3:]) for fields in first_lines)
    if len(chromos) != 1:
        raise ValueError("Pileups do not start on the same chromosome: "
                         "{}".format(sorted(chromos)))
    positions = [int(fields[1]) for fields in first_lines]
    max_starting_pos = max(positions)
    for pileup, pos in zip(pileups, positions):
        while pos < max_starting_pos:
            fields = pileup.readline().split("\t")
            if len(fields) < 2:
                raise ValueError("Pileup ended before reaching starting "
                                 "position {}.".format(max_starting_pos))
            pos = int(fields[1])
 
def seek_to(position, chromo, pileups, indexes, seek_profiling = False):
    """Seeks every pileup to the first line at or after ``chromo:position``.

    Pileups are read line by line. On the first pass, a pileup that is on
    another chromosome or far from the destination jumps to the first index
    milestone at or after the target (see :py:func:`create_seek_index`). A
    pileup that is past the target is rewound in 10000-byte steps. Positions
    are compared lexicographically as ``(chromosome, position)`` pairs.

    :param position: Target genomic position.
    :type position: int
    :param chromo: Target chromosome, either a number (``7``) or a label such
                   as ``"chr7"``.
    :param pileups: Open, seekable pileup file objects.
    :param indexes: One list of ``(chromosome, position, offset)`` tuples per
                    pileup, as returned by :py:func:`create_seek_index`.
                    Empty lists are allowed; those pileups are read
                    sequentially.
    :param seek_profiling: Print timing and debug information when True.
    :returns: True if at least one pileup reached a non-numeric chromosome
              (e.g. ``chrX``) or its end of file. In that case coverage cannot
              be computed for the locus. Returns False otherwise.
    """
    # Accept "chrN" labels as well as numbers: pileup chromosomes are parsed
    # to ints below and compared against this value.
    chromo_label = str(chromo)
    if chromo_label.startswith("chr"):
        chromo_label = chromo_label[3:]
    chromo = int(chromo_label)
    target = (chromo, position)

    first = True
    stop = False
    ti = time.time()
    if seek_profiling:
        print("\tseek {}:{}".format(chromo, position))
    # Read current lines.
    lines = [f.readline() for f in pileups]
    # Per pileup: True once it is past the autosomes or at EOF.
    at_end = [False] * len(pileups)
    while not stop:
        stop = True
        for i in range(len(pileups)):
            fields = lines[i].split("\t")
            if len(fields) < 2:
                # readline() returned "": end of file. Leave this pileup.
                at_end[i] = True
                continue
            try:
                this_chromo = int(fields[0][3:])
            except ValueError:
                # Sex chromosome: the autosomes are exhausted for this pileup.
                at_end[i] = True
                continue
            at_end[i] = False
            this_pos = int(fields[1])

            # On the first pass, use the index when we are on another
            # chromosome or far from the destination.
            # NOTE(review): the chromosome-weighted distance is a heuristic
            # kept from the original code.
            if (first and indexes[i] and
                    (this_chromo != chromo or
                     abs(this_chromo * this_pos - chromo * position) > 100000)):
                if seek_profiling:
                    sys.stdout.write("\tINDEX LOOKUP... ")
                milestones = sorted(indexes[i], key = lambda x: (x[0], x[1]))
                n_milestones = len(milestones)
                # First milestone at or after the target. Bounds are checked
                # before indexing.
                j = 0
                while j < n_milestones and milestones[j][0] < chromo:
                    j += 1
                while (j < n_milestones and milestones[j][0] == chromo and
                       milestones[j][1] < position):
                    j += 1
                if j == n_milestones:
                    # Target lies after the last milestone: start from the
                    # last one and read forward.
                    j = n_milestones - 1
                if seek_profiling:
                    print("INDEX HIT: ({}:{}) -> {}".format(
                        chromo, position, milestones[j]))
                this_chromo, this_pos, offset = milestones[j]
                pileups[i].seek(offset)
                stop = False

            # Past the destination: rewind in 10000-byte steps until a line
            # at or before the target is found, or the file start is reached.
            if first and (this_chromo, this_pos) > target:
                # Derive each step from the previous seek target (not from
                # tell()), so long lines cannot stall the rewind.
                rewind_to = pileups[i].tell()
                go_back = True
                while go_back:
                    rewind_to = max(rewind_to - 10000, 0)
                    pileups[i].seek(rewind_to)
                    if rewind_to > 0:
                        # Discard the partial line we landed in.
                        pileups[i].readline()
                    lines[i] = pileups[i].readline()
                    back_fields = lines[i].split("\t")
                    this_pos = int(back_fields[1])
                    this_chromo = int(back_fields[0][3:])
                    if (this_chromo, this_pos) <= target or rewind_to == 0:
                        go_back = False
                stop = False

            if (this_chromo, this_pos) < target:
                stop = False
                # Advance this pileup individually.
                lines[i] = pileups[i].readline()
        first = False

    if seek_profiling:
        tf = time.time()
        positions = []
        for line in lines:
            line_fields = line.split("\t")
            positions.append(int(line_fields[1]) if len(line_fields) > 1
                             else None)
        print("\tendseek {} ({})".format(positions, tf - ti))
    return any(at_end)
 
def open_pileup(path):
    """Open the pileup at ``path`` for reading.

    Files whose name ends in ``.gz`` are opened with ``gzip`` in binary mode.
    Any other file is opened as a plain text file.
    """
    compressed = path.endswith(".gz")
    if not compressed:
        return open(path, "r")
    return gzip.open(path, "rb")

def check_profiles(graph):
    r"""Counts the signatures and prints the matrix given a signature graph.

    As described in the :py:func:`generate_cnv_graph.main` method, the
    signatures represent the status for every member of the family at a given
    locus.

    A sample matrix could be:

    +-------+-------+--------+--------+-------+
    | Twin1 | Twin2 | Mother | Father | Count |
    +=======+=======+========+========+=======+
    |   \+  |  \+   |   \-   |    0   |   52  |
    +-------+-------+--------+--------+-------+
    |   0   |  \-   |   \-   |    0   |  105  |
    +-------+-------+--------+--------+-------+
    |   \-  |  \-   |   \-   |   \-   |   21  |
    +-------+-------+--------+--------+-------+

    This says that at 52 loci, both twins had gains, the mother had a
    deletion and the father had no detected CNV. The same reasoning applies
    to the two other signatures.

    If every connected component has a ``region_doc`` attribute (set by
    :py:func:`get_coverage`), the mean of ``region_doc - out_doc`` over the
    components of each profile is appended to its row. Components whose
    coverage is ``None`` are counted, but excluded from that mean. A profile
    with no coverage at all prints ``NA``. Rows are printed in sorted profile
    order.

    :param graph: Iterable of connected components. Each has an ``adj_list``
                  of vertices whose ``cnv`` has ``source`` and ``type``.
    """
    indexes = {
        "twin1": 0,
        "twin2": 1,
        "mother": 2,
        "father": 3,
    }
    symbols = {
        "loss": '-',
        "gain": '+',
    }
    profiles = {}
    # profile (e.g. "+0+0") -> [sum of coverage deltas, number of components]
    coverage = {}
    depth_mode = True
    for cc in graph:
        # Coverage columns are printed only if every component has them.
        depth_mode = depth_mode and hasattr(cc, "region_doc")
        # All no calls at first.
        profile = ['0'] * 4
        for vertex in cc.adj_list:
            cnv = vertex.cnv
            profile[indexes[cnv.source]] = symbols[cnv.type]
        profile = "".join(profile)
        profiles[profile] = profiles.get(profile, 0) + 1

        region_doc = getattr(cc, "region_doc", None)
        out_doc = getattr(cc, "out_doc", None)
        # get_coverage sets None when no coverage could be computed.
        if depth_mode and region_doc is not None and out_doc is not None:
            delta = region_doc - out_doc
            if profile in coverage:
                coverage[profile][0] = coverage[profile][0] + delta
                coverage[profile][1] += 1
            else:
                coverage[profile] = [delta, 1]

    if depth_mode:
        # Dcov is the difference in coverage from in region to out of the
        # region: cov_in - cov_out => negative if deletion.
        print("Twin1\tTwin2\tMother\tFather\tCount\tDcov_twin1\t"
              "Dcov_twin2\tDcov_mother\tDcov_father")
        for key in sorted(profiles):
            if key in coverage:
                delta_sum, n = coverage[key]
                dcov = [str(v) for v in delta_sum / float(n)]
            else:
                dcov = ["NA"] * 4
            print("{}\t{}\t{}".format(
                "\t".join(key),
                profiles[key],
                "\t".join(dcov),
            ))
    else:
        print("Twin1\tTwin2\tMother\tFather\tCount")
        for key in sorted(profiles):
            print("{}\t{}".format("\t".join(key), profiles[key]))
 
def clean_ccs(ccs):
    """Merges CNVs that have the same source (sample) and type.

    :param ccs: The graph as a list of Connected Components.
    :type ccs: list
    :returns: ``ccs`` itself, modified in place.

    This is used so that connected components represent families with a
    single representation for every individual. Within each component, all
    vertices with the same ``(source, type)`` are replaced by a single vertex.
    Its CNV (``algo="merged"``) spans from the smallest start to the largest
    end of the group. This also merges indirectly overlapping loci: two CNVs
    of one individual that are both overlapped by a CNV of another family
    member end up merged.

    :raises ValueError: if CNVs to be merged lie on different chromosomes.
    """
    empty_ccs = []
    for cc in ccs:
        # Group vertices by (sample, CNV type), keeping first-seen order.
        groups = {}
        order = []
        for vertex in cc.adj_list:
            key = (vertex.cnv.source, vertex.cnv.type)
            if key not in groups:
                groups[key] = []
                order.append(key)
            groups[key].append(vertex)

        for key in order:
            members = groups[key]
            if len(members) < 2:
                continue
            cnvs = [v.cnv for v in members]
            if len(set(c.chr for c in cnvs)) != 1:
                raise ValueError("Cannot merge CNVs from different "
                                 "chromosomes: {}".format(cnvs))
            params = {
                "start": min(c.start for c in cnvs),
                "end": max(c.end for c in cnvs),
                "algo": "merged",
                "source": cnvs[0].source,
                "chr": cnvs[0].chr,
                "type": cnvs[0].type,
            }
            merged = Vertex(cnv_struct.cnv(**params))
            for v in members:
                cc.remove(v)
            n_before = len(cc.adj_list)
            if cc.adj_list:
                # Links the merged vertex to the overlapping vertices.
                cc.add(merged)
            if len(cc.adj_list) == n_before:
                # ConnectedComponent.add refuses (with a warning) a vertex
                # that overlaps no remaining vertex, e.g. when the component
                # only held this sample's CNVs. Keep the merged CNV anyway so
                # the sample is not dropped from the locus.
                cc.adj_list.append(merged)

        if not cc.adj_list:
            empty_ccs.append(cc)
    for cc in empty_ccs:
        ccs.remove(cc)
    return ccs
     
def create_family_graph(cnvs, threshold):
    """Creates the graph representing the CNVs from a given family as
    connected components.

    This graph is defined as follows. The nodes represent CNVs and the edges
    represent overlap between CNVs. The complete graph is thus made of
    multiple connected components representing different loci.

    CNVs on ``chrX`` and ``chrY`` are skipped. When a CNV is suitable for
    several components, they are merged into one (see :py:func:`merge_ccs`).

    :param cnvs: A list of CNVs, or a structure that
                 ``parsers.family_to_list`` converts to one.
    :param threshold: Overlap threshold given to every new
                      :py:class:`ConnectedComponent`.
    :returns: The list of connected components.
    """
    if not isinstance(cnvs, list):
        cnvs = parsers.family_to_list(cnvs)
    connected_components = []
    for cnv in cnvs:
        if cnv.chr in ("chrX", "chrY"):
            continue
        matching_ccs = [cc for cc in connected_components if cc.suitable(cnv)]
        if not matching_ccs:
            # Create a new CC for this cnv.
            connected_components.append(ConnectedComponent(cnv, threshold))
        elif len(matching_ccs) == 1:
            # Add this CNV to the only matching cc.
            matching_ccs[0].add(cnv)
        else:
            # Merge the matching ccs, using this cnv as the link. The master
            # is one of matching_ccs, so it is removed and appended back.
            master = merge_ccs(matching_ccs, cnv)
            for cc in matching_ccs:
                connected_components.remove(cc)
            connected_components.append(master)
    return connected_components
 
def merge_ccs(cc_list, cnv):
    """Merges connected components by using their respective overlap to cnv.

    The first component of ``cc_list`` is the master. The vertices of the
    other components are appended to its ``adj_list``, and its ``start`` and
    ``end`` are widened to span every component. ``cnv`` is then added to it
    through its ``add`` method, which links it to the merged vertices. The
    other components are not modified.

    :param cc_list: Non-empty list of connected components to merge.
    :param cnv: The CNV that overlaps every component of ``cc_list``.
    :returns: The master component (``cc_list[0]``), modified in place.
    """
    master = cc_list[0]
    others = cc_list[1:]
    master.start = min([master.start] + [cc.start for cc in others])
    master.end = max([master.end] + [cc.end for cc in others])
    for cc in others:
        master.adj_list.extend(cc.adj_list)
    master.add(cnv)
    return master
 
class ConnectedComponent(object):
    """A class representing a graph-theoretic connected component.

    :param first_vertex: The first vertex is either the first CNV in the
                         connected component or the first
                         :py:class:`generate_cnv_graph.Vertex` object.
                         Internally, whenever a CNV is added to a connected
                         component, it is converted to an instance of this
                         class.
    :type first_vertex: Either :py:class:`cnv_struct.cnv` or
                        :py:class:`generate_cnv_graph.Vertex`
    :param overlap_threshold: An overlap threshold that defines if two CNVs
                              will be considered identical.
    :type overlap_threshold: float (between 0 and 1)
    :raises TypeError: if ``first_vertex`` is neither a CNV nor a Vertex.

    Internally, a ConnectedComponent object knows where the represented locus
    starts and ends, on which chromosome it lies, and the overlap threshold
    used. It represents its set of vertices as a list (the ``adj_list``
    attribute).
    """
    def __init__(self, first_vertex, overlap_threshold = 0.7):
        if isinstance(first_vertex, cnv_struct.cnv):
            first_vertex = Vertex(first_vertex)
        elif not isinstance(first_vertex, Vertex):
            raise TypeError("Invalid type for new ConnectedComponent: "
                            "{}".format(type(first_vertex)))

        self.start = first_vertex.cnv.start
        self.end = first_vertex.cnv.end
        self.chr = first_vertex.cnv.chr
        self.adj_list = [first_vertex, ]
        self.overlap_threshold = overlap_threshold

    def cnv_generator(self):
        """Returns a generator (iterable) of cnv objects from the vertices of
           this connected component. """
        return (i.cnv for i in self.adj_list)

    def suitable(self, cnv):
        """Determines if the given cnv should be added to this connected
           component.

        Returns True when the smallest value returned by
        ``cnv_struct.ro_to_best_in_list`` for ``cnv`` against this
        component's CNVs exceeds the overlap threshold.
        NOTE(review): this presumably means the reciprocal overlaps with the
        best-matching CNVs; confirm in ``cnv_struct``.
        """
        ov = cnv_struct.ro_to_best_in_list(cnv, self.cnv_generator())
        return (min(ov) > self.overlap_threshold)

    def remove(self, vertex):
        """Removes a vertex from this connected component.

        Every edge touching ``vertex`` is first removed from the edge lists of
        the vertices in this component.

        :raises ValueError: if ``vertex`` is not in the component.
        """
        for other_vertex in self.adj_list:
            # Slice assignment keeps the same list object for each vertex.
            other_vertex.edges[:] = [
                edge for edge in other_vertex.edges
                if not (edge.v1 == vertex or edge.v2 == vertex)
            ]
        self.adj_list.remove(vertex)

    def add(self, cnv):
        """Adds the given cnv (or Vertex) to the adjacency list.

        An :py:class:`Edge` is created between the new vertex and every
        existing vertex that ``cnv_struct.overlap`` accepts with this
        component's threshold. Its weight is the ``global_ov`` overlap. The
        component's ``start`` and ``end`` are widened accordingly. If no
        existing vertex overlaps (which includes an empty component), a
        warning is printed and nothing is added.

        :raises TypeError: if ``cnv`` is neither a ``cnv_struct.cnv`` nor a
                           :py:class:`Vertex`.
        """
        if isinstance(cnv, cnv_struct.cnv):
            new_vertex = Vertex(cnv)
        elif isinstance(cnv, Vertex):
            new_vertex = cnv
        else:
            raise TypeError("Cannot add a '{}' object to the ConnectedComponent "
                            "only Vertex and cnv_struct.cnv classes are "
                            "allowed".format(type(cnv)))
        added = False
        for v in self.adj_list:
            # If it overlaps with the new vertex, link both with a weighted edge.
            if cnv_struct.overlap(new_vertex.cnv,
                                  v.cnv,
                                  self.overlap_threshold):
                weight = cnv_struct.overlap(
                    new_vertex.cnv,
                    v.cnv,
                    global_ov = True
                )
                e = Edge(new_vertex, v, weight)
                new_vertex.add_edge(e)
                v.add_edge(e)
                added = True

        if added:
            self.adj_list.append(new_vertex)
            self.start = min(self.start, new_vertex.cnv.start)
            self.end = max(self.end, new_vertex.cnv.end)
        else:
            print("Warning: No overlapping vertex in connected component. "
                  "the cnv was not added.")

    def __repr__(self):
        s = "ConnectedComponent:\n"
        s += pprint.pformat(self.adj_list)
        return s
 
class Vertex(object):
    """A container class for CNVs which allows Edge objects to link
    overlapping CNVs together.
    """
    def __init__(self, cnv_obj):
        """Initialize this vertex with the given CNV object and no edges.

        :param cnv_obj: The CNV object for this Vertex.
        :type cnv_obj: :py:class:`cnv_struct.cnv`
        """
        self.cnv = cnv_obj
        self.edges = []

    def add_edge(self, e):
        """Add an edge between vertices.

        :param e: The edge object to add to this vertex.
        :type e: :py:class:`Edge`

        Concretely, this is used to link overlapping CNVs together.
        """
        self.edges.append(e)

    def __repr__(self):
        return "<Vertex ({})>".format(self.cnv)
 
class Edge(object):
    """A simple weighted edge linking two vertices together.

    Edges are created with this class and attached to both endpoints with
    :py:meth:`Vertex.add_edge`, as done in :py:meth:`ConnectedComponent.add`.

    :param v1: First endpoint.
    :param v2: Second endpoint.
    :param w: Weight of the edge. ``ConnectedComponent.add`` uses the
              ``global_ov`` overlap between the two CNVs.
    """
    def __init__(self, v1, v2, w):
        self.v1 = v1
        self.v2 = v2
        self.weight = w
 
def test():
    """Manual smoke test: build a small family graph and print it.

    It prints the components produced by :py:func:`create_family_graph`
    followed by :py:func:`clean_ccs`, then a blank line, then three components
    assembled by hand from the same CNVs for visual comparison.
    NOTE(review): the ninth CNV (chr2:34-42, father) is not part of the
    hand-built components.
    """
    specs = (
        ("chr1:2-5", "father"),
        ("chr1:3-10", "twin1"),
        ("chr1:20-25", "twin1"),
        ("chr1:27-30", "twin2"),
        ("chr1:24-28", "mother"),
        ("chr2:24-28", "mother"),
        ("chr2:19-24", "father"),
        ("chr2:25-35", "twin1"),
        ("chr2:34-42", "father"),
    )
    cnvs = [cnv_struct.cnv(pos = pos, type = "gain", source = source)
            for pos, source in specs]
    automatic = clean_ccs(create_family_graph(cnvs, 0))

    # Each entry: (index of the seed CNV, indices of the CNVs added after it).
    layout = ((0, (1,)), (2, (3, 4)), (5, (6, 7)))
    manual = []
    for seed, followers in layout:
        component = ConnectedComponent(cnvs[seed], 0)
        for k in followers:
            component.add(cnvs[k])
        manual.append(component)

    for component in automatic:
        print(component)
    print("")
    for component in manual:
        print(component)

if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description = "Generates the cnv graph used to build the summary matrix."
    )
    # The validation helpers return the parser with their options registered.
    for register in (validation.add_pickle_args, validation.add_dir_args):
        arg_parser = register(arg_parser)
    arg_parser.add_argument("--threshold", type = float, default = 0.7,
                            help = "Overlap threshold.")
    arg_parser.add_argument("--save", type = str, default = None,
                            help = "Output pickle if you want to keep the graph.")
    # One optional pileup path per family member.
    for member in ("twin1", "twin2", "mother", "father"):
        arg_parser.add_argument("--{}_pileup".format(member),
                                type = str,
                                default = None,
                                help = "Path to pileup for {}".format(member))
    main(arg_parser.parse_args())