Skip to content

Commit

Permalink
Merge pull request #22 from GidonFrischkorn/4-implement-check_data
Browse files Browse the repository at this point in the history
Second version of the check_data function
  • Loading branch information
GidonFrischkorn authored Feb 28, 2024
2 parents 12d0a7b + 1045583 commit 8bbd558
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 23 deletions.
71 changes: 50 additions & 21 deletions R/helpers-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,42 +143,71 @@ check_var_setsize <- function(setsize, data) {

# check_data method for M3 models.
#
# Validates the response-category columns named in
# `model$resp_vars$resp_cats`, collects them into a single matrix column `Y`
# (dropping the original columns), and then validates or attaches the number
# of response options taken from `model$other_vars$num_options`.
#
# Arguments:
#   model   - model object; `model$resp_vars$resp_cats` holds the names of the
#             response columns and `model$other_vars$num_options` holds either
#             column names (character) or fixed counts (numeric), one entry
#             per response column.
#   data    - data frame holding the response columns (and, when
#             `num_options` is character, the option-count columns).
#   formula - not inspected by this method.
#
# Returns: whatever the next `check_data` method returns.
# Errors:  via stop() when names contain spaces/punctuation, when required
#          columns are missing, when `num_options` does not have one entry
#          per response column, or when it is neither character nor numeric.
check_data.M3 <- function(model, data, formula) {
  # Names of the response columns and of all columns present in the data.
  # `col_name` is captured before the response columns are removed below.
  resp_name <- model$resp_vars$resp_cats
  col_name <- colnames(data)

  # Check if the response variables are legal or not.
  if (any(grepl("[[:punct:]]|\\s", resp_name))) {
    stop("Space and punctuation are not allowed in the response variable names.")
  }

  # Check if the response variables are all present in the data
  missing_list <- setdiff(resp_name, col_name)
  if (length(missing_list) > 0) {
    stop(paste0(
      "The response variable(s) '",
      paste0(missing_list, collapse = "', '"),
      "' is not present in the data."
    ))
  }

  # Transfer all of the response variables to a matrix and name it 'Y'.
  # drop = FALSE keeps a one-column selection as a data frame.
  data$Y <- as.matrix(data[, resp_name, drop = FALSE])
  colnames(data$Y) <- seq_along(resp_name)
  data <- dplyr::select(data, -dplyr::all_of(resp_name))

  # Get the vector of the options variables
  nOpt_vect <- model$other_vars$num_options

  # Check whether the option variables have the same length as the response
  # variables.
  if (length(nOpt_vect) != length(resp_name)) {
    stop("The option variables should have the same length as the response variables.")
  }

  if (is.character(nOpt_vect)) {
    # A character vector names the columns in the data that hold the number
    # of options; here they are only validated, not modified.
    option_name <- nOpt_vect

    # Check if the name of the number of options is legal or not.
    if (any(grepl("[[:punct:]]|\\s", option_name))) {
      stop("Space and punctuation are not allowed in the number of options variable name.")
    }

    # Check if the number of options is present in the data
    missing_list <- setdiff(option_name, col_name)
    if (length(missing_list) > 0) {
      stop(paste0(
        "The variable(s) '",
        paste0(missing_list, collapse = "', '"),
        "' is not present in the data."
      ))
    }
  } else if (is.numeric(nOpt_vect)) {
    # A numeric vector gives the number of options for each response variable
    # in all conditions: build a one-row table with one `nOpt<resp>` column
    # per response variable and attach it to every row of the data.
    nOpt_data <- tidyr::pivot_wider(
      data.frame(name = paste0("nOpt", resp_name), value = nOpt_vect),
      names_from = "name",
      values_from = "value"
    )
    data <- dplyr::cross_join(data, nOpt_data)
  } else {
    stop("The number of options should be a string or a numeric vector.")
  }

  NextMethod("check_data")
}


Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-helpers-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ test_that("check_var_setsize rejects invalid input", {
test_that("check_data() returns a data.frame()", {
  # One entry per supported model; each entry is called below to build a
  # model object (supported_models() and get_model() are defined elsewhere).
  mls <- lapply(supported_models(print_call = FALSE), get_model)
  for (ml in mls) {
    # Every model is called with the same full set of arguments, including
    # the response categories and their number of options.
    model <- ml(
      resp_err = "y", nt_features = "x", setsize = 2, nt_distances = "z",
      resp_cats = c("IIP", "IOP", "NPL"), num_options = c(1, 2, 3)
    )
    # The data must contain every column referenced by the model above.
    dat <- data.frame(y = 1, x = 1, z = 2, IIP = 1, IOP = 1, NPL = 1)
    expect_s3_class(
      check_data(model, dat, bmmformula(kappa ~ 1)),
      "data.frame"
    )
  }
})
Expand Down

0 comments on commit 8bbd558

Please sign in to comment.