Skip to content

Commit

Permalink
initial work on lprior tags
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Jan 14, 2025
1 parent 9d1acf2 commit df007fb
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 14 deletions.
23 changes: 12 additions & 11 deletions R/priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@
#' @export
set_prior <- function(prior, class = "b", coef = "", group = "",
resp = "", dpar = "", nlpar = "",
lb = NA, ub = NA, check = TRUE) {
input <- nlist(prior, class, coef, group, resp, dpar, nlpar, lb, ub, check)
lb = NA, ub = NA, lprior = "", check = TRUE) {
input <- nlist(prior, class, coef, group, resp, dpar, nlpar, lb, ub, lprior, check)
input <- try(as.data.frame(input), silent = TRUE)
if (is_try_error(input)) {
stop2("Processing arguments of 'set_prior' has failed:\n", input)
Expand All @@ -375,7 +375,7 @@ set_prior <- function(prior, class = "b", coef = "", group = "",

# validate arguments passed to 'set_prior'
.set_prior <- function(prior, class, coef, group, resp,
dpar, nlpar, lb, ub, check) {
dpar, nlpar, lb, ub, lprior, check) {
prior <- as_one_character(prior)
class <- as_one_character(class)
group <- as_one_character(group)
Expand All @@ -386,16 +386,17 @@ set_prior <- function(prior, class = "b", coef = "", group = "",
check <- as_one_logical(check)
lb <- as_one_character(lb, allow_na = TRUE)
ub <- as_one_character(ub, allow_na = TRUE)
lprior <- as_one_character(lprior)
if (dpar == "mu") {
# distributional parameter 'mu' is currently implicit #1368
dpar <- ""
}
if (!check) {
# prior will be added to the log-posterior as is
class <- coef <- group <- resp <- dpar <- nlpar <- lb <- ub <- ""
class <- coef <- group <- resp <- dpar <- nlpar <- lb <- ub <- lprior <- ""
}
source <- "user"
out <- nlist(prior, source, class, coef, group, resp, dpar, nlpar, lb, ub)
out <- nlist(prior, source, class, coef, group, resp, dpar, nlpar, lb, ub, lprior)
do_call(brmsprior, out)
}

Expand Down Expand Up @@ -558,7 +559,7 @@ default_prior.default <- function(object, data, family = gaussian(), autocor = N
# explicitly label default priors as such
prior$source <- "default"
# apply 'unique' as the same prior may have been included multiple times
to_order <- with(prior, order(resp, dpar, nlpar, class, group, coef))
to_order <- with(prior, order(resp, dpar, nlpar, class, group, coef, lprior))
prior <- unique(prior[to_order, , drop = FALSE])
rownames(prior) <- NULL
class(prior) <- c("brmsprior", "data.frame")
Expand Down Expand Up @@ -1565,7 +1566,7 @@ get_sample_prior <- function(prior) {
# create data.frames containing prior information
brmsprior <- function(prior = "", class = "", coef = "", group = "",
resp = "", dpar = "", nlpar = "", lb = "", ub = "",
source = "", ls = list()) {
lprior = "", source = "", ls = list()) {
if (length(ls)) {
if (is.null(names(ls))) {
stop("Argument 'ls' must be named.")
Expand All @@ -1580,7 +1581,7 @@ brmsprior <- function(prior = "", class = "", coef = "", group = "",
}
out <- data.frame(
prior, class, coef, group,
resp, dpar, nlpar, lb, ub, source,
resp, dpar, nlpar, lb, ub, lprior, source,
stringsAsFactors = FALSE
)
class(out) <- c("brmsprior", "data.frame")
Expand All @@ -1594,7 +1595,7 @@ empty_prior <- function() {
brmsprior(
prior = char0, source = char0, class = char0,
coef = char0, group = char0, resp = char0,
dpar = char0, nlpar = char0, lb = char0, ub = char0
dpar = char0, nlpar = char0, lb = char0, ub = char0, lprior = char0
)
}

Expand Down Expand Up @@ -1623,7 +1624,7 @@ prior_bounds <- function(prior) {
# all columns of brmsprior objects
all_cols_prior <- function() {
c("prior", "class", "coef", "group", "resp",
"dpar", "nlpar", "lb", "ub", "source")
"dpar", "nlpar", "lb", "ub", "lprior", "source")
}

# relevant columns for duplication checks in brmsprior objects
Expand Down Expand Up @@ -1915,7 +1916,7 @@ as.brmsprior <- function(x) {

defaults <- c(
class = "b", coef = "", group = "", resp = "",
dpar = "", nlpar = "", lb = NA, ub = NA
dpar = "", nlpar = "", lb = NA, ub = NA, lprior = ""
)
for (v in names(defaults)) {
if (!v %in% names(x)) {
Expand Down
17 changes: 14 additions & 3 deletions R/stan-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL,
c(index) <- j
}
prior_ij <- subset2(prior, coef = coef[i, j])
lprior_tag <- prior_ij$lprior
if (NROW(px) > 1L) {
# disambiguate priors of coefficients with the same name
# coming from different model components
Expand Down Expand Up @@ -131,7 +132,13 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL,
coef_prior, par_ij, broadcast = broadcast,
bound = bound, resp = px$resp[1], normalize = normalize
)
# add to the lprior
str_add(out$tpar_prior) <- paste0(lpp(), coef_prior, ";\n")
# add to the lprior of the tag if specified
if (!is.null(lprior_tag) && lprior_tag != "") {
str_add(out$tpar_prior) <- paste0(lpp(tag = lprior_tag), coef_prior, ";\n")
}

}
}
}
Expand Down Expand Up @@ -241,7 +248,7 @@ stan_base_prior <- function(prior, col = "prior", sel_prior = NULL, ...) {
return(brmsprior()[, col])
}
}
vars <- c("group", "nlpar", "dpar", "resp", "class")
vars <- c("group", "nlpar", "dpar", "resp", "class", "lprior")
for (v in vars) {
take <- nzchar(prior[[v]])
if (any(take)) {
Expand Down Expand Up @@ -698,7 +705,11 @@ stopif_prior_bound <- function(prior, class, ...) {
}

# lprior plus equal
lpp <- function(wsp = 2) {
lpp <- function(wsp = 2, tag = NULL) {
wsp <- collapse(rep(" ", wsp))
paste0(wsp, "lprior += ")
if (is.null(tag)) {
paste0(wsp, "lprior", " += ")
} else {
paste0(wsp, "lprior_", tag, " += ")
}
}
5 changes: 5 additions & 0 deletions R/stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ stancode.default <- function(object, data, family = gaussian(),
backend = getOption("brms.backend", "rstan"),
silent = TRUE, save_model = NULL, ...) {

lprior_tags <- prior$lprior[prior$lprior != ""]

normalize <- as_one_logical(normalize)
parse <- as_one_logical(parse)
backend <- match.arg(backend, backend_choices())
Expand Down Expand Up @@ -278,12 +280,15 @@ stancode.default <- function(object, data, family = gaussian(),

# generate transformed parameters block
scode_lprior_def <- " real lprior = 0; // prior contributions to the log posterior\n"
scode_lprior_tags_def <- paste0(
" real lprior_", unique(lprior_tags), " = 0;\n", collapse = "")
scode_transformed_parameters <- paste0(
"transformed parameters {\n",
scode_predictor[["tpar_def"]],
scode_re[["tpar_def"]],
scode_Xme[["tpar_def"]],
str_if(normalize, scode_lprior_def),
str_if(normalize, scode_lprior_tags_def),
collapse_stanvars(stanvars, "tparameters", "start"),
scode_predictor[["tpar_prior_const"]],
scode_re[["tpar_prior_const"]],
Expand Down

0 comments on commit df007fb

Please sign in to comment.