Skip to content
Snippets Groups Projects
FIN_plot.R 15.7 KiB
Newer Older
#' Plot simulation results
#'
#' Dispatches to one of the package's plotting helpers for an seiqhrf
#' simulation, depending on \code{method}.
#'
#' @param x An seiqhrf object returned from function \code{\link{seiqhrf}}.
#' @param method If NULL (default), plot compartment counts over time via
#'               \code{\link{plot_sirplus}}.
#'               If "times", plot duration frequency distributions via
#'               \code{\link{plot_times}}.
#'               If "weekly_local", plot local weekly estimates from the
#'               simulation via \code{get_weekly_local}.
#'               Any other value is an error.
#' @param comp_remove Compartments to remove, e.g. c("s.num", "r.num").
#'        Default "none" keeps every compartment.
#' @param time_lim Number of steps (days) to plot.
#' @param ci "y"/"n": include 95\% confidence intervals in the sirplus plot.
#' @param sep_compartments "y"/"n": use faceting to show each compartment in
#'        a separate plot; only works if plotting a single simulation.
#' @param trans Y-axis transformation (e.g. "log2", "log10"). The default
#'        "na" means no transformation.
#' @param known Dataframe with known compartment numbers to plot alongside
#'        projections.
#' @param start_date Date for day 0. Default: as.Date("2020-03-21").
#' @param x_axis Title for x-axis. Default: 'Days since beginning of epidemic'.
#' @param plot_title Title for whole plot. Default: 'SEIQHRF'.
#' @param return_df Used only when method == "weekly_local": if TRUE, return
#'        the full result of \code{get_weekly_local} (the dataframe used for
#'        plotting as well as the ggplot object); if FALSE, return only its
#'        \code{plot} element.
#' @param market.share Between 0 and 1, fraction of local hospital beds in
#'        the simulated unit (e.g. state). Used only when
#'        method == "weekly_local".
#' @param icu_percent Between 0 and 1, fraction of patients needing
#'        hospitalization that should go to ICU. Used only when
#'        method == "weekly_local".
#' @param total_population True population size, needed only if the
#'        simulation size is smaller than the true population size (e.g. due
#'        to computational cost). Used only when method == "weekly_local".
#' @param ... Currently unused.
#'
#' @return For method NULL or "times", a ggplot2 object. For
#'         method "weekly_local", either the full result of
#'         \code{get_weekly_local} (return_df = TRUE) or its \code{plot}
#'         element (return_df = FALSE).
#'
#' @importFrom tidyr pivot_longer
#' @importFrom dplyr mutate
#' @importFrom dplyr "%>%"
#' @importFrom dplyr bind_rows
#' @importFrom dplyr select
#' @importFrom dplyr filter
#' @import ggplot2
#' @export
plot.seiqhrf <- function(x,
                         method = NULL,
                         comp_remove = "none",
                         time_lim = 90,
                         ci = 'y',
                         sep_compartments = 'n',
                         trans = 'na',
                         known = NULL,
                         start_date = as.Date("2020-03-21"),
                         x_axis = 'Days since beginning of epidemic',
                         plot_title = 'SEIQHRF',
                         return_df = TRUE,
                         market.share = .04,
                         icu_percent = .1,
                         total_population = NULL, ...) {
    if (is.null(method)) {
        return(plot_sirplus(x,
                            comp_remove = comp_remove,
                            time_lim = time_lim,
                            ci = ci,
                            sep_compartments = sep_compartments,
                            trans = trans,
                            known = known,
                            start_date = start_date,
                            x_axis = x_axis,
                            plot_title = plot_title))
    }

    # switch() needs a single string; a numeric value would select a branch
    # by position rather than by name.
    if (!is.character(method) || length(method) != 1L) {
        stop("`method` must be NULL, \"times\" or \"weekly_local\"",
             call. = FALSE)
    }

    switch(method,
        times = {
            if (!inherits(x, "seiqhrf")) {
                stop("If method == times, x needs to be an seiqhrf object")
            }
            plot_times(x)
        },
        weekly_local = {
            ret <- get_weekly_local(x, market.share = market.share,
                                    icu_percent = icu_percent,
                                    start_date = start_date,
                                    time_limit = time_lim,
                                    total_population = total_population)
            if (return_df) ret else ret$plot
        },
        # Previously an unrecognised method fell through and returned NULL.
        stop("Unknown `method` \"", method,
             "\"; expected NULL, \"times\" or \"weekly_local\"",
             call. = FALSE)
    )
}
#' Wrapper for primary sirplus plotting function
#'
#' Flexible function to generate sirplus plots (i.e. compartment counts over
#' time). This function allows for plotting multiple experiments, viewing the
#' plots on different scales (e.g. log2), plotting compartments separately,
#' adding 95\% CIs, and plotting known data alongside the simulations.
#'
#' @param x An seiqhrf object returned from function \code{\link{seiqhrf}},
#'        or a named list of such objects (one experiment per element).
#' @param comp_remove Compartments to remove, e.g. c("s.num", "r.num").
#'        Default "none" keeps every compartment.
#' @param time_lim Number of steps (days) to plot. Default: 90.
#' @param ci "y"/"n": include 95\% confidence intervals.
#' @param sep_compartments "y"/"n": use faceting to show each compartment in
#'        a separate plot; only works if plotting a single simulation (it
#'        replaces the per-experiment faceting).
#' @param trans Y-axis transformation (e.g. "log2", "log10"). The default
#'        "na" means no transformation.
#' @param known Dataframe with known compartment numbers to plot alongside
#'        projections.
#' @param start_date Date for day 0. Default: as.Date("2020-03-21").
#' @param x_axis Title for x-axis. Default: 'Days since beginning of epidemic'.
#' @param plot_title Title for whole plot. Default: 'SEIQHRF'.
#' @param ... Currently unused.
#'
#' @return ggplot2 object
#'
#' @importFrom tidyr pivot_longer
#' @importFrom dplyr mutate
#' @importFrom dplyr "%>%"
#' @importFrom dplyr bind_rows
#' @importFrom dplyr select
#' @importFrom dplyr filter
#' @export
plot_sirplus <- function(x,
                         comp_remove = "none",
                         time_lim = 90,
                         ci = 'y',
                         sep_compartments = 'n',
                         trans = 'na',
                         known = NULL,
                         start_date = as.Date("2020-03-21"),
                         x_axis = 'Days since beginning of epidemic',
                         plot_title = 'SEIQHRF', ...) {

    # Convert from seiqhrf object(s) to a long dataframe
    plot_df <- format_sims(x, time_lim = time_lim, start_date = start_date)

    # Keep experiment facets in the order the experiments were supplied
    exp_levels <- unique(plot_df$experiment)
    reo_exp <- function(e) factor(e, levels = exp_levels)

    # Get confidence intervals
    if (ci == 'y') {
        plot_df <- get_ci(x, plot_df)
    }

    # Add known compartment counts
    if (is.data.frame(known)) {
        plot_df <- add_known(plot_df, known = known, start_date = start_date)
    }

    # Define compartment names, colours and legend labels
    comps <- c("s.num", "e.num", "i.num", "q.num", "h.num", "r.num", "f.num")
    compcols <- c(s.num = "#4477AA", e.num = "#66CCEE", i.num = "#CCBB44",
                  q.num = "#AA3377", h.num = "#EE6677", r.num = "#228833",
                  f.num = "#BBBBBB")
    complabels <- c(s.num = "S: Susceptible", e.num = "E: Asymptomatic",
                    i.num = "I: Infected", q.num = "Q: Self-isolated",
                    h.num = "H: Hospitalized", r.num = "R: Recovered",
                    f.num = "F: Case Fatalities")

    # Filter compartments
    comp_plot <- setdiff(comps, comp_remove)
    plot_df <- plot_df %>% filter(compartment %in% comp_plot)

    # Base plot: one line per compartment; linetype separates simulated
    # ("sim") from known data
    p <- ggplot(plot_df, aes(x = Date, y = count, colour = compartment,
                             linetype = sim)) +
        geom_line(size = 1.2, alpha = 0.8) +
        scale_x_date(date_breaks = "1 week", date_labels = "%m-%d") +
        scale_colour_manual(values = compcols, labels = complabels) +
        labs(title = plot_title, x = x_axis, y = "Prevalence") +
        theme_bw() + theme(axis.text.x = element_text(angle = 90))

    if (length(exp_levels) > 1) {
        p <- p + facet_grid(reo_exp(experiment) ~ ., scales = 'free')
    }

    # Note: this replaces the per-experiment faceting above
    if (sep_compartments == 'y') {
        p <- p + facet_grid(compartment ~ ., scales = 'free')
    }

    if (trans != "na") {
        p <- p + scale_y_continuous(trans = trans)
    }

    if (ci == 'y') {
        # The colour scale is already set above; only the ribbon fill scale
        # is new here (re-adding the colour scale triggers a ggplot2 message).
        p <- p + geom_ribbon(aes(ymin = qntCI.1, ymax = qntCI.2, x = Date,
                                 fill = compartment, colour = NULL),
                             alpha = 0.4) +
            scale_fill_manual(values = compcols, guide = "none")
    }

    p
}
#' Plot compartment duration distributions
#'
#' Function to plot duration frequency distributions. If multiple simulations
#' were performed (nsims > 1), durations from all simulations are appended to
#' each other. Only durations of at most 30 steps are shown.
#'
#' @param sim An seiqhrf object returned from function \code{\link{seiqhrf}}.
#'        Reads \code{sim$control$nsims} and the per-simulation tables
#'        \code{sim$times$sim1}, \code{sim$times$sim2}, ...
#'
#' @return ggplot2 object
#'
#' @import ggplot2
#' @importFrom tidyr pivot_longer
#' @importFrom dplyr mutate
#' @importFrom dplyr "%>%"
#' @importFrom dplyr bind_rows
#' @importFrom dplyr select
#' @importFrom dplyr filter
#'
#' @export
plot_times <- function(sim) {

    # Stack the per-simulation timing tables, tagging each row with its
    # simulation index `s`
    per_sim <- lapply(seq_len(sim$control$nsims), function(s) {
        sim$times[[paste0("sim", s)]] %>% mutate(s = s)
    })
    times <- bind_rows(per_sim)

    # Negative infection/exposure times are clamped to -5 (presumably these
    # are individuals seeded before the simulation start - TODO confirm),
    # then each period is the difference between two event times.
    times <- times %>%
        mutate(infTime = ifelse(infTime < 0, -5, infTime),
               expTime = ifelse(expTime < 0, -5, expTime)) %>%
        mutate(incubation_period = infTime - expTime,
               illness_duration = recovTime - expTime,
               illness_duration_hosp = dischTime - expTime,
               hosp_los = dischTime - hospTime,
               quarantine_delay = quarTime - infTime,
               survival_time = fatTime - infTime) %>%
        select(s, incubation_period, quarantine_delay, illness_duration,
               illness_duration_hosp, hosp_los, survival_time) %>%
        pivot_longer(-s, names_to = "period_type", values_to = "duration") %>%
        mutate(period_type = factor(period_type,
                                    levels = c("incubation_period",
                                               "quarantine_delay",
                                               "illness_duration",
                                               "illness_duration_hosp",
                                               "hosp_los",
                                               "survival_time"),
                                    labels = c("Incubation\nperiod",
                                               "Delay entering\nisolation",
                                               "Illness\nduration",
                                               "Illness\nduration (hosp)",
                                               "Hospital stay\nduration",
                                               "Survival time\nfor fatalities"),
                                    ordered = TRUE))

    # filter() also drops NA durations (events that never happened)
    times %>%
        filter(duration <= 30) %>%
        ggplot(aes(x = duration)) +
        geom_bar() +
        facet_grid(period_type ~ ., scales = "free_y") +
        labs(title = "Compartment Duration Distributions")
}
#' Format seiqhrf objects into dataframe for ggplot
#'
#' @param x An seiqhrf object returned from function \code{\link{seiqhrf}},
#'        or a named list of seiqhrf objects (one per experiment; the names
#'        become the experiment labels).
#' @param time_lim Number of steps (days) to keep. Default: 90.
#' @param start_date Date for day 0. Default: as.Date("2020-03-21").
#'
#' @return Long dataframe with one row per time step, experiment and column
#'         of \code{as.data.frame(<sim>, out = "mean")} (other than time),
#'         with columns time, experiment, compartment, count, sim (always
#'         "sim") and Date (start_date + time).
#'
#' @importFrom tidyr pivot_longer
#' @importFrom dplyr mutate
#' @importFrom dplyr "%>%"
#' @importFrom dplyr bind_rows
#' @importFrom dplyr filter
#' @export
format_sims <- function(x, time_lim = 90, start_date = as.Date("2020-03-21")) {
    # Normalise input to a list of simulations plus matching experiment labels
    if (inherits(x, "seiqhrf")) {
        sims <- list(x)
        sim_id <- "seiqhrf model"
    } else {
        sim_id <- names(x)
        if (is.null(sim_id) || any(is.na(sim_id) | sim_id == "")) {
            stop("Please assign a name to each element in sims")
        }
        sims <- x
    }

    # Mean trajectory of each simulation, labelled with its experiment name.
    # Built as a list and bound once (the old loop re-appended plot_df to
    # itself on every iteration, duplicating rows).
    plot_df <- bind_rows(Map(function(sim, id) {
        as.data.frame(sim, out = "mean") %>% mutate(experiment = id)
    }, sims, sim_id))

    plot_df %>%
        filter(time <= time_lim) %>%
        pivot_longer(-c(time, experiment),
                     names_to = "compartment",
                     values_to = "count") %>%
        mutate(sim = 'sim',
               Date = start_date + time)
}


