From c65408dacb0342fb4dae3d091c3546abb0ab73ba Mon Sep 17 00:00:00 2001
From: pqiao29 <pqiao@student.unimelb.edu.au>
Date: Mon, 30 Mar 2020 15:54:37 +1100
Subject: [PATCH] clean up initialize.FUN

---
 R/Chu_init.R                     | 170 +++++++++++++++++++++++++++++++
 R/init.R                         | 150 ++++++++-------------------
 man/initialize.FUN.Rd            |   6 +-
 tests/testthat/test-initialize.R |  21 ++++
 4 files changed, 236 insertions(+), 111 deletions(-)
 create mode 100644 R/Chu_init.R
 create mode 100644 tests/testthat/test-initialize.R

diff --git a/R/Chu_init.R b/R/Chu_init.R
new file mode 100644
index 0000000..69caf0b
--- /dev/null
+++ b/R/Chu_init.R
@@ -0,0 +1,170 @@
+## Churches' original function (internal in sirplus).
+## Builds the master list dat (param, init, control, attr) for an ICM run;
+## when seed is not NULL the RNG is seeded first. Returns dat.
+initialize.icm <- function(param, init, control, seed = NULL) {
+    
+    if(!is.null(seed)) set.seed(seed)
+    
+    ## Master List for Data ##
+    dat <- list()
+    dat$param <- param
+    dat$init <- init
+    dat$control <- control
+    # Set attributes: n is the sum of all entries of init whose class is "numeric".
+    # NOTE(review): in the two-group branch g1inits/g2inits are positions within
+    # numeric.init but are used to index init; they agree only if init is all numeric.
+    dat$attr <- list()
+    numeric.init <- init[which(sapply(init, class) == "numeric")]
+    n <- do.call("sum", numeric.init)
+    dat$attr$active <- rep(1, n)
+    if (dat$param$groups == 1) {
+        dat$attr$group <- rep(1, n)
+    } else {
+        g2inits <- grep(".g2", names(numeric.init))
+        g1inits <- setdiff(1:length(numeric.init), g2inits)
+        nG1 <- sum(sapply(g1inits, function(x) init[[x]]))
+        nG2 <- sum(sapply(g2inits, function(x) init[[x]]))
+        dat$attr$group <- c(rep(1, nG1), rep(2, max(0, nG2)))
+    }
+    
+    # Initialize status and infection time
+    dat <- init_status.icm(dat)
+    
+    
+    # Summary out list
+    dat <- get_prev.icm(dat, at = 1)
+    
+    return(dat)
+}
+
+## Sample initial compartments per individual and create the *Time attributes.
+init_status.icm <- function(dat) {
+    # dat: list with $control$type, $param$groups, $attr$group and $init counts.
+    # Variables ---------------------------------------------------------------
+    type <- dat$control$type
+    group <- dat$attr$group
+    nGroups <- dat$param$groups
+    
+    nG1 <- sum(group == 1)
+    nG2 <- sum(group == 2)
+    
+    e.num <- dat$init$e.num
+    i.num <- dat$init$i.num
+    q.num <- dat$init$q.num
+    h.num <- dat$init$h.num
+    r.num <- dat$init$r.num
+    f.num <- dat$init$f.num
+    e.num.g2 <- dat$init$e.num.g2
+    i.num.g2 <- dat$init$i.num.g2
+    q.num.g2 <- dat$init$q.num.g2
+    h.num.g2 <- dat$init$h.num.g2
+    r.num.g2 <- dat$init$r.num.g2
+    f.num.g2 <- dat$init$f.num.g2
+    
+    # Status (all "s"; "i" drawn first, the rest only from ids still "s") ------
+    status <- rep("s", nG1 + nG2)
+    status[sample(which(group == 1), size = i.num)] <- "i"
+    if (nGroups == 2) {
+        status[sample(which(group == 2), size = i.num.g2)] <- "i"
+    }
+    if (type %in% c("SIR", "SEIR", "SEIQHR", "SEIQHRF")) {
+        status[sample(which(group == 1 & status == "s"), size = r.num)] <- "r"
+        if (nGroups == 2) {
+            status[sample(which(group == 2 & status == "s"), size = r.num.g2)] <- "r"
+        }
+    }
+    if (type %in% c("SEIR", "SEIQHR", "SEIQHRF")) {
+        status[sample(which(group == 1 & status == "s"), size = e.num)] <- "e"
+        if (nGroups == 2) {
+            status[sample(which(group == 2 & status == "s"), size = e.num.g2)] <- "e"
+        }
+    }
+    if (type %in% c("SEIQHR", "SEIQHRF")) {
+        status[sample(which(group == 1 & status == "s"), size = q.num)] <- "q"
+        if (nGroups == 2) {
+            status[sample(which(group == 2 & status == "s"), size = q.num.g2)] <- "q"
+        }
+        status[sample(which(group == 1 & status == "s"), size = h.num)] <- "h"
+        if (nGroups == 2) {
+            status[sample(which(group == 2 & status == "s"), size = h.num.g2)] <- "h"
+        }
+    }
+    if (type %in% c("SEIQHRF")) {
+        status[sample(which(group == 1 & status == "s"), size = f.num)] <- "f"
+        if (nGroups == 2) {
+            status[sample(which(group == 2 & status == "s"), size = f.num.g2)] <- "f"
+        }
+    }
+    
+    dat$attr$status <- status
+    # NOTE(review): sample() treats a length-one numeric x as 1:x - verify inputs.
+    # NOTE: idsExp, idsRecov, idsHosp and idsQuar below are computed but never used.
+    # Exposure Time ----------------------------------------------------------
+    idsExp <- which(status == "e")
+    expTime <- rep(NA, length(status))
+    # leave exposure time uninitialised for now, and
+    # just set to NA at start.
+    dat$attr$expTime <- expTime
+    
+    # Infection Time ----------------------------------------------------------
+    idsInf <- which(status == "i")
+    infTime <- rep(NA, length(status))
+    dat$attr$infTime <- infTime # overwritten below
+    
+    # Recovery Time ----------------------------------------------------------
+    idsRecov <- which(status == "r")
+    recovTime <- rep(NA, length(status))
+    dat$attr$recovTime <- recovTime
+    
+    # Need for Hospitalisation Time ----------------------------------------------------------
+    idsHosp <- which(status == "h")
+    hospTime <- rep(NA, length(status))
+    dat$attr$hospTime <- hospTime
+    
+    # Quarantine Time ----------------------------------------------------------
+    idsQuar <- which(status == "q")
+    quarTime <- rep(NA, length(status))
+    dat$attr$quarTime <- quarTime
+    
+    # Hospital-need cessation Time ----------------------------------------------------------
+    dischTime <- rep(NA, length(status))
+    dat$attr$dischTime <- dischTime
+    
+    # Case-fatality Time ----------------------------------------------------------
+    fatTime <- rep(NA, length(status))
+    dat$attr$fatTime <- fatTime
+    
+    # The block below is disabled by if (FALSE); it would draw infTime at random
+    # (note the initial infections may then have negative infTime!).
+    if (FALSE) {
+        # not sure what the following section is trying to do, but it
+        # mucks up the gamma-distributed incubation periods, so set
+        # infTime for initial infected people to t=1 instead
+        if (dat$param$vital == TRUE && dat$param$di.rate > 0) {
+            infTime[idsInf] <- -rgeom(n = length(idsInf), prob = dat$param$di.rate) + 2
+        } else {
+            if (dat$control$type == "SI" || dat$param$rec.rate == 0) {
+                # infTime a uniform draw over the number of sim time steps
+                infTime[idsInf] <- ssample(1:(-dat$control$nsteps + 2),
+                                           length(idsInf), replace = TRUE)
+            } else {
+                if (nGroups == 1) {
+                    infTime[idsInf] <- ssample(1:(-round(1 / dat$param$rec.rate) + 2),
+                                               length(idsInf), replace = TRUE)
+                }
+                if (nGroups == 2) {
+                    infG1 <- which(status == "i" & group == 1)
+                    infTime[infG1] <- ssample(1:(-round(1 / dat$param$rec.rate) + 2),
+                                              length(infG1), replace = TRUE)
+                    infG2 <- which(status == "i" & group == 2)
+                    infTime[infG2] <- ssample(1:(-round(1 / dat$param$rec.rate.g2) + 2),
+                                              length(infG2), replace = TRUE)
+                }
+            }
+        }
+    }
+    infTime[idsInf] <- 1
+    dat$attr$infTime <- infTime
+    
+    return(dat)
+}
\ No newline at end of file
diff --git a/R/init.R b/R/init.R
index b8b3352..1eb38a5 100644
--- a/R/init.R
+++ b/R/init.R
@@ -4,41 +4,33 @@
 #'
 #' @param param ICM parameters.
 #' @param init Initial value parameters.
