From e97db92d802a16a4f0d0b5cf5c366e8ad0945b09 Mon Sep 17 00:00:00 2001
From: pqiao29 <pqiao@student.unimelb.edu.au>
Date: Mon, 30 Mar 2020 22:39:31 +1100
Subject: [PATCH] clean up recovery.FUN: till line 117 in revised
 (recovery.FUN) and line 192 in original (progress.seiqhrf.icm)

---
 R/Chu_recovery.R                   | 454 +++++++++++++++++++++++++++++
 R/{mod_status.R => recovery.FUN.R} | 244 ++++++----------
 man/recovery.FUN.Rd                |   6 +-
 tests/testthat/test-initialize.R   |   4 +-
 tests/testthat/test-recovery.R     |  20 ++
 5 files changed, 576 insertions(+), 152 deletions(-)
 create mode 100644 R/Chu_recovery.R
 rename R/{mod_status.R => recovery.FUN.R} (69%)
 create mode 100644 tests/testthat/test-recovery.R

diff --git a/R/Chu_recovery.R b/R/Chu_recovery.R
new file mode 100644
index 0000000..dd91707
--- /dev/null
+++ b/R/Chu_recovery.R
@@ -0,0 +1,454 @@
+## Churches' original function
+## internal in sirplus
+
+progress.seiqhrf.icm <- function(dat, at, seed = NULL) {
+    
+    if(!is.null(seed)) set.seed(seed)
+    #print(at)
+    #print(dat$control$type)
+    #print("-------")
+    
+    # Conditions --------------------------------------------------------------
+    if (!(dat$control$type %in% c("SIR", "SIS", "SEIR", "SEIQHR", "SEIQHRF"))) {
+        return(dat)
+    }
+    
+    
+    # Variables ---------------------------------------------------------------
+    active <- dat$attr$active
+    status <- dat$attr$status
+    
+    groups <- dat$param$groups
+    group <- dat$attr$group
+    
+    type <- dat$control$type
+    recovState <- ifelse(type %in% c("SIR", "SEIR", "SEIQHR", "SEIQHRF"), "r", "s")
+    progState <- "i"
+    quarState <- "q"
+    hospState <- "h"
+    fatState <- "f"
+    
+    # --- progress from exposed to infectious ----
+    prog.rand <- dat$control$prog.rand
+    prog.rate <- dat$param$prog.rate
+    prog.rate.g2 <- dat$param$prog.rate.g2
+    prog.dist.scale <- dat$param$prog.dist.scale
+    prog.dist.shape <- dat$param$prog.dist.shape
+    prog.dist.scale.g2 <- dat$param$prog.dist.scale.g2
+    prog.dist.shape.g2 <- dat$param$prog.dist.shape.g2
+    
+    nProg <- nProgG2 <- 0
+    idsElig <- which(active == 1 & status == "e")
+    nElig <- length(idsElig)
+    
+    if (nElig > 0) {
+        
+        gElig <- group[idsElig]
+        rates <- c(prog.rate, prog.rate.g2)
+        ratesElig <- rates[gElig]
+        
+        if (prog.rand == TRUE) {
+            vecProg <- which(rbinom(nElig, 1, ratesElig) == 1)
+            if (length(vecProg) > 0) {
+                idsProg <- idsElig[vecProg]
+                nProg <- sum(group[idsProg] == 1)
+                nProgG2 <- sum(group[idsProg] == 2)
+                status[idsProg] <- progState
+                dat$attr$infTime[idsProg] <- at
+            }
+        } else {
+            vecTimeSinceExp <- at - dat$attr$expTime[idsElig]
+            gammaRatesElig <- pweibull(vecTimeSinceExp, prog.dist.shape, scale=prog.dist.scale) 
+            nProg <- round(sum(gammaRatesElig[gElig == 1], na.rm=TRUE))
+            if (nProg > 0) {
+                ids2bProg <- ssample(idsElig[gElig == 1], 
+                                     nProg, prob = gammaRatesElig[gElig == 1])
+                status[ids2bProg] <- progState
+                dat$attr$infTime[ids2bProg] <- at
+                # debug
+                if (FALSE & at <= 30) {
+                    print(paste("at:", at))
+                    print("idsElig:")
+                    print(idsElig[gElig == 1])
+                    print("vecTimeSinceExp:")
+                    print(vecTimeSinceExp[gElig == 1])
+                    print("gammaRatesElig:")
+                    print(gammaRatesElig)
+                    print(paste("nProg:",nProg))
+                    print(paste("sum of elig rates:", round(sum(gammaRatesElig[gElig == 1]))))
+                    print(paste("sum(gElig == 1):", sum(gElig == 1)))
+                    print("ids progressed:")
+                    print(ids2bProg)
+                    print("probs of ids to be progressed:")
+                    print(gammaRatesElig[which(idsElig %in% ids2bProg)]) 
+                    print("days since exposed of ids to be progressed:")
+                    print(vecTimeSinceExp[which(idsElig %in% ids2bProg)]) 
+                    print("------")
+                }  
+            }
+            if (groups == 2) {
+                nProgG2 <- round(sum(gammaRatesElig[gElig == 2], na.rm=TRUE))
+                if (nProgG2 > 0) {
+                    ids2bProgG2 <- ssample(idsElig[gElig == 2], 
+                                           nProgG2, prob = gammaRatesElig[gElig == 2])
+                    status[ids2bProgG2] <- progState
+                    dat$attr$infTime[ids2bProgG2] <- at
+                }
+            }
+        }
+    }
+    dat$attr$status <- status
+    
+    if (type %in% c("SEIQHR", "SEIQHRF")) {  
+        # ----- quarantine ------- 
+        quar.rand <- dat$control$quar.rand
+        quar.rate <- dat$param$quar.rate
+        quar.rate.g2 <- dat$param$quar.rate.g2
+        
+        nQuar <- nQuarG2 <- 0
+        idsElig <- which(active == 1 & status == "i")
+        nElig <- length(idsElig)
+        
+        if (nElig > 0) {
+            
+            gElig <- group[idsElig]
+            rates <- c(quar.rate, quar.rate.g2)
+            
+            if (length(quar.rate) > 1) {
+                qrate <- quar.rate[at]
+            } else {
+                qrate <- quar.rate
+            }
+            if (length(quar.rate.g2) > 1) {
+                qrate.g2 <- quar.rate.g2[at]
+            } else {
+                qrate.g2 <- quar.rate.g2
+            }
+            rates <- c(qrate, qrate.g2)
+            ratesElig <- rates[gElig]
+            if (quar.rand == TRUE) {
+                vecQuar <- which(rbinom(nElig, 1, ratesElig) == 1)
+                if (length(vecQuar) > 0) {
+                    idsQuar <- idsElig[vecQuar]
+                    nQuar <- sum(group[idsQuar] == 1)
+                    nQuarG2 <- sum(group[idsQuar] == 2)
+                    status[idsQuar] <- quarState
+                    dat$attr$quarTime[idsQuar] <- at
+                }
+            } else {
+                nQuar <- min(round(sum(ratesElig[gElig == 1])), sum(gElig == 1))
+                idsQuar <- ssample(idsElig[gElig == 1], nQuar)
+                status[idsQuar] <- quarState
+                dat$attr$quarTime[idsQuar] <- at
+                if (groups == 2) {
+                    nQuarG2 <- min(round(sum(ratesElig[gElig == 2])), sum(gElig == 2))
+                    idsQuarG2 <- ssample(idsElig[gElig == 2], nQuarG2)
+                    status[idsQuarG2] <- quarState
+                    dat$attr$quarTime[idsQuarG2] <- at
+                }
+            }
+        }
+        dat$attr$status <- status
+        
+        # ----- need to be hospitalised ------- 
+        hosp.rand <- dat$control$hosp.rand
+        hosp.rate <- dat$param$hosp.rate
+        hosp.rate.g2 <- dat$param$hosp.rate.g2
+        
+        nHosp <- nHospG2 <- 0
+        idsElig <- which(active == 1 & (status == "i" | status == "q"))
+        nElig <- length(idsElig)
+        idsHosp <- numeric(0)
+        
+        if (nElig > 0) {
+            
+            gElig <- group[idsElig]
+            rates <- c(hosp.rate, hosp.rate.g2)
+            ratesElig <- rates[gElig]
+            
+            if (hosp.rand == TRUE) {
+                vecHosp <- which(rbinom(nElig, 1, ratesElig) == 1)
+                if (length(vecHosp) > 0) {
+                    idsHosp <- idsElig[vecHosp]
+                    nHosp <- sum(group[idsHosp] == 1)
+                    nHospG2 <- sum(group[idsHosp] == 2)
+                    status[idsHosp] <- hospState
+                }
+            } else {
+                nHosp <- min(round(sum(ratesElig[gElig == 1])), sum(gElig == 1))
+                idsHosp <- ssample(idsElig[gElig == 1], nHosp)
+                status[idsHosp] <- hospState
+                if (groups == 2) {
+                    nHospG2 <- min(round(sum(ratesElig[gElig == 2])), sum(gElig == 2))
+                    idsHospG2 <- ssample(idsElig[gElig == 2], nHospG2)
+                    status[idsHospG2] <- hospState
+                    idsHosp <- c(idsHosp, idsHospG2)
+                }
+            }
+        }
+        dat$attr$status <- status
+        dat$attr$hospTime[idsHosp] <- at
+        
+        # ----- discharge from need to be hospitalised ------- 
+        disch.rand <- dat$control$disch.rand
+        disch.rate <- dat$param$disch.rate
+        disch.rate.g2 <- dat$param$disch.rate.g2
+        
+        nDisch <- nDischG2 <- 0
+        idsElig <- which(active == 1 & status == "h")
+        nElig <- length(idsElig)
+        idsDisch <- numeric(0)
+        
+        if (nElig > 0) {
+            
+            gElig <- group[idsElig]
+            rates <- c(disch.rate, disch.rate.g2)
+            
+            if (length(disch.rate) > 1) {
+                dcrate <- disch.rate[at]
+            } else {
+                dcrate <- disch.rate
+            }
+            if (length(disch.rate.g2) > 1) {
+                dcrate.g2 <- disch.rate.g2[at]
+            } else {
+                dcrate.g2 <- disch.rate.g2
+            }
+            
+            rates <- c(dcrate, dcrate.g2)
+            ratesElig <- rates[gElig]
+            
+            if (disch.rand == TRUE) {
+                vecDisch <- which(rbinom(nElig, 1, ratesElig) == 1)
+                if (length(vecDisch) > 0) {
+                    idsDisch <- idsElig[vecDisch]
+                    nDisch <- sum(group[idsDisch] == 1)
+                    nDischG2 <- sum(group[idsDisch] == 2)
+                    status[idsDisch] <- recovState
+                }
+            } else {
+                nDisch <- min(round(sum(ratesElig[gElig == 1])), sum(gElig == 1))
+                idsDisch <- ssample(idsElig[gElig == 1], nDisch)
+                status[idsDisch] <- recovState
+                if (groups == 2) {
+                    nDischG2 <- min(round(sum(ratesElig[gElig == 2])), sum(gElig == 2))
+                    idsDischG2 <- ssample(idsElig[gElig == 2], nDischG2)
+                    status[idsDischG2] <- recovState
+                    idsDisch <- c(idsDisch, idsDischG2)
+                }
+            }
+        }
+        dat$attr$status <- status
+        dat$attr$dischTime[idsDisch] <- at
+    }
+    
+    # ----- recover ------- 
+    rec.rand <- dat$control$rec.rand
+    rec.rate <- dat$param$rec.rate
+    rec.rate.g2 <- dat$param$rec.rate.g2
+    rec.dist.scale <- dat$param$rec.dist.scale
+    rec.dist.shape <- dat$param$rec.dist.shape
+    rec.dist.scale.g2 <- dat$param$rec.dist.scale.g2
+    rec.dist.shape.g2 <- dat$param$rec.dist.shape.g2
+    
+    nRecov <- nRecovG2 <- 0
+    idsElig <- which(active == 1 & (status == "i" | status == "q" | status == "h"))
+    nElig <- length(idsElig)
+    idsRecov <- numeric(0)
+    
+    if (nElig > 0) {
+        
+        gElig <- group[idsElig]
+        rates <- c(rec.rate, rec.rate.g2)
+        ratesElig <- rates[gElig]
+        
+        if (rec.rand == TRUE) {
+            vecRecov <- which(rbinom(nElig, 1, ratesElig) == 1)
+            if (length(vecRecov) > 0) {
+                idsRecov <- idsElig[vecRecov]
+                nRecov <- sum(group[idsRecov] == 1)
+                nRecovG2 <- sum(group[idsRecov] == 2)
+                status[idsRecov] <- recovState
+            }
+        } else {
+            vecTimeSinceExp <- at - dat$attr$expTime[idsElig]
+            vecTimeSinceExp[is.na(vecTimeSinceExp)] <- 0
+            gammaRatesElig <- pweibull(vecTimeSinceExp, rec.dist.shape, scale=rec.dist.scale) 
+            nRecov <- round(sum(gammaRatesElig[gElig == 1], na.rm=TRUE))
+            if (nRecov > 0) {
+                idsRecov <- ssample(idsElig[gElig == 1], 
+                                    nRecov, prob = gammaRatesElig[gElig == 1])
+                status[idsRecov] <- recovState
+                # debug
+                if (FALSE & at <= 30) {
+                    print(paste("at:", at))
+                    print("idsElig:")
+                    print(idsElig[gElig == 1])
+                    print("vecTimeSinceExp:")
+                    print(vecTimeSinceExp[gElig == 1])
+                    print("gammaRatesElig:")
+                    print(gammaRatesElig)
+                    print(paste("nRecov:",nRecov))
+                    print(paste("sum of elig rates:", round(sum(gammaRatesElig[gElig == 1]))))
+                    print(paste("sum(gElig == 1):", sum(gElig == 1)))
+                    print("ids recovered:")
+                    print(idsRecov)
+                    print("probs of ids to be progressed:")
+                    print(gammaRatesElig[which(idsElig %in% idsRecov)]) 
+                    print("days since exposed of ids to be Recovered:")
+                    print(vecTimeSinceExp[which(idsElig %in% idsRecov)]) 
+                    print("------")
+                }  
+                
+            }
+            if (groups == 2) {
+                nRecovG2 <- round(sum(gammaRatesElig[gElig == 2], na.rm=TRUE))
+                if (nRecovG2 > 0) {
+                    idsRecovG2 <- ssample(idsElig[gElig == 2], 
+                                          nRecovG2, prob = gammaRatesElig[gElig == 2])
+                    status[idsRecovG2] <- recovState
+                    idsRecov <- c(idsRecov, idsRecovG2)
+                }
+            }
+        }
+    }
+    dat$attr$status <- status
+    dat$attr$recovTime[idsRecov] <- at
+    
+    fatEnable <- TRUE
+    if (fatEnable & type %in% c("SEIQHRF")) {  
+        # ----- case fatality ------- 
+        fat.rand <- dat$control$fat.rand
+        fat.rate.base <- dat$param$fat.rate.base
+        fat.rate.base.g2 <- dat$param$fat.rate.base.g2
+        fat.rate.base.g2 <- ifelse(is.null(fat.rate.base.g2), 
+                                   0, fat.rate.base.g2)
+        fat.rate.overcap <- dat$param$fat.rate.overcap
+        fat.rate.overcap.g2 <- dat$param$fat.rate.overcap.g2
+        fat.rate.overcap.g2 <- ifelse(is.null(fat.rate.overcap.g2), 
+                                      0, fat.rate.overcap.g2)
+        hosp.cap <- dat$param$hosp.cap
+        fat.tcoeff <- dat$param$fat.tcoeff
+        
+        nFat <- nFatG2 <- 0
+        idsElig <- which(active == 1 & status =="h")
+        nElig <- length(idsElig)
+        
+        if (nElig > 0) {
+            gElig <- group[idsElig]
+            timeInHospElig <- at - dat$attr$hospTime[idsElig]
+            rates <- c(fat.rate.base, fat.rate.base.g2)
+            h.num.yesterday <- 0
+            if (!is.null(dat$epi$h.num[at - 1])) {
+                h.num.yesterday <- dat$epi$h.num[at - 1]
+                if (h.num.yesterday > hosp.cap) {
+                    blended.rate <- ((hosp.cap * fat.rate.base) + 
+                                         ((h.num.yesterday - hosp.cap) * fat.rate.overcap)) / 
+                        h.num.yesterday
+                    blended.rate.g2 <- ((hosp.cap * fat.rate.base.g2) + 
+                                            ((h.num.yesterday - hosp.cap) * fat.rate.overcap.g2)) / 
+                        h.num.yesterday
+                    rates <- c(blended.rate, blended.rate.g2)
+                }  
+            } 
+            ratesElig <- rates[gElig]
+            ratesElig <- ratesElig + timeInHospElig*fat.tcoeff*ratesElig
+            
+            if (fat.rand == TRUE) {
+                vecFat <- which(rbinom(nElig, 1, ratesElig) == 1)
+                if (length(vecFat) > 0) {
+                    idsFat <- idsElig[vecFat]
+                    nFat <- sum(group[idsFat] == 1)
+                    nFatG2 <- sum(group[idsFat] == 2)
+                    status[idsFat] <- fatState
+                    dat$attr$fatTime[idsFat] <- at
+                }
+            } else {
+                nFat <- min(round(sum(ratesElig[gElig == 1])), sum(gElig == 1))
+                idsFat <- ssample(idsElig[gElig == 1], nFat)
+                status[idsFat] <- fatState
+                dat$attr$fatTime[idsFat] <- at
+                if (groups == 2) {
+                    nFatG2 <- min(round(sum(ratesElig[gElig == 2])), sum(gElig == 2))
+                    idsFatG2 <- ssample(idsElig[gElig == 2], nFatG2)
+                    status[idsFatG2] <- fatState
+                    dat$attr$fatTime[idsFatG2] <- at
+                }
+            }
+        }
+        dat$attr$status <- status
+    }
+    
+    # Output ------------------------------------------------------------------
+    outName_a <- ifelse(type %in% c("SIR", "SEIR"), "ir.flow", "is.flow")
+    outName_a[2] <- paste0(outName_a, ".g2")
+    if (type %in% c("SEIR", "SEIQHR", "SEIQHRF")) {
+        outName_b <- "ei.flow"
+        outName_b[2] <- paste0(outName_b, ".g2")
+    }
+    if (type %in% c("SEIQHR", "SEIQHRF")) {
+        outName_c <- "iq.flow"
+        outName_c[2] <- paste0(outName_c, ".g2")
+        outName_d <- "iq2h.flow"
+        outName_d[2] <- paste0(outName_d, ".g2")
+    }
+    if (type %in% c("SEIQHRF")) {
+        outName_e <- "hf.flow"
+        outName_e[2] <- paste0(outName_e, ".g2")
+    }
+    ## Summary statistics
+    if (at == 2) {
+        dat$epi[[outName_a[1]]] <- c(0, nRecov)
+        if (type %in% c("SEIR", "SEIQHR")) {
+            dat$epi[[outName_b[1]]] <- c(0, nProg) 
+        }
+        if (type %in% c("SEIQHR", "SEIQHRF")) {
+            dat$epi[[outName_c[1]]] <- c(0, nQuar) 
+            dat$epi[[outName_d[1]]] <- c(0, nHosp) 
+        }
+        if (fatEnable & type %in% c("SEIQHRF")) {
+            dat$epi[[outName_e[1]]] <- c(0, nFat) 
+        }
+    } else {
+        dat$epi[[outName_a[1]]][at] <- nRecov
+        if (type %in% c("SEIR", "SEIQHR")) {
+            dat$epi[[outName_b[1]]][at] <- nProg 
+        }
+        if (type %in% c("SEIQHR", "SEIQHRF")) {
+            dat$epi[[outName_c[1]]][at] <- nQuar 
+            dat$epi[[outName_d[1]]][at] <- nHosp 
+        }
+        if (fatEnable & type %in% c("SEIQHRF")) {
+            dat$epi[[outName_e[1]]][at] <- nFat 
+        }
+    }
+    if (groups == 2) {
+        if (at == 2) {
+            dat$epi[[outName_a[2]]] <- c(0, nRecovG2)
+            if (type %in% c("SEIR", "SEIQHR", "SEIQHRF")) {
+                dat$epi[[outName_b[2]]] <- c(0, nProgG2) 
+            }
+            if (type %in% c("SEIQHR", "SEIQHRF")) {
+                dat$epi[[outName_c[2]]] <- c(0, nQuarG2) 
+                dat$epi[[outName_d[2]]] <- c(0, nHospG2) 
+            }
+            if (type %in% c("SEIQHRF")) {
+                dat$epi[[outName_e[2]]] <- c(0, nFatG2) 
+            }
+        } else {
+            dat$epi[[outName_a[2]]][at] <- nRecovG2
+            if (type %in% c("SEIR", "SEIQHR", "SEIQHRF")) {
+                dat$epi[[outName_b[2]]][at] <- nProgG2 
+            }
+            if (type %in% c("SEIQHR", "SEIQHRF")) {
+                dat$epi[[outName_c[2]]][at] <- nQuarG2 
+                dat$epi[[outName_d[2]]][at] <- nHospG2 
+            }
+            if (type %in% c("SEIQHRF")) {
+                dat$epi[[outName_e[2]]][at] <- nFatG2 
+            }
+        }
+    }
+    
+    return(dat)
+}
diff --git a/R/mod_status.R b/R/recovery.FUN.R
similarity index 69%
rename from R/mod_status.R
rename to R/recovery.FUN.R
index c8e15f2..60f9200 100644
--- a/R/mod_status.R
+++ b/R/recovery.FUN.R
@@ -3,199 +3,147 @@
 #' Function to get progress of icms
 #'
 #' @param dat Object containing all data
