Skip to content

Commit

Permalink
further logic
Browse files Browse the repository at this point in the history
  • Loading branch information
advieser committed Nov 25, 2024
1 parent 33dbac7 commit f25bcc7
Showing 1 changed file with 37 additions and 16 deletions.
53 changes: 37 additions & 16 deletions R/PipeOpDecode.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`].
#'
#' @description
#' Description
#'
#' @section Construction:
#' ```
Expand Down Expand Up @@ -50,7 +51,7 @@ PipeOpDecode = R6Class("PipeOpDecode",
group_pattern = p_uty(custom_check = check_string, tags = c("train", "predict"))
)
ps$values = list(treatment_encoding = FALSE, group_pattern = "^([^.]*)\\.")
super$initialize(id, param_set = ps, param_vals = param_vals, packages = "stats", tags = "encode", feature_types = c("factor", "ordered"))
super$initialize(id, param_set = ps, param_vals = param_vals, tags = "encode", feature_types = c("integer", "numeric"))
}
),
private = list(
Expand All @@ -62,42 +63,62 @@ PipeOpDecode = R6Class("PipeOpDecode",
# If pattern == "", all columns are collapsed into one column
if (pv$group_pattern == "") {
return(list(colmaps = list(result = set_names(cols, cols))))
# should make sure that result is available (or name we chose instead)
}

# Extract group names
# Drop columns that do contain group_pattern
# What about cols starting with .? -> probably let user exclude this by changing group_pattern are using affect_columns
cols = cols[grepl(pv$group_pattern, cols)]

# Extract factor names
matches = regmatches(cols, regexec(pv$group_pattern, cols))
grps = unlist(map(matches, function(x) if (length(x)) x[[2]] else ""))
fcts = vapply(matches, function(x) x[[2]], character(1))
# Extract level names
lvls = set_names(gsub(pv$group_pattern, "", cols), cols)

# Drop entries for which no match to group_pattern was found
keep = fcts != ""
fcts = fcts[keep]
lvls = lvls[keep]

# add "" = "ref" if pv$treatment_encoding == TRUE
# test that split is consistent for this use case
list(colmaps = split(lvls, fcts))
},
s = list(colmaps = split(lvls, fcts))

if (pv$treatment_encoding) {
# Set default name for reference level
ref_name = "ref"
counter = 1
while (ref_name %in% cols) {
ref_name = paste0("ref.", counter)
counter = counter + 1
}
# Append ref_name with empty name to all list entries
for (i in seq_along(s[["colmaps"]])) {
s[["colmaps"]][[i]][[length(s[["colmaps"]][[i]]) + 1]] = ref_name
}
}

# take maximum value, bc could be scaled
# treatment dass alles 0 ist, hard coden, referenzname als reference nennen (und ref.1 falls es die spalte schon gibt)
s
},

# decide when to assign "ref" (e.g. no unique maximum)
.transform_dt = function(dt, levels) {
colmaps = self$state$colmaps

for (fct in names(colmaps)) {
old_cols = names(colmaps[[fct]])
lvls = unname(colmaps[[fct]])
lvls = colmaps[[fct]]

# Do we check that which.max is unique?
# Generally what checks should we perform? e.g. that group_pattern contains a capturing group

# Find the column with the maximal value for each row
dt[, (fct) := old_cols[apply(.SD, 1, which.max)], .SDcols = old_cols]
# Assign the corresponding value from the named vector to the new column
dt[, (fct) := lvls[get(fct)]]
# dt[, (fct) := lvls[get(fct)]]
dt[, (fct) := lvls[dt[[fct]]]]

# type conversion to factor?

# Remove the old columns (can move this to outside the loop)
dt[, (old_cols) := NULL]
}

dt
}
)
)
Expand Down

0 comments on commit f25bcc7

Please sign in to comment.