Skip to content
Snippets Groups Projects
Commit 464852e3 authored by Jeffrey Pullin's avatar Jeffrey Pullin
Browse files

Add SVM classifier

parent 854d1781
Branches
Tags
No related merge requests found
......@@ -30,6 +30,11 @@ plot_classifier_metric <- function(data, metric) {
plot_metric <- metric_lookup[metric]
plot_data <- dataset_lookup[data$data_id[[1]]]
plot_classifier <- if (data$classifier[[1]] == "knn") {
"KNN"
} else {
"SVM"
}
# https://github.com/tidyverse/ggplot2/issues/2799
cf <- coord_flip(ylim = c(0.5, 1))
......@@ -50,7 +55,8 @@ plot_classifier_metric <- function(data, metric) {
x = "Method",
colour = "Package",
y = plot_metric,
title = paste0("Multiclass prediction ", plot_data, ", ", plot_metric)
title = paste0("Multiclass prediction ", plot_data, ", ", plot_metric,
", ", plot_classifier)
) +
theme_bw()
}
......@@ -89,9 +95,12 @@ plot_confusion_matrix <- function(data, pars) {
```{r load-data}
pred_perf_data <- retrieve_real_data_parameters() %>%
select(-c(fit_method, covariate, rankby, lambda, test_use, rankby_abs, func,
test.type, pval.type, metric, test.use)) %>%
expand_grid(classifier = c("svm", "knn")) %>%
rowwise() %>%
mutate(
pred_perf_filename = paste0("pred_perf-", data_id, "-", method_name, ".rds"),
pred_perf_filename = paste0("pred_perf-", data_id, "-", method_name, "-", classifier, ".rds"),
pred_perf_path = here::here("results", "pred_perf", pred_perf_filename),
pred_perf = list(readRDS(pred_perf_path))
) %>%
......@@ -105,7 +114,11 @@ pred_perf_data <- retrieve_real_data_parameters() %>%
```{r pbmc3k-predicitive-performance}
# Plot the "median_f1_score" metric for the pbmc3k dataset, one figure per
# classifier. Dataset and classifier are selected in a single filter() call;
# a separate data_id-only filter beforehand would be redundant.
pred_perf_data %>%
  filter(data_id == "pbmc3k", classifier == "knn") %>%
  plot_classifier_metric("median_f1_score")

pred_perf_data %>%
  filter(data_id == "pbmc3k", classifier == "svm") %>%
  plot_classifier_metric("median_f1_score")
pbmc3k_seurat_wilcox_confmat <- pred_perf_data %>%
......@@ -124,7 +137,11 @@ saveRDS(
```{r endothelial-multicalss-prediction}
# Plot the "mean_f1_score" metric for the endothelial dataset, one figure per
# classifier. Dataset and classifier are selected in a single filter() call;
# a separate data_id-only filter beforehand would be redundant.
pred_perf_data %>%
  filter(data_id == "endothelial", classifier == "knn") %>%
  plot_classifier_metric("mean_f1_score")

pred_perf_data %>%
  filter(data_id == "endothelial", classifier == "svm") %>%
  plot_classifier_metric("mean_f1_score")
```
......@@ -132,7 +149,11 @@ pred_perf_data %>%
```{r zeisel-multicalss-prediction}
# Plot the "mean_f1_score" metric for the zeisel dataset, one figure per
# classifier. Dataset and classifier are selected in a single filter() call;
# a separate data_id-only filter beforehand would be redundant.
pred_perf_data %>%
  filter(data_id == "zeisel", classifier == "knn") %>%
  plot_classifier_metric("mean_f1_score")

pred_perf_data %>%
  filter(data_id == "zeisel", classifier == "svm") %>%
  plot_classifier_metric("mean_f1_score")
```
......@@ -140,7 +161,11 @@ pred_perf_data %>%
```{r paul-multiclass-prediction}
# Plot the "mean_f1_score" metric for the paul dataset, one figure per
# classifier. Dataset and classifier are selected in a single filter() call;
# a separate data_id-only filter beforehand would be redundant.
pred_perf_data %>%
  filter(data_id == "paul", classifier == "knn") %>%
  plot_classifier_metric("mean_f1_score")

pred_perf_data %>%
  filter(data_id == "paul", classifier == "svm") %>%
  plot_classifier_metric("mean_f1_score")
```
......@@ -148,7 +173,11 @@ pred_perf_data %>%
```{r ss3-pbmc-multiclass-prediction}
# Plot the "mean_f1_score" metric for the ss3_pbmc dataset, one figure per
# classifier. Dataset and classifier are selected in a single filter() call;
# a separate data_id-only filter beforehand would be redundant.
pred_perf_data %>%
  filter(data_id == "ss3_pbmc", classifier == "knn") %>%
  plot_classifier_metric("mean_f1_score")

pred_perf_data %>%
  filter(data_id == "ss3_pbmc", classifier == "svm") %>%
  plot_classifier_metric("mean_f1_score")
```
......@@ -156,7 +185,11 @@ pred_perf_data %>%
```{r mesenchymal-multiclass-prediction}
# Plot the "mean_f1_score" metric for the mesenchymal dataset, one figure per
# classifier. Dataset and classifier are selected in a single filter() call;
# a separate data_id-only filter beforehand would be redundant.
pred_perf_data %>%
  filter(data_id == "mesenchymal", classifier == "knn") %>%
  plot_classifier_metric("mean_f1_score")

pred_perf_data %>%
  filter(data_id == "mesenchymal", classifier == "svm") %>%
  plot_classifier_metric("mean_f1_score")
```
......@@ -164,15 +197,24 @@ pred_perf_data %>%
```{r lawlor-multiclass-prediction}
# Plot the "mean_f1_score" metric for the lawlor dataset, one figure per
# classifier. Dataset and classifier are selected in a single filter() call;
# a separate data_id-only filter beforehand would be redundant.
pred_perf_data %>%
  filter(data_id == "lawlor", classifier == "knn") %>%
  plot_classifier_metric("mean_f1_score")

# The SVM figure gets its own axis range (0 to 0.6). Adding a coordinate
# system to a ggplot replaces the one already present -- presumably the
# coord_flip() set up inside plot_classifier_metric(); TODO confirm.
pred_perf_data %>%
  filter(data_id == "lawlor", classifier == "svm") %>%
  plot_classifier_metric("mean_f1_score") +
  coord_flip(ylim = c(0.0, 0.6))
```
## Astrocyte
```{r astrocyte-multiclass-prediction}
# Plot the "mean_f1_score" metric for the astrocyte dataset, one figure per
# classifier. Dataset and classifier are selected in a single filter() call;
# a separate data_id-only filter beforehand would be redundant.
pred_perf_data %>%
  filter(data_id == "astrocyte", classifier == "knn") %>%
  plot_classifier_metric("mean_f1_score")

pred_perf_data %>%
  filter(data_id == "astrocyte", classifier == "svm") %>%
  plot_classifier_metric("mean_f1_score")
```
......@@ -180,7 +222,7 @@ pred_perf_data %>%
```{r zhao-multiclass-prediction}
zhao_pred_perf <- pred_perf_data %>%
filter(data_id == "zhao") %>%
filter(data_id == "zhao", classifier == "knn") %>%
plot_classifier_metric("median_f1_score") +
coord_flip(ylim = c(0.5, 0.85))
......@@ -190,12 +232,18 @@ saveRDS(
zhao_pred_perf,
here::here("figures", "raw", "zhao-pred-perf.rds")
)
# SVM results for the zhao dataset: plot the "median_f1_score" metric and
# restrict the flipped value axis to the 0.5-0.85 range.
plot_classifier_metric(
  filter(pred_perf_data, data_id == "zhao", classifier == "svm"),
  "median_f1_score"
) +
  coord_flip(ylim = c(0.5, 0.85))
```
## Overall
```{r overall-mutliclass-prediction}
```{r overall-mutliclass-prediction-knn}
overall_multiclass_pred_plot <- pred_perf_data %>%
filter(classifier == "knn") %>%
group_by(data_id, pars, method) %>%
summarise(median_f1_score = median(median_f1_score), .groups = "drop") %>%
group_by(data_id) %>%
......@@ -211,11 +259,11 @@ overall_multiclass_pred_plot <- pred_perf_data %>%
ggplot(aes(x = plot_data_id, y = plot_pars)) +
geom_tile(aes(fill = rank), colour = "black") +
scale_fill_distiller(palette = "RdYlBu",
breaks = seq(1, 41, by = 5),
labels = seq(41, 1, by = -5)) +
breaks = seq(1, 56, by = 5),
labels = seq(56, 1, by = -5)) +
theme_bw() +
labs(
title = "Median F1-score rank across datasets",
title = "Median F1-score rank across datasets, KNN",
x = "Dataset",
y = "Method",
fill = "Rank",
......@@ -229,8 +277,107 @@ overall_multiclass_pred_plot <- pred_perf_data %>%
)
overall_multiclass_pred_plot
# KNN results: heatmap of the median F1 score, standardised within each
# dataset (z-score), with one tile per dataset / parameter-set combination.
pred_perf_data %>%
  filter(classifier == "knn") %>%
  group_by(data_id, pars, method) %>%
  summarise(median_f1_score = median(median_f1_score), .groups = "drop") %>%
  # scale() returns a one-column matrix; [, 1] pulls it back out as a vector.
  group_by(data_id) %>%
  mutate(score = scale(median_f1_score)[, 1]) %>%
  ungroup() %>%
  mutate(
    plot_method = method_lookup[method],
    plot_data_id = dataset_lookup[data_id],
    # Factor levels (and so the y-axis order) follow the mean z-score.
    plot_pars = fct_reorder(factor(pars_lookup[pars]), score, .fun = mean)
  ) %>%
  ggplot(aes(x = plot_data_id, y = plot_pars, fill = score)) +
  geom_tile(colour = "black") +
  scale_fill_distiller(palette = "RdYlBu") +
  labs(
    title = "Mean z-score F1-score across datasets, KNN",
    x = "Dataset",
    y = "Method",
    fill = "Z-score"
  ) +
  theme_bw() +
  theme(
    axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1),
    axis.ticks.y = element_blank(),
    panel.border = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  )
# SVM results with the "lawlor" dataset excluded: heatmap of the median F1
# score, standardised within each dataset (z-score).
# The classifier filter selects "svm" so that the plotted data agree with the
# plot title (", SVM"); filtering on "knn" here would draw KNN scores under an
# SVM heading.
pred_perf_data %>%
  filter(data_id != "lawlor", classifier == "svm") %>%
  group_by(data_id, pars, method) %>%
  summarise(median_f1_score = median(median_f1_score), .groups = "drop") %>%
  # Standardise within each dataset; scale() returns a one-column matrix, so
  # [, 1] pulls the values back out as a plain vector.
  group_by(data_id) %>%
  mutate(score = scale(median_f1_score)[, 1]) %>%
  ungroup() %>%
  mutate(
    plot_method = method_lookup[method],
    plot_pars = pars_lookup[pars],
    plot_data_id = dataset_lookup[data_id]
  ) %>%
  # Factor levels (and so the y-axis order) follow the mean z-score.
  mutate(plot_pars = fct_reorder(factor(plot_pars), score, .fun = mean)) %>%
  ggplot(aes(x = plot_data_id, y = plot_pars)) +
  geom_tile(aes(fill = score), colour = "black") +
  scale_fill_distiller(palette = "RdYlBu") +
  theme_bw() +
  labs(
    title = "Mean z-score F1-score across datasets, SVM",
    x = "Dataset",
    y = "Method",
    fill = "Z-score"
  ) +
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.border = element_blank(),
    axis.ticks.y = element_blank(),
    axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)
  )
# Persist the overall plot object under figures/raw (path resolved from the
# project root via here::here()).
saveRDS(
  object = overall_multiclass_pred_plot,
  file = here::here("figures", "raw", "overall-mc-pred-plot.rds")
)
```
```{r overall-mutliclass-prediction-svm}
# SVM results: heatmap of each parameter set's within-dataset rank of the
# median F1 score.
pred_perf_data %>%
  filter(classifier == "svm") %>%
  group_by(data_id, pars, method) %>%
  summarise(median_f1_score = median(median_f1_score), .groups = "drop") %>%
  # rank() assigns 1 to the smallest score within each dataset.
  group_by(data_id) %>%
  mutate(rank = rank(median_f1_score)) %>%
  ungroup() %>%
  mutate(
    plot_method = method_lookup[method],
    plot_data_id = dataset_lookup[data_id],
    # Factor levels (and so the y-axis order) follow the median rank.
    plot_pars = fct_reorder(factor(pars_lookup[pars]), rank, .fun = median)
  ) %>%
  ggplot(aes(x = plot_data_id, y = plot_pars, fill = rank)) +
  geom_tile(colour = "black") +
  # Legend labels run in the opposite direction to the breaks, so the
  # smallest computed rank is displayed as 56.
  scale_fill_distiller(
    palette = "RdYlBu",
    breaks = seq(1, 56, by = 5),
    labels = seq(56, 1, by = -5)
  ) +
  labs(
    title = "Median F1-score rank across datasets, SVM",
    x = "Dataset",
    y = "Method",
    fill = "Rank"
  ) +
  theme_bw() +
  theme(
    axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1),
    axis.ticks.y = element_blank(),
    panel.border = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  )
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment