From 7e59ad400e0200a05a3783dba8a95f6827f7625a Mon Sep 17 00:00:00 2001
From: Laurent Modolo <laurent.modolo@ens-lyon.fr>
Date: Thu, 16 May 2019 17:04:45 +0200
Subject: [PATCH] rmi_splitter.py: add a Reads class to replace SeqIO

---
 src/rmi_splitter/rmi_splitter.py | 46 +++++++++++++++++++-------------
 1 file changed, 27 insertions(+), 19 deletions(-)

diff --git a/src/rmi_splitter/rmi_splitter.py b/src/rmi_splitter/rmi_splitter.py
index d395576..8c0abdf 100644
--- a/src/rmi_splitter/rmi_splitter.py
+++ b/src/rmi_splitter/rmi_splitter.py
@@ -11,8 +11,6 @@ from contextlib import ExitStack
 import sys
 import getopt
 from gzip import open as gzopen
-from Bio import SeqIO
-from Bio.SeqRecord import SeqRecord
 import yaml  # pip install pyyaml
 from . import suffix_tree
 
@@ -39,8 +37,8 @@ class Reads:
         read a read from an open buffer
         """
         try:
-            self.header = str(fin.readline())
-            self.seq = str(fin.readline())
+            self.header = str(fin.readline())[1:-1]
+            self.seq = str(fin.readline())[:-1]
             fin.readline()
             self.str2qual(str(fin.readline()))
         except IOError as err:
@@ -55,7 +53,8 @@ class Reads:
         del self.qual
         self.qual = list()
         for base_qual in qual:
-            self.qual.append(ord(base_qual)-64)
+            if base_qual != "\n":
+                self.qual.append(ord(base_qual)-64)
 
     def qual2str(self):
         """
@@ -71,7 +70,7 @@ class Reads:
         write a read in an open buffer
         """
         try:
-            fout.write(self.header + "\n")
+            fout.write("@" + self.header + "\n")
             fout.write(self.seq + "\n")
             fout.write("+\n")
             fout.write(self.qual2str() + "\n")
@@ -80,6 +79,16 @@ class Reads:
             fout.close()
             raise err
 
+    def __str__(self):
+        """
+        write a read as an str
+        """
+        reads = "@" + self.header + "\n"
+        reads += self.seq + "\n"
+        reads += "+\n"
+        reads += self.qual2str() + "\n"
+        return reads
+
 
 def load_yaml(path):
     """
@@ -194,9 +203,9 @@ def extract_barcode_pos(reads, start, stop, header):
     stop = stop + 1
     if header:
         if start == 0:
-            seq = reads.head[-stop:]
+            seq = reads.header[-stop:]
             return {'seq': seq, 'qual': [40 for x in range(len(seq))]}
-        seq = reads.head[-stop:-start]
+        seq = reads.header[-stop:-start]
         return {'seq': seq, 'qual': [40 for x in range(len(seq))]}
     if start == 0:
         return {
@@ -240,7 +249,9 @@ def write_umi_in_header(reads, config, adaptator, ntuple=1, verbose=False):
     params: adaptator
     """
     umi = extract_barcode(reads, config, adaptator, ntuple, verbose)
-    reads.header += str('_' + str(umi['seq']))
+    header = reads.header.split(" ")
+    header[0] += "_" + str(umi['seq'])
+    reads.header = " ".join(header)
     return reads
 
 
@@ -318,7 +329,7 @@ def match_barcode(reads, config, adaptator, barcode_dictionary, ntuple=1,
         )
         return barcode_dictionary.search_reads(
             adaptator=adaptator,
-            seq=reads['seq'],
+            seq=read['seq'],
             qual=read['qual'],
             cache=True,
             verbose=verbose
@@ -326,7 +337,7 @@ def match_barcode(reads, config, adaptator, barcode_dictionary, ntuple=1,
     except KeyError:
         if verbose:
             print("error: match_barcode() \"" +
-                  str(read['seq']) +
+                  str(reads.seq) +
                   "\" not found in \"" +
                   str(adaptator) +
                   "\" for reads " +
@@ -334,7 +345,7 @@ def match_barcode(reads, config, adaptator, barcode_dictionary, ntuple=1,
     except IndexError:
         if verbose:
             print("error: match_barcode() \"" +
-                  str(read['seq']) +
+                  str(reads.seq) +
                   "\" not found in \"" +
                   str(adaptator) +
                   "\" for reads " +
@@ -427,15 +438,12 @@ def remove_barcode_pos(reads, start, stop, header, verbose=False):
                 reads.header[(size-start):]
     else:
         if start == 0:
-            reads.qual = \
-                reads.qual[stop:]
+            reads.seq = reads.seq[stop:]
             reads.qual = reads.qual[stop:]
         else:
             stop = stop - 1
-            reads.qual = \
-                reads.qual[0:start] + \
-                reads.qual[stop:]
             reads.seq = reads.seq[0:start] + reads.seq[stop:]
+            reads.qual = reads.qual[0:start] + reads.qual[stop:]
     return reads
 
 
@@ -522,7 +530,7 @@ def remove_barcodes(reads, config, ntuple=1, verbose=False):
     params: reads int()
     return: seq
     """
-    if isinstance(Reads, reads):
+    if not isinstance(reads, Reads):
         if verbose:
             print("error: remove_barcode(), reads is not of type Reads")
         raise ValueError
@@ -621,7 +629,7 @@ def parse_ntuples_fastq(ffastqs,
                               config=config,
                               barcode_dictionary=barcode_dictionary,
                               verbose=verbose)
-        write_seqs(seqs=reads_list,
+        write_seqs(reads_list=reads_list,
                    fouts=fouts,
                    config=config,
                    sample=sample,
-- 
GitLab