Skip to content
Snippets Groups Projects
Commit 76474bcf authored by elabaron's avatar elabaron
Browse files

add src/collapse_annotation.py

parent f65a0eef
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python3
# Author: Francois Aguet
import numpy as np
import pandas as pd
from collections import defaultdict
from bx.intervals.intersection import IntervalTree
import argparse
import os
import gzip
class Exon:
def __init__(self, exon_id, number, transcript, start_pos, end_pos):
self.id = exon_id
self.number = int(number)
self.transcript = transcript
self.start_pos = start_pos
self.end_pos = end_pos
class Transcript:
def __init__(self, transcript_id, transcript_name, transcript_type, gene, start_pos, end_pos):
self.id = transcript_id
self.name = transcript_name
self.type = transcript_type
self.gene = gene
self.start_pos = start_pos
self.end_pos = end_pos
self.exons = []
class Gene:
def __init__(self, gene_id, gene_name, gene_type, chrom, strand, start_pos, end_pos):
self.id = gene_id
self.name = gene_name
self.biotype = gene_type
self.chr = chrom
self.strand = strand
self.start_pos = start_pos
self.end_pos = end_pos
self.transcripts = []
class Annotation:
def __init__(self, gtfpath):
"""Parse GTF and construct gene/transcript/exon hierarchy"""
if gtfpath.endswith('.gtf.gz'):
opener = gzip.open(gtfpath, 'rt')
else:
opener = open(gtfpath, 'r')
self.genes = []
with opener as gtf:
for row in gtf:
row = row.strip().split('\t')
if row[0][0]=='#': continue # skip header
chrom = row[0]
annot_type = row[2]
start_pos = int(row[3])
end_pos = int(row[4])
strand = row[6]
attributes = defaultdict()
for a in row[8].replace('"', '').replace('_biotype', '_type').split(';')[:-1]:
kv = a.strip().split(' ')
if kv[0]!='tag':
attributes[kv[0]] = kv[1]
else:
attributes.setdefault('tags', []).append(kv[1])
if annot_type=='gene':
assert 'gene_id' in attributes
if 'gene_name' not in attributes:
attributes['gene_name'] = attributes['gene_id']
gene_id = attributes['gene_id']
g = Gene(gene_id, attributes['gene_name'], attributes['gene_type'], chrom, strand, start_pos, end_pos)
g.source = row[1]
g.phase = row[7]
g.attributes_string = row[8].replace('_biotype', '_type')
self.genes.append(g)
elif annot_type=='transcript':
assert 'transcript_id' in attributes
if 'transcript_name' not in attributes:
attributes['transcript_name'] = attributes['transcript_id']
transcript_id = attributes['transcript_id']
t = Transcript(attributes.pop('transcript_id'), attributes.pop('transcript_name'),
attributes.pop('transcript_type'), g, start_pos, end_pos)
t.attributes = attributes
g.transcripts.append(t)
elif annot_type=='exon':
if 'exon_id' in attributes:
e = Exon(attributes['exon_id'], attributes['exon_number'], t, start_pos, end_pos)
else:
e = Exon(str(len(t.exons)+1), len(t.exons)+1, t, start_pos, end_pos)
t.exons.append(e)
if np.mod(len(self.genes),1000)==0:
print('Parsing GTF: {0:d} genes processed\r'.format(len(self.genes)), end='\r')
print('Parsing GTF: {0:d} genes processed\r'.format(len(self.genes)))
self.genes = np.array(self.genes)
def interval_union(intervals):
"""
Returns the union of all intervals in the input list
intervals: list of tuples or 2-element lists
"""
intervals.sort(key=lambda x: x[0])
union = [intervals[0]]
for i in intervals[1:]:
if i[0] <= union[-1][1]: # overlap w/ previous
if i[1] > union[-1][1]: # only extend if larger
union[-1][1] = i[1]
else:
union.append(i)
return union
def subtract_segment(a, b):
"""
Subtract segment a from segment b,
return 'a' if no overlap
"""
if a[0]>=b[0] and a[0]<=b[1] and a[1]>b[1]:
return (b[1]+1,a[1])
elif a[0]<b[0] and a[1]>=b[0] and a[1]<=b[1]:
return (a[0], b[0]-1)
elif a[0]<b[0] and a[1]>b[1]:
return [(a[0],b[0]-1), (b[1]+1,a[1])]
elif a[0]>=b[0] and a[1]<=b[1]:
return []
else:
return a
def add_transcript_attributes(attributes_string):
"""
Adds transcript attributes if they were missing
(see https://www.gencodegenes.org/pages/data_format.html)
'status' fields were dropped in Gencode 26 and later
"""
# GTF specification
if 'gene_status' in attributes_string:
attribute_order = ['gene_id', 'transcript_id', 'gene_type', 'gene_status', 'gene_name',
'transcript_type', 'transcript_status', 'transcript_name']
add_list = ['transcript_id', 'transcript_type', 'transcript_status', 'transcript_name']
else:
attribute_order = ['gene_id', 'transcript_id', 'gene_type', 'gene_name', 'transcript_type', 'transcript_name']
add_list = ['transcript_id', 'transcript_type', 'transcript_name']
if 'level' in attributes_string:
attribute_order += ['level']
attr = attributes_string.strip(';').split('; ')
req = []
opt = []
for k in attr:
if k.split()[0] in attribute_order:
req.append(k)
else:
opt.append(k)
attr_dict = {i.split()[0]:i.split()[1].replace(';','') for i in req}
if 'gene_name' not in attr_dict:
attr_dict['gene_name'] = attr_dict['gene_id']
if 'transcript_id' not in attr_dict:
attr_dict['transcript_id'] = attr_dict['gene_id']
for k in add_list:
if k not in attr_dict:
attr_dict[k] = attr_dict[k.replace('transcript', 'gene')]
return '; '.join([k+' '+attr_dict[k] for k in attribute_order] + opt)+';'
def collapse_annotation(annot, transcript_gtf, collapsed_gtf, blacklist=set(), collapse_only=False):
"""
Collapse transcripts into a single gene model; remove overlapping intervals
"""
exclude = set(['retained_intron', 'readthrough_transcript'])
# 1) 1st pass: collapse each gene, excluding blacklisted transcript types
merged_coord_dict = {}
for g in annot.genes:
exon_coords = []
for t in g.transcripts:
if (t.id not in blacklist) and (t.type!='retained_intron') and (('tags' not in t.attributes) or len(set(t.attributes['tags']).intersection(exclude))==0):
for e in t.exons:
exon_coords.append([e.start_pos, e.end_pos])
if exon_coords:
merged_coord_dict[g.id] = interval_union(exon_coords)
if not collapse_only:
# 2) build interval tree with merged domains
interval_trees = defaultdict()
for g in annot.genes:
if g.id in merged_coord_dict:
for i in merged_coord_dict[g.id]:
# half-open intervals [a,b)
interval_trees.setdefault(g.chr, IntervalTree()).add(i[0], i[1]+1, [i, g.id])
# 3) query intervals of each gene, remove overlaps
new_coord_dict = {}
for g in annot.genes:
if g.id in merged_coord_dict:
new_intervals = []
for i in merged_coord_dict[g.id]: # loop merged exons
ints = interval_trees[g.chr].find(i[0], i[1]+1)
# remove self
ints = [r[0] for r in ints if r[1]!=g.id]
m = set([tuple(i)])
for v in ints:
m = [subtract_segment(mx, v) for mx in m]
# flatten
m0 = []
for k in m:
if isinstance(k, tuple):
m0.append(k)
else:
m0.extend(k)
m = m0
new_intervals.extend(m)
if new_intervals:
new_coord_dict[g.id] = new_intervals
# 4) remove genes containing single-base exons only
for g in annot.genes:
if g.id in new_coord_dict:
exon_lengths = np.array([i[1]-i[0]+1 for i in new_coord_dict[g.id]])
if np.all(exon_lengths==1):
new_coord_dict.pop(g.id)
else:
new_coord_dict = merged_coord_dict
# 5) write to GTF
if transcript_gtf.endswith('.gtf.gz'):
opener = gzip.open(transcript_gtf, 'rt')
else:
opener = open(transcript_gtf, 'r')
with open(collapsed_gtf, 'w') as output_gtf, opener as input_gtf:
# copy header
for line in input_gtf:
if line[:2]=='##' or line[:2]=='#!':
output_gtf.write(line)
comment = line[:2]
else:
break
output_gtf.write(comment+'collapsed version generated by GTEx pipeline\n')
for g in annot.genes:
if g.id in new_coord_dict:
start_pos = str(np.min([i[0] for i in new_coord_dict[g.id]]))
end_pos = str(np.max([i[1] for i in new_coord_dict[g.id]]))
if 'transcript_id' in g.attributes_string:
attr = g.attributes_string
else:
attr = add_transcript_attributes(g.attributes_string)
output_gtf.write('\t'.join([g.chr, g.source, 'gene', start_pos, end_pos, '.', g.strand, g.phase, attr])+'\n')
output_gtf.write('\t'.join([g.chr, g.source, 'transcript', start_pos, end_pos, '.', g.strand, g.phase, attr])+'\n')
if g.strand=='-':
new_coord_dict[g.id] = new_coord_dict[g.id][::-1]
for k,i in enumerate(new_coord_dict[g.id], 1):
output_gtf.write('\t'.join([
g.chr, g.source, 'exon', str(i[0]), str(i[1]), '.', g.strand, g.phase,
attr+' exon_id "'+g.id+'_{0:d}; exon_number {0:d}";'.format(k)])+'\n')
if __name__=='__main__':
parser = argparse.ArgumentParser(description='Collapse isoforms into single transcript per gene and remove overlapping intervals between genes')
parser.add_argument('transcript_gtf', help='Transcript annotation in GTF format')
parser.add_argument('output_gtf', help='Name of the output file')
parser.add_argument('--transcript_blacklist', help='List of transcripts to exclude (e.g., unannotated readthroughs)')
parser.add_argument('--collapse_only', action='store_true', help='')
args = parser.parse_args()
annotation = Annotation(args.transcript_gtf)
if args.transcript_blacklist:
blacklist_df = pd.read_csv(args.transcript_blacklist, sep='\t')
blacklist = set(blacklist_df[blacklist_df.columns[0]].values)
else:
blacklist = set()
print('Collapsing transcripts')
collapse_annotation(annotation, args.transcript_gtf, args.output_gtf, blacklist=blacklist, collapse_only=args.collapse_only)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment