#' Get times
#'
#' Extract the event timings of every simulation run and assemble them into
#' a single long data frame of period durations.
#'
#' @param simulate_results results from `simulate()` function. Its `sim`
#'   element must provide `control$nsims` and a `times` list holding one
#'   data frame per run, named `sim1`, `sim2`, ...
#'
#' @return A long data frame (as produced by `tidyr::pivot_longer()`) with
#'   columns `s` (simulation run number), `period_type` (ordered factor of
#'   human-readable period names) and `duration`.
#'
#' @importFrom tidyr pivot_longer
#' 
get_times <- function(simulate_results) {

    sim <- simulate_results$sim
    nsims <- sim$control$nsims

    # seq_len() yields no iterations (rather than c(1, 0) like 1:0) when
    # nsims is 0, so that case reaches the explicit error below instead of
    # failing later on an undefined object.
    run_times <- lapply(seq_len(nsims), function(s) {
        sim$times[[paste0("sim", s)]] %>% mutate(s = s)
    })
    if (length(run_times) == 0) {
        stop("`simulate_results` contains no simulation runs.", call. = FALSE)
    }

    # Bind all runs once instead of growing the data frame inside a loop.
    times <- bind_rows(run_times)

    # Negative infection/exposure times are clamped to -5 before durations
    # are computed. NOTE(review): presumably negative values mark individuals
    # already exposed/infected when the run starts - confirm why -5 is used.
    times %>%
        mutate(infTime = ifelse(infTime < 0, -5, infTime),
               expTime = ifelse(expTime < 0, -5, expTime)) %>%
        mutate(incubation_period = infTime - expTime,
               illness_duration = recovTime - expTime,
               illness_duration_hosp = dischTime - expTime,
               hosp_los = dischTime - hospTime,
               quarantine_delay = quarTime - infTime,
               survival_time = fatTime - infTime) %>%
        select(s, incubation_period, quarantine_delay, illness_duration,
               illness_duration_hosp, hosp_los, survival_time) %>%
        # One row per (run, individual, period type) pair
        pivot_longer(-s, names_to = "period_type", values_to = "duration") %>%
        mutate(period_type = factor(period_type,
                                    levels = c("incubation_period",
                                               "quarantine_delay",
                                               "illness_duration",
                                               "illness_duration_hosp",
                                               "hosp_los",
                                               "survival_time"),
                                    labels = c("Incubation period",
                                               "Delay entering isolation",
                                               "Illness duration",
                                               "Illness duration (hosp)",
                                               "Hosp care required duration",
                                               "Survival time case fatalities"),
                                    ordered = TRUE))
}


#' Plot times
#'
#' Draw bar charts of how often each duration value occurs, with one facet
#' row per period type. Only durations no greater than 30 are shown; rows
#' whose duration is `NA` are also dropped by the filter.
#'
#' @param times results from `get_times()` function.
#'
#' @return ggplot object
#' 
plot_times <- function(times){
    short_durations <- filter(times, duration <= 30)

    ggplot(short_durations, aes(x = duration)) +
        geom_bar() +
        # Free y scales: period types can have very different counts
        facet_grid(period_type ~ ., scales = "free_y") +
        labs(title = "Duration frequency distributions",
             subtitle = "Baseline simulation")
}

#' Plot models
#'
#' Plot compartment prevalence over time for a single model, or for several
#' models stacked in facet rows (one row per model) for comparison.
#'
#' @param sims Simulation results laid out so that element `2 * i` is named
#'   `df` and holds the compartment counts of model `i` (presumably one or
#'   more `simulate()` outputs combined with `c()` - confirm against callers).
#' @param sim_id Character vector of labels, one per model in `sims`.
#' @param comp_remove Compartment column names (e.g. `"s.num"`) to leave out
#'   of the plot. The default `"none"` matches no compartment.
#' @param time_lim Only time points less than or equal to this are drawn.
#' @param h.beds Hospital bed capacity, drawn as a horizontal reference line.
#' @param plot_title Title of the plot.
#'
#' @return ggplot
#' 
#' @import dplyr
#' 
plot_models <- function(sims = baseline_sim,
                        sim_id = 'baseline',
                        comp_remove = 'none',
                        time_lim = 100,
                        h.beds = 40,
                        plot_title = 'ICM plot'){

    if (length(sim_id) == 0) {
        stop("`sim_id` must name at least one model.", call. = FALSE)
    }

    # Define a standard set of colours to represent compartments
    comps <- c("s.num", "e.num", "i.num", "q.num", "h.num", "r.num", "f.num")
    compcols <- c(s.num = "#4477AA", e.num = "#66CCEE", i.num = "#CCBB44", 
                  q.num = "#AA3377", h.num = "#EE6677", r.num = "#228833", 
                  f.num = "#BBBBBB")
    complabels <- c(s.num = "Susceptible", e.num = "Infected/asymptomatic", 
                    i.num = "Infected/infectious", q.num = "Self-isolated", 
                    h.num = "Requires hospitalisation", r.num = "Recovered",
                    f.num = "Case fatality")

    # Take each model's data frame (element 2 * i of `sims`), tag it with its
    # label and bind all models once, so each model appears exactly once.
    model_dfs <- lapply(seq_along(sim_id), function(i) {
        sims[i * 2][["df"]] %>% mutate(experiment = sim_id[i])
    })

    plot_df <- bind_rows(model_dfs) %>%
        filter(time <= time_lim) %>%
        pivot_longer(-c(time, experiment),
                     names_to = "compartment",
                     values_to = "count") %>%
        # Keep only the compartments that were not explicitly removed
        filter(compartment %in% setdiff(comps, comp_remove))

    p <- ggplot(plot_df, aes(x = time, y = count, colour = compartment)) +
        #scale_y_continuous(trans='log2') +
        geom_line(size = 1.5, alpha = 0.8) +
        geom_hline(aes(yintercept = h.beds, linetype = "Hospital beds"),
                   size = 1, color = "#EE6677", alpha = 0.5) +
        scale_colour_manual(values = compcols, labels = complabels) +
        labs(title = plot_title,
             x = "Days since beginning of epidemic",
             y = "Prevalence (persons)") +
        theme_bw()

    # Several models are compared in stacked panels, one row per model
    if (length(sim_id) > 1) {
        p <- p + facet_grid(experiment ~ .)
    }

    p
}