From 701a6993f8a5e8b65799ee3b333e05d31a10219f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Tue, 12 Mar 2024 11:47:24 +0100 Subject: [PATCH] use categorical_logit_glm primitive --- DESCRIPTION | 4 +-- NEWS.md | 2 ++ R/stan-likelihood.R | 52 ++++++++++++++++----------- R/stan-predictor.R | 62 ++++++++++++++++++++++++--------- R/stan-response.R | 1 - tests/testthat/tests.stancode.R | 10 ++++++ 6 files changed, 91 insertions(+), 40 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 1a276484d..2dbf1358d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,8 +2,8 @@ Package: brms Encoding: UTF-8 Type: Package Title: Bayesian Regression Models using 'Stan' -Version: 2.20.16 -Date: 2024-03-08 +Version: 2.20.17 +Date: 2024-03-12 Authors@R: c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com", role = c("aut", "cre")), diff --git a/NEWS.md b/NEWS.md index eb800cc6e..8ff19ca94 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,8 @@ if potentially results-changing arguments are provided to the criterion method. * Allow to turn off automatic broadcasting of `constant` priors. * Allow for joint likelihood evaluation in `kfold` via argument `joint`. +* Use several Stan built-in functions implemented since version 2.26 +to improve the efficiency of multiple model classes. (#1077) ### Other Changes diff --git a/R/stan-likelihood.R b/R/stan-likelihood.R index f5a145f91..ab2d3f3e2 100644 --- a/R/stan-likelihood.R +++ b/R/stan-likelihood.R @@ -745,7 +745,7 @@ stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = NULL, stan_log_lik_cumulative <- function(bterms, resp = "", mix = "", threads = NULL, ...) { - if (use_glm_primitive(bterms, allow_special_terms = FALSE)) { + if (use_glm_primitive(bterms)) { p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads) out <- sdist("ordered_logistic_glm", p$x, p$beta, p$alpha) } else { @@ -773,11 +773,16 @@ stan_log_lik_categorical <- function(bterms, resp = "", mix = "", threads = NULL, ...) { stopifnot(bterms$family$link == "logit") stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed - # if (use_glm_primitive_categorical(bterms)) { - # # TODO: support categorical_logit_glm - # } - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", type = "multi") - sdist("categorical_logit", p$mu) + if (use_glm_primitive_categorical(bterms)) { + bterms1 <- bterms$dpars[[1]] + bterms1$family <- bterms$family + p <- args_glm_primitive(bterms1, resp = resp, threads = threads) + out <- sdist("categorical_logit_glm", p$x, p$alpha, p$beta) + } else { + p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", type = "multi") + out <- sdist("categorical_logit", p$mu) + } + out } stan_log_lik_multinomial <- function(bterms, resp = "", mix = "", ...) { @@ -1000,10 +1005,8 @@ stan_log_lik_custom <- function(bterms, resp = "", mix = "", threads = NULL, ... # use Stan GLM primitive functions? # @param bterms a brmsterms object -# @param allow_special_terms still use glm primitives if -# random effects, splines, etc. are present? # @return TRUE or FALSE -use_glm_primitive <- function(bterms, allow_special_terms = TRUE) { +use_glm_primitive <- function(bterms) { stopifnot(is.brmsterms(bterms)) # the model can only have a single predicted parameter # and no additional residual or autocorrelation structure @@ -1016,6 +1019,7 @@ use_glm_primitive <- function(bterms, allow_special_terms = TRUE) { } # some primitives do not support special terms in the way # required by brms' Stan code generation + allow_special_terms <- !mu$family$family %in% c("cumulative", "categorical") if (!allow_special_terms && has_special_terms(mu)) { return(FALSE) } @@ -1033,22 +1037,22 @@ use_glm_primitive <- function(bterms, allow_special_terms = TRUE) { # use Stan categorical GLM primitive function? # @param bterms a brmsterms object -# @param ... passed to use_glm_primitive # @return TRUE or FALSE -use_glm_primitive_categorical <- function(bterms, ...) { - # NOTE: this function is not yet in use; see stan_log_lik_categorical +use_glm_primitive_categorical <- function(bterms) { stopifnot(is.brmsterms(bterms)) - stopifnot(is_categorical(bterms)) - bterms_tmp <- bterms - bterms_tmp$dpars <- list() + if (!is_categorical(bterms)) { + return(FALSE) + } + tmp <- bterms + tmp$dpars <- list() # we know that all dpars in categorical models are mu parameters out <- rep(FALSE, length(bterms$dpars)) for (i in seq_along(bterms$dpars)) { - bterms_tmp$dpars$mu <- bterms$dpars[[i]] - bterms_tmp$dpars$mu$family <- bterms$family - out[i] <- use_glm_primitive(bterms_tmp, ...) && + tmp$dpars$mu <- bterms$dpars[[i]] + tmp$dpars$mu$family <- bterms$family + out[i] <- use_glm_primitive(tmp) && # the design matrix of all mu parameters must match - all.equal(bterms_tmp$dpars$mu$fe, bterms$dpars[[1]]$fe) + all.equal(tmp$dpars$mu$fe, bterms$dpars[[1]]$fe) } all(out) } @@ -1068,6 +1072,10 @@ args_glm_primitive <- function(bterms, resp = "", threads = NULL) { } else if (center_X) { sfx_X <- "c" } + is_categorical <- is_categorical(bterms) + if (is_categorical) { + sfx_X <- glue("{sfx_X}_{bterms$dpar}") + } x <- glue("X{sfx_X}{resp}{slice}") beta <- glue("b{sfx_b}{resp}") if (has_special_terms(bterms)) { @@ -1077,7 +1085,11 @@ args_glm_primitive <- function(bterms, resp = "", threads = NULL) { if (center_X) { alpha <- glue("Intercept{resp}") } else { - alpha <- "0" + if (is_categorical) { + alpha <- glue("rep_vector(0, ncat{resp})") + } else { + alpha <- "0" + } } } nlist(x, alpha, beta) diff --git a/R/stan-predictor.R b/R/stan-predictor.R index a80ab49b3..45a40f1da 100644 --- a/R/stan-predictor.R +++ b/R/stan-predictor.R @@ -46,7 +46,7 @@ stan_predictor.brmsterms <- function(x, data, prior, normalize, ...) { str_add_list(out) <- stan_response(x, data = data, normalize = normalize) valid_dpars <- valid_dpars(x) args <- nlist(data, prior, normalize, nlpars = names(x$nlpars), ...) - args$primitive <- use_glm_primitive(x) + args$primitive <- use_glm_primitive(x) || use_glm_primitive_categorical(x) for (nlp in names(x$nlpars)) { nlp_args <- list(x$nlpars[[nlp]]) str_add_list(out) <- do_call(stan_predictor, c(nlp_args, args)) @@ -2095,25 +2095,53 @@ stan_dpar_transform <- function(bterms, prior, threads, normalize, ...) { resp <- usc(bterms$resp) if (any(conv_cats_dpars(families))) { stopifnot(length(families) == 1L) - is_logistic_normal <- any(is_logistic_normal(families)) - len_mu <- glue("ncat{p}{str_if(is_logistic_normal, '-1')}") - str_add(out$model_def) <- glue( - " // linear predictor matrix\n", - " array[N{resp}] vector[{len_mu}] mu{p};\n" - ) - mu_dpars <- make_stan_names(glue("mu{bterms$family$cats}")) - mu_dpars <- glue("{mu_dpars}{p}[n]") iref <- get_refcat(bterms$family, int = TRUE) - if (is_logistic_normal) { - mu_dpars <- mu_dpars[-iref] + mus <- make_stan_names(glue("mu{bterms$family$cats}")) + mus <- glue("{mus}{p}") + if (use_glm_primitive_categorical(bterms)) { + bterms1 <- bterms$dpars[[1]] + center_X <- stan_center_X(bterms1) + ct <- str_if(center_X, "c") + K <- glue("K{ct}_{bterms1$dpar}{p}") + str_add(out$model_def) <- glue( + " // joint regression coefficients over categories\n", + " matrix[{K}, ncat{p}] b{p};\n" + ) + bnames <- glue("b_{mus}") + bnames[iref] <- glue("rep_vector(0, {K})") + str_add(out$model_comp_catjoin) <- cglue( + " b{p}[, {seq_along(bnames)}] = {bnames};\n" + ) + if (center_X) { + Inames <- glue("Intercept_{mus}") + Inames[iref] <- "0" + str_add(out$model_def) <- glue( + " // joint intercepts over categories\n", + " vector[ncat{p}] Intercept{p};\n" + ) + str_add(out$model_comp_catjoin) <- glue( + " Intercept{p} = {stan_vector(Inames)};\n" + ) + } } else { - mu_dpars[iref] <- "0" + is_logistic_normal <- any(is_logistic_normal(families)) + len_mu <- glue("ncat{p}{str_if(is_logistic_normal, '-1')}") + str_add(out$model_def) <- glue( + " // linear predictor matrix\n", + " array[N{resp}] vector[{len_mu}] mu{p};\n" + ) + mus <- glue("{mus}[n]") + if (is_logistic_normal) { + mus <- mus[-iref] + } else { + mus[iref] <- "0" + } + str_add(out$model_comp_catjoin) <- glue( + " for (n in 1:N{resp}) {{\n", + " mu{p}[n] = {stan_vector(mus)};\n", + " }}\n" + ) } - str_add(out$model_comp_catjoin) <- glue( - " for (n in 1:N{resp}) {{\n", - " mu{p}[n] = {stan_vector(mu_dpars)};\n", - " }}\n" - ) } if (any(families %in% "skew_normal")) { # as suggested by Stephen Martin use sigma and mu of CP diff --git a/R/stan-response.R b/R/stan-response.R index 2da027189..6446cfbca 100644 --- a/R/stan-response.R +++ b/R/stan-response.R @@ -774,4 +774,3 @@ stan_hurdle_ordinal_lpmf <- function(family, link) { } out } - diff --git a/tests/testthat/tests.stancode.R b/tests/testthat/tests.stancode.R index 6437b6042..d6c831808 100644 --- a/tests/testthat/tests.stancode.R +++ b/tests/testthat/tests.stancode.R @@ -646,6 +646,16 @@ test_that("Stan code for categorical models is correct", { scode <- stancode(y ~ x + (1 |ID| .g), data = dat, family = categorical(refcat = NA)) expect_match2(scode, "mu[n] = transpose([mu1[n], mu2[n], mu3[n], muab[n]]);") + + # test use of glm primitive + scode <- stancode(y ~ x, data = dat, family = categorical()) + expect_match2(scode, "b[, 1] = rep_vector(0, Kc_mu2);") + expect_match2(scode, "b[, 3] = b_mu3;") + expect_match2(scode, "Intercept = transpose([0, Intercept_mu2, Intercept_mu3, Intercept_muab]);") + expect_match2(scode, "target += categorical_logit_glm_lpmf(Y | Xc_mu2, Intercept, b);") + + scode <- stancode(bf(y ~ x, center = FALSE), data = dat, family = categorical()) + expect_match2(scode, "target += categorical_logit_glm_lpmf(Y | X_mu2, rep_vector(0, ncat), b);") }) test_that("Stan code for multinomial models is correct", {