-#' @param control Control parameters.
+#' @param control Control parameters
+#' @param seed random seed for checking consistency with other versions.
 #'
 #' @return Updated dat
 #' @export
-initialize.FUN <- function(param, init, control) {
-
+initialize.FUN <- function(param, init, control, seed = NULL) {
+  if(!is.null(seed)) set.seed(seed)
+  # (a NULL seed leaves the RNG state untouched)
   ## Master List for Data ##
   dat <- list()
   dat$param <- param
   dat$init <- init
   dat$control <- control
-
-
+  
   # Set attributes
   dat$attr <- list()
   numeric.init <- init[which(sapply(init, class) == "numeric")]
   n <- do.call("sum", numeric.init)
   dat$attr$active <- rep(1, n)
-  if (dat$param$groups == 1) {
-    dat$attr$group <- rep(1, n)
-  } else {
-    g2inits <- grep(".g2", names(numeric.init))
-    g1inits <- setdiff(1:length(numeric.init), g2inits)
-    nG1 <- sum(sapply(g1inits, function(x) init[[x]]))
-    nG2 <- sum(sapply(g2inits, function(x) init[[x]]))
-    dat$attr$group <- c(rep(1, nG1), rep(2, max(0, nG2)))
-  }
-
+  dat$attr$group <- rep(1, n)
+  # (single-group only: everyone is put in group 1, whatever param$groups says)
   # Initialize status and infection time
   dat <- init_status.icm(dat)
-
-
+  
   # Summary out list
   dat <- get_prev.icm(dat, at = 1)
-
+  
   return(dat)
 }
 
@@ -53,15 +45,14 @@ initialize.FUN <- function(param, init, control) {
 #' @importFrom EpiModel ssample
 #' @export
 init_status.icm <- function(dat) {
-
+  
   # Variables ---------------------------------------------------------------
   type <- dat$control$type
   group <- dat$attr$group
   nGroups <- dat$param$groups
-
-  nG1 <- sum(group == 1)
-  nG2 <- sum(group == 2)
-
+  
+  nG <- sum(group == 1)
+  
   e.num <- dat$init$e.num
   i.num <- dat$init$i.num
   q.num <- dat$init$q.num
@@ -74,111 +65,52 @@ init_status.icm <- function(dat) {
   h.num.g2 <- dat$init$h.num.g2
   r.num.g2 <- dat$init$r.num.g2
   f.num.g2 <- dat$init$f.num.g2
-
+  
   # Status ------------------------------------------------------------------
-  status <- rep("s", nG1 + nG2)
+  status <- rep("s", nG)
   status[sample(which(group == 1), size = i.num)] <- "i"
-  if (nGroups == 2) {
-    status[sample(which(group == 2), size = i.num.g2)] <- "i"
-  }
   if (type %in% c("SIR", "SEIR", "SEIQHR", "SEIQHRF")) {
     status[sample(which(group == 1 & status == "s"), size = r.num)] <- "r"
-    if (nGroups == 2) {
-      status[sample(which(group == 2 & status == "s"), size = r.num.g2)] <- "r"
-    }
   }
   if (type %in% c("SEIR", "SEIQHR", "SEIQHRF")) {
     status[sample(which(group == 1 & status == "s"), size = e.num)] <- "e"
-    if (nGroups == 2) {
-      status[sample(which(group == 2 & status == "s"), size = e.num.g2)] <- "e"
-    }
   }
   if (type %in% c("SEIQHR", "SEIQHRF")) {
     status[sample(which(group == 1 & status == "s"), size = q.num)] <- "q"
-    if (nGroups == 2) {
-      status[sample(which(group == 2 & status == "s"), size = q.num.g2)] <- "q"
-    }
     status[sample(which(group == 1 & status == "s"), size = h.num)] <- "h"
-    if (nGroups == 2) {
-      status[sample(which(group == 2 & status == "s"), size = h.num.g2)] <- "h"
-    }
   }
   if (type %in% c("SEIQHRF")) {
     status[sample(which(group == 1 & status == "s"), size = f.num)] <- "f"
-    if (nGroups == 2) {
-      status[sample(which(group == 2 & status == "s"), size = f.num.g2)] <- "f"
-    }
   }
-
+  
   dat$attr$status <- status
-
-
-  # Exposure Time ----------------------------------------------------------
-  idsExp <- which(status == "e")
-  expTime <- rep(NA, length(status))
+  n <- length(status)
+  
   # leave exposure time uninitialised for now, and
   # just set to NA at start.
-  dat$attr$expTime <- expTime
-
+  
+  # Exposure Time ----------------------------------------------------------
+  dat$attr$expTime <- rep(NA, n)
+  
   # Infection Time ----------------------------------------------------------
-  idsInf <- which(status == "i")
-  infTime <- rep(NA, length(status))
-  dat$attr$infTime <- infTime # overwritten below
-
-# Recovery Time ----------------------------------------------------------
-  idsRecov <- which(status == "r")
-  recovTime <- rep(NA, length(status))
-  dat$attr$recovTime <- recovTime
-
- # Need for Hospitalisation Time ----------------------------------------------------------
-  idsHosp <- which(status == "h")
-  hospTime <- rep(NA, length(status))
-  dat$attr$hospTime <- hospTime
-
-   # Quarantine Time ----------------------------------------------------------
-  idsQuar <- which(status == "q")
-  quarTime <- rep(NA, length(status))
-  dat$attr$quarTime <- quarTime
-
- # Hospital-need cessation  Time ----------------------------------------------------------
-  dischTime <- rep(NA, length(status))
-  dat$attr$dischTime <- dischTime
-
- # Case-fatality  Time ----------------------------------------------------------
-  fatTime <- rep(NA, length(status))
-  dat$attr$fatTime <- fatTime
-
-  # If vital=TRUE, infTime is a uniform draw over the duration of infection
-  # note the initial infections may have negative infTime!
-  if (FALSE) {
-    # not sure what the following section is trying to do, but it
-    # mucks up the gamma-distributed incumabtion periods, so set
-    # infTime for initial infected people to t=1 instead
-    if (dat$param$vital == TRUE && dat$param$di.rate > 0) {
-      infTime[idsInf] <- -rgeom(n = length(idsInf), prob = dat$param$di.rate) + 2
-    } else {
-      if (dat$control$type == "SI" || dat$param$rec.rate == 0) {
-        # infTime a uniform draw over the number of sim time steps
-        infTime[idsInf] <- ssample(1:(-dat$control$nsteps + 2),
-                                   length(idsInf), replace = TRUE)
-      } else {
-        if (nGroups == 1) {
-          infTime[idsInf] <- ssample(1:(-round(1 / dat$param$rec.rate) + 2),
-                                     length(idsInf), replace = TRUE)
-        }
-        if (nGroups == 2) {
-          infG1 <- which(status == "i" & group == 1)
-          infTime[infG1] <- ssample(1:(-round(1 / dat$param$rec.rate) + 2),
-                                    length(infG1), replace = TRUE)
-          infG2 <- which(status == "i" & group == 2)
-          infTime[infG2] <- ssample(1:(-round(1 / dat$param$rec.rate.g2) + 2),
-                                    length(infG2), replace = TRUE)
-        }
-      }
-    }
-  }
-  infTime[idsInf] <- 1
+  infTime <- rep(NA, n)
+  infTime[status == "i"] <- 1
   dat$attr$infTime <- infTime
-
+  
+  # Recovery Time ----------------------------------------------------------
+  dat$attr$recovTime <- rep(NA, n)
+  
+  # Need for Hospitalisation Time ----------------------------------------------------------
+  dat$attr$hospTime <- rep(NA, n)
+  
+  # Quarantine Time ----------------------------------------------------------
+  dat$attr$quarTime <- rep(NA, n)
+  
+  # Hospital-need cessation  Time ----------------------------------------------------------
+  dat$attr$dischTime <- rep(NA, n)
+  
+  # Case-fatality  Time ----------------------------------------------------------
+  dat$attr$fatTime <- rep(NA, n)
+  
   return(dat)
 }
diff --git a/man/initialize.FUN.Rd b/man/initialize.FUN.Rd
index 8df60cc..3138a1b 100644
--- a/man/initialize.FUN.Rd
+++ b/man/initialize.FUN.Rd
@@ -4,14 +4,16 @@
 \alias{initialize.FUN}
 \title{Initialize ICM}
 \usage{
-initialize.FUN(param, init, control)
+initialize.FUN(param, init, control, seed = NULL)
 }
 \arguments{
 \item{param}{ICM parameters.}
 
 \item{init}{Initial value parameters.}
 
-\item{control}{Control parameters.}
+\item{control}{Control parameters}
+
+\item{seed}{random seed for checking consistency with other versions.}
 }
 \value{
 Updated dat
diff --git a/tests/testthat/test-initialize.R b/tests/testthat/test-initialize.R
new file mode 100644
index 0000000..7775cad
--- /dev/null
+++ b/tests/testthat/test-initialize.R
@@ -0,0 +1,21 @@
+test_that("Identical output as Churches' original function: initialize.FUN", {
+    # Both initialisers call set.seed(seed) themselves, so a shared seed must
+    # yield identical dat lists.
+    full_params <- set_param()
+    control <- full_params$control
+    param <- full_params$param
+    init <- full_params$init
+
+    # Fixed seeds keep the comparison reproducible from run to run.
+    seed_list <- c(1, 17, 42, 123, 256, 389, 512, 640, 777, 999)
+
+    for (seed in seed_list) {
+        dat_ref <- initialize.icm(param, init, control, seed = seed)
+        dat_new <- initialize.FUN(param, init, control, seed = seed)
+        # No rm(.Random.seed) here: that object lives in the global
+        # environment, so rm() called from the test environment only warned
+        # "object '.Random.seed' not found" and reset nothing.
+        expect_identical(dat_new, dat_ref,
+                         label = paste0("initialize.FUN(seed = ", seed, ")"))
+    }
+})
\ No newline at end of file
-- 
GitLab