-#' @param at ?
+#' @param at time point
+#' @param seed random seed for checking consistency
 #'
 #' @return progress
 #' @importFrom stats pweibull
+#' @importFrom stats rbinom
+#' @importFrom EpiModel ssample
 #' @importFrom stats rgeom
 #' @importFrom stats sd
 #' @export
-recovery.FUN <- function(dat, at) {
-
-  #print(at)
-  #print(dat$control$type)
-  #print("-------")
-
+recovery.FUN <- function(dat, at, seed = NULL) {
+  
+  if(!is.null(seed)) set.seed(seed)
   # Conditions --------------------------------------------------------------
   if (!(dat$control$type %in% c("SIR", "SIS", "SEIR", "SEIQHR", "SEIQHRF"))) {
     return(dat)
   }
-
+  
+  
+  # internal function --------------------------------------------------------------
+  update_status <- function(rate, rand, active, status, label, state, at, prog, expTime = NULL, prog.dist.scale = NULL, prog.dist.shape = NULL){
+    
+    smp_sz <- 0
+    at_idx <- NULL
+    
+    idsElig <- which(active == 1 & status == label)
+    nElig <- length(idsElig)
+    
+    if (nElig > 0) {
+      gElig <- rep(1, nElig)
+      
+      if (rand) {
+        vecProg <- which(stats::rbinom(nElig, 1, rate) == 1)
+        if (length(vecProg) > 0) {
+          idsProg <- idsElig[vecProg]
+          smp_sz <- length(idsProg)
+          status[idsProg] <<- state
+          at_idx <- idsProg
+        }
+      }else{
+        do_sample <- TRUE
+        if(!prog){
+          smp_sz <- min(round(sum(rates[gElig == 1])), sum(gElig == 1))
+          smp_prob <- NULL
+        }else{
+          vecTimeSinceExp <- at - expTime[idsElig]
+          gammaRatesElig <- stats::pweibull(vecTimeSinceExp, prog.dist.shape, scale=prog.dist.scale)
+          smp_sz <- round(sum(gammaRatesElig[gElig == 1], na.rm=TRUE)) 
+          smp_prob <- gammaRatesElig[gElig == 1]
+          if(smp_sz <= 0) do_sample <- FALSE
+        }
+        
+        if(do_sample){
+          ids2bProg <- EpiModel::ssample(idsElig[gElig == 1], smp_sz, prob = smp_prob)
+          status[ids2bProg] <<- state
+          at_idx <- ids2bProg
+        }
+        
+        
+      }
+      
+      dat$attr$status <<- status
+    }
+    
+    list(smp_sz, at_idx)
+  }
+  
+  
   # Variables ---------------------------------------------------------------
   active <- dat$attr$active
   status <- dat$attr$status
-
+  
   groups <- dat$param$groups
   group <- dat$attr$group
-
+  
   type <- dat$control$type
   recovState <- ifelse(type %in% c("SIR", "SEIR", "SEIQHR", "SEIQHRF"), "r", "s")
-  progState <- "i"
-  quarState <- "q"
   hospState <- "h"
   fatState <- "f"
-
+  
   # --- progress from exposed to infectious ----
-  prog.rand <- dat$control$prog.rand
-  prog.rate <- dat$param$prog.rate
-  prog.rate.g2 <- dat$param$prog.rate.g2
-  prog.dist.scale <- dat$param$prog.dist.scale
-  prog.dist.shape <- dat$param$prog.dist.shape
-  prog.dist.scale.g2 <- dat$param$prog.dist.scale.g2
-  prog.dist.shape.g2 <- dat$param$prog.dist.shape.g2
-
-  nProg <- nProgG2 <- 0
-  idsElig <- which(active == 1 & status == "e")
-  nElig <- length(idsElig)
-
-  if (nElig > 0) {
-
-    gElig <- group[idsElig]
-    rates <- c(prog.rate, prog.rate.g2)
-    ratesElig <- rates[gElig]
-
-    if (prog.rand == TRUE) {
-      vecProg <- which(rbinom(nElig, 1, ratesElig) == 1)
-      if (length(vecProg) > 0) {
-        idsProg <- idsElig[vecProg]
-        nProg <- sum(group[idsProg] == 1)
-        nProgG2 <- sum(group[idsProg] == 2)
-        status[idsProg] <- progState
-        dat$attr$infTime[idsProg] <- at
-      }
-    } else {
-      vecTimeSinceExp <- at - dat$attr$expTime[idsElig]
-      gammaRatesElig <- pweibull(vecTimeSinceExp, prog.dist.shape, scale=prog.dist.scale)
-      nProg <- round(sum(gammaRatesElig[gElig == 1], na.rm=TRUE))
-      if (nProg > 0) {
-        ids2bProg <- ssample(idsElig[gElig == 1],
-                      nProg, prob = gammaRatesElig[gElig == 1])
-        status[ids2bProg] <- progState
-        dat$attr$infTime[ids2bProg] <- at
-        # debug
-        if (FALSE & at <= 30) {
-          print(paste("at:", at))
-          print("idsElig:")
-          print(idsElig[gElig == 1])
-          print("vecTimeSinceExp:")
-          print(vecTimeSinceExp[gElig == 1])
-          print("gammaRatesElig:")
-          print(gammaRatesElig)
-          print(paste("nProg:",nProg))
-          print(paste("sum of elig rates:", round(sum(gammaRatesElig[gElig == 1]))))
-          print(paste("sum(gElig == 1):", sum(gElig == 1)))
-          print("ids progressed:")
-          print(ids2bProg)
-          print("probs of ids to be progressed:")
-          print(gammaRatesElig[which(idsElig %in% ids2bProg)])
-          print("days since exposed of ids to be progressed:")
-          print(vecTimeSinceExp[which(idsElig %in% ids2bProg)])
-          print("------")
-        }
-      }
-      if (groups == 2) {
-        nProgG2 <- round(sum(gammaRatesElig[gElig == 2], na.rm=TRUE))
-        if (nProgG2 > 0) {
-          ids2bProgG2 <- ssample(idsElig[gElig == 2],
-                        nProgG2, prob = gammaRatesElig[gElig == 2])
-          status[ids2bProgG2] <- progState
-          dat$attr$infTime[ids2bProgG2] <- at
-        }
-      }
-    }
-  }
-  dat$attr$status <- status
-
+  res <- update_status(rate = dat$param$prog.rate,
+                       rand = dat$control$prog.rand, 
+                       active, status, 
+                       label = "e", 
+                       state =  "i", 
+                       at, prog = TRUE, 
+                       expTime = dat$attr$expTime, 
+                       prog.dist.scale = dat$param$prog.dist.scale,
+                       prog.dist.shape = dat$param$prog.dist.shape)
+  
+  nProg <- res[[1]]
+  if(!is.null(res[[2]])) dat$attr$infTime[res[[2]]] <- at
+  
+  
   if (type %in% c("SEIQHR", "SEIQHRF")) {
-    # ----- quarantine -------
-    quar.rand <- dat$control$quar.rand
+    
+    # ----- quarantine ------- 
     quar.rate <- dat$param$quar.rate
-    quar.rate.g2 <- dat$param$quar.rate.g2
-
-    nQuar <- nQuarG2 <- 0
-    idsElig <- which(active == 1 & status == "i")
-    nElig <- length(idsElig)
-
-    if (nElig > 0) {
-
-      gElig <- group[idsElig]
-      rates <- c(quar.rate, quar.rate.g2)
-
-      if (length(quar.rate) > 1) {
-          qrate <- quar.rate[at]
-      } else {
-          qrate <- quar.rate
-      }
-      if (length(quar.rate.g2) > 1) {
-          qrate.g2 <- quar.rate.g2[at]
-      } else {
-          qrate.g2 <- quar.rate.g2
-      }
-      rates <- c(qrate, qrate.g2)
-      ratesElig <- rates[gElig]
-      if (quar.rand == TRUE) {
-        vecQuar <- which(rbinom(nElig, 1, ratesElig) == 1)
-        if (length(vecQuar) > 0) {
-          idsQuar <- idsElig[vecQuar]
-          nQuar <- sum(group[idsQuar] == 1)
-          nQuarG2 <- sum(group[idsQuar] == 2)
-          status[idsQuar] <- quarState
-          dat$attr$quarTime[idsQuar] <- at
-        }
-      } else {
-        nQuar <- min(round(sum(ratesElig[gElig == 1])), sum(gElig == 1))
-        idsQuar <- ssample(idsElig[gElig == 1], nQuar)
-        status[idsQuar] <- quarState
-        dat$attr$quarTime[idsQuar] <- at
-        if (groups == 2) {
-          nQuarG2 <- min(round(sum(ratesElig[gElig == 2])), sum(gElig == 2))
-          idsQuarG2 <- ssample(idsElig[gElig == 2], nQuarG2)
-          status[idsQuarG2] <- quarState
-          dat$attr$quarTime[idsQuarG2] <- at
-        }
-      }
-    }
-    dat$attr$status <- status
-
+    rate <- ifelse(length(quar.rate) > 1, quar.rate[at], quar.rate)
+    
+    res <- update_status(rate,
+                         rand = dat$control$quar.rand,
+                         active = dat$attr$active, 
+                         status = dat$attr$status,
+                         label = "i", 
+                         state = "q", at, prog = FALSE)
+    nQuar <- res[[1]]
+    if(!is.null(res[[2]])) dat$attr$quarTime[res[[2]]] <- at
+    
+    
     # ----- need to be hospitalised -------
     hosp.rand <- dat$control$hosp.rand
     hosp.rate <- dat$param$hosp.rate
-    hosp.rate.g2 <- dat$param$hosp.rate.g2
-
+    
     nHosp <- nHospG2 <- 0
     idsElig <- which(active == 1 & (status == "i" | status == "q"))
     nElig <- length(idsElig)
     idsHosp <- numeric(0)
-
+    
     if (nElig > 0) {
-
+      
       gElig <- group[idsElig]
-      rates <- c(hosp.rate, hosp.rate.g2)
+      rates <- hosp.rate
       ratesElig <- rates[gElig]
-
-      if (hosp.rand == TRUE) {
-        vecHosp <- which(rbinom(nElig, 1, ratesElig) == 1)
+      
+      if (hosp.rand) {
+        vecHosp <- which(stats::rbinom(nElig, 1, ratesElig) == 1)
         if (length(vecHosp) > 0) {
           idsHosp <- idsElig[vecHosp]
           nHosp <- sum(group[idsHosp] == 1)
-          nHospG2 <- sum(group[idsHosp] == 2)
           status[idsHosp] <- hospState
         }
       } else {
         nHosp <- min(round(sum(ratesElig[gElig == 1])), sum(gElig == 1))
         idsHosp <- ssample(idsElig[gElig == 1], nHosp)
         status[idsHosp] <- hospState
-        if (groups == 2) {
-          nHospG2 <- min(round(sum(ratesElig[gElig == 2])), sum(gElig == 2))
-          idsHospG2 <- ssample(idsElig[gElig == 2], nHospG2)
-          status[idsHospG2] <- hospState
-          idsHosp <- c(idsHosp, idsHospG2)
-        }
       }
     }
     dat$attr$status <- status
     dat$attr$hospTime[idsHosp] <- at
-
     # ----- discharge from need to be hospitalised -------
     disch.rand <- dat$control$disch.rand
     disch.rate <- dat$param$disch.rate
diff --git a/man/recovery.FUN.Rd b/man/recovery.FUN.Rd
index d4fc81b..0f81dc3 100644
--- a/man/recovery.FUN.Rd
+++ b/man/recovery.FUN.Rd
@@ -4,12 +4,14 @@
 \alias{recovery.FUN}
 \title{Progress icm}
 \usage{
-recovery.FUN(dat, at)
+recovery.FUN(dat, at, seed = NULL)
 }
 \arguments{
 \item{dat}{Object containing all data}
 
-\item{at}{?}
+\item{at}{time point}
+
+\item{seed}{random seed for checking consistency}
 }
 \value{
 progress
diff --git a/tests/testthat/test-initialize.R b/tests/testthat/test-initialize.R
index 7775cad..6710c67 100644
--- a/tests/testthat/test-initialize.R
+++ b/tests/testthat/test-initialize.R
@@ -16,6 +16,6 @@ test_that("Identical output as Churches' original function: initialize.FUN", {
         i <- i + 1
         rm(.Random.seed)
     }
- 
+    
     expect_equal(sum(comp), No_seeds)
-})
\ No newline at end of file
+})
diff --git a/tests/testthat/test-recovery.R b/tests/testthat/test-recovery.R
new file mode 100644
index 0000000..8b63194
--- /dev/null
+++ b/tests/testthat/test-recovery.R
@@ -0,0 +1,20 @@
+test_that("Identical output as Churches' original function: recovery.FUN", {
+    
+    at <- 2
+    dat <- do.call(initialize.FUN, set_param())
+    dat <- do.call(infection.FUN, list(dat, at))   
+    
+    No_seeds <- 10
+    seed_list <- sample(1:1000, No_seeds)
+    comp <- rep(NA, No_seeds)
+    i <- 1
+    for(seed in seed_list){
+        dat1 <- do.call(recovery.FUN, list(dat, at, seed))
+        dat2 <- do.call(progress.seiqhrf.icm, list(dat, at, seed))
+        comp[i] <- identical(dat1, dat2)
+        i <- i + 1
+        rm(.Random.seed)
+    }
+    
+    expect_equal(sum(comp), No_seeds)
+})
\ No newline at end of file
-- 
GitLab