Skip to content

Commit

Permalink
use categorical_logit_glm primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Mar 12, 2024
1 parent 15c0754 commit 701a699
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 40 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.20.16
Date: 2024-03-08
Version: 2.20.17
Date: 2024-03-12
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "[email protected]",
role = c("aut", "cre")),
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
if potentially results-changing arguments are provided to the criterion method.
* Allow to turn off automatic broadcasting of `constant` priors.
* Allow for joint likelihood evaluation in `kfold` via argument `joint`.
* Use several Stan built-in functions implemented since version 2.26
to improve the efficiency of multiple model classes. (#1077)

### Other Changes

Expand Down
52 changes: 32 additions & 20 deletions R/stan-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = NULL,

stan_log_lik_cumulative <- function(bterms, resp = "", mix = "",
threads = NULL, ...) {
if (use_glm_primitive(bterms, allow_special_terms = FALSE)) {
if (use_glm_primitive(bterms)) {
p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads)
out <- sdist("ordered_logistic_glm", p$x, p$beta, p$alpha)
} else {
Expand Down Expand Up @@ -773,11 +773,16 @@ stan_log_lik_categorical <- function(bterms, resp = "", mix = "",
threads = NULL, ...) {
stopifnot(bterms$family$link == "logit")
stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed
# if (use_glm_primitive_categorical(bterms)) {
# # TODO: support categorical_logit_glm
# }
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", type = "multi")
sdist("categorical_logit", p$mu)
if (use_glm_primitive_categorical(bterms)) {
bterms1 <- bterms$dpars[[1]]
bterms1$family <- bterms$family
p <- args_glm_primitive(bterms1, resp = resp, threads = threads)
out <- sdist("categorical_logit_glm", p$x, p$alpha, p$beta)
} else {
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", type = "multi")
out <- sdist("categorical_logit", p$mu)
}
out
}

stan_log_lik_multinomial <- function(bterms, resp = "", mix = "", ...) {
Expand Down Expand Up @@ -1000,10 +1005,8 @@ stan_log_lik_custom <- function(bterms, resp = "", mix = "", threads = NULL, ...

# use Stan GLM primitive functions?
# @param bterms a brmsterms object
# @param allow_special_terms still use glm primitives if
# random effects, splines, etc. are present?
# @return TRUE or FALSE
use_glm_primitive <- function(bterms, allow_special_terms = TRUE) {
use_glm_primitive <- function(bterms) {
stopifnot(is.brmsterms(bterms))
# the model can only have a single predicted parameter
# and no additional residual or autocorrelation structure
Expand All @@ -1016,6 +1019,7 @@ use_glm_primitive <- function(bterms, allow_special_terms = TRUE) {
}
# some primitives do not support special terms in the way
# required by brms' Stan code generation
allow_special_terms <- !mu$family$family %in% c("cumulative", "categorical")
if (!allow_special_terms && has_special_terms(mu)) {
return(FALSE)
}
Expand All @@ -1033,22 +1037,22 @@ use_glm_primitive <- function(bterms, allow_special_terms = TRUE) {

# use Stan categorical GLM primitive function?
# @param bterms a brmsterms object
# @param ... passed to use_glm_primitive
# @return TRUE or FALSE
use_glm_primitive_categorical <- function(bterms, ...) {
# NOTE: this function is not yet in use; see stan_log_lik_categorical
use_glm_primitive_categorical <- function(bterms) {
stopifnot(is.brmsterms(bterms))
stopifnot(is_categorical(bterms))
bterms_tmp <- bterms
bterms_tmp$dpars <- list()
if (!is_categorical(bterms)) {
return(FALSE)
}
tmp <- bterms
tmp$dpars <- list()
# we know that all dpars in categorical models are mu parameters
out <- rep(FALSE, length(bterms$dpars))
for (i in seq_along(bterms$dpars)) {
bterms_tmp$dpars$mu <- bterms$dpars[[i]]
bterms_tmp$dpars$mu$family <- bterms$family
out[i] <- use_glm_primitive(bterms_tmp, ...) &&
tmp$dpars$mu <- bterms$dpars[[i]]
tmp$dpars$mu$family <- bterms$family
out[i] <- use_glm_primitive(tmp) &&
# the design matrix of all mu parameters must match
all.equal(bterms_tmp$dpars$mu$fe, bterms$dpars[[1]]$fe)
all.equal(tmp$dpars$mu$fe, bterms$dpars[[1]]$fe)
}
all(out)
}
Expand All @@ -1068,6 +1072,10 @@ args_glm_primitive <- function(bterms, resp = "", threads = NULL) {
} else if (center_X) {
sfx_X <- "c"
}
is_categorical <- is_categorical(bterms)
if (is_categorical) {
sfx_X <- glue("{sfx_X}_{bterms$dpar}")
}
x <- glue("X{sfx_X}{resp}{slice}")
beta <- glue("b{sfx_b}{resp}")
if (has_special_terms(bterms)) {
Expand All @@ -1077,7 +1085,11 @@ args_glm_primitive <- function(bterms, resp = "", threads = NULL) {
if (center_X) {
alpha <- glue("Intercept{resp}")
} else {
alpha <- "0"
if (is_categorical) {
alpha <- glue("rep_vector(0, ncat{resp})")
} else {
alpha <- "0"
}
}
}
nlist(x, alpha, beta)
Expand Down
62 changes: 45 additions & 17 deletions R/stan-predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ stan_predictor.brmsterms <- function(x, data, prior, normalize, ...) {
str_add_list(out) <- stan_response(x, data = data, normalize = normalize)
valid_dpars <- valid_dpars(x)
args <- nlist(data, prior, normalize, nlpars = names(x$nlpars), ...)
args$primitive <- use_glm_primitive(x)
args$primitive <- use_glm_primitive(x) || use_glm_primitive_categorical(x)
for (nlp in names(x$nlpars)) {
nlp_args <- list(x$nlpars[[nlp]])
str_add_list(out) <- do_call(stan_predictor, c(nlp_args, args))
Expand Down Expand Up @@ -2095,25 +2095,53 @@ stan_dpar_transform <- function(bterms, prior, threads, normalize, ...) {
resp <- usc(bterms$resp)
if (any(conv_cats_dpars(families))) {
stopifnot(length(families) == 1L)
is_logistic_normal <- any(is_logistic_normal(families))
len_mu <- glue("ncat{p}{str_if(is_logistic_normal, '-1')}")
str_add(out$model_def) <- glue(
" // linear predictor matrix\n",
" array[N{resp}] vector[{len_mu}] mu{p};\n"
)
mu_dpars <- make_stan_names(glue("mu{bterms$family$cats}"))
mu_dpars <- glue("{mu_dpars}{p}[n]")
iref <- get_refcat(bterms$family, int = TRUE)
if (is_logistic_normal) {
mu_dpars <- mu_dpars[-iref]
mus <- make_stan_names(glue("mu{bterms$family$cats}"))
mus <- glue("{mus}{p}")
if (use_glm_primitive_categorical(bterms)) {
bterms1 <- bterms$dpars[[1]]
center_X <- stan_center_X(bterms1)
ct <- str_if(center_X, "c")
K <- glue("K{ct}_{bterms1$dpar}{p}")
str_add(out$model_def) <- glue(
" // joint regression coefficients over categories\n",
" matrix[{K}, ncat{p}] b{p};\n"
)
bnames <- glue("b_{mus}")
bnames[iref] <- glue("rep_vector(0, {K})")
str_add(out$model_comp_catjoin) <- cglue(
" b{p}[, {seq_along(bnames)}] = {bnames};\n"
)
if (center_X) {
Inames <- glue("Intercept_{mus}")
Inames[iref] <- "0"
str_add(out$model_def) <- glue(
" // joint intercepts over categories\n",
" vector[ncat{p}] Intercept{p};\n"
)
str_add(out$model_comp_catjoin) <- glue(
" Intercept{p} = {stan_vector(Inames)};\n"
)
}
} else {
mu_dpars[iref] <- "0"
is_logistic_normal <- any(is_logistic_normal(families))
len_mu <- glue("ncat{p}{str_if(is_logistic_normal, '-1')}")
str_add(out$model_def) <- glue(
" // linear predictor matrix\n",
" array[N{resp}] vector[{len_mu}] mu{p};\n"
)
mus <- glue("{mus}[n]")
if (is_logistic_normal) {
mus <- mus[-iref]
} else {
mus[iref] <- "0"
}
str_add(out$model_comp_catjoin) <- glue(
" for (n in 1:N{resp}) {{\n",
" mu{p}[n] = {stan_vector(mus)};\n",
" }}\n"
)
}
str_add(out$model_comp_catjoin) <- glue(
" for (n in 1:N{resp}) {{\n",
" mu{p}[n] = {stan_vector(mu_dpars)};\n",
" }}\n"
)
}
if (any(families %in% "skew_normal")) {
# as suggested by Stephen Martin use sigma and mu of CP
Expand Down
1 change: 0 additions & 1 deletion R/stan-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -774,4 +774,3 @@ stan_hurdle_ordinal_lpmf <- function(family, link) {
}
out
}

10 changes: 10 additions & 0 deletions tests/testthat/tests.stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,16 @@ test_that("Stan code for categorical models is correct", {
scode <- stancode(y ~ x + (1 |ID| .g), data = dat,
family = categorical(refcat = NA))
expect_match2(scode, "mu[n] = transpose([mu1[n], mu2[n], mu3[n], muab[n]]);")

# test use of glm primitive
scode <- stancode(y ~ x, data = dat, family = categorical())
expect_match2(scode, "b[, 1] = rep_vector(0, Kc_mu2);")
expect_match2(scode, "b[, 3] = b_mu3;")
expect_match2(scode, "Intercept = transpose([0, Intercept_mu2, Intercept_mu3, Intercept_muab]);")
expect_match2(scode, "target += categorical_logit_glm_lpmf(Y | Xc_mu2, Intercept, b);")

scode <- stancode(bf(y ~ x, center = FALSE), data = dat, family = categorical())
expect_match2(scode, "target += categorical_logit_glm_lpmf(Y | X_mu2, rep_vector(0, ncat), b);")
})

test_that("Stan code for multinomial models is correct", {
Expand Down

0 comments on commit 701a699

Please sign in to comment.