diff --git a/src/collapse_annotation.py b/src/collapse_annotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..cba9e5ef29309f8da27a07da8a4fa4eb0897c20d
--- /dev/null
+++ b/src/collapse_annotation.py
@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+# Author: Francois Aguet
+import numpy as np
+import pandas as pd
+from collections import defaultdict
+from bx.intervals.intersection import IntervalTree
+import argparse
+import os
+import gzip
+
+
+class Exon:
+    def __init__(self, exon_id, number, transcript, start_pos, end_pos):
+        self.id = exon_id
+        self.number = int(number)
+        self.transcript = transcript
+        self.start_pos = start_pos
+        self.end_pos = end_pos
+
+class Transcript:
+    def __init__(self, transcript_id, transcript_name, transcript_type, gene, start_pos, end_pos):
+        self.id = transcript_id
+        self.name = transcript_name
+        self.type = transcript_type
+        self.gene = gene
+        self.start_pos = start_pos
+        self.end_pos = end_pos
+        self.exons = []
+
+class Gene:
+    def __init__(self, gene_id, gene_name, gene_type, chrom, strand, start_pos, end_pos):
+        self.id = gene_id
+        self.name = gene_name
+        self.biotype = gene_type
+        self.chr = chrom
+        self.strand = strand
+        self.start_pos = start_pos
+        self.end_pos = end_pos
+        self.transcripts = []
+
+class Annotation:
+    def __init__(self, gtfpath):
+        """Parse GTF and construct gene/transcript/exon hierarchy"""
+
+        if gtfpath.endswith('.gtf.gz'):
+            opener = gzip.open(gtfpath, 'rt')
+        else:
+            opener = open(gtfpath, 'r')
+
+        self.genes = []
+        with opener as gtf:
+            for row in gtf:
+                row = row.strip().split('\t')
+
+                if row[0][0]=='#': continue # skip header
+
+                chrom = row[0]
+                annot_type = row[2]
+                start_pos = int(row[3])
+                end_pos  = int(row[4])
+                strand = row[6]
+
+                attributes = defaultdict()
+                for a in row[8].replace('"', '').replace('_biotype', '_type').split(';')[:-1]:
+                    kv = a.strip().split(' ')
+                    if kv[0]!='tag':
+                        attributes[kv[0]] = kv[1]
+                    else:
+                        attributes.setdefault('tags', []).append(kv[1])
+
+                if annot_type=='gene':
+                    assert 'gene_id' in attributes
+                    if 'gene_name' not in attributes:
+                        attributes['gene_name'] = attributes['gene_id']
+                    gene_id = attributes['gene_id']
+                    g = Gene(gene_id, attributes['gene_name'], attributes['gene_type'], chrom, strand, start_pos, end_pos)
+                    g.source = row[1]
+                    g.phase = row[7]
+                    g.attributes_string = row[8].replace('_biotype', '_type')
+                    self.genes.append(g)
+
+                elif annot_type=='transcript':
+                    assert 'transcript_id' in attributes
+                    if 'transcript_name' not in attributes:
+                        attributes['transcript_name'] = attributes['transcript_id']
+                    transcript_id = attributes['transcript_id']
+                    t = Transcript(attributes.pop('transcript_id'), attributes.pop('transcript_name'),
+                                   attributes.pop('transcript_type'), g, start_pos, end_pos)
+                    t.attributes = attributes
+                    g.transcripts.append(t)
+
+                elif annot_type=='exon':
+                    if 'exon_id' in attributes:
+                        e = Exon(attributes['exon_id'], attributes['exon_number'], t, start_pos, end_pos)
+                    else:
+                        e = Exon(str(len(t.exons)+1), len(t.exons)+1, t, start_pos, end_pos)
+                    t.exons.append(e)
+
+                if np.mod(len(self.genes),1000)==0:
+                    print('Parsing GTF: {0:d} genes processed\r'.format(len(self.genes)), end='\r')
+            print('Parsing GTF: {0:d} genes processed\r'.format(len(self.genes)))
+
+        self.genes = np.array(self.genes)
+
+
+def interval_union(intervals):
+    """
+    Returns the union of all intervals in the input list
+      intervals: list of tuples or 2-element lists
+    """
+    intervals.sort(key=lambda x: x[0])
+    union = [intervals[0]]
+    for i in intervals[1:]:
+        if i[0] <= union[-1][1]:  # overlap w/ previous
+            if i[1] > union[-1][1]:  # only extend if larger
+                union[-1][1] = i[1]
+        else:
+            union.append(i)
+    return union
+
+
+def subtract_segment(a, b):
+    """
+    Subtract segment a from segment b,
+      return 'a' if no overlap
+    """
+    if a[0]>=b[0] and a[0]<=b[1] and a[1]>b[1]:
+        return (b[1]+1,a[1])
+    elif a[0]<b[0] and a[1]>=b[0] and a[1]<=b[1]:
+        return (a[0], b[0]-1)
+    elif a[0]<b[0] and a[1]>b[1]:
+        return [(a[0],b[0]-1), (b[1]+1,a[1])]
+    elif a[0]>=b[0] and a[1]<=b[1]:
+        return []
+    else:
+        return a
+
+
+def add_transcript_attributes(attributes_string):
+    """
+    Adds transcript attributes if they were missing
+    (see https://www.gencodegenes.org/pages/data_format.html)
+    'status' fields were dropped in Gencode 26 and later
+    """
+    # GTF specification
+    if 'gene_status' in attributes_string:
+        attribute_order = ['gene_id', 'transcript_id', 'gene_type', 'gene_status', 'gene_name',
+                           'transcript_type', 'transcript_status', 'transcript_name']
+        add_list = ['transcript_id', 'transcript_type', 'transcript_status', 'transcript_name']
+    else:
+        attribute_order = ['gene_id', 'transcript_id', 'gene_type', 'gene_name', 'transcript_type', 'transcript_name']
+        add_list = ['transcript_id', 'transcript_type', 'transcript_name']
+    if 'level' in attributes_string:
+        attribute_order += ['level']
+
+    attr = attributes_string.strip(';').split('; ')
+    req = []
+    opt = []
+    for k in attr:
+        if k.split()[0] in attribute_order:
+            req.append(k)
+        else:
+            opt.append(k)
+    attr_dict = {i.split()[0]:i.split()[1].replace(';','') for i in req}
+    if 'gene_name' not in attr_dict:
+        attr_dict['gene_name'] = attr_dict['gene_id']
+    if 'transcript_id' not in attr_dict:
+        attr_dict['transcript_id'] = attr_dict['gene_id']
+    for k in add_list:
+        if k not in attr_dict:
+            attr_dict[k] = attr_dict[k.replace('transcript', 'gene')]
+
+    return '; '.join([k+' '+attr_dict[k] for k in attribute_order] + opt)+';'
+
+
+def collapse_annotation(annot, transcript_gtf, collapsed_gtf, blacklist=set(), collapse_only=False):
+    """
+    Collapse transcripts into a single gene model; remove overlapping intervals
+    """
+
+    exclude = set(['retained_intron', 'readthrough_transcript'])
+
+    # 1) 1st pass: collapse each gene, excluding blacklisted transcript types
+    merged_coord_dict = {}
+    for g in annot.genes:
+        exon_coords = []
+        for t in g.transcripts:
+            if (t.id not in blacklist) and (t.type!='retained_intron') and (('tags' not in t.attributes) or len(set(t.attributes['tags']).intersection(exclude))==0):
+                for e in t.exons:
+                    exon_coords.append([e.start_pos, e.end_pos])
+        if exon_coords:
+            merged_coord_dict[g.id] = interval_union(exon_coords)
+
+    if not collapse_only:
+        # 2) build interval tree with merged domains
+        interval_trees = defaultdict()
+        for g in annot.genes:
+            if g.id in merged_coord_dict:
+                for i in merged_coord_dict[g.id]:
+                    # half-open intervals [a,b)
+                    interval_trees.setdefault(g.chr, IntervalTree()).add(i[0], i[1]+1, [i, g.id])
+
+        # 3) query intervals of each gene, remove overlaps
+        new_coord_dict = {}
+        for g in annot.genes:
+            if g.id in merged_coord_dict:
+                new_intervals = []
+                for i in merged_coord_dict[g.id]:  # loop merged exons
+                    ints = interval_trees[g.chr].find(i[0], i[1]+1)
+                    # remove self
+                    ints = [r[0] for r in ints if r[1]!=g.id]
+                    m = set([tuple(i)])
+                    for v in ints:
+                        m = [subtract_segment(mx, v) for mx in m]
+                        # flatten
+                        m0 = []
+                        for k in m:
+                            if isinstance(k, tuple):
+                                m0.append(k)
+                            else:
+                                m0.extend(k)
+                        m = m0
+                    new_intervals.extend(m)
+                if new_intervals:
+                    new_coord_dict[g.id] = new_intervals
+
+        # 4) remove genes containing single-base exons only
+        for g in annot.genes:
+            if g.id in new_coord_dict:
+                exon_lengths = np.array([i[1]-i[0]+1 for i in new_coord_dict[g.id]])
+                if np.all(exon_lengths==1):
+                    new_coord_dict.pop(g.id)
+    else:
+        new_coord_dict = merged_coord_dict
+
+    # 5) write to GTF
+    if transcript_gtf.endswith('.gtf.gz'):
+        opener = gzip.open(transcript_gtf, 'rt')
+    else:
+        opener = open(transcript_gtf, 'r')
+
+    with open(collapsed_gtf, 'w') as output_gtf, opener as input_gtf:
+        # copy header
+        for line in input_gtf:
+            if line[:2]=='##' or line[:2]=='#!':
+                output_gtf.write(line)
+                comment = line[:2]
+            else:
+                break
+        output_gtf.write(comment+'collapsed version generated by GTEx pipeline\n')
+        for g in annot.genes:
+            if g.id in new_coord_dict:
+                start_pos = str(np.min([i[0] for i in new_coord_dict[g.id]]))
+                end_pos = str(np.max([i[1] for i in new_coord_dict[g.id]]))
+                if 'transcript_id' in g.attributes_string:
+                    attr = g.attributes_string
+                else:
+                    attr = add_transcript_attributes(g.attributes_string)
+                output_gtf.write('\t'.join([g.chr, g.source, 'gene', start_pos, end_pos, '.', g.strand, g.phase, attr])+'\n')
+                output_gtf.write('\t'.join([g.chr, g.source, 'transcript', start_pos, end_pos, '.', g.strand, g.phase, attr])+'\n')
+                if g.strand=='-':
+                    new_coord_dict[g.id] = new_coord_dict[g.id][::-1]
+                for k,i in enumerate(new_coord_dict[g.id], 1):
+                    output_gtf.write('\t'.join([
+                        g.chr, g.source, 'exon', str(i[0]), str(i[1]), '.', g.strand, g.phase,
+                        attr+' exon_id "'+g.id+'_{0:d}; exon_number {0:d}";'.format(k)])+'\n')
+
+
+if __name__=='__main__':
+
+    parser = argparse.ArgumentParser(description='Collapse isoforms into single transcript per gene and remove overlapping intervals between genes')
+    parser.add_argument('transcript_gtf', help='Transcript annotation in GTF format')
+    parser.add_argument('output_gtf', help='Name of the output file')
+    parser.add_argument('--transcript_blacklist', help='List of transcripts to exclude (e.g., unannotated readthroughs)')
+    parser.add_argument('--collapse_only', action='store_true', help='')
+    args = parser.parse_args()
+
+    annotation = Annotation(args.transcript_gtf)
+
+    if args.transcript_blacklist:
+        blacklist_df = pd.read_csv(args.transcript_blacklist, sep='\t')
+        blacklist = set(blacklist_df[blacklist_df.columns[0]].values)
+    else:
+        blacklist = set()
+
+    print('Collapsing transcripts')
+    collapse_annotation(annotation, args.transcript_gtf, args.output_gtf, blacklist=blacklist, collapse_only=args.collapse_only)