diff --git a/R/PipeOpDecode.R b/R/PipeOpDecode.R index 489e5e3eb..edf6f07c6 100644 --- a/R/PipeOpDecode.R +++ b/R/PipeOpDecode.R @@ -5,6 +5,7 @@ #' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`]. #' #' @description +#' Description #' #' @section Construction: #' ``` @@ -50,7 +51,7 @@ PipeOpDecode = R6Class("PipeOpDecode", group_pattern = p_uty(custom_check = check_string, tags = c("train", "predict")) ) ps$values = list(treatment_encoding = FALSE, group_pattern = "^([^.]*)\\.") - super$initialize(id, param_set = ps, param_vals = param_vals, packages = "stats", tags = "encode", feature_types = c("factor", "ordered")) + super$initialize(id, param_set = ps, param_vals = param_vals, tags = "encode", feature_types = c("integer", "numeric")) } ), private = list( @@ -62,42 +63,62 @@ PipeOpDecode = R6Class("PipeOpDecode", # If pattern == "", all columns are collapsed into one column if (pv$group_pattern == "") { return(list(colmaps = list(result = set_names(cols, cols)))) + # should make sure that result is available (or name we chose instead) } - # Extract group names + # Drop columns that do contain group_pattern + # What about cols starting with .? -> probably let user exclude this by changing group_pattern are using affect_columns + cols = cols[grepl(pv$group_pattern, cols)] + + # Extract factor names matches = regmatches(cols, regexec(pv$group_pattern, cols)) - grps = unlist(map(matches, function(x) if (length(x)) x[[2]] else "")) + fcts = vapply(matches, function(x) x[[2]], character(1)) # Extract level names lvls = set_names(gsub(pv$group_pattern, "", cols), cols) - # Drop entries for which no match to group_pattern was found - keep = fcts != "" - fcts = fcts[keep] - lvls = lvls[keep] - - # add "" = "ref" if pv$treatment_encoding == TRUE # test that split is consistent for this use case - list(colmaps = split(lvls, fcts)) - }, + s = list(colmaps = split(lvls, fcts)) + + if (pv$treatment_encoding) { + # Set default name for reference level + ref_name = "ref" + counter = 1 + while (ref_name %in% cols) { + ref_name = paste0("ref.", counter) + counter = counter + 1 + } + # Append ref_name with empty name to all list entries + for (i in seq_along(s[["colmaps"]])) { + s[["colmaps"]][[i]][[length(s[["colmaps"]][[i]]) + 1]] = ref_name + } + } - # take maximum value, bc could be scaled - # treatment dass alles 0 ist, hard coden, referenzname als reference nennen (und ref.1 falls es die spalte schon gibt) + s + }, - # decide when to assign "ref" (e.g. no unique maximum) .transform_dt = function(dt, levels) { colmaps = self$state$colmaps for (fct in names(colmaps)) { old_cols = names(colmaps[[fct]]) - lvls = unname(colmaps[[fct]]) + lvls = colmaps[[fct]] + + # Do we check that which.max is unique? + # Generally what checks should we perform? e.g. that group_pattern contains a capturing group # Find the column with the maximal value for each row dt[, (fct) := old_cols[apply(.SD, 1, which.max)], .SDcols = old_cols] # Assign the corresponding value from the named vector to the new column - dt[, (fct) := lvls[get(fct)]] + # dt[, (fct) := lvls[get(fct)]] + dt[, (fct) := lvls[dt[[fct]]]] + + # type conversion to factor? + # Remove the old columns (can move this to outside the loop) dt[, (old_cols) := NULL] } + + dt } ) )