diff --git a/R/priors.R b/R/priors.R index d3342aa6..82fcf22e 100644 --- a/R/priors.R +++ b/R/priors.R @@ -360,8 +360,8 @@ #' @export set_prior <- function(prior, class = "b", coef = "", group = "", resp = "", dpar = "", nlpar = "", - lb = NA, ub = NA, check = TRUE) { - input <- nlist(prior, class, coef, group, resp, dpar, nlpar, lb, ub, check) + lb = NA, ub = NA, lprior = "", check = TRUE) { + input <- nlist(prior, class, coef, group, resp, dpar, nlpar, lb, ub, lprior, check) input <- try(as.data.frame(input), silent = TRUE) if (is_try_error(input)) { stop2("Processing arguments of 'set_prior' has failed:\n", input) @@ -375,7 +375,7 @@ set_prior <- function(prior, class = "b", coef = "", group = "", # validate arguments passed to 'set_prior' .set_prior <- function(prior, class, coef, group, resp, - dpar, nlpar, lb, ub, check) { + dpar, nlpar, lb, ub, lprior, check) { prior <- as_one_character(prior) class <- as_one_character(class) group <- as_one_character(group) @@ -386,16 +386,17 @@ set_prior <- function(prior, class = "b", coef = "", group = "", check <- as_one_logical(check) lb <- as_one_character(lb, allow_na = TRUE) ub <- as_one_character(ub, allow_na = TRUE) + lprior <- as_one_character(lprior) if (dpar == "mu") { # distributional parameter 'mu' is currently implicit #1368 dpar <- "" } if (!check) { # prior will be added to the log-posterior as is - class <- coef <- group <- resp <- dpar <- nlpar <- lb <- ub <- "" + class <- coef <- group <- resp <- dpar <- nlpar <- lb <- ub <- lprior <- "" } source <- "user" - out <- nlist(prior, source, class, coef, group, resp, dpar, nlpar, lb, ub) + out <- nlist(prior, source, class, coef, group, resp, dpar, nlpar, lb, ub, lprior) do_call(brmsprior, out) } @@ -558,7 +559,7 @@ default_prior.default <- function(object, data, family = gaussian(), autocor = N # explicitly label default priors as such prior$source <- "default" # apply 'unique' as the same prior may have been included multiple times - to_order <- with(prior, order(resp, dpar, nlpar, class, group, coef)) + to_order <- with(prior, order(resp, dpar, nlpar, class, group, coef, lprior)) prior <- unique(prior[to_order, , drop = FALSE]) rownames(prior) <- NULL class(prior) <- c("brmsprior", "data.frame") @@ -1565,7 +1566,7 @@ get_sample_prior <- function(prior) { # create data.frames containing prior information brmsprior <- function(prior = "", class = "", coef = "", group = "", resp = "", dpar = "", nlpar = "", lb = "", ub = "", - source = "", ls = list()) { + lprior = "", source = "", ls = list()) { if (length(ls)) { if (is.null(names(ls))) { stop("Argument 'ls' must be named.") @@ -1580,7 +1581,7 @@ brmsprior <- function(prior = "", class = "", coef = "", group = "", } out <- data.frame( prior, class, coef, group, - resp, dpar, nlpar, lb, ub, source, + resp, dpar, nlpar, lb, ub, lprior, source, stringsAsFactors = FALSE ) class(out) <- c("brmsprior", "data.frame") @@ -1594,7 +1595,7 @@ empty_prior <- function() { brmsprior( prior = char0, source = char0, class = char0, coef = char0, group = char0, resp = char0, - dpar = char0, nlpar = char0, lb = char0, ub = char0 + dpar = char0, nlpar = char0, lb = char0, ub = char0, lprior = char0 ) } @@ -1623,7 +1624,7 @@ prior_bounds <- function(prior) { # all columns of brmsprior objects all_cols_prior <- function() { c("prior", "class", "coef", "group", "resp", - "dpar", "nlpar", "lb", "ub", "source") + "dpar", "nlpar", "lb", "ub", "lprior", "source") } # relevant columns for duplication checks in brmsprior objects @@ -1915,7 +1916,7 @@ as.brmsprior <- function(x) { defaults <- c( class = "b", coef = "", group = "", resp = "", - dpar = "", nlpar = "", lb = NA, ub = NA + dpar = "", nlpar = "", lb = NA, ub = NA, lprior = "" ) for (v in names(defaults)) { if (!v %in% names(x)) { diff --git a/R/stan-prior.R b/R/stan-prior.R index 4307d48e..92fd1585 100644 --- a/R/stan-prior.R +++ b/R/stan-prior.R @@ -95,6 +95,7 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL, c(index) <- j } prior_ij <- subset2(prior, coef = coef[i, j]) + lprior_tag <- prior_ij$lprior if (NROW(px) > 1L) { # disambiguate priors of coefficients with the same name # coming from different model components @@ -131,7 +132,13 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL, coef_prior, par_ij, broadcast = broadcast, bound = bound, resp = px$resp[1], normalize = normalize ) + # add to the lprior str_add(out$tpar_prior) <- paste0(lpp(), coef_prior, ";\n") + # add to the lprior of the tag if specified + if (!is.null(lprior_tag) && lprior_tag != "") { + str_add(out$tpar_prior) <- paste0(lpp(tag = lprior_tag), coef_prior, ";\n") + } + } } } @@ -241,7 +248,7 @@ stan_base_prior <- function(prior, col = "prior", sel_prior = NULL, ...) { return(brmsprior()[, col]) } } - vars <- c("group", "nlpar", "dpar", "resp", "class") + vars <- c("group", "nlpar", "dpar", "resp", "class", "lprior") for (v in vars) { take <- nzchar(prior[[v]]) if (any(take)) { @@ -698,7 +705,11 @@ stopif_prior_bound <- function(prior, class, ...) { } # lprior plus equal -lpp <- function(wsp = 2) { +lpp <- function(wsp = 2, tag = NULL) { wsp <- collapse(rep(" ", wsp)) - paste0(wsp, "lprior += ") + if (is.null(tag)) { + paste0(wsp, "lprior", " += ") + } else { + paste0(wsp, "lprior_", tag, " += ") + } } diff --git a/R/stancode.R b/R/stancode.R index ece2f74f..4b0ae17b 100644 --- a/R/stancode.R +++ b/R/stancode.R @@ -114,6 +114,8 @@ stancode.default <- function(object, data, family = gaussian(), backend = getOption("brms.backend", "rstan"), silent = TRUE, save_model = NULL, ...) { + lprior_tags <- prior$lprior[prior$lprior != ""] + normalize <- as_one_logical(normalize) parse <- as_one_logical(parse) backend <- match.arg(backend, backend_choices()) @@ -278,12 +280,15 @@ stancode.default <- function(object, data, family = gaussian(), # generate transformed parameters block scode_lprior_def <- " real lprior = 0; // prior contributions to the log posterior\n" + scode_lprior_tags_def <- paste0( + " real lprior_", unique(lprior_tags), " = 0;\n", collapse = "") scode_transformed_parameters <- paste0( "transformed parameters {\n", scode_predictor[["tpar_def"]], scode_re[["tpar_def"]], scode_Xme[["tpar_def"]], str_if(normalize, scode_lprior_def), + str_if(normalize, scode_lprior_tags_def), collapse_stanvars(stanvars, "tparameters", "start"), scode_predictor[["tpar_prior_const"]], scode_re[["tpar_prior_const"]],