diff --git a/src/collapse_annotation.py b/src/collapse_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..cba9e5ef29309f8da27a07da8a4fa4eb0897c20d --- /dev/null +++ b/src/collapse_annotation.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +# Author: Francois Aguet +import numpy as np +import pandas as pd +from collections import defaultdict +from bx.intervals.intersection import IntervalTree +import argparse +import os +import gzip + + +class Exon: + def __init__(self, exon_id, number, transcript, start_pos, end_pos): + self.id = exon_id + self.number = int(number) + self.transcript = transcript + self.start_pos = start_pos + self.end_pos = end_pos + +class Transcript: + def __init__(self, transcript_id, transcript_name, transcript_type, gene, start_pos, end_pos): + self.id = transcript_id + self.name = transcript_name + self.type = transcript_type + self.gene = gene + self.start_pos = start_pos + self.end_pos = end_pos + self.exons = [] + +class Gene: + def __init__(self, gene_id, gene_name, gene_type, chrom, strand, start_pos, end_pos): + self.id = gene_id + self.name = gene_name + self.biotype = gene_type + self.chr = chrom + self.strand = strand + self.start_pos = start_pos + self.end_pos = end_pos + self.transcripts = [] + +class Annotation: + def __init__(self, gtfpath): + """Parse GTF and construct gene/transcript/exon hierarchy""" + + if gtfpath.endswith('.gtf.gz'): + opener = gzip.open(gtfpath, 'rt') + else: + opener = open(gtfpath, 'r') + + self.genes = [] + with opener as gtf: + for row in gtf: + row = row.strip().split('\t') + + if row[0][0]=='#': continue # skip header + + chrom = row[0] + annot_type = row[2] + start_pos = int(row[3]) + end_pos = int(row[4]) + strand = row[6] + + attributes = defaultdict() + for a in row[8].replace('"', '').replace('_biotype', '_type').split(';')[:-1]: + kv = a.strip().split(' ') + if kv[0]!='tag': + attributes[kv[0]] = kv[1] + else: + attributes.setdefault('tags', []).append(kv[1]) + + if annot_type=='gene': + assert 'gene_id' in attributes + if 'gene_name' not in attributes: + attributes['gene_name'] = attributes['gene_id'] + gene_id = attributes['gene_id'] + g = Gene(gene_id, attributes['gene_name'], attributes['gene_type'], chrom, strand, start_pos, end_pos) + g.source = row[1] + g.phase = row[7] + g.attributes_string = row[8].replace('_biotype', '_type') + self.genes.append(g) + + elif annot_type=='transcript': + assert 'transcript_id' in attributes + if 'transcript_name' not in attributes: + attributes['transcript_name'] = attributes['transcript_id'] + transcript_id = attributes['transcript_id'] + t = Transcript(attributes.pop('transcript_id'), attributes.pop('transcript_name'), + attributes.pop('transcript_type'), g, start_pos, end_pos) + t.attributes = attributes + g.transcripts.append(t) + + elif annot_type=='exon': + if 'exon_id' in attributes: + e = Exon(attributes['exon_id'], attributes['exon_number'], t, start_pos, end_pos) + else: + e = Exon(str(len(t.exons)+1), len(t.exons)+1, t, start_pos, end_pos) + t.exons.append(e) + + if np.mod(len(self.genes),1000)==0: + print('Parsing GTF: {0:d} genes processed\r'.format(len(self.genes)), end='\r') + print('Parsing GTF: {0:d} genes processed\r'.format(len(self.genes))) + + self.genes = np.array(self.genes) + + +def interval_union(intervals): + """ + Returns the union of all intervals in the input list + intervals: list of tuples or 2-element lists + """ + intervals.sort(key=lambda x: x[0]) + union = [intervals[0]] + for i in intervals[1:]: + if i[0] <= union[-1][1]: # overlap w/ previous + if i[1] > union[-1][1]: # only extend if larger + union[-1][1] = i[1] + else: + union.append(i) + return union + + +def subtract_segment(a, b): + """ + Subtract segment a from segment b, + return 'a' if no overlap + """ + if a[0]>=b[0] and a[0]<=b[1] and a[1]>b[1]: + return (b[1]+1,a[1]) + elif a[0]<b[0] and a[1]>=b[0] and a[1]<=b[1]: + return (a[0], b[0]-1) + elif a[0]<b[0] and a[1]>b[1]: + return [(a[0],b[0]-1), (b[1]+1,a[1])] + elif a[0]>=b[0] and a[1]<=b[1]: + return [] + else: + return a + + +def add_transcript_attributes(attributes_string): + """ + Adds transcript attributes if they were missing + (see https://www.gencodegenes.org/pages/data_format.html) + 'status' fields were dropped in Gencode 26 and later + """ + # GTF specification + if 'gene_status' in attributes_string: + attribute_order = ['gene_id', 'transcript_id', 'gene_type', 'gene_status', 'gene_name', + 'transcript_type', 'transcript_status', 'transcript_name'] + add_list = ['transcript_id', 'transcript_type', 'transcript_status', 'transcript_name'] + else: + attribute_order = ['gene_id', 'transcript_id', 'gene_type', 'gene_name', 'transcript_type', 'transcript_name'] + add_list = ['transcript_id', 'transcript_type', 'transcript_name'] + if 'level' in attributes_string: + attribute_order += ['level'] + + attr = attributes_string.strip(';').split('; ') + req = [] + opt = [] + for k in attr: + if k.split()[0] in attribute_order: + req.append(k) + else: + opt.append(k) + attr_dict = {i.split()[0]:i.split()[1].replace(';','') for i in req} + if 'gene_name' not in attr_dict: + attr_dict['gene_name'] = attr_dict['gene_id'] + if 'transcript_id' not in attr_dict: + attr_dict['transcript_id'] = attr_dict['gene_id'] + for k in add_list: + if k not in attr_dict: + attr_dict[k] = attr_dict[k.replace('transcript', 'gene')] + + return '; '.join([k+' '+attr_dict[k] for k in attribute_order] + opt)+';' + + +def collapse_annotation(annot, transcript_gtf, collapsed_gtf, blacklist=set(), collapse_only=False): + """ + Collapse transcripts into a single gene model; remove overlapping intervals + """ + + exclude = set(['retained_intron', 'readthrough_transcript']) + + # 1) 1st pass: collapse each gene, excluding blacklisted transcript types + merged_coord_dict = {} + for g in annot.genes: + exon_coords = [] + for t in g.transcripts: + if (t.id not in blacklist) and (t.type!='retained_intron') and (('tags' not in t.attributes) or len(set(t.attributes['tags']).intersection(exclude))==0): + for e in t.exons: + exon_coords.append([e.start_pos, e.end_pos]) + if exon_coords: + merged_coord_dict[g.id] = interval_union(exon_coords) + + if not collapse_only: + # 2) build interval tree with merged domains + interval_trees = defaultdict() + for g in annot.genes: + if g.id in merged_coord_dict: + for i in merged_coord_dict[g.id]: + # half-open intervals [a,b) + interval_trees.setdefault(g.chr, IntervalTree()).add(i[0], i[1]+1, [i, g.id]) + + # 3) query intervals of each gene, remove overlaps + new_coord_dict = {} + for g in annot.genes: + if g.id in merged_coord_dict: + new_intervals = [] + for i in merged_coord_dict[g.id]: # loop merged exons + ints = interval_trees[g.chr].find(i[0], i[1]+1) + # remove self + ints = [r[0] for r in ints if r[1]!=g.id] + m = set([tuple(i)]) + for v in ints: + m = [subtract_segment(mx, v) for mx in m] + # flatten + m0 = [] + for k in m: + if isinstance(k, tuple): + m0.append(k) + else: + m0.extend(k) + m = m0 + new_intervals.extend(m) + if new_intervals: + new_coord_dict[g.id] = new_intervals + + # 4) remove genes containing single-base exons only + for g in annot.genes: + if g.id in new_coord_dict: + exon_lengths = np.array([i[1]-i[0]+1 for i in new_coord_dict[g.id]]) + if np.all(exon_lengths==1): + new_coord_dict.pop(g.id) + else: + new_coord_dict = merged_coord_dict + + # 5) write to GTF + if transcript_gtf.endswith('.gtf.gz'): + opener = gzip.open(transcript_gtf, 'rt') + else: + opener = open(transcript_gtf, 'r') + + with open(collapsed_gtf, 'w') as output_gtf, opener as input_gtf: + # copy header + for line in input_gtf: + if line[:2]=='##' or line[:2]=='#!': + output_gtf.write(line) + comment = line[:2] + else: + break + output_gtf.write(comment+'collapsed version generated by GTEx pipeline\n') + for g in annot.genes: + if g.id in new_coord_dict: + start_pos = str(np.min([i[0] for i in new_coord_dict[g.id]])) + end_pos = str(np.max([i[1] for i in new_coord_dict[g.id]])) + if 'transcript_id' in g.attributes_string: + attr = g.attributes_string + else: + attr = add_transcript_attributes(g.attributes_string) + output_gtf.write('\t'.join([g.chr, g.source, 'gene', start_pos, end_pos, '.', g.strand, g.phase, attr])+'\n') + output_gtf.write('\t'.join([g.chr, g.source, 'transcript', start_pos, end_pos, '.', g.strand, g.phase, attr])+'\n') + if g.strand=='-': + new_coord_dict[g.id] = new_coord_dict[g.id][::-1] + for k,i in enumerate(new_coord_dict[g.id], 1): + output_gtf.write('\t'.join([ + g.chr, g.source, 'exon', str(i[0]), str(i[1]), '.', g.strand, g.phase, + attr+' exon_id "'+g.id+'_{0:d}; exon_number {0:d}";'.format(k)])+'\n') + + +if __name__=='__main__': + + parser = argparse.ArgumentParser(description='Collapse isoforms into single transcript per gene and remove overlapping intervals between genes') + parser.add_argument('transcript_gtf', help='Transcript annotation in GTF format') + parser.add_argument('output_gtf', help='Name of the output file') + parser.add_argument('--transcript_blacklist', help='List of transcripts to exclude (e.g., unannotated readthroughs)') + parser.add_argument('--collapse_only', action='store_true', help='') + args = parser.parse_args() + + annotation = Annotation(args.transcript_gtf) + + if args.transcript_blacklist: + blacklist_df = pd.read_csv(args.transcript_blacklist, sep='\t') + blacklist = set(blacklist_df[blacklist_df.columns[0]].values) + else: + blacklist = set() + + print('Collapsing transcripts') + collapse_annotation(annotation, args.transcript_gtf, args.output_gtf, blacklist=blacklist, collapse_only=args.collapse_only)