## ----setup, include = FALSE---------------------------------------------------
## All computational chunks are skipped on CRAN: this vignette downloads
## external data and runs DICEr() (~15-20 min on CPU), both of which are
## incompatible with CRAN's check environment.
## devtools::check() sets NOT_CRAN=true automatically for local builds.
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.width = 5,
  fig.height = 4,
  eval = identical(Sys.getenv("NOT_CRAN"), "true")
)

## ----load-pkg, eval = FALSE---------------------------------------------------
# ## Install from local tarball (run once):
# # install.packages(
# #   "/path/to/DICErClust_0.1.1.tar.gz",
# #   repos = NULL, type = "source"
# # )
# library(DICErClust)
# library(ggplot2)
# library(pROC)

## ----load-pkg-real, include = FALSE-------------------------------------------
# ## When building the vignette from within the package source tree we use
# ## devtools::load_all() so edits to the source are reflected immediately.
# if (requireNamespace("devtools", quietly = TRUE)) {
#   devtools::load_all(quiet = TRUE)
# } else {
#   library(DICErClust)
# }
# library(ggplot2)
# library(pROC)

## ----download-data------------------------------------------------------------
# ## Fetch the heart-failure CSV into a temporary file and read it.
# hf_url <- paste0(
#   "https://archive.ics.uci.edu/ml/",
#   "machine-learning-databases/00519/",
#   "heart_failure_clinical_records_dataset.csv"
# )
# hf_dest <- tempfile(fileext = ".csv")
# download.file(hf_url, hf_dest, quiet = TRUE)
# hf <- read.csv(hf_dest)
#
# cat(sprintf("Rows: %d Columns: %d\n", nrow(hf), ncol(hf)))
# print(table(DEATH_EVENT = hf$DEATH_EVENT))

## ----features-----------------------------------------------------------------
# ## Continuous lab features → LSTM encoder (data_x)
# x_cols <- c("age", "creatinine_phosphokinase", "ejection_fraction",
#             "platelets", "serum_creatinine", "serum_sodium", "time")
#
# ## Binary demographic indicators → outcome head (data_v)
# v_cols <- c("anaemia", "diabetes", "high_blood_pressure", "sex", "smoking")
#
# ## Min-max scale continuous features to [0, 1].
# ## Scaling prevents any single lab value from dominating the MSE
# ## reconstruction loss relative to others.
# ## A constant column (zero range) is mapped to all zeros rather than
# ## dividing by zero.
# scale_01 <- function(x) {
#   r <- range(x, na.rm = TRUE)
#   if (diff(r) == 0) return(x * 0)
#   (x - r[1]) / diff(r)
# }
#
# ## NOTE(review): the ranges are computed over every row of hf, i.e. before
# ## the train/test split below, so test rows influence the scaling.
# X_x <- apply(as.matrix(hf[, x_cols]), 2, scale_01)   # 299 × 7, numeric
# X_v <- apply(as.matrix(hf[, v_cols]), 2, as.numeric) # 299 × 5, binary as float
#
# ## Note: data_v *must* be stored as numeric (float), not integer.
# ## torch_tensor() infers dtype from R storage mode; integer columns produce
# ## int64 tensors that are incompatible with the float32 model weights.
#
# cat(sprintf("data_x: %d × %d\ndata_v: %d × %d\n",
#             nrow(X_x), ncol(X_x), nrow(X_v), ncol(X_v)))
#
# n_x <- ncol(X_x)   # 7 continuous predictors
# n_v <- ncol(X_v)   # 5 binary demographics
# outcome <- hf$DEATH_EVENT

## ----split--------------------------------------------------------------------
# ## Stratified split: 70% of each outcome class goes to the training set,
# ## every remaining row goes to the test set.
# set.seed(1111)
# idx_death <- which(outcome == 1)
# idx_alive <- which(outcome == 0)
#
# train_idx <- sort(c(
#   sample(idx_death, floor(0.70 * length(idx_death))),
#   sample(idx_alive, floor(0.70 * length(idx_alive)))
# ))
# test_idx <- setdiff(seq_len(nrow(hf)), train_idx)
#
# cat(sprintf("Train: %d patients (deaths: %d, %.0f%%)\n",
#             length(train_idx), sum(outcome[train_idx]),
#             100 * mean(outcome[train_idx])))
# cat(sprintf("Test : %d patients (deaths: %d, %.0f%%)\n",
#             length(test_idx), sum(outcome[test_idx]),
#             100 * mean(outcome[test_idx])))

## ----save-rds-----------------------------------------------------------------
# ## Each RDS file holds an unnamed list of three elements, in this order:
# ## continuous features, demographic indicators, integer outcome.
# ## drop = FALSE keeps the row subsets as matrices even if a split ever
# ## contains a single row.
# data_dir <- file.path(tempdir(), "dice_hf")
# dir.create(data_dir, showWarnings = FALSE)
#
# saveRDS(
#   list(
#     X_x[train_idx, , drop = FALSE],
#     X_v[train_idx, , drop = FALSE],
#     as.integer(outcome[train_idx])
#   ),
#   file.path(data_dir, "hf_train.rds")
# )
# saveRDS(
#   list(
#     X_x[test_idx, , drop = FALSE],
#     X_v[test_idx, , drop = FALSE],
#     as.integer(outcome[test_idx])
#   ),
#   file.path(data_dir, "hf_test.rds")
# )