#' Get 95\% confidence intervals
#'
#' Summarises each simulation with \code{summary.seiqhrf}, reshapes the
#' summary to one row per time step and compartment (one column per summary
#' metric, e.g. the qntCI.1/qntCI.2 bounds used by \code{plot_sirplus}) and
#' full-joins it onto \code{plot_df}.
#'
#' @param x An seiqhrf object returned from function \code{\link{seiqhrf}},
#'        or a named list of seiqhrf objects (names are used as experiment
#'        labels, matching \code{format_sims}).
#' @param plot_df Long dataframe as produced by \code{format_sims}.
#'
#' @return plot_df with the summary metric columns added. Missing values in
#'         the summary table are set to 0 before joining.
#' @importFrom tidyr separate
#' @importFrom tidyr pivot_longer
#' @importFrom tidyr pivot_wider
#' @importFrom dplyr mutate
#' @importFrom dplyr "%>%"
#' @importFrom dplyr bind_rows
#' @importFrom dplyr full_join
#'
#' @export
#'
get_ci <- function(x, plot_df) {
    # Normalise input to a list of simulations plus matching experiment
    # labels (same labelling as format_sims, so the join keys line up)
    if (inherits(x, "seiqhrf")) {
        sims <- list(x)
        sim_id <- "seiqhrf model"
    } else {
        sim_id <- names(x)
        if (is.null(sim_id) || any(is.na(sim_id) | sim_id == "")) {
            stop("Please assign a name to each element in sims")
        }
        sims <- x
    }

    # Reshape one simulation's summary: rows are time steps (taken from the
    # row names); columns are split on "num." into compartment and metric,
    # presumably names like "s.num.<metric>" - TODO confirm against
    # summary.seiqhrf.
    summarise_sim <- function(sim, id) {
        ci <- as.data.frame.list(summary.seiqhrf(sim))
        ci %>%
            mutate(time = as.numeric(row.names(ci))) %>%
            pivot_longer(cols = -time, names_to = 'compartment',
                         values_to = 'mean') %>%
            tidyr::separate(compartment, into = c('compartment', 'metric'),
                            sep = 'num.') %>%
            mutate(compartment = paste0(compartment, 'num'),
                   experiment = id) %>%
            pivot_wider(names_from = metric, values_from = mean)
    }

    # Bind all summaries once (the old loop re-appended ci_info to itself)
    ci_info <- bind_rows(Map(summarise_sim, sims, sim_id)) %>%
        mutate(sim = 'sim')
    ci_info[is.na(ci_info)] <- 0

    plot_df %>%
        full_join(ci_info, by = c('time', 'compartment', 'experiment', 'sim'))
}

#' Add known counts to sims dataframe for ggplot
#'
#' The known data are reshaped to long form and appended once per experiment
#' present in \code{plot_df}, with \code{sim = "known"}. Columns of
#' \code{plot_df} that the known data lack are filled with 0.
#'
#' @param plot_df Long dataframe of simulated counts (as produced by
#'        \code{format_sims}, optionally with CI columns from \code{get_ci}).
#' @param known Dataframe with known compartment numbers to plot alongside
#'        projections: a \code{time} column plus one column per compartment
#'        (and optionally a \code{Date} column).
#' @param start_date Date for day 0, used to derive \code{Date} from
#'        \code{time} when \code{known} has no Date column.
#'        Default: as.Date("2020-03-21").
#'
#' @return plot_df with the known data rows appended.
#'
#' @importFrom tidyr pivot_longer
#' @importFrom dplyr mutate
#' @importFrom dplyr "%>%"
#' @export
#'
add_known <- function(plot_df, known = NULL,
                      start_date = as.Date("2020-03-21")) {
    # Add Date to known data when it only has a time column
    if (!("Date" %in% names(known))) {
        known$Date <- start_date + known$time
    }

    known <- known %>% pivot_longer(cols = -c(time, Date),
                                    names_to = 'compartment',
                                    values_to = 'count')

    # Append one copy of the known data per experiment so it shows up in
    # every experiment facet
    for (exp_id in unique(plot_df$experiment)) {
        known <- known %>% mutate(experiment = exp_id, sim = 'known')
        missing_cols <- setdiff(names(plot_df), names(known))
        if (length(missing_cols) > 0) {
            add <- rep(0, length(missing_cols))
            names(add) <- missing_cols
            known <- known %>% mutate(!!! add)
        }
        plot_df <- rbind(plot_df, known)
    }

    plot_df
}