From 84e8ac51b8a4ce4d79e378d2f2095af99c884202 Mon Sep 17 00:00:00 2001 From: rlyu <rlyu@svi.edu.au> Date: Tue, 19 Jul 2022 12:14:55 +1000 Subject: [PATCH] sxo code refactor --- src/sgcocaller/sgcocaller_sxo.nim | 132 +++++++++++++++++------------- 1 file changed, 75 insertions(+), 57 deletions(-) diff --git a/src/sgcocaller/sgcocaller_sxo.nim b/src/sgcocaller/sgcocaller_sxo.nim index 2b0f849..1d68a60 100755 --- a/src/sgcocaller/sgcocaller_sxo.nim +++ b/src/sgcocaller/sgcocaller_sxo.nim @@ -1,6 +1,7 @@ ## sgcocaller sxo, using pre-generated mtx files for finding crossovers using a HMM model from findPath import pathTrackBack -from graph import addViNodeIthSperm, SeqSpermViNodes +from graph import addViNodeIthSperm,SpermViNodes +# SeqSpermViNodes import tables import hts import utils @@ -9,20 +10,24 @@ import streams # bins of SNP indexes import strutils import os +import math -proc readCountMtxToSeq*(mtxFileStream:FileStream, countMtx:var seq[seq[int]], by_cell = false):int = +proc readCountMtxToSeq*(mtxFileStream:FileStream, countMtx:var seq[seq[int16]], by_cell = false):int = var currentLineSeq:seq[int] var currentLine:seq[string] - while not mtxFileStream.atEnd(): - currentLine = mtxFileStream.readLine().splitWhitespace() - ## i j 1-based from file - currentLineSeq = map(currentLine, proc(x: string): int = parseInt(x)) - if by_cell: - countMtx[(currentLineSeq[0]-1)][(currentLineSeq[1]-1)] = currentLineSeq[2] - else: - countMtx[(currentLineSeq[1]-1)][(currentLineSeq[0]-1)] = currentLineSeq[2] + ## i j 1-based from file + if by_cell: + while not mtxFileStream.atEnd(): + currentLine = mtxFileStream.readLine().splitWhitespace() + currentLineSeq = map(currentLine, proc(x: string): int = parseInt(x)) + countMtx[(currentLineSeq[0]-1)][(currentLineSeq[1]-1)] = int16(currentLineSeq[2]) + else: + while not mtxFileStream.atEnd(): + currentLine = mtxFileStream.readLine().splitWhitespace() + currentLineSeq = map(currentLine, proc(x: string): int = parseInt(x)) + countMtx[(currentLineSeq[1]-1)][(currentLineSeq[0]-1)] = int16(currentLineSeq[2]) return 0 -proc addViNodeSXO(barcodeTable: TableRef, alleleCountTable: Table[string,allele_expr], scSpermSeq: var SeqSpermViNodes, +proc addViNodeSXO(barcodeTable: TableRef, alleleCountTable: Table[string,allele_expr], scSpermSeq: TableRef[int,SpermViNodes], snpIndex: int, thetaRef: float, thetaAlt: float, @@ -34,33 +39,27 @@ proc addViNodeSXO(barcodeTable: TableRef, alleleCountTable: Table[string,allele_ for bc, ac in alleleCountTable.pairs: var ithSperm = barcodeTable[bc] nnsize.inc() - outaltCountMtxFS.writeLine($snpIndex & " " & $(ithSperm+1) & " " & $ac.calt) - outtotalCountMtxFS.writeLine($snpIndex & " " & $(ithSperm+1) & " " & $(ac.calt+ac.cref)) + outaltCountMtxFS.writeLine($(snpIndex+1) & " " & $(ithSperm+1) & " " & $ac.calt) + outtotalCountMtxFS.writeLine($(snpIndex+1) & " " & $(ithSperm+1) & " " & $(ac.calt+ac.cref)) var emissionArray = getEmission(thetaRef=thetaRef,thetaAlt=thetaAlt,cRef=ac.cref,cAlt=ac.cAlt) - discard addViNodeIthSperm(scSpermSeq = scSpermSeq, cAlt = int(ac.calt), cRef = int(ac.cref), ithSperm = ithSperm, emissionArray = emissionArray, snpIndex = snpIndex,initProb = initProb,rec_pos =rec_pos,cmPmb = cmPmb) + discard addViNodeIthSperm(scSpermSeq = scSpermSeq, cAlt = ac.calt, cRef = ac.cref, ithSperm = ithSperm, emissionArray = emissionArray, snpIndex = (snpIndex+1), initProb = initProb,rec_pos =rec_pos,cmPmb = cmPmb) return 0 ## barcodeTable, cell barcode:cell index -proc sgcocallerSXO*(barcodeTable:TableRef, phase_dir:string, out_dir:string, thetaREF:float, thetaALT:float, cmPmb:float, s_Chrs:seq[string], initProb: array[stateRef..stateAlt, float], phasedSnpAnnotFileName:string): int = +proc sgcocallerSXO*(barcodeTable:TableRef, phase_dir:string, out_dir:string, thetaREF:float, thetaALT:float, cmPmb:float, + s_Chrs:seq[string], initProb: array[stateRef..stateAlt, float], phasedSnpAnnotFileName:string, + batchSize:int): int = var ncells = barcodeTable.len var nsnps, ithSperm:int var currentEntrySeq: seq[string] - var totalCountMtxByCell: seq[seq[int]] - var altCountMtxByCell: seq[seq[int]] + var totalCountMtxByCell: seq[seq[int16]] + var altCountMtxByCell: seq[seq[int16]] var alleleCountTable: Table[string,allele_expr] ## iterate through each selected chromosome for chrom in s_Chrs: - var nnsize = 0 - ## number of non zeros - var scSpermSeq:SeqSpermViNodes - ## matches with the order in barcodeTable - scSpermSeq.setLen(barcodeTable.len) + ## number of non zeros = number of elements in the mtx var phasedSnpannoFS,totalCountMtxFS,altCountMtxFS,outFileVStateMtx,outaltCountMtxFS,outSnpAnnotFS, outtotalCountMtxFS,viSegmentInfo:FileStream let sparseMatrixHeader = "%%MatrixMarket matrix coordinate integer general" - # if fileExists(phase_dir & chrom & "_corrected_phased_snpAnnot.txt"): - # phasedSnpannoFS = openFileStream(phase_dir & chrom & "_corrected_phased_snpAnnot.txt", fmRead) - # elif fileExists(phase_dir & chrom & "_phased_snpAnnot.txt"): - # phasedSnpannoFS = openFileStream(phase_dir & chrom & "_phased_snpAnnot.txt", fmRead) phasedSnpannoFS = openFileStream(phasedSnpAnnotFileName, fmRead) try: #_totalCount @@ -85,45 +84,64 @@ proc sgcocallerSXO*(barcodeTable:TableRef, phase_dir:string, out_dir:string, the currentEntrySeq = totalCountMtxFS.readLine().splitWhitespace() nsnps = parseInt(currentEntrySeq[0]) ## gtMtx is cell by Snp format - totalCountMtxByCell = newSeqWith(nsnps,newSeq[int](ncells)) - altCountMtxByCell = newSeqWith(nsnps,newSeq[int](ncells)) + totalCountMtxByCell = newSeqWith(nsnps,newSeq[int16](ncells)) + altCountMtxByCell = newSeqWith(nsnps,newSeq[int16](ncells)) discard readCountMtxToSeq(mtxFileStream = totalCountMtxFS, countMtx = totalCountMtxByCell, by_cell = true) discard readCountMtxToSeq(mtxFileStream = altCountMtxFS, countMtx = altCountMtxByCell, by_cell = true) - discard phasedSnpannoFS.readLine() var isnpIndex = -1 - var osnpIndex = 0 +# var osnpIndex = 0 + ## do crossover calling batch by batch to avoid using too much RAM + var all_sperm_index = (0..(barcodeTable.len-1)).toSeq() + let num_sub_batches = int(floor(all_sperm_index.len/batchSize)) + var sub_batches = all_sperm_index.distribute(num_sub_batches,spread=false) + var scSpermSeq = newTable[int,SpermViNodes]() + for sperm_batch in sub_batches: +# echo "sperm_batch.len : " & $sperm_batch + isnpIndex = -1 +# osnpIndex = 0 + phasedSnpannoFS.setPosition(0) + discard phasedSnpannoFS.readLine() + while not phasedSnpannoFS.atEnd(): + currentEntrySeq = phasedSnpannoFS.readLine().splitWhitespace() + isnpIndex.inc() + if currentEntrySeq[3] == "-1": + continue + else: + alleleCountTable = initTable[string,allele_expr]() + for bc,ithSperm in barcodeTable: + if not (ithSperm in sperm_batch): continue + if totalCountMtxByCell[isnpIndex][ithSperm] == 0: + ## not adding this vi node to the spermSeq + continue + if currentEntrySeq[3] == "1": + alleleCountTable[bc] = allele_expr(cref:(altCountMtxByCell[isnpIndex][ithSperm]), calt: (totalCountMtxByCell[isnpIndex][ithSperm] - altCountMtxByCell[isnpIndex][ithSperm])) + else: + alleleCountTable[bc] = allele_expr(cref:(totalCountMtxByCell[isnpIndex][ithSperm] - altCountMtxByCell[isnpIndex][ithSperm]), calt: (altCountMtxByCell[isnpIndex][ithSperm])) + if alleleCountTable.len == 0: + continue + # if currentEntrySeq[3] == "1": + # outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[2], currentEntrySeq[1]], sep="\t") ) + # else: + # outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[1], currentEntrySeq[2]], sep="\t") ) + discard addViNodeSXO(barcodeTable = barcodeTable, alleleCountTable = alleleCountTable, scSpermSeq = scSpermSeq, + thetaRef = thetaRef, thetaAlt = thetaAlt, snpIndex = isnpIndex, rec_pos = parseInt(currentEntrySeq[0]), nnsize = nnsize, + initProb = initProb, cmPmb = cmPmb, outaltCountMtxFS = outaltCountMtxFS, outtotalCountMtxFS = outtotalCountMtxFS) + discard pathTrackBack(scSpermSeq = scSpermSeq, thetaRef = thetaRef, thetaAlt=thetaAlt, cmPmb = cmPmb, outFileVStateMtx = outFileVStateMtx, + viSegmentInfo = viSegmentInfo) + scSpermSeq.clear() + phasedSnpannoFS.setPosition(0) + discard phasedSnpannoFS.readLine() while not phasedSnpannoFS.atEnd(): currentEntrySeq = phasedSnpannoFS.readLine().splitWhitespace() - isnpIndex.inc() - if currentEntrySeq[3] == "-1": - continue + if currentEntrySeq[3] == "1": + outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[2], currentEntrySeq[1]], sep="\t") ) + elif currentEntrySeq[3] == "0": + outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[1], currentEntrySeq[2]], sep="\t") ) else: - alleleCountTable = initTable[string,allele_expr]() - for bc in barcodeTable.keys(): - ithSperm = barcodeTable[bc] - if totalCountMtxByCell[isnpIndex][ithSperm] == 0: - ## not adding this vi node to the spermSeq - continue - if currentEntrySeq[3] == "1": - alleleCountTable[bc] = allele_expr(cref:(altCountMtxByCell[isnpIndex][ithSperm]), calt: (totalCountMtxByCell[isnpIndex][ithSperm] - altCountMtxByCell[isnpIndex][ithSperm])) - else: - alleleCountTable[bc] = allele_expr(cref:(totalCountMtxByCell[isnpIndex][ithSperm] - altCountMtxByCell[isnpIndex][ithSperm]), calt: (altCountMtxByCell[isnpIndex][ithSperm])) - if alleleCountTable.len == 0: - continue - osnpIndex.inc() - if currentEntrySeq[3] == "1": - outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[2], currentEntrySeq[1]], sep="\t") ) - else: - outSnpAnnotFS.writeLine(join([currentEntrySeq[0], currentEntrySeq[1], currentEntrySeq[2]], sep="\t") ) - discard addViNodeSXO(barcodeTable = barcodeTable, alleleCountTable = alleleCountTable, scSpermSeq = scSpermSeq, - thetaRef = thetaRef, thetaAlt = thetaAlt, snpIndex = osnpIndex, rec_pos = parseInt(currentEntrySeq[0]), nnsize = nnsize, - initProb = initProb, cmPmb = cmPmb, outaltCountMtxFS = outaltCountMtxFS, outtotalCountMtxFS = outtotalCountMtxFS) - - discard pathTrackBack(scSpermSeq = scSpermSeq, thetaRef = thetaRef,thetaAlt=thetaAlt,cmPmb = cmPmb,outFileVStateMtx = outFileVStateMtx, - viSegmentInfo = viSegmentInfo) + outSnpAnnotFS.writeLine(join([currentEntrySeq[0], "*", "*"], sep="\t") ) for fs in [outFileVStateMtx,outaltCountMtxFS,outtotalCountMtxFS]: fs.setPosition(49) - fs.write($osnpIndex & " " & $barcodeTable.len & " " & $nnsize) + fs.write($(isnpIndex+1) & " " & $barcodeTable.len & " " & $nnsize) for outFileStream in [phasedSnpannoFS,totalCountMtxFS,altCountMtxFS,outFileVStateMtx,outaltCountMtxFS,outSnpAnnotFS, outtotalCountMtxFS,viSegmentInfo]: outFileStream.close() return 0 -- GitLab