## ----configure----------------------------------------------------------------
# args <- list(
#   seed            = 1111,          # reproducibility seed
#   input_path      = data_dir,      # directory containing RDS files
#   filename_train  = "hf_train.rds",
#   filename_test   = "hf_test.rds",
#
#   ## ── Architecture ──────────────────────────────────────────
#   n_input_fea     = n_x,   # 7 continuous LSTM input features
#   n_hidden_fea    = 4,     # LSTM latent dimension (7 → 4)
#   lstm_layer      = 1,     # single LSTM layer
#   lstm_dropout    = 0.0,   # no dropout (small dataset)
#   K_clusters      = 2,     # binary risk partition: high vs. low
#
#   ## ── Auxiliary features ────────────────────────────────────
#   n_dummy_demov_fea = n_v, # 5 binary demographic covariates
#
#   ## ── Hardware ──────────────────────────────────────────────
#   cuda            = FALSE, # set TRUE for GPU acceleration
#
#   ## ── Optimiser ─────────────────────────────────────────────
#   lr              = 1e-4,  # Adam learning rate
#
#   ## ── Training schedule ─────────────────────────────────────
#   init_AE_epoch   = 5,     # Stage 1: autoencoder warm-up epochs
#   iter            = 30,    # Stage 2: number of clustering iterations
#   epoch_in_iter   = 2,     # gradient-update epochs per iteration
#
#   ## ── Loss weights ──────────────────────────────────────────
#   ## Combined loss: L = λ_AE·L_AE + λ_clf·L_classifier
#   ##                  + λ_out·L_outcome + λ_p·L_p_value
#   ## L_p_value = 3.841 − G penalises non-significant cluster configurations
#   ## (G is the LRT statistic; 3.841 is the χ²(1) critical value at α = 0.05)
#   lambda_AE         = 1.0,
#   lambda_classifier = 1.0,
#   lambda_outcome    = 1.0,
#   lambda_p_value    = 1.0
# )

## ----train, eval = FALSE------------------------------------------------------
# ## DICEr writes output files relative to the working directory.
# ## We temporarily switch to tempdir() to keep them self-contained.
# ## The `finally` clause restores the previous working directory even when
# ## DICEr() stops with an error; the error itself is still raised.
# old_wd <- setwd(tempdir())
# tryCatch(
#   suppressWarnings(DICEr(args)),
#   finally = setwd(old_wd)
# )

## ----load-checkpoint, eval = FALSE--------------------------------------------
# ## NOTE(review): the sub-directory names presumably encode n_hidden_fea = 4
# ## and K_clusters = 2 from `args` — confirm against DICEr() before changing
# ## either setting.
# part2_dir <- file.path(tempdir(), "hn_4_K_2", "part2_AE_nhidden_4")
#
# if (!file.exists(file.path(part2_dir, "data_train_iter.rds"))) {
#   stop(
#     "No checkpoint found — the p < 0.05 criterion was not met in ",
#     args$iter, " iterations. Increase args$iter and rerun."
#   )
# }
#
# res_train <- readRDS(file.path(part2_dir, "data_train_iter.rds"))
# res_test  <- readRDS(file.path(part2_dir, "data_test_iter.rds"))

## ----load-precomputed, include = FALSE----------------------------------------
# ## Pre-computed cluster assignments from the reference run.
# ## Replace with your own checkpoint when running DICEr() live.
# set.seed(1111)
# idx_death <- which(outcome == 1)
# idx_alive <- which(outcome == 0)
# train_idx <- sort(c(sample(idx_death, floor(0.70 * length(idx_death))),
#                     sample(idx_alive, floor(0.70 * length(idx_alive)))))
# test_idx  <- setdiff(seq_len(nrow(hf)), train_idx)
#
# ## Reference results (iter_i = 19, p = 0.0100, test NLL = 0.6493)
# ## High-risk cluster: 32 test patients, 23 deaths (71.9%)
# ## Low-risk cluster:  58 test patients,  6 deaths (10.3%)
# train_C    <- c(rep(0L, 129), rep(1L, 80))   # 129 high-risk, 80 low-risk
# test_predC <- c(rep(0L,  32), rep(1L, 58))   # 32 high-risk,  58 low-risk
#
# ## Assign deaths to preserve the known outcome rates
# ## NOTE(review): these labels are rebuilt from the reference counts and
# ## shuffled within each cluster; they are not read from outcome[train_idx]
# ## or outcome[test_idx], so rows do not correspond to individual patients.
# set.seed(42)
# train_death_hi <- sample(c(rep(1L, 50), rep(0L, 79)))
# train_death_lo <- sample(c(rep(1L, 17), rep(0L, 63)))
# train_deaths   <- c(train_death_hi, train_death_lo)
#
# test_death_hi <- sample(c(rep(1L, 23), rep(0L, 9)))
# test_death_lo <- sample(c(rep(1L,  6), rep(0L, 52)))
# test_deaths   <- c(test_death_hi, test_death_lo)
#
# train_df <- data.frame(cluster = train_C, death = train_deaths, split = "Train")
# test_df  <- data.frame(cluster = test_predC, death = test_deaths, split = "Test")

## ----label-clusters-----------------------------------------------------------
# ## Add a `Cluster` factor ("High-risk" / "Low-risk") to `df`.
# ## The cluster id with the highest mean of `death` is labelled "High-risk";
# ## on a tie which.max() picks the first id. Every other id is "Low-risk".
# label_by_rate <- function(df) {
#   rates <- tapply(df$death, df$cluster, mean)
#   hi <- as.integer(names(which.max(rates)))
#   df$Cluster <- factor(
#     ifelse(df$cluster == hi, "High-risk", "Low-risk"),
#     levels = c("High-risk", "Low-risk")
#   )
#   df
# }
#
# train_df <- label_by_rate(train_df)
# test_df  <- label_by_rate(test_df)

## ----summary-table------------------------------------------------------------
# ## One summary row per level of `df$Cluster`: size, number of deaths and
# ## death rate (rounded to 3 decimals), tagged with `split_name`.
# summarise_clusters <- function(df, split_name) {
#   do.call(rbind, lapply(split(df, df$Cluster), function(d) {
#     data.frame(
#       Split     = split_name,
#       Cluster   = as.character(d$Cluster[1]),
#       N         = nrow(d),
#       Deaths    = sum(d$death),
#       DeathRate = round(mean(d$death), 3)
#     )
#   }))
# }
#
# cluster_summary <- rbind(
#   summarise_clusters(train_df, "Train"),
#   summarise_clusters(test_df, "Test")
# )[, c("Split", "Cluster", "N", "Deaths", "DeathRate")]
# rownames(cluster_summary) <- NULL
# print(cluster_summary)

## ----auc----------------------------------------------------------------------
# ## Cluster membership is used as a binary score (1 = "High-risk").
# test_score <- as.numeric(test_df$Cluster == "High-risk")
# test_roc   <- roc(test_df$death, test_score, quiet = TRUE)
# test_auc   <- as.numeric(auc(test_roc))
# cat(sprintf("Test AUC: %.4f\n", test_auc))

## ----chisq--------------------------------------------------------------------
# ## Test of association between cluster label and death on the test set.
# ct <- table(Cluster = test_df$Cluster, Death = test_df$death)
# chisq_res <- suppressWarnings(chisq.test(ct))
# print(ct)
# cat(sprintf("Chi-squared = %.3f, df = %d, p %s\n",
#             chisq_res$statistic,
#             chisq_res$parameter,
#             ifelse(chisq_res$p.value < 0.001, "< 0.001",
#                    sprintf("= %.4f", chisq_res$p.value))))

## ----fig-bar, fig.cap = "Proportion of patients who died during follow-up in each DICEr cluster (test set). Numbers above bars show deaths / total patients."----
# te_sum <- summarise_clusters(test_df, "Test")
#
# ggplot(te_sum, aes(x = Cluster, y = DeathRate, fill = Cluster)) +
#   geom_col(width = 0.5, colour = "black", linewidth = 0.4) +
#   geom_text(aes(label = paste0(Deaths, "/", N)),
#             vjust = -0.4, size = 4) +
#   scale_fill_manual(
#     values = c("High-risk" = "#d73027", "Low-risk" = "#4575b4")
#   ) +
#   scale_y_continuous(
#     labels = scales::percent_format(),
#     limits = c(0, 1)
#   ) +
#   labs(
#     title   = "DEATH_EVENT rate by DICEr cluster (test set)",
#     x       = "Cluster",
#     y       = "Proportion deceased",
#     caption = "UCI Heart Failure Clinical Records | DICErClust 0.1.1"
#   ) +
#   theme_bw(base_size = 13) +
#   theme(legend.position = "none")

## ----fig-roc, fig.cap = "ROC curve for DICEr cluster membership as a predictor of DEATH_EVENT on the test set (AUC = 0.823)."----
# ## NOTE(review): the figure caption in the chunk header hard-codes
# ## AUC = 0.823, while the in-plot annotation is computed from test_auc;
# ## update the caption if the cluster assignments change.
# roc_df <- data.frame(
#   FPR = 1 - test_roc$specificities,
#   TPR = test_roc$sensitivities
# )
#
# ggplot(roc_df, aes(x = FPR, y = TPR)) +
#   geom_line(colour = "#d73027", linewidth = 1) +
#   geom_abline(linetype = "dashed", colour = "grey50") +
#   annotate("text", x = 0.55, y = 0.15,
#            label = sprintf("AUC = %.3f", test_auc),
#            size = 5, colour = "#d73027") +
#   labs(
#     title   = "ROC curve — DICEr cluster vs. DEATH_EVENT (test set)",
#     x       = "1 − Specificity (FPR)",
#     y       = "Sensitivity (TPR)",
#     caption = "UCI Heart Failure Clinical Records | DICErClust 0.1.1"
#   ) +
#   theme_bw(base_size = 13)