From b7fdca20a5c712734c132aecd359268b9f4903bd Mon Sep 17 00:00:00 2001 From: Aki Vehtari Date: Mon, 18 Mar 2024 19:01:15 +0200 Subject: [PATCH] update doc and loo recommendations --- R/loo.R | 34 +++++++++++++++++++++++++++------- R/loo_moment_match.R | 4 ++-- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/R/loo.R b/R/loo.R index 2565bee56..2df2032ad 100644 --- a/R/loo.R +++ b/R/loo.R @@ -27,9 +27,10 @@ #' details. #' @param reloo Logical; Indicate whether \code{\link{reloo}} #' should be applied on problematic observations. Defaults to \code{FALSE}. -#' @param k_threshold The threshold at which pareto \eqn{k} -#' estimates are treated as problematic. Defaults to \code{0.7}. -#' Only used if argument \code{reloo} is \code{TRUE}. +#' @param k_threshold The Pareto \eqn{k} threshold above which observations +#' are treated as problematic, that is, the observations to which +#' \code{\link{loo_moment_match}} or \code{\link{reloo}} is applied if argument +#' \code{moment_match} or \code{reloo} is \code{TRUE}. Defaults to \code{0.7}. #' See \code{\link[loo:pareto-k-diagnostic]{pareto_k_ids}} for more details. #' @param save_psis Should the \code{"psis"} object created internally be saved #' in the returned object? For more details see \code{\link[loo:loo]{loo}}. @@ -677,13 +678,19 @@ recommend_loo_options <- function(loo, k_threshold = 0.7, moment_match = FALSE, } else { model_name <- "" } - n <- length(loo::pareto_k_ids(loo, threshold = k_threshold)) ndraws <- dim(loo)[1] %||% Inf - if (n > 0 && ndraws < 2200) { + n <- length(loo::pareto_k_ids(loo, threshold = k_threshold)) + k_threshold2 <- ps_khat_threshold(ndraws) + if (k_threshold2 < k_threshold) { + n2 <- length(loo::pareto_k_ids(loo, threshold = k_threshold2)) + } else { + n2 <- n + } + if (n2 > n && k_threshold2<=0.7) { warning2( - "Found ", n, " observations with a pareto_k > ", k_threshold, + "Found ", n2, " observations with a pareto_k > ", round(k_threshold2,2), model_name, ". 
We recommend to run more iterations to get at least ", - "about 2200 posterior draws for more reliable pareteo_k estimation." + "about 2200 posterior draws to improve the LOO accuracy." ) out <- "loo_more_draws" } else if (n > 0 && !moment_match) { @@ -991,3 +998,16 @@ print.iclist <- function(x, digits = 2, ...) { print(round(mat, digits = digits), na.print = "") invisible(x) } + +#' Pareto-smoothing k-hat threshold +#' +#' Given sample size S, computes the khat threshold for a reliable Pareto +#' smoothed estimate (to have a small probability of large error). See +#' section 3.2.4, equation (13) of Vehtari et al. (2024), "Pareto smoothed importance sampling". +#' @param S sample size +#' @param ... unused +#' @return threshold +#' @noRd +ps_khat_threshold <- function(S, ...) { + 1 - 1 / log10(S) +} diff --git a/R/loo_moment_match.R b/R/loo_moment_match.R index fcf252384..bb998d210 100644 --- a/R/loo_moment_match.R +++ b/R/loo_moment_match.R @@ -9,8 +9,8 @@ #' @inheritParams predict.brmsfit #' @param x An object of class \code{brmsfit}. #' @param loo An object of class \code{loo} originally created from \code{x}. -#' @param k_threshold The threshold at which Pareto \eqn{k} -#' estimates are treated as problematic. Defaults to \code{0.7}. +#' @param k_threshold The Pareto \eqn{k} threshold above which moment +#' matching is applied to an observation. Defaults to \code{0.7}. #' See \code{\link[loo:pareto-k-diagnostic]{pareto_k_ids}} #' for more details. #' @param check Logical; If \code{TRUE} (the default), some checks