From 84e8ac51b8a4ce4d79e378d2f2095af99c884202 Mon Sep 17 00:00:00 2001
From: rlyu <rlyu@svi.edu.au>
Date: Tue, 19 Jul 2022 12:14:55 +1000
Subject: [PATCH] sxo code refactor

---
 src/sgcocaller/sgcocaller_sxo.nim | 132 +++++++++++++++++-------------
 1 file changed, 75 insertions(+), 57 deletions(-)

diff --git a/src/sgcocaller/sgcocaller_sxo.nim b/src/sgcocaller/sgcocaller_sxo.nim
index 2b0f849..1d68a60 100755
--- a/src/sgcocaller/sgcocaller_sxo.nim
+++ b/src/sgcocaller/sgcocaller_sxo.nim
@@ -1,6 +1,7 @@
 ## sgcocaller sxo, using pre-generated mtx files for finding crossovers using a HMM model
 from findPath import pathTrackBack
-from graph  import addViNodeIthSperm, SeqSpermViNodes
+from graph  import addViNodeIthSperm,SpermViNodes
+# SeqSpermViNodes
 import tables
 import hts
 import utils
@@ -9,20 +10,24 @@ import streams
 # bins of SNP indexes
 import strutils
 import os
+import math
 
-proc readCountMtxToSeq*(mtxFileStream:FileStream, countMtx:var seq[seq[int]], by_cell = false):int = 
+proc readCountMtxToSeq*(mtxFileStream:FileStream, countMtx:var seq[seq[int16]], by_cell = false):int = 
   var currentLineSeq:seq[int]
   var currentLine:seq[string]
-  while not mtxFileStream.atEnd():
-    currentLine = mtxFileStream.readLine().splitWhitespace()
-    ## i j 1-based from file
-    currentLineSeq = map(currentLine, proc(x: string): int = parseInt(x))
-    if by_cell:
-      countMtx[(currentLineSeq[0]-1)][(currentLineSeq[1]-1)] = currentLineSeq[2]
-    else:
-      countMtx[(currentLineSeq[1]-1)][(currentLineSeq[0]-1)] = currentLineSeq[2]
+  ## Each remaining line is "i j value" with 1-based i and j: by_cell stores value at countMtx[i-1][j-1], otherwise at countMtx[j-1][i-1]
+  if by_cell:
+    while not mtxFileStream.atEnd():
+      currentLine = mtxFileStream.readLine().splitWhitespace()
+      currentLineSeq = map(currentLine, proc(x: string): int = parseInt(x))
+      countMtx[(currentLineSeq[0]-1)][(currentLineSeq[1]-1)] = int16(currentLineSeq[2])
+  else:
+    while not mtxFileStream.atEnd():
+      currentLine = mtxFileStream.readLine().splitWhitespace()
+      currentLineSeq = map(currentLine, proc(x: string): int = parseInt(x))
+      countMtx[(currentLineSeq[1]-1)][(currentLineSeq[0]-1)] = int16(currentLineSeq[2])
   return 0
-proc addViNodeSXO(barcodeTable: TableRef, alleleCountTable: Table[string,allele_expr], scSpermSeq: var SeqSpermViNodes,
+proc addViNodeSXO(barcodeTable: TableRef, alleleCountTable: Table[string,allele_expr], scSpermSeq: TableRef[int,SpermViNodes],
                   snpIndex: int,
                   thetaRef: float,
                   thetaAlt: float,
@@ -34,33 +39,27 @@ proc addViNodeSXO(barcodeTable: TableRef, alleleCountTable: Table[string,allele_
   for bc, ac in alleleCountTable.pairs:
     var ithSperm = barcodeTable[bc]
     nnsize.inc()
-    outaltCountMtxFS.writeLine($snpIndex & " " & $(ithSperm+1) & " " & $ac.calt)
-    outtotalCountMtxFS.writeLine($snpIndex & " " & $(ithSperm+1) & " " & $(ac.calt+ac.cref))
+    outaltCountMtxFS.writeLine($(snpIndex+1) & " " & $(ithSperm+1) & " " & $ac.calt)
+    outtotalCountMtxFS.writeLine($(snpIndex+1) & " " & $(ithSperm+1) & " " & $(ac.calt+ac.cref))
     var emissionArray = getEmission(thetaRef=thetaRef,thetaAlt=thetaAlt,cRef=ac.cref,cAlt=ac.cAlt)
-    discard addViNodeIthSperm(scSpermSeq = scSpermSeq, cAlt = int(ac.calt), cRef = int(ac.cref), ithSperm = ithSperm, emissionArray = emissionArray, snpIndex = snpIndex,initProb = initProb,rec_pos =rec_pos,cmPmb = cmPmb)
+    discard addViNodeIthSperm(scSpermSeq = scSpermSeq, cAlt = ac.calt, cRef = ac.cref, ithSperm = ithSperm, emissionArray = emissionArray, snpIndex = (snpIndex+1), initProb = initProb,rec_pos =rec_pos,cmPmb = cmPmb)
   return 0
 ## barcodeTable, cell barcode:cell index
-proc sgcocallerSXO*(barcodeTable:TableRef, phase_dir:string, out_dir:string, thetaREF:float, thetaALT:float, cmPmb:float, s_Chrs:seq[string], initProb: array[stateRef..stateAlt, float], phasedSnpAnnotFileName:string): int =
+proc sgcocallerSXO*(barcodeTable:TableRef, phase_dir:string, out_dir:string, thetaREF:float, thetaALT:float, cmPmb:float,
+                    s_Chrs:seq[string], initProb: array[stateRef..stateAlt, float], phasedSnpAnnotFileName:string,
+                    batchSize:int): int =
   var ncells = barcodeTable.len
   var nsnps, ithSperm:int
   var currentEntrySeq: seq[string]
-  var totalCountMtxByCell: seq[seq[int]]
-  var altCountMtxByCell: seq[seq[int]]
+  var totalCountMtxByCell: seq[seq[int16]]
+  var altCountMtxByCell: seq[seq[int16]]
   var alleleCountTable: Table[string,allele_expr]
   ## iterate through each selected chromosome
   for chrom in s_Chrs:
-
     var nnsize = 0
-    ## number of non zeros
-    var scSpermSeq:SeqSpermViNodes
-    ## matches with the order in barcodeTable
-    scSpermSeq.setLen(barcodeTable.len)
+    ## number of non zeros = number of elements in the mtx
     var phasedSnpannoFS,totalCountMtxFS,altCountMtxFS,outFileVStateMtx,outaltCountMtxFS,outSnpAnnotFS, outtotalCountMtxFS,viSegmentInfo:FileStream
     let sparseMatrixHeader = "%%MatrixMarket matrix coordinate integer general"
-    # if fileExists(phase_dir & chrom & "_corrected_phased_snpAnnot.txt"):
-    #   phasedSnpannoFS = openFileStream(phase_dir & chrom & "_corrected_phased_snpAnnot.txt", fmRead)
-    # elif fileExists(phase_dir & chrom & "_phased_snpAnnot.txt"):
-    #   phasedSnpannoFS = openFileStream(phase_dir & chrom & "_phased_snpAnnot.txt", fmRead)
     phasedSnpannoFS =  openFileStream(phasedSnpAnnotFileName, fmRead)
     try:
       #_totalCount
@@ -85,45 +84,64 @@ proc sgcocallerSXO*(barcodeTable:TableRef, phase_dir:string, out_dir:string, the
     currentEntrySeq = totalCountMtxFS.readLine().splitWhitespace()
     nsnps = parseInt(currentEntrySeq[0])
     ## gtMtx is cell by Snp format
-    totalCountMtxByCell = newSeqWith(nsnps,newSeq[int](ncells))
-    altCountMtxByCell = newSeqWith(nsnps,newSeq[int](ncells))
+    totalCountMtxByCell = newSeqWith(nsnps,newSeq[int16](ncells))
+    altCountMtxByCell = newSeqWith(nsnps,newSeq[int16](ncells))
     discard readCountMtxToSeq(mtxFileStream = totalCountMtxFS, countMtx = totalCountMtxByCell, by_cell = true)
     discard readCountMtxToSeq(mtxFileStream = altCountMtxFS, countMtx = altCountMtxByCell, by_cell = true)
-    discard phasedSnpannoFS.readLine()
     var isnpIndex = -1 
-    var osnpIndex = 0
+#    var osnpIndex = 0
+    ## do crossover calling batch by batch to avoid using too much RAM; always at least one batch, since distribute() needs num >= 1 (fewer cells than batchSize, or batchSize <= 0, means a single batch)
+    var all_sperm_index = (0..(barcodeTable.len-1)).toSeq()
+    let num_sub_batches = if batchSize > 0: max(1, all_sperm_index.len div batchSize) else: 1
+    var sub_batches = all_sperm_index.distribute(num_sub_batches,spread=false)
+    var scSpermSeq = newTable[int,SpermViNodes]()
+    for sperm_batch in sub_batches:
+#      echo "sperm_batch.len : " & $sperm_batch
+      isnpIndex = -1 
+#      osnpIndex = 0
+      phasedSnpannoFS.setPosition(0)
+      discard phasedSnpannoFS.readLine()
+      while not phasedSnpannoFS.atEnd():
+        currentEntrySeq = phasedSnpannoFS.readLine().splitWhitespace()
+        isnpIndex.inc()
+        if currentEntrySeq[3] == "-1":        
+          continue
+        else:
+          alleleCountTable = initTable[string,allele_expr]()
+          for bc,ithSperm in barcodeTable:
+            if not (ithSperm in sperm_batch): continue
+            if totalCountMtxByCell[isnpIndex][ithSperm] == 0:
+              ## not adding this vi node to the spermSeq
+              continue
+            if currentEntrySeq[3] == "1":
+              alleleCountTable[bc] = allele_expr(cref:(altCountMtxByCell[isnpIndex][ithSperm]), calt: (totalCountMtxByCell[isnpIndex][ithSperm] - altCountMtxByCell[isnpIndex][ithSperm]))
+            else:
+              alleleCountTable[bc] = allele_expr(cref:(totalCountMtxByCell[isnpIndex][ithSperm] - altCountMtxByCell[isnpIndex][ithSperm]), calt: (altCountMtxByCell[isnpIndex][ithSperm]))     
+          if alleleCountTable.len == 0:
+            continue
+          # if currentEntrySeq[3] == "1":
+          #   outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[2], currentEntrySeq[1]], sep="\t") )
+          # else:
+          #   outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[1], currentEntrySeq[2]], sep="\t") )
+          discard addViNodeSXO(barcodeTable = barcodeTable,  alleleCountTable = alleleCountTable,  scSpermSeq = scSpermSeq,
+                            thetaRef = thetaRef, thetaAlt = thetaAlt, snpIndex = isnpIndex, rec_pos = parseInt(currentEntrySeq[0]), nnsize = nnsize,
+                            initProb = initProb, cmPmb = cmPmb, outaltCountMtxFS = outaltCountMtxFS, outtotalCountMtxFS = outtotalCountMtxFS)   
+      discard pathTrackBack(scSpermSeq = scSpermSeq, thetaRef = thetaRef, thetaAlt=thetaAlt, cmPmb = cmPmb, outFileVStateMtx = outFileVStateMtx,
+                            viSegmentInfo = viSegmentInfo)
+      scSpermSeq.clear()
+    phasedSnpannoFS.setPosition(0)
+    discard phasedSnpannoFS.readLine()
     while not phasedSnpannoFS.atEnd():
       currentEntrySeq = phasedSnpannoFS.readLine().splitWhitespace()
-      isnpIndex.inc()
-      if currentEntrySeq[3] == "-1":        
-        continue
+      if currentEntrySeq[3] == "1":
+        outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[2], currentEntrySeq[1]], sep="\t") )
+      elif currentEntrySeq[3] == "0":
+        outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[1], currentEntrySeq[2]], sep="\t") )
       else:
-        alleleCountTable = initTable[string,allele_expr]()
-        for bc in barcodeTable.keys():
-          ithSperm = barcodeTable[bc]
-          if totalCountMtxByCell[isnpIndex][ithSperm] == 0:
-            ## not adding this vi node to the spermSeq
-            continue
-          if currentEntrySeq[3] == "1":
-            alleleCountTable[bc] = allele_expr(cref:(altCountMtxByCell[isnpIndex][ithSperm]), calt: (totalCountMtxByCell[isnpIndex][ithSperm] - altCountMtxByCell[isnpIndex][ithSperm]))
-          else:
-            alleleCountTable[bc] = allele_expr(cref:(totalCountMtxByCell[isnpIndex][ithSperm] - altCountMtxByCell[isnpIndex][ithSperm]), calt: (altCountMtxByCell[isnpIndex][ithSperm]))     
-        if alleleCountTable.len == 0:
-          continue
-        osnpIndex.inc()
-        if currentEntrySeq[3] == "1":
-          outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[2], currentEntrySeq[1]], sep="\t") )
-        else:
-          outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[1], currentEntrySeq[2]], sep="\t") )
-        discard addViNodeSXO(barcodeTable = barcodeTable,  alleleCountTable = alleleCountTable,  scSpermSeq = scSpermSeq,
-                           thetaRef = thetaRef, thetaAlt = thetaAlt, snpIndex = osnpIndex, rec_pos = parseInt(currentEntrySeq[0]), nnsize = nnsize,
-                           initProb = initProb, cmPmb = cmPmb, outaltCountMtxFS = outaltCountMtxFS, outtotalCountMtxFS = outtotalCountMtxFS)
-       
-    discard pathTrackBack(scSpermSeq = scSpermSeq, thetaRef = thetaRef,thetaAlt=thetaAlt,cmPmb = cmPmb,outFileVStateMtx = outFileVStateMtx,
-                          viSegmentInfo = viSegmentInfo)
+        outSnpAnnotFS.writeLine(join([currentEntrySeq[0], "*", "*"], sep="\t") )
     for fs in [outFileVStateMtx,outaltCountMtxFS,outtotalCountMtxFS]:
       fs.setPosition(49)
-      fs.write($osnpIndex & " " & $barcodeTable.len & " " & $nnsize) 
+      fs.write($(isnpIndex+1) & " " & $barcodeTable.len & " " & $nnsize) 
     for outFileStream in [phasedSnpannoFS,totalCountMtxFS,altCountMtxFS,outFileVStateMtx,outaltCountMtxFS,outSnpAnnotFS, outtotalCountMtxFS,viSegmentInfo]:
       outFileStream.close()
   return 0
-- 
GitLab