From 4f1271bbe27f6ae7a0efa0e7d0286acf6d4a882c Mon Sep 17 00:00:00 2001 From: Andrey Zakharevich Date: Fri, 2 Jul 2021 15:06:46 +0300 Subject: [PATCH 1/8] Use single taskdata object for all the examples --- src/abstractors/abstractors.jl | 1 - src/abstractors/sort_array.jl | 22 -- src/data_transformers/find_const.jl | 10 +- src/data_transformers/find_dependent_key.jl | 12 +- .../find_matching_obj_group.jl | 25 +-- .../find_neg_shift_by_key.jl | 7 +- .../find_proportionate_by_key.jl | 29 ++- .../find_proportionate_key.jl | 4 +- src/data_transformers/find_shifted_by_key.jl | 24 +-- src/data_transformers/find_shifted_key.jl | 4 +- src/data_transformers/match_transformers.jl | 72 +++---- src/find_solution.jl | 35 ++- src/operations/wrap_matcher.jl | 12 +- src/solution.jl | 204 +++++++++--------- src/taskdata/taskdata.jl | 8 +- test/test_ignore_background.jl | 72 ------- test/test_select_by_color.jl | 52 ----- test/utils.jl | 8 +- 18 files changed, 214 insertions(+), 387 deletions(-) delete mode 100644 src/abstractors/sort_array.jl delete mode 100644 test/test_ignore_background.jl delete mode 100644 test/test_select_by_color.jl diff --git a/src/abstractors/abstractors.jl b/src/abstractors/abstractors.jl index deb8285..d0c364d 100644 --- a/src/abstractors/abstractors.jl +++ b/src/abstractors/abstractors.jl @@ -356,7 +356,6 @@ include("background_color.jl") include("solid_objects.jl") include("group_obj_by_color.jl") include("compact_similar_objects.jl") -# include("sort_array.jl") include("transpose.jl") include("repeat_object_infinite.jl") include("unwrap_tuple.jl") diff --git a/src/abstractors/sort_array.jl b/src/abstractors/sort_array.jl deleted file mode 100644 index ece2b6f..0000000 --- a/src/abstractors/sort_array.jl +++ /dev/null @@ -1,22 +0,0 @@ - -struct SortArray <: AbstractorClass end - -SortArray(key, to_abs) = Abstractor(SortArray(), key, to_abs) -abs_keys(::SortArray) = ["sorted"] - -init_create_check_data(::SortArray, key, solution) = Dict("effective" => 
false) - -function check_task_value(::SortArray, value::AbstractVector{T}, data, aux_values) where {T} - if !hasmethod(isless, Tuple{T,T}) - return false - end - - if sort(value) != value - data["effective"] = true - end - return true -end - -to_abstract_value(p::Abstractor{SortArray}, source_value) = Dict(p.output_keys[1] => sort(source_value)) - -from_abstract_value(p::Abstractor{SortArray}, source_value) = Dict(p.output_keys[1] => source_value) diff --git a/src/data_transformers/find_const.jl b/src/data_transformers/find_const.jl index 559f202..201389c 100644 --- a/src/data_transformers/find_const.jl +++ b/src/data_transformers/find_const.jl @@ -1,19 +1,19 @@ using ..Operations: SetConst -function find_const(taskdata::Vector{TaskData}, _, _, key::String)::Vector{SetConst} +function find_const(taskdata::TaskData, _, _, key::String)::Vector{SetConst} result = nothing if !in(key, updated_keys(taskdata)) return [] end - for task_data in taskdata - if !haskey(task_data, key) + for value in taskdata[key] + if ismissing(value) continue end if isnothing(result) - result = task_data[key] + result = value end - possible_value = common_value(result, task_data[key]) + possible_value = common_value(result, value) if isnothing(possible_value) return [] end diff --git a/src/data_transformers/find_dependent_key.jl b/src/data_transformers/find_dependent_key.jl index 1f9b5cc..d846090 100644 --- a/src/data_transformers/find_dependent_key.jl +++ b/src/data_transformers/find_dependent_key.jl @@ -1,24 +1,22 @@ using ..Operations: CopyParam -function find_dependent_key(taskdata::Vector{TaskData}, field_info, invalid_sources::AbstractSet{String}, key::String) +function find_dependent_key(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String) upd_keys = updated_keys(taskdata) skipmissing( - imap(keys(taskdata[1])) do input_key + imap(keys(taskdata)) do input_key if in(input_key, invalid_sources) || field_info[key].type != field_info[input_key].type || 
(!in(key, upd_keys) && !in(input_key, upd_keys)) return missing end - for task_data in taskdata - if !haskey(task_data, input_key) + for (input_value, out_value) in zip(taskdata[input_key], taskdata[key]) + if ismissing(input_value) return missing end - if !haskey(task_data, key) + if ismissing(out_value) continue end - input_value = task_data[input_key] - out_value = task_data[key] if !check_match(input_value, out_value) return missing end diff --git a/src/data_transformers/find_matching_obj_group.jl b/src/data_transformers/find_matching_obj_group.jl index ababb5f..4e699f0 100644 --- a/src/data_transformers/find_matching_obj_group.jl +++ b/src/data_transformers/find_matching_obj_group.jl @@ -28,35 +28,24 @@ using ..Abstractors: Abstractor, SelectGroup _check_group_type(::Type, ::Type) = false _check_group_type(::Type{Dict{K,V}}, expected::Type) where {K,V} = V == expected -function _get_matching_transformers( - taskdata::Vector{TaskData}, - field_info, - invalid_sources::AbstractSet{String}, - key::String, -) - if endswith(key, "|selected_group") || any(!haskey(task_data, key) for task_data in taskdata) +function _get_matching_transformers(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String) + if endswith(key, "|selected_group") || any(ismissing(val) for val in taskdata[key]) return [] end upd_keys = updated_keys(taskdata) flatten( - imap(keys(taskdata[1])) do input_key + imap(keys(taskdata)) do input_key if in(input_key, invalid_sources) || (!in(key, upd_keys) && !in(input_key, upd_keys)) || !_check_group_type(field_info[input_key].type, field_info[key].type) || - any(!haskey(task, input_key) for task in taskdata) || - all(length(task[input_key]) <= 1 for task in taskdata) + any(ismissing(input_value) for input_value in taskdata[input_key]) || + all(length(input_value) <= 1 for input_value in taskdata[input_key]) return [] end matching_groups = [] - for task_data in taskdata - if !haskey(task_data, input_key) - return [] - end - 
input_value = task_data[input_key] - out_value = task_data[key] + for (input_value, out_value) in zip(taskdata[input_key], taskdata[key]) if !check_matching_group(input_value, out_value, matching_groups) return [] - end end imap(unroll_groups(matching_groups)) do group_keys @@ -64,7 +53,7 @@ function _get_matching_transformers( to_abs = MapValues( key_name, "output", - Dict(task_data["output"] => value for (task_data, value) in zip(taskdata, group_keys)), + Dict(output_value => value for (output_value, value) in zip(taskdata["output"], group_keys)), ) from_abs = Abstractor(SelectGroup(), true, [input_key, key_name], [key, key * "|rejected"], String[]) return (to_abstract = to_abs, from_abstract = from_abs) diff --git a/src/data_transformers/find_neg_shift_by_key.jl b/src/data_transformers/find_neg_shift_by_key.jl index 213a5dc..8036ee6 100644 --- a/src/data_transformers/find_neg_shift_by_key.jl +++ b/src/data_transformers/find_neg_shift_by_key.jl @@ -1,11 +1,10 @@ using ..Operations: DecByParam -_shifted_neg_key_filter(shift_key, input_value, output_value, task_data) = - haskey(task_data, shift_key) && - check_match(apply_func(input_value, (x, y) -> x .- y, task_data[shift_key]), output_value) +_shifted_neg_key_filter(shift_value, input_value, output_value) = + check_match(apply_func(input_value, (x, y) -> x .- y, shift_value), output_value) -find_neg_shifted_by_key(taskdata::Vector{TaskData}, field_info, invalid_sources::AbstractSet{String}, key::String) = +find_neg_shifted_by_key(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String) = find_matching_for_key( taskdata, field_info, diff --git a/src/data_transformers/find_proportionate_by_key.jl b/src/data_transformers/find_proportionate_by_key.jl index f1e2641..7c90cd6 100644 --- a/src/data_transformers/find_proportionate_by_key.jl +++ b/src/data_transformers/find_proportionate_by_key.jl @@ -1,33 +1,28 @@ using ..Operations: MultByParam -_init_factor_keys(input_key, field_info, 
task_data, invalid_sources) = [ - key for (key, value) in task_data if !in(key, invalid_sources) && ( +_init_factor_keys(input_key, field_info, taskdata, invalid_sources) = [ + key for (key, values) in taskdata if !in(key, invalid_sources) && + (all(!ismissing(val) for val in values)) && + ( field_info[key].type == Int64 || field_info[key].type == Tuple{Int64,Int64} || ( field_info[key].type == field_info[input_key].type && - (isa(value, Dict) ? keys(value) == keys(task_data[input_key]) : true) + (isa(values[1], Dict) ? keys(values[1]) == keys(taskdata[input_key][1]) : true) ) ) ] -_factor_key_filter(shift_key, input_value, output_value, task_data) = - haskey(task_data, shift_key) && - check_match(apply_func(input_value, (x, y) -> x .* y, task_data[shift_key]), output_value) +_factor_key_filter(factor_value, input_value, output_value) = + check_match(apply_func(input_value, (x, y) -> x .* y, factor_value), output_value) -_check_effective_factor_key(shift_key, input_key, taskdata) = - all(haskey(task_data, shift_key) for task_data in taskdata) && any( - apply_func(task_data[input_key], (x, y) -> x .* y, task_data[shift_key]) != task_data[input_key] for - task_data in taskdata - ) - -function find_proportionate_by_key( - taskdata::Vector{TaskData}, - field_info, - invalid_sources::AbstractSet{String}, - key::String, +_check_effective_factor_key(factor_key, input_key, taskdata) = any( + apply_func(input_value, (x, y) -> x .* y, factor_value) != input_value for + (input_value, factor_value) in zip(taskdata[input_key], taskdata[factor_key]) ) + +function find_proportionate_by_key(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String) find_matching_for_key( taskdata, field_info, diff --git a/src/data_transformers/find_proportionate_key.jl b/src/data_transformers/find_proportionate_key.jl index a7daa66..c37eca9 100644 --- a/src/data_transformers/find_proportionate_key.jl +++ b/src/data_transformers/find_proportionate_key.jl @@ -3,8 +3,8 @@ 
using ..Operations: MultParam _init_factors(::Any...) = [-9, -8, -7, -6, -5, -4, -3, -2, -1, 2, 3, 4, 5, 6, 7, 8, 9] -_factor_filter(factor, input_value, output_value, _) = +_factor_filter(factor, input_value, output_value) = check_match(apply_func(input_value, (x, y) -> x .* y, factor), output_value) -find_proportionate_key(taskdata::Vector{TaskData}, field_info, invalid_sources::AbstractSet{String}, key::String) = +find_proportionate_key(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String) = find_matching_for_key(taskdata, field_info, invalid_sources, key, _init_factors, _factor_filter, MultParam) diff --git a/src/data_transformers/find_shifted_by_key.jl b/src/data_transformers/find_shifted_by_key.jl index 79ba0c4..e13940c 100644 --- a/src/data_transformers/find_shifted_by_key.jl +++ b/src/data_transformers/find_shifted_by_key.jl @@ -1,28 +1,28 @@ using ..Operations: IncByParam -_init_shift_keys(input_key, field_info, task_data, invalid_sources) = [ - key for (key, value) in task_data if !in(key, invalid_sources) && ( +_init_shift_keys(input_key, field_info, taskdata, invalid_sources) = [ + key for (key, values) in taskdata if !in(key, invalid_sources) && + (all(!ismissing(val) for val in values)) && + ( field_info[key].type == Int64 || field_info[key].type == Tuple{Int64,Int64} || ( field_info[key].type == field_info[input_key].type && - (isa(value, Dict) ? keys(value) == keys(task_data[input_key]) : true) + (isa(values[1], Dict) ? 
keys(values[1]) == keys(taskdata[input_key][1]) : true) ) ) ] -_shifted_key_filter(shift_key, input_value, output_value, task_data) = - haskey(task_data, shift_key) && - check_match(apply_func(input_value, (x, y) -> x .+ y, task_data[shift_key]), output_value) +_shifted_key_filter(shift_value, input_value, output_value) = + check_match(apply_func(input_value, (x, y) -> x .+ y, shift_value), output_value) -_check_effective_shift_key(shift_key, input_key, taskdata) = - all(haskey(task_data, shift_key) for task_data in taskdata) && any( - apply_func(task_data[input_key], (x, y) -> x .+ y, task_data[shift_key]) != task_data[input_key] for - task_data in taskdata - ) +_check_effective_shift_key(shift_key, input_key, taskdata) = any( + apply_func(input_value, (x, y) -> x .+ y, shift_value) != input_value for + (input_value, shift_value) in zip(taskdata[input_key], taskdata[shift_key]) +) -function find_shifted_by_key(taskdata::Vector{TaskData}, field_info, invalid_sources::AbstractSet{String}, key::String) +function find_shifted_by_key(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String) find_matching_for_key( taskdata, field_info, diff --git a/src/data_transformers/find_shifted_key.jl b/src/data_transformers/find_shifted_key.jl index 8bf0876..07d17e2 100644 --- a/src/data_transformers/find_shifted_key.jl +++ b/src/data_transformers/find_shifted_key.jl @@ -29,9 +29,9 @@ function _init_shift(input_value, output_value) ), ) end -_shifted_filter(shift, input_value, output_value, _) = +_shifted_filter(shift, input_value, output_value) = check_match(apply_func(input_value, (x, y) -> x .+ y, shift), output_value) -find_shifted_key(taskdata::Vector{TaskData}, field_info, invalid_sources::AbstractSet{String}, key::String) = +find_shifted_key(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String) = find_matching_for_key(taskdata, field_info, invalid_sources, key, _init_shift, _shifted_filter, IncParam) diff --git 
a/src/data_transformers/match_transformers.jl b/src/data_transformers/match_transformers.jl index acf39bb..8e5cfa5 100644 --- a/src/data_transformers/match_transformers.jl +++ b/src/data_transformers/match_transformers.jl @@ -2,7 +2,7 @@ using IterTools: imap using Base.Iterators: flatten -function get_match_transformers(taskdata::Array{TaskData}, field_info, invalid_sources, key) +function get_match_transformers(taskdata::TaskData, field_info, invalid_sources, key) find_matches_funcs = [ find_const, find_dependent_key, @@ -39,7 +39,7 @@ function match_fields(solution::Solution) try matched_results = Dict() valid_solutions_count = - prod(length(unpack_value(task[key])) for task in solution.taskdata if haskey(task, key)) + prod(length(unpack_value(value)) for value in solution.taskdata[key] if !ismissing(value)) iter = find_matched_fields(key, solution) state = () counter = 0 @@ -50,9 +50,8 @@ function match_fields(solution::Solution) end counter += 1 new_solution, state = next - key_result = [task[key] for task in new_solution.taskdata] - if !haskey(matched_results, key_result) - matched_results[key_result] = new_solution + if !haskey(matched_results, new_solution.taskdata[key]) + matched_results[new_solution.taskdata[key]] = new_solution end end if !isempty(matched_results) @@ -69,7 +68,7 @@ end function find_matching_for_key( - taskdata::Vector{TaskData}, + taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String, @@ -82,26 +81,24 @@ function find_matching_for_key( end upd_keys = updated_keys(taskdata) flatten( - imap(keys(taskdata[1])) do input_key + imap(keys(taskdata)) do input_key if in(input_key, invalid_sources) || field_info[key].type != field_info[input_key].type || (!in(key, upd_keys) && !in(input_key, upd_keys)) return [] end candidates = [] - for task_data in taskdata - if !haskey(task_data, input_key) + for (input_value, out_value) in zip(taskdata[input_key], taskdata[key]) + if ismissing(input_value) return [] end - if 
!haskey(task_data, key) + if ismissing(out_value) continue end - input_value = task_data[input_key] - out_value = task_data[key] if isempty(candidates) candidates = init_func(input_value, out_value) end - filter!(candidate -> filter_func(candidate, input_value, out_value, task_data), candidates) + filter!(candidate -> filter_func(candidate, input_value, out_value), candidates) if isempty(candidates) return [] end @@ -112,7 +109,7 @@ function find_matching_for_key( end function find_matching_for_key( - taskdata::Vector{TaskData}, + taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String, @@ -126,38 +123,41 @@ function find_matching_for_key( end upd_keys = updated_keys(taskdata) flatten( - imap(keys(taskdata[1])) do input_key + imap(keys(taskdata)) do input_key if in(input_key, invalid_sources) || field_info[key].type != field_info[input_key].type return [] end need_updated_candidates = !in(key, upd_keys) && !in(input_key, upd_keys) - candidates = [] - for task_data in taskdata - if !haskey(task_data, input_key) - return [] - end - if isempty(candidates) - candidates = init_func(input_key, field_info, task_data, invalid_sources) - if need_updated_candidates - candidates = filter(k -> in(k, upd_keys), candidates) + candidates = init_func(input_key, field_info, taskdata, invalid_sources) + if need_updated_candidates + candidates = filter(k -> in(k, upd_keys), candidates) + end + if isempty(candidates) + return [] + end + res = [] + for candidate in candidates + valid = true + for (input_value, out_value, cand_value) in + zip(taskdata[input_key], taskdata[key], taskdata[candidate]) + if ismissing(input_value) + return [] + end + if ismissing(out_value) + continue + end + if ismissing(cand_value) || !filter_func(cand_value, input_value, out_value) + valid = false + break end end - if isempty(candidates) - return [] - end - if !haskey(task_data, key) - continue - end - out_value = task_data[key] - input_value = task_data[input_key] - 
filter!(candidate -> filter_func(candidate, input_value, out_value, task_data), candidates) - if isempty(candidates) - return [] + if valid + push!(res, candidate) end end return [ transformer_class(key, input_key, candidate) for - candidate in candidates if candidate_checker(candidate, input_key, taskdata) + candidate in res if candidate_checker(candidate, input_key, taskdata) ] end, ) diff --git a/src/find_solution.jl b/src/find_solution.jl index 4aa5486..0ac8495 100644 --- a/src/find_solution.jl +++ b/src/find_solution.jl @@ -105,20 +105,14 @@ end function is_subsolution(old_sol::Solution, new_sol::Solution)::Bool equals = true - for (new_inp_vals, new_out_vals, old_inp_vals, old_out_vals, new_task_data, old_task_data) in zip( - new_sol.inp_val_hashes, - new_sol.out_val_hashes, - old_sol.inp_val_hashes, - old_sol.out_val_hashes, - new_sol.taskdata, - old_sol.taskdata, - ) + if !issetequal(keys(new_sol.taskdata), keys(old_sol.taskdata)) + equals = false + end + for (new_inp_vals, new_out_vals, old_inp_vals, old_out_vals) in + zip(new_sol.inp_val_hashes, new_sol.out_val_hashes, old_sol.inp_val_hashes, old_sol.out_val_hashes) if !issubset(new_out_vals, old_out_vals) || !issubset(new_inp_vals, old_inp_vals) return false end - if !issetequal(keys(new_task_data), keys(old_task_data)) - equals = false - end end if equals && old_sol != new_sol return false @@ -162,8 +156,8 @@ function pop_solution(queue, visited, border) end end -function generate_solutions(taskdata::Array, debug::Bool) - init_solution = Solution(taskdata) +function generate_solutions(task_info::Array, debug::Bool) + init_solution = Solution(task_info) queue = PriorityQueue() visited = Set() border = Set() @@ -206,8 +200,10 @@ end function solve_task(task_info::Dict, debug::Bool, early_stop = true::Bool) answers = [] + test_input = [task["input"] for task in task_info["test"]] + test_output = [task["output"] for task in task_info["test"]] for solution in generate_solutions(task_info["train"], debug) - 
answer = [solution(task["input"]) for task in task_info["test"]] + answer = solution(test_input) if !in(answer, answers) @info("found") @info(solution) @@ -216,20 +212,17 @@ function solve_task(task_info::Dict, debug::Bool, early_stop = true::Bool) if length(answers) >= 3 break end - if early_stop - if all( - compare_grids(target["output"], out_grid) == 0 for (out_grid, target) in zip(answer, task_info["test"]) - ) - break - end + if early_stop && compare_grids(test_output, answer) == 0 + break end end return answers end function validate_results(test_info::Vector, answers::Vector)::Bool + test_output = [task["output"] for task in test_info] for answer in answers - if all(compare_grids(target["output"], out_grid) == 0 for (out_grid, target) in zip(answer, test_info)) + if compare_grids(test_output, answer) == 0 return true end end diff --git a/src/operations/wrap_matcher.jl b/src/operations/wrap_matcher.jl index fb90386..d1cc38a 100644 --- a/src/operations/wrap_matcher.jl +++ b/src/operations/wrap_matcher.jl @@ -25,20 +25,20 @@ _check_matcher(value::AbstractDict) = any(_check_matcher(v) for v in values(valu _check_matcher(value::AbstractVector) = any(_check_matcher(v) for v in value) _filter_unmatched_keys(keys, taskdata) = - filter(key -> any(_check_matcher(task[key]) for task in taskdata if haskey(task, key)), keys) + filter(key -> any(_check_matcher(val) for val in taskdata[key] if !ismissing(val)), keys) function wrap_operation(taskdata, operation) unmatched_keys = _filter_unmatched_keys(operation.output_keys, taskdata) if isempty(unmatched_keys) return taskdata, operation end - for key in unmatched_keys, task in taskdata - if haskey(task, key) && !haskey(task, key * "|unfilled") - task[key*"|unfilled"] = task[key] + for key in unmatched_keys + if !haskey(taskdata, key * "|unfilled") + taskdata[key*"|unfilled"] = taskdata[key] end end - for key in operation.output_keys, task in taskdata - delete!(task, key) + for key in operation.output_keys + delete!(taskdata, 
key) end return taskdata, WrapMatcher( diff --git a/src/solution.jl b/src/solution.jl index 182802f..34febf5 100644 --- a/src/solution.jl +++ b/src/solution.jl @@ -1,8 +1,6 @@ export Solutions module Solutions -export validate_solution - using ..Operations: Operation, Project, get_sorting_keys struct Block @@ -88,7 +86,7 @@ _is_valid_value(val) = true _is_valid_value(val::Union{Array,Dict}) = !isempty(val) struct Solution - taskdata::Vector{TaskData} + taskdata::TaskData field_info::Dict{String,FieldInfo} blocks::Vector{Block} unfilled_fields::Set{String} @@ -113,21 +111,17 @@ struct Solution input_transformed_fields, complexity_score::Float64, ) - inp_val_hashes = Set{UInt64}[] - out_val_hashes = Set{UInt64}[] - for task_data in taskdata - inp_vals = Set{UInt64}() - out_vals = Set{UInt64}() - for (key, value) in task_data + inp_val_hashes = fill(Set{UInt64}(), length(taskdata["input"])) + out_val_hashes = fill(Set{UInt64}(), length(taskdata["input"])) + for (key, values) in taskdata + for (i, value) in enumerate(values) if in(key, transformed_fields) || in(key, filled_fields) || in(key, unfilled_fields) - push!(out_vals, hash(value)) + push!(out_val_hashes[i], hash(value)) end if in(key, unused_fields) || in(key, used_fields) || in(key, input_transformed_fields) - push!(inp_vals, hash(value)) + push!(inp_val_hashes[i], hash(value)) end end - push!(inp_val_hashes, inp_vals) - push!(out_val_hashes, out_vals) end new( taskdata, @@ -147,12 +141,16 @@ struct Solution end end -function Solution(taskdata) +function Solution(task_info) Solution( - [TaskData(task, Dict{String,Any}(), Set{String}()) for task in taskdata], + TaskData( + Dict("input" => [task["input"] for task in task_info], "output" => [task["output"] for task in task_info]), + Dict{String,Any}(), + Set{String}(), + ), Dict( - "input" => FieldInfo(taskdata[1]["input"], "input", [], [["input"]]), - "output" => FieldInfo(taskdata[1]["output"], "input", [], [Set()]), + "input" => 
FieldInfo(task_info[1]["input"], "input", [], [["input"]]), + "output" => FieldInfo(task_info[1]["output"], "input", [], [Set()]), ), [Block()], Set(["output"]), @@ -166,7 +164,7 @@ function Solution(taskdata) end persist_updates(solution::Solution) = Solution( - [persist_data(task) for task in solution.taskdata], + persist_data(solution.taskdata), solution.field_info, solution.blocks, solution.unfilled_fields, @@ -206,17 +204,14 @@ function move_to_next_block(solution::Solution)::Solution blocks[end] = Block(reverse(prev_block_ops)) - last_block_output = [ - blocks[end]( - filter( - keyval -> !in(keyval[1], solution.unfilled_fields) && !in(keyval[1], solution.transformed_fields), - task, - ), - ) for task in solution.taskdata - ] + last_block_output = blocks[end]( + filter( + keyval -> !in(keyval[1], solution.unfilled_fields) && !in(keyval[1], solution.transformed_fields), + solution.taskdata, + ), + ) - taskdata = - [merge(task_data, block_output) for (task_data, block_output) in zip(solution.taskdata, last_block_output)] + taskdata = merge(solution.taskdata, last_block_output) field_info = solution.field_info for op in blocks[end].operations @@ -225,10 +220,10 @@ function move_to_next_block(solution::Solution)::Solution input_field_info = [field_info[k] for k in op.input_keys if haskey(field_info, k)] i = argmin([length(info.derived_from) for info in input_field_info]) dependent_key = input_field_info[i].derived_from - for task in taskdata - if haskey(task, key) && _is_valid_value(task[key]) + for val in taskdata[key] + if !ismissing(val) && _is_valid_value(val) field_info[key] = FieldInfo( - task[key], + val, dependent_key, vcat([info.precursor_types for info in input_field_info]...), [(field_info[k].previous_fields for k in op.input_keys)..., [key]], @@ -268,24 +263,27 @@ function move_to_next_block(solution::Solution)::Solution i = argmin([length(info.derived_from) for info in input_field_info]) dependent_key = input_field_info[i].derived_from - 
projected_output = [project_op(block_output) for block_output in last_block_output] + projected_output = project_op(last_block_output) - taskdata = [filter(keyval -> !startswith(keyval[1], "projected|"), task_data) for task_data in taskdata] + taskdata = filter(keyval -> !startswith(keyval[1], "projected|"), taskdata) unused_fields = filter(key -> !startswith(key, "projected|"), solution.unused_fields) field_info = filter(keyval -> !startswith(keyval[1], "|projected"), field_info) for key in project_op.output_keys - for (observed_task, output) in zip(taskdata, projected_output) - if haskey(output, key) - observed_task[key] = output[key] - push!(unused_fields, key) - if !haskey(field_info, key) && _is_valid_value(output[key]) - field_info[key] = FieldInfo( - output[key], - dependent_key, - vcat([info.precursor_types for info in input_field_info]...), - [(field_info[k].previous_fields for k in project_op.input_keys)..., [key]], - ) + if haskey(projected_output, key) + taskdata[key] = projected_output[key] + push!(unused_fields, key) + if !haskey(field_info, key) + for val in projected_output[key] + if _is_valid_value(val) + field_info[key] = FieldInfo( + val, + dependent_key, + vcat([info.precursor_types for info in input_field_info]...), + [(field_info[k].previous_fields for k in project_op.input_keys)..., [key]], + ) + break + end end end end @@ -293,9 +291,7 @@ function move_to_next_block(solution::Solution)::Solution end if isempty(solution.unfilled_fields) - for (task, block_output) in zip(taskdata, last_block_output) - task["projected|output"] = block_output["output"] - end + taskdata["projected|output"] = last_block_output["output"] end if !isempty(new_block.operations) @@ -517,7 +513,7 @@ function insert_operation( try op = isnothing(reversed_op) ? 
operation : reversed_op - taskdata = [op(task) for task in solution.taskdata] + taskdata = op(solution.taskdata) if isnothing(reversed_op) && !no_wrap taskdata, operation = wrap_operation(taskdata, operation) @@ -558,10 +554,10 @@ function insert_operation( out_dependent_key = input_field_info[i].derived_from end for key in new_input_fields - for task in taskdata - if haskey(task, key) && _is_valid_value(task[key]) + for val in taskdata[key] + if !ismissing(val) && _is_valid_value(val) field_info[key] = FieldInfo( - task[key], + val, get_source_key(operation, out_dependent_key), vcat([info.precursor_types for info in output_field_info]...), [Set()], @@ -595,10 +591,10 @@ function insert_operation( end else push!(unused_fields, key) - for task in taskdata - if haskey(task, key) && _is_valid_value(task[key]) + for val in taskdata[key] + if !ismissing(val) && _is_valid_value(val) field_info[key] = FieldInfo( - task[key], + val, inp_dependent_key, vcat([info.precursor_types for info in input_field_info]...), [(field_info[k].previous_fields for k in operation.input_keys)..., [key]], @@ -678,7 +674,7 @@ Base.show(io::IO, s::Solution) = print( "Dict(\n", (vcat( ( - ["\t\t", keyval, ",\n"] for keyval in s.field_info if haskey(s.taskdata[1], keyval[1]) && ( + ["\t\t", keyval, ",\n"] for keyval in s.field_info if haskey(s.taskdata, keyval[1]) && ( in(keyval[1], s.unfilled_fields) || in(keyval[1], s.unused_fields) || in(keyval[1], s.input_transformed_fields) || @@ -691,33 +687,40 @@ Base.show(io::IO, s::Solution) = print( "\n)", ) -function (solution::Solution)(input_grid::Array{Int,2})::Array{Int,2} - observed_data = TaskData(Dict{String,Any}("input" => input_grid), Dict{String,Any}(), Set()) +function (solution::Solution)(input_grids::Vector{Array{Int,2}})::Array{Int,2} + observed_data = TaskData(Dict{String,Vector}("input" => input_grids), Dict{String,Any}(), Set()) for block in solution.blocks observed_data = block(observed_data) end - get(observed_data, "output", 
Array{Int}(undef, 0, 0)) + get(observed_data, "output", fill(Array{Int}(undef, 0, 0), length(input_grids))) end Base.:(==)(a::Solution, b::Solution)::Bool = a.blocks == b.blocks Base.hash(s::Solution, h::UInt64) = hash(s.blocks, h) -function check_task(solution::Solution, input_grid::Array{Int,2}, target::Array{Int,2}) - out = solution(input_grid) - compare_grids(target, out) +function check_task(solution::Solution, input_grids::Vector{Array{Int,2}}, targets::Vector{Array{Int,2}}) + out = solution(input_grids) + compare_grids(targets, out) end -function compare_grids(target::Array{Int,2}, output::Array{Int,2}) - if size(target) != size(output) - return reduce(*, size(target)) +function compare_grids(targets::Vector{Array{Int,2}}, outputs::Vector{Array{Int,2}}) + result = 0 + for (target, output) in zip(targets, outputs) + if size(target) != size(output) + result += reduce(*, size(target)) + else + result += sum(output .!= target) + end end - sum(output .!= target) + return result end -function get_score(taskdata, complexity_score)::Int - score = - sum(compare_grids(task["output"], get(task, "projected|output", Array{Int}(undef, 0, 0))) for task in taskdata) +function get_score(taskdata::TaskData, complexity_score)::Int + score = compare_grids( + taskdata["output"], + get(taskdata, "projected|output", fill(Array{Int}(undef, 0, 0), length(taskdata["output"]))), + ) # if complexity_score > 100 # score += floor(complexity_score) # end @@ -727,39 +730,40 @@ end using ..Complexity: get_complexity function get_unmatched_complexity_score(solution::Solution) - unmatched_data_score = [ - sum(Float64[get_complexity(value) for (key, value) in task_data if in(key, solution.unfilled_fields)]) for - task_data in solution.taskdata - ] - transformed_data_score = [ - sum(Float64[get_complexity(value) / 10 for (key, value) in task_data if in(key, solution.transformed_fields)]) for task_data in solution.taskdata - ] - unused_data_score = [ - sum( - Float64[ - startswith(key, 
"projected|") ? get_complexity(value) / 6 : get_complexity(value) for - (key, value) in task_data if in(key, solution.unused_fields) - ], - ) for task_data in solution.taskdata - ] - inp_transformed_data_score = [ - sum( - Float64[ - get_complexity(value) / 3 for (key, value) in task_data if in(key, solution.input_transformed_fields) - ], - ) for task_data in solution.taskdata - ] - return ( - sum(unmatched_data_score) + - # sum(transformed_data_score) + - sum(unused_data_score) + - sum(inp_transformed_data_score) + - solution.complexity_score - ) / length(solution.taskdata) -end + unmatched_data_score = sum( + Float64[ + sum(Float64[get_complexity(value) for value in values]) for + (key, values) in solution.taskdata if in(key, solution.unfilled_fields) + ], + ) + transformed_data_score = sum( + Float64[ + sum(Float64[get_complexity(value) / 10 for value in values]) for + (key, values) in solution.taskdata if in(key, solution.transformed_fields) + ], + ) + unused_data_score = sum( + Float64[ + sum( + Float64[ + startswith(key, "projected|") ? 
get_complexity(value) / 6 : get_complexity(value) for + value in values + ], + ) for (key, values) in solution.taskdata if in(key, solution.unused_fields) + ], + ) -function validate_solution(solution, taskdata) - sum(check_task(solution, task["input"], task["output"]) for task in taskdata) + inp_transformed_data_score = sum( + Float64[ + sum(Float64[get_complexity(value) / 3 for value in values]) for + (key, values) in solution.taskdata if in(key, solution.input_transformed_fields) + ], + ) + return (unmatched_data_score + + # transformed_data_score + + unused_data_score + + inp_transformed_data_score + + solution.complexity_score) / length(solution.taskdata["input"]) end end diff --git a/src/taskdata/taskdata.jl b/src/taskdata/taskdata.jl index bc3c35b..4f3d371 100644 --- a/src/taskdata/taskdata.jl +++ b/src/taskdata/taskdata.jl @@ -2,8 +2,8 @@ export Taskdata module Taskdata struct TaskData <: AbstractDict{String,Any} - persistent_data::Dict{String,Any} - updated_values::Dict{String,Any} + persistent_data::Dict{String,Vector} + updated_values::Dict{String,Vector} keys_to_delete::Set{String} end @@ -104,8 +104,4 @@ function updated_keys(t::TaskData) return filter(k -> !in(k, t.keys_to_delete), keys(t.updated_values)) end -function updated_keys(taskdata::Vector{TaskData}) - union((updated_keys(task) for task in taskdata)...) 
-end - end diff --git a/test/test_ignore_background.jl b/test/test_ignore_background.jl deleted file mode 100644 index 86c5bde..0000000 --- a/test/test_ignore_background.jl +++ /dev/null @@ -1,72 +0,0 @@ - -using .ObjectPrior: Object -# using .Abstractors:IgnoreBackground - -@testset "Ignore background" begin - return - @testset "get ignore background" begin - solution = make_dummy_solution( - [ - Dict( - "output|spatial_objects|selected_by|0|background" => [Object([-1 0 0; 0 0 0; 0 0 0], (1, 1))], - "output|spatial_objects|rejected_by|0|background" => [Object([1], (1, 1))], - "output|background" => 0, - ), - Dict( - "output|spatial_objects|selected_by|0|background" => [Object([-1 0 -1; 0 0 0; 0 0 0], (1, 1))], - "output|spatial_objects|rejected_by|0|background" => [Object([1], (1, 1)), Object([1], (1, 3))], - "output|background" => 0, - ), - Dict( - "output|spatial_objects|selected_by|0|background" => - [Object([0 -1 0; 0 0 0], (2, 1)), Object([0], (1, 2))], - "output|spatial_objects|rejected_by|0|background" => - [Object([1], (1, 1)), Object([1], (1, 3)), Object([1], (2, 2))], - "output|background" => 0, - ), - Dict( - "output|spatial_objects|selected_by|0|background" => [Object([-1 0 -1; 0 0 0; 0 0 0], (1, 1))], - "output|spatial_objects|rejected_by|0|background" => [Object([1], (1, 1)), Object([1], (1, 3))], - "output|background" => 0, - ), - Dict( - "output|spatial_objects|selected_by|0|background" => [Object([-1 0 0; 0 0 0; 0 0 0], (1, 1))], - "output|spatial_objects|rejected_by|0|background" => [Object([1], (1, 1))], - "output|background" => 0, - ), - Dict( - "output|spatial_objects|selected_by|0|background" => - [Object([0], (2, 1)), Object([-1 0; 0 0], (2, 2)), Object([0], (1, 2))], - "output|spatial_objects|rejected_by|0|background" => - [Object([1], (1, 1)), Object([1], (3, 1)), Object([1], (1, 3)), Object([1], (2, 2))], - "output|background" => 0, - ), - Dict( - "output|spatial_objects|selected_by|0|background" => - [Object([0], (2, 1)), Object([0], (2, 
3)), Object([0], (1, 2)), Object([0], (3, 2))], - "output|spatial_objects|rejected_by|0|background" => [ - Object([1], (1, 1)), - Object([1], (3, 1)), - Object([1], (1, 3)), - Object([1], (2, 2)), - Object([1], (3, 3)), - ], - "output|background" => 0, - ), - Dict( - "output|spatial_objects|selected_by|0|background" => - [Object([0], (2, 1)), Object([-1 0; 0 0], (2, 2)), Object([0], (1, 2))], - "output|spatial_objects|rejected_by|0|background" => - [Object([1], (1, 1)), Object([1], (3, 1)), Object([1], (1, 3)), Object([1], (2, 2))], - "output|background" => 0, - ), - ], - ["output|spatial_objects|selected_by|0|background", "output|spatial_objects|rejected_by|0|background"], - ) - abstractors = create(IgnoreBackground(), solution, "output|spatial_objects|selected_by|0|background") - @test length(abstractors) == 1 - priority, abstractor = abstractors[1] - @test abstractor.to_abstract == IgnoreBackground("output|spatial_objects|selected_by|0|background", true) - @test abstractor.from_abstract == IgnoreBackground("output|spatial_objects|selected_by|0|background", false) - end -end diff --git a/test/test_select_by_color.jl b/test/test_select_by_color.jl deleted file mode 100644 index 29133a2..0000000 --- a/test/test_select_by_color.jl +++ /dev/null @@ -1,52 +0,0 @@ - -# using .Abstractors:SelectColor - - -@testset "Select Objects by color" begin - return - @testset "select objects" begin - selector = SelectColor("key", "selector_key", true) - input_data = Dict( - "key" => [Object([0 0], (1, 1)), Object([0; 0], (3, 2)), Object([1], (4, 1)), Object([2], (1, 5))], - "selector_key" => 0, - ) - output_data = selector([], [], input_data)[2] - @test output_data == Dict( - "key" => [Object([0 0], (1, 1)), Object([0; 0], (3, 2)), Object([1], (4, 1)), Object([2], (1, 5))], - "selector_key" => 0, - "key|selected_by_color|selector_key" => [Object([0 0], (1, 1)), Object([0; 0], (3, 2))], - "key|rejected_by_color|selector_key" => [Object([1], (4, 1)), Object([2], (1, 5))], - ) - 
delete!(output_data, "key") - selector = SelectColor("key", "selector_key", false) - reverted_data = selector([], [], output_data)[2] - @test reverted_data["key"] == input_data["key"] - end - - @testset "create selector" begin - solution = make_dummy_solution([ - Dict( - "input|key" => - [Object([0 0], (1, 1)), Object([0; 0], (3, 2)), Object([1], (4, 1)), Object([2], (1, 5))], - "input|selector_key" => 0, - ), - Dict( - "input|key" => - [Object([0 0], (1, 1)), Object([0; 0], (3, 2)), Object([1], (4, 1)), Object([2], (1, 5))], - "input|selector_key" => 0, - ), - ]) - abstractors = create(SelectColor(), solution, "input|key") - @test length(abstractors) == 1 - priority, abstractor = abstractors[1] - @test priority == 2.3 - @test abstractor.to_abstract == SelectColor("input|key", "input|selector_key", true) - @test abstractor.to_abstract.input_keys == ["input|key", "input|selector_key"] - @test abstractor.to_abstract.output_keys == - ["input|key|selected_by_color|input|selector_key", "input|key|rejected_by_color|input|selector_key"] - @test abstractor.from_abstract == SelectColor("input|key", "input|selector_key", false) - @test abstractor.from_abstract.input_keys == - ["input|key|selected_by_color|input|selector_key", "input|key|rejected_by_color|input|selector_key"] - @test abstractor.from_abstract.output_keys == ["input|key"] - end -end diff --git a/test/utils.jl b/test/utils.jl index 8b63a5e..142dff3 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -5,7 +5,6 @@ using .DataTransformers: match_fields using .Abstractors: create using .Taskdata: TaskData -make_sample_taskdata(len) = fill(Dict("input" => Array{Int}(undef, 0, 0), "output" => Array{Int}(undef, 0, 0)), len) struct FakeOperation <: Operation input_keys::Any @@ -15,11 +14,12 @@ end (op::FakeOperation)(task_data) = task_data -make_taskdata(tasks) = [make_taskdata(task) for task in tasks] +make_taskdata(tasks) = + TaskData(Dict{String,Vector}(), Dict(key => [task[key] for task in tasks] for key in 
keys(tasks[1])), Set()) -make_taskdata(task::Dict) = TaskData(Dict{String,Any}(), task, Set()) +make_taskdata(task::Dict) = TaskData(Dict{String,Vector}(), Dict(key => [val] for (key, val) in task), Set()) -make_field_info(taskdata) = Dict(key => FieldInfo(val, "input", [], [Set()]) for (key, val) in taskdata[1]) +make_field_info(taskdata) = Dict(key => FieldInfo(vals[1], "input", [], [Set()]) for (key, vals) in taskdata) function make_dummy_solution(data, unfilled = []) unused = Set(filter(k -> !in(k, unfilled) && k != "input" && k != "output", keys(data[1]))) From 3366a6b9a72539c2bb94b4fdd8c5ba82e85875fc Mon Sep 17 00:00:00 2001 From: Andrey Zakharevich Date: Fri, 2 Jul 2021 16:59:02 +0300 Subject: [PATCH 2/8] Make data transformers work with singular taskdata --- src/operations/copy_param.jl | 5 +--- src/operations/dec_by_param.jl | 9 ++++--- src/operations/inc_by_param.jl | 9 ++++--- src/operations/inc_param.jl | 8 +++--- src/operations/map_values.jl | 14 +++++----- src/operations/mult_by_param.jl | 9 ++++--- src/operations/mult_param.jl | 8 +++--- src/operations/operation.jl | 1 + src/operations/set_const.jl | 6 ++--- src/pattern_matching/update_value.jl | 7 +++++ src/solution.jl | 22 ++++++++-------- test/test_pattern_matching.jl | 38 ++++++++++++++-------------- test/utils.jl | 10 +++++--- 13 files changed, 80 insertions(+), 66 deletions(-) diff --git a/src/operations/copy_param.jl b/src/operations/copy_param.jl index 0d382bc..4117e40 100644 --- a/src/operations/copy_param.jl +++ b/src/operations/copy_param.jl @@ -11,7 +11,4 @@ Base.show(io::IO, op::CopyParam) = print(io, "CopyParam(\"", op.output_keys[1], Base.:(==)(a::CopyParam, b::CopyParam) = a.output_keys == b.output_keys && a.input_keys == b.input_keys Base.hash(op::CopyParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_keys, h) -function (op::CopyParam)(task_data) - data = update_value(task_data, op.output_keys[1], task_data[op.input_keys[1]]) - data -end +(op::CopyParam)(taskdata::TaskData) 
= update_value(taskdata, op.output_keys[1], taskdata[op.input_keys[1]]) diff --git a/src/operations/dec_by_param.jl b/src/operations/dec_by_param.jl index 716a15b..62e5ba5 100644 --- a/src/operations/dec_by_param.jl +++ b/src/operations/dec_by_param.jl @@ -12,7 +12,10 @@ Base.show(io::IO, op::DecByParam) = Base.:(==)(a::DecByParam, b::DecByParam) = a.output_keys == b.output_keys && a.input_keys == b.input_keys Base.hash(op::DecByParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_keys, h) -function (op::DecByParam)(task_data) - output_value = apply_func(task_data[op.input_keys[1]], (a, b) -> a .- b, task_data[op.input_keys[2]]) - update_value(task_data, op.output_keys[1], output_value) +function (op::DecByParam)(taskdata::TaskData) + output_value = [ + apply_func(val1, (a, b) -> a .- b, val2) for + (val1, val2) in zip(taskdata[op.input_keys[1]], taskdata[op.input_keys[2]]) + ] + update_value(taskdata, op.output_keys[1], output_value) end diff --git a/src/operations/inc_by_param.jl b/src/operations/inc_by_param.jl index 46d3ccb..ccb5b9c 100644 --- a/src/operations/inc_by_param.jl +++ b/src/operations/inc_by_param.jl @@ -13,7 +13,10 @@ Base.show(io::IO, op::IncByParam) = Base.:(==)(a::IncByParam, b::IncByParam) = a.output_keys == b.output_keys && a.input_keys == b.input_keys Base.hash(op::IncByParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_keys, h) -function (op::IncByParam)(task_data) - output_value = apply_func(task_data[op.input_keys[1]], (a, b) -> a .+ b, task_data[op.input_keys[2]]) - update_value(task_data, op.output_keys[1], output_value) +function (op::IncByParam)(taskdata::TaskData) + output_value = [ + apply_func(val1, (a, b) -> a .+ b, val2) for + (val1, val2) in zip(taskdata[op.input_keys[1]], taskdata[op.input_keys[2]]) + ] + update_value(taskdata, op.output_keys[1], output_value) end diff --git a/src/operations/inc_param.jl b/src/operations/inc_param.jl index dac7964..a9dd790 100644 --- a/src/operations/inc_param.jl +++ 
b/src/operations/inc_param.jl @@ -14,8 +14,8 @@ Base.:(==)(a::IncParam, b::IncParam) = a.output_keys == b.output_keys && a.input_keys == b.input_keys && a.shift == b.shift Base.hash(op::IncParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_keys, h) + hash(op.shift, h) -function (op::IncParam)(task_data) - output_value = apply_func(task_data[op.input_keys[1]], (a, b) -> a .+ b, op.shift) - data = update_value(task_data, op.output_keys[1], output_value) - update_value(data, op.output_keys[2], op.shift) +function (op::IncParam)(taskdata::TaskData) + output_value = [apply_func(val, (a, b) -> a .+ b, op.shift) for val in taskdata[op.input_keys[1]]] + data = update_value(taskdata, op.output_keys[1], output_value) + update_value(data, op.output_keys[2], fill(op.shift, length(taskdata["input"]))) end diff --git a/src/operations/map_values.jl b/src/operations/map_values.jl index 611cbe1..9a7be6c 100644 --- a/src/operations/map_values.jl +++ b/src/operations/map_values.jl @@ -19,12 +19,10 @@ Base.:(==)(a::MapValues, b::MapValues) = a.output_keys == b.output_keys && a.input_keys == b.input_keys && a.match_pairs == b.match_pairs Base.hash(op::MapValues, h::UInt64) = hash(op.output_keys, h) + hash(op.input_keys, h) + hash(op.match_pairs, h) -function (op::MapValues)(task_data) - input_value = task_data[op.input_keys[1]] - if isa(input_value, Dict) - output_value = Dict(key => op.match_pairs[value] for (key, value) in input_value) - else - output_value = op.match_pairs[input_value] - end - update_value(task_data, op.output_keys[1], output_value) +function (op::MapValues)(taskdata::TaskData) + output_value = [ + isa(input_value, Dict) ? 
Dict(key => op.match_pairs[value] for (key, value) in input_value) : + op.match_pairs[input_value] for input_value in taskdata[op.input_keys[1]] + ] + update_value(taskdata, op.output_keys[1], output_value) end diff --git a/src/operations/mult_by_param.jl b/src/operations/mult_by_param.jl index c2d328d..53a0644 100644 --- a/src/operations/mult_by_param.jl +++ b/src/operations/mult_by_param.jl @@ -12,7 +12,10 @@ Base.show(io::IO, op::MultByParam) = Base.:(==)(a::MultByParam, b::MultByParam) = a.output_keys == b.output_keys && a.input_keys == b.input_keys Base.hash(op::MultByParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_keys, h) -function (op::MultByParam)(task_data) - output_value = apply_func(task_data[op.input_keys[1]], (a, b) -> a .* b, task_data[op.input_keys[2]]) - update_value(task_data, op.output_keys[1], output_value) +function (op::MultByParam)(taskdata::TaskData) + output_value = [ + apply_func(val1, (a, b) -> a .* b, val2) for + (val1, val2) in zip(taskdata[op.input_keys[1]], taskdata[op.input_keys[2]]) + ] + update_value(taskdata, op.output_keys[1], output_value) end diff --git a/src/operations/mult_param.jl b/src/operations/mult_param.jl index 6292cb3..c32aae8 100644 --- a/src/operations/mult_param.jl +++ b/src/operations/mult_param.jl @@ -14,8 +14,8 @@ Base.:(==)(a::MultParam, b::MultParam) = a.output_keys == b.output_keys && a.input_keys == b.input_keys && a.factor == b.factor Base.hash(op::MultParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_keys, h) + hash(op.factor, h) -function (op::MultParam)(task_data) - output_value = apply_func(task_data[op.input_keys[1]], (a, b) -> a .* b, op.factor) - data = update_value(task_data, op.output_keys[1], output_value) - update_value(data, op.output_keys[2], op.factor) +function (op::MultParam)(taskdata::TaskData) + output_value = [apply_func(val, (a, b) -> a .* b, op.factor) for val in taskdata[op.input_keys[1]]] + data = update_value(taskdata, op.output_keys[1], output_value) + 
update_value(data, op.output_keys[2], fill(op.factor, length(taskdata["input"]))) end diff --git a/src/operations/operation.jl b/src/operations/operation.jl index 309c1eb..08c62eb 100644 --- a/src/operations/operation.jl +++ b/src/operations/operation.jl @@ -11,6 +11,7 @@ get_sorting_keys(operation::Operation) = operation.output_keys needed_input_keys(operation::Operation) = operation.input_keys using ..PatternMatching: update_value, apply_func +using ..Taskdata: TaskData include("set_const.jl") include("copy_param.jl") diff --git a/src/operations/set_const.jl b/src/operations/set_const.jl index 4090606..5aa6072 100644 --- a/src/operations/set_const.jl +++ b/src/operations/set_const.jl @@ -13,7 +13,5 @@ Base.:(==)(a::SetConst, b::SetConst) = a.output_keys == b.output_keys && a.value Base.hash(op::SetConst, h::UInt64) = hash(op.output_keys, h) + hash(op.value, h) -function (op::SetConst)(task_data) - data = update_value(task_data, op.output_keys[1], op.value) - data -end +(op::SetConst)(taskdata::TaskData) = + update_value(taskdata, op.output_keys[1], fill(op.value, length(taskdata["input"]))) diff --git a/src/pattern_matching/update_value.jl b/src/pattern_matching/update_value.jl index a24aa85..9e4072b 100644 --- a/src/pattern_matching/update_value.jl +++ b/src/pattern_matching/update_value.jl @@ -29,6 +29,13 @@ function update_value(data::TaskData, path_keys::Array, value, ::Any)::TaskData data end +function update_value(data::TaskData, path_keys::Array, value, current_value::Vector)::TaskData + for (i, val) in enumerate(value) + data = update_value(data, vcat(path_keys, [i]), val, current_value[i]) + end + data +end + function update_value(data::TaskData, path_keys::Array, value, current_value::Dict)::TaskData for key in keys(current_value) data = update_value(data, vcat(path_keys, [key]), value[key], current_value[key]) diff --git a/src/solution.jl b/src/solution.jl index 34febf5..a6eedd4 100644 --- a/src/solution.jl +++ b/src/solution.jl @@ -591,15 +591,17 @@ 
function insert_operation( end else push!(unused_fields, key) - for val in taskdata[key] - if !ismissing(val) && _is_valid_value(val) - field_info[key] = FieldInfo( - val, - inp_dependent_key, - vcat([info.precursor_types for info in input_field_info]...), - [(field_info[k].previous_fields for k in operation.input_keys)..., [key]], - ) - break + if haskey(taskdata, key) + for val in taskdata[key] + if !ismissing(val) && _is_valid_value(val) + field_info[key] = FieldInfo( + val, + inp_dependent_key, + vcat([info.precursor_types for info in input_field_info]...), + [(field_info[k].previous_fields for k in operation.input_keys)..., [key]], + ) + break + end end end end @@ -704,7 +706,7 @@ function check_task(solution::Solution, input_grids::Vector{Array{Int,2}}, targe compare_grids(targets, out) end -function compare_grids(targets::Vector{Array{Int,2}}, outputs::Vector{Array{Int,2}}) +function compare_grids(targets::Vector, outputs::Vector) result = 0 for (target, output) in zip(targets, outputs) if size(target) != size(output) diff --git a/test/test_pattern_matching.jl b/test/test_pattern_matching.jl index a97c284..5ade320 100644 --- a/test/test_pattern_matching.jl +++ b/test/test_pattern_matching.jl @@ -54,7 +54,7 @@ using .Abstractors: iter_source_either_values end @testset "fix value" begin - keys = ["spatial_objects|grouped|0|first|splitted|first", 1] + keys = ["spatial_objects|grouped|0|first|splitted|first", 1, 1] value = Object([1], (0, 5)) task_data = make_taskdata( Dict( @@ -102,18 +102,18 @@ using .Abstractors: iter_source_either_values new_task_data = update_value(task_data, keys, value) @test Dict(new_task_data) == Dict( - "spatial_objects|grouped|0|step" => Dict(1 => (0, 6), 3 => (0, 6)), - "spatial_objects|grouped|0|first|splitted|first" => Dict( + "spatial_objects|grouped|0|step" => [Dict(1 => (0, 6), 3 => (0, 6))], + "spatial_objects|grouped|0|first|splitted|first" =>[Dict( 1 => Object([1], (0, 5)), 3 => Either([ Option(Object([3], (0, 8)), 
8357411015276601514), Option(Object([3], (6, 8)), -6298199269447202670), ]), - ), - "spatial_objects|grouped|0|first|splitted|step" => Dict( + ),], + "spatial_objects|grouped|0|first|splitted|step" => [Dict( 1 => (1, 0), 3 => Either([Option((-1, 0), -6298199269447202670), Option((1, 0), 8357411015276601514)]), - ), + ),], ) task_data = make_taskdata( @@ -152,18 +152,18 @@ using .Abstractors: iter_source_either_values ) new_task_data = update_value(task_data, keys, value) @test Dict(new_task_data) == Dict( - "spatial_objects|grouped|0|step" => Dict(1 => (0, 6), 3 => (0, 6)), - "spatial_objects|grouped|0|first|splitted|first" => Dict( + "spatial_objects|grouped|0|step" => [Dict(1 => (0, 6), 3 => (0, 6))], + "spatial_objects|grouped|0|first|splitted|first" => [Dict( 1 => Object([1], (0, 5)), 3 => Either([ Option(Object([3], (0, 8)), 8357411015276601514), Option(Object([3], (6, 8)), -6298199269447202670), ]), - ), - "spatial_objects|grouped|0|first|splitted|step" => Dict( + ),], + "spatial_objects|grouped|0|first|splitted|step" => [Dict( 1 => Either([Option((-1, 0)), Option((1, 0))]), 3 => Either([Option((-1, 0), -6298199269447202670), Option((1, 0), 8357411015276601514)]), - ), + ),], ) end @@ -587,9 +587,9 @@ using .Abstractors: iter_source_either_values ]), ), ) - new_data = update_value(taskdata, "vert_kernel|horz_kernel|splitted|step", (-1, -1)) + new_data = update_value(taskdata, "vert_kernel|horz_kernel|splitted|step", [(-1, -1)]) @test new_data == Dict( - "vert_kernel|horz_kernel|splitted|first" => Either([ + "vert_kernel|horz_kernel|splitted|first" => Any[Either([ Option( Either([ Option(Object([0], (2, 2)), 9066318667632083547), @@ -606,17 +606,17 @@ using .Abstractors: iter_source_either_values 895491277412430600, ), Option(Object([0], (3, 2)), 1062607700805768252), - ]), - "vert_kernel|horz_kernel|splitted|step" => (-1, -1), - "vert_is_left" => AuxValue( + ])], + "vert_kernel|horz_kernel|splitted|step" => Any[(-1, -1)], + "vert_is_left" => Any[AuxValue( 
Either([ Option(true, 895491277412430600), Option(true, 14914076444286691944), Option(false, 1062607700805768252), Option(false, 15081192867680029596), ]), - ), - "vert_kernel|horz_is_top" => Either([ + )], + "vert_kernel|horz_is_top" => Any[Either([ Option( AuxValue(Either([Option(true, 9066318667632083547), Option(false, 2841229356805458503)])), 14914076444286691944, @@ -627,7 +627,7 @@ using .Abstractors: iter_source_either_values 895491277412430600, ), Option(AuxValue(true), 1062607700805768252), - ]), + ])], ) end end diff --git a/test/utils.jl b/test/utils.jl index 142dff3..fe8a3b7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -15,9 +15,9 @@ end (op::FakeOperation)(task_data) = task_data make_taskdata(tasks) = - TaskData(Dict{String,Vector}(), Dict(key => [task[key] for task in tasks] for key in keys(tasks[1])), Set()) + TaskData(Dict{String,Vector}(), Dict(key => Any[task[key] for task in tasks] for key in keys(tasks[1])), Set()) -make_taskdata(task::Dict) = TaskData(Dict{String,Vector}(), Dict(key => [val] for (key, val) in task), Set()) +make_taskdata(task::Dict) = TaskData(Dict{String,Vector}(), Dict(key => Any[val] for (key, val) in task), Set()) make_field_info(taskdata) = Dict(key => FieldInfo(vals[1], "input", [], [Set()]) for (key, vals) in taskdata) @@ -50,8 +50,10 @@ function _compare_operations(expected, solutions) end filtered_taskdata(solution) = [ - Dict(filter(keyval -> keyval[1] != "input" && keyval[1] != "output" && keyval[1] != "projected|output", task)) - for task in solution.taskdata + Dict( + key => values[i] for + (key, values) in solution.taskdata if key != "input" && key != "output" && key != "projected|output" + ) for (i, _) in enumerate(solution.taskdata["input"]) ] filtered_ops(solution) = From f804db854edc17b71a1e28522baa9c4ed92be35c Mon Sep 17 00:00:00 2001 From: Andrey Zakharevich Date: Fri, 2 Jul 2021 18:40:12 +0300 Subject: [PATCH 3/8] Fix most abstractor tests --- src/abstractors/abstractors.jl | 35 +++++++---- 
src/abstractors/transpose.jl | 10 ++-- test/test_compact_similar_objects.jl | 6 +- test/test_group_objects.jl | 30 +++++----- test/test_repeat_obj_inf.jl | 24 ++++---- test/test_solid_objects.jl | 22 ++++--- test/test_split_object.jl | 86 +++++++++++++++------------- test/test_unite_touching.jl | 48 ++++++++++++---- 8 files changed, 157 insertions(+), 104 deletions(-) diff --git a/src/abstractors/abstractors.jl b/src/abstractors/abstractors.jl index d0c364d..38bd1e5 100644 --- a/src/abstractors/abstractors.jl +++ b/src/abstractors/abstractors.jl @@ -46,15 +46,16 @@ function Abstractor(cls::AbstractorClass, key::String, to_abs::Bool, found_aux_k end -function (p::Abstractor)(task_data) - out_data = copy(task_data) +function (p::Abstractor)(taskdata) + out_data = copy(taskdata) input_values = fetch_input_values(p, out_data) if p.to_abstract func = to_abstract_value else func = from_abstract_value end - merge!(out_data, wrap_func_call_value_root(p, func, input_values...)) + updated_values = [wrap_func_call_value_root(p, func, inputs...) 
for inputs in zip(input_values...)] + merge!(out_data, Dict(key => [values[key] for values in updated_values] for key in p.output_keys)) return out_data end @@ -296,24 +297,33 @@ function create( solution, key, )::Array{Tuple{Float64,NamedTuple{(:to_abstract, :from_abstract),Tuple{Abstractor,Abstractor}}},1} - if any(haskey(solution.taskdata[1], k) for k in abs_keys(cls, key)) + if any(haskey(solution.taskdata, k) for k in abs_keys(cls, key)) return [] end - found_aux_keys = [aux_keys(cls, key, task) for task in solution.taskdata] - if !all(all(length(keys) == length(aux_keys(cls))) && keys == found_aux_keys[1] for keys in found_aux_keys) + found_aux_keys = aux_keys(cls, key, solution.taskdata) + if length(found_aux_keys) != length(aux_keys(cls)) return [] end data = init_create_check_data(cls, key, solution) + if any(ismissing(value) for value in solution.taskdata[key]) + return [] + end + aux_values = get_aux_values_for_task(cls, solution.taskdata, key, solution) + if isempty(aux_values) + aux_values = fill([], length(solution.taskdata[key])) + else + aux_values = zip(aux_values...) 
+ end + if !all( - haskey(task_data, key) && - wrap_check_task_value(cls, task_data[key], data, get_aux_values_for_task(cls, task_data, key, solution)) for - task_data in solution.taskdata + wrap_check_task_value(cls, value, data, aux_vals) for + (value, aux_vals) in zip(solution.taskdata[key], aux_values) ) return [] end output = [] - for (priority, abstractor) in create_abstractors(cls, data, key, found_aux_keys[1]) + for (priority, abstractor) in create_abstractors(cls, data, key, found_aux_keys) push!(output, (priority * (1.1^(length(split(key, '|')) - 1)), abstractor)) end output @@ -333,11 +343,12 @@ using ..PatternMatching: Matcher, unwrap_matcher wrap_check_task_value(cls::AbstractorClass, value::Matcher, data, aux_values) = all(wrap_check_task_value(cls, v, data, aux_values) for v in unwrap_matcher(value)) -get_aux_values_for_task(cls::AbstractorClass, task_data, key, solution) = - [task_data[k] for k in aux_keys(cls, key, task_data)] +get_aux_values_for_task(cls::AbstractorClass, taskdata, key, solution) = + [taskdata[k] for k in aux_keys(cls, key, taskdata)] function create_abstractors(cls::AbstractorClass, data, key, found_aux_keys) if haskey(data, "effective") && data["effective"] == false + @info 6 return [] end [( diff --git a/src/abstractors/transpose.jl b/src/abstractors/transpose.jl index 2867ea8..f22de58 100644 --- a/src/abstractors/transpose.jl +++ b/src/abstractors/transpose.jl @@ -13,25 +13,25 @@ function check_task_value(::Transpose, value::AbstractArray{Int,2}, data, aux_va end -get_aux_values_for_task(::Transpose, task_data, key, solution) = +get_aux_values_for_task(::Transpose, taskdata, key, solution) = in(key, solution.unfilled_fields) ? 
values( filter( kv -> - isa(kv[2], Array{Int,2}) && + isa(kv[2][1], Array{Int,2}) && kv[1] != key && (in(kv[1], solution.unfilled_fields) || in(kv[1], solution.transformed_fields)), - task_data, + taskdata, ), ) : values( filter( kv -> - isa(kv[2], Array{Int,2}) && + isa(kv[2][1], Array{Int,2}) && kv[1] != key && !in(kv[1], solution.unfilled_fields) && !in(kv[1], solution.transformed_fields), - task_data, + taskdata, ), ) diff --git a/test/test_compact_similar_objects.jl b/test/test_compact_similar_objects.jl index 4995eb3..7678bae 100644 --- a/test/test_compact_similar_objects.jl +++ b/test/test_compact_similar_objects.jl @@ -9,9 +9,9 @@ using .PatternMatching: ObjectShape, common_value reshaper = CompactSimilarObjects("key", true) out_data = reshaper(source_data) @test out_data == Dict( - "key" => Set([Object([1], (1, 1)), Object([1], (2, 3))]), - "key|common_shape" => ObjectShape(Object([1], (1, 1))), - "key|positions" => Set([(1, 1), (2, 3)]), + "key" => [Set([Object([1], (1, 1)), Object([1], (2, 3))])], + "key|common_shape" => [ObjectShape(Object([1], (1, 1)))], + "key|positions" => [Set([(1, 1), (2, 3)])], ) delete!(out_data, "key") reshaper = CompactSimilarObjects("key", false) diff --git a/test/test_group_objects.jl b/test/test_group_objects.jl index a552327..6c038af 100644 --- a/test/test_group_objects.jl +++ b/test/test_group_objects.jl @@ -17,19 +17,23 @@ import .ObjectPrior: Object, Color grouper = GroupObjectsByColor("key", true) out_data = grouper(source_data) @test out_data == Dict( - "key" => Set([ - Object([1], (1, 1)), - Object([1 1], (2, 4)), - Object([2], (2, 2)), - Object([3], (9, 1)), - Object([2], (1, 3)), - ]), - "key|grouped" => Dict( - Color(1) => Set([Object([1], (1, 1)), Object([1 1], (2, 4))]), - Color(2) => Set([Object([2], (2, 2)), Object([2], (1, 3))]), - Color(3) => Set([Object([3], (9, 1))]), - ), - "key|group_keys" => Set([Color(1), Color(2), Color(3)]), + "key" => [ + Set([ + Object([1], (1, 1)), + Object([1 1], (2, 4)), + 
Object([2], (2, 2)), + Object([3], (9, 1)), + Object([2], (1, 3)), + ]), + ], + "key|grouped" => [ + Dict( + Color(1) => Set([Object([1], (1, 1)), Object([1 1], (2, 4))]), + Color(2) => Set([Object([2], (2, 2)), Object([2], (1, 3))]), + Color(3) => Set([Object([3], (9, 1))]), + ), + ], + "key|group_keys" => [Set([Color(1), Color(2), Color(3)])], ) delete!(out_data, "key") ungrouper = GroupObjectsByColor("key", false) diff --git a/test/test_repeat_obj_inf.jl b/test/test_repeat_obj_inf.jl index ea8aaf9..b0bd0b3 100644 --- a/test/test_repeat_obj_inf.jl +++ b/test/test_repeat_obj_inf.jl @@ -42,16 +42,20 @@ using .PatternMatching: Either, Option repeater = RepeatObjectInfinite("input|key", true, source_data) out_data = repeater(source_data) @test out_data == Dict( - "input|key" => Set([Object([1], (1, 1)), Object([1], (2, 2)), Object([1], (3, 3))]), - "input|key|first" => Either([ - Option(Object([1], (1, 1)), hash((Object([1], (1, 1)), (1, 1)))), - Option(Object([1], (3, 3)), hash((Object([1], (3, 3)), (-1, -1)))), - ]), - "input|key|step" => Either([ - Option((1, 1), hash((Object([1], (1, 1)), (1, 1)))), - Option((-1, -1), hash((Object([1], (3, 3)), (-1, -1)))), - ]), - "input|grid_size" => (3, 3), + "input|key" => [Set([Object([1], (1, 1)), Object([1], (2, 2)), Object([1], (3, 3))])], + "input|key|first" => [ + Either([ + Option(Object([1], (1, 1)), hash((Object([1], (1, 1)), (1, 1)))), + Option(Object([1], (3, 3)), hash((Object([1], (3, 3)), (-1, -1)))), + ]), + ], + "input|key|step" => [ + Either([ + Option((1, 1), hash((Object([1], (1, 1)), (1, 1)))), + Option((-1, -1), hash((Object([1], (3, 3)), (-1, -1)))), + ]), + ], + "input|grid_size" => [(3, 3)], ) delete!(out_data, "input|key") abs_data = make_taskdata( diff --git a/test/test_solid_objects.jl b/test/test_solid_objects.jl index abf2eb8..becd242 100644 --- a/test/test_solid_objects.jl +++ b/test/test_solid_objects.jl @@ -101,15 +101,19 @@ using .ObjectPrior: Object ) out_data = abs(data) @test out_data == 
Dict( - "input|bgr_grid|spatial_objects" => Either([ - Option(Set{Object}([Object([2 2; 2 2], (2, 2))]), 6951943934144298334), - Option(Set{Object}([Object([0 0 0; 0 -1 -1; 0 -1 -1], (1, 1))]), 73827427852322294), - ]), - "input|bgr_grid|grid_size" => (3, 3), - "input|bgr_grid" => Either([ - Option([-1 -1 -1; -1 2 2; -1 2 2], 6951943934144298334), - Option([0 0 0; 0 -1 -1; 0 -1 -1], 73827427852322294), - ]), + "input|bgr_grid|spatial_objects" => [ + Either([ + Option(Set{Object}([Object([2 2; 2 2], (2, 2))]), 6951943934144298334), + Option(Set{Object}([Object([0 0 0; 0 -1 -1; 0 -1 -1], (1, 1))]), 73827427852322294), + ]), + ], + "input|bgr_grid|grid_size" => [(3, 3)], + "input|bgr_grid" => [ + Either([ + Option([-1 -1 -1; -1 2 2; -1 2 2], 6951943934144298334), + Option([0 0 0; 0 -1 -1; 0 -1 -1], 73827427852322294), + ]), + ], ) end end diff --git a/test/test_split_object.jl b/test/test_split_object.jl index 08c5228..fba4955 100644 --- a/test/test_split_object.jl +++ b/test/test_split_object.jl @@ -24,16 +24,18 @@ using .ObjectPrior: Object splitter = SplitObject("key", true) out_data = splitter(value) @test out_data == Dict( - "key" => Object([1; 1; 1; 1; 1; 1; 1], (1, 6)), - "key|splitted" => Set([ - Object([1], (1, 6)), - Object([1], (2, 6)), - Object([1], (3, 6)), - Object([1], (4, 6)), - Object([1], (5, 6)), - Object([1], (6, 6)), - Object([1], (7, 6)), - ]), + "key" => [Object([1; 1; 1; 1; 1; 1; 1], (1, 6))], + "key|splitted" => [ + Set([ + Object([1], (1, 6)), + Object([1], (2, 6)), + Object([1], (3, 6)), + Object([1], (4, 6)), + Object([1], (5, 6)), + Object([1], (6, 6)), + Object([1], (7, 6)), + ]), + ], ) delete!(out_data, "key") splitter = SplitObject("key", false) @@ -51,36 +53,40 @@ using .ObjectPrior: Object splitter = SplitObject("key", true) out_data = splitter(value) @test out_data == Dict( - "key" => Either([ - Option(Object([1; 1; 1; 1; 1; 1; 1], (1, 6)), 1519798033240906986), - Option(Object([1; 1; 1; 1; 1; 1; 1], (1, 18)), -8964597388769226366), 
- ]), - "key|splitted" => Either([ - Option( - Set([ - Object([1], (1, 6)), - Object([1], (2, 6)), - Object([1], (3, 6)), - Object([1], (4, 6)), - Object([1], (5, 6)), - Object([1], (6, 6)), - Object([1], (7, 6)), - ]), - 1519798033240906986, - ), - Option( - Set([ - Object([1], (1, 18)), - Object([1], (2, 18)), - Object([1], (3, 18)), - Object([1], (4, 18)), - Object([1], (5, 18)), - Object([1], (6, 18)), - Object([1], (7, 18)), - ]), - -8964597388769226366, - ), - ]), + "key" => [ + Either([ + Option(Object([1; 1; 1; 1; 1; 1; 1], (1, 6)), 1519798033240906986), + Option(Object([1; 1; 1; 1; 1; 1; 1], (1, 18)), -8964597388769226366), + ]), + ], + "key|splitted" => [ + Either([ + Option( + Set([ + Object([1], (1, 6)), + Object([1], (2, 6)), + Object([1], (3, 6)), + Object([1], (4, 6)), + Object([1], (5, 6)), + Object([1], (6, 6)), + Object([1], (7, 6)), + ]), + 1519798033240906986, + ), + Option( + Set([ + Object([1], (1, 18)), + Object([1], (2, 18)), + Object([1], (3, 18)), + Object([1], (4, 18)), + Object([1], (5, 18)), + Object([1], (6, 18)), + Object([1], (7, 18)), + ]), + -8964597388769226366, + ), + ]), + ], ) end end diff --git a/test/test_unite_touching.jl b/test/test_unite_touching.jl index d9563a7..8e9033e 100644 --- a/test/test_unite_touching.jl +++ b/test/test_unite_touching.jl @@ -82,7 +82,8 @@ using .Abstractors: UniteTouching ) abstractor = UniteTouching("input|spatial_objects|grouped", true) @test abstractor(task_data) == Dict( - "input|spatial_objects|grouped" => Dict{Any,Any}( + "input|spatial_objects|grouped" => [ + Dict{Any,Any}( 2 => Set([ Object([2], (1, 13)), Object( @@ -154,7 +155,9 @@ using .Abstractors: UniteTouching Object([8], (21, 14)), ]), ), - "input|spatial_objects|grouped|united_touch" => Dict( + ], + "input|spatial_objects|grouped|united_touch" => [ + Dict( 2 => Set([ Object( [ @@ -214,6 +217,7 @@ using .Abstractors: UniteTouching ), ]), ), + ], ) task_data = make_taskdata( Dict( @@ -287,7 +291,8 @@ using .Abstractors: UniteTouching 
), ) @test abstractor(task_data) == Dict( - "input|spatial_objects|grouped" => Dict{Any,Any}( + "input|spatial_objects|grouped" => [ + Dict{Any,Any}( 2 => Set([ Object( [ @@ -347,7 +352,9 @@ using .Abstractors: UniteTouching Object([3], (18, 18)), ]), ), - "input|spatial_objects|grouped|united_touch" => Dict( + ], + "input|spatial_objects|grouped|united_touch" => [ + Dict( 2 => Set([ Object( [ @@ -407,10 +414,12 @@ using .Abstractors: UniteTouching ), ]), ), + ], ) next_abstrractor = UniteTouching("input|spatial_objects|grouped|united_touch", true) @test next_abstrractor(abstractor(task_data)) == Dict( - "input|spatial_objects|grouped" => Dict{Any,Any}( + "input|spatial_objects|grouped" => [ + Dict{Any,Any}( 2 => Set([ Object( [ @@ -470,7 +479,9 @@ using .Abstractors: UniteTouching Object([3], (18, 18)), ]), ), - "input|spatial_objects|grouped|united_touch" => Dict( + ], + "input|spatial_objects|grouped|united_touch" => [ + Dict( 2 => Set([ Object( [ @@ -530,7 +541,9 @@ using .Abstractors: UniteTouching ), ]), ), - "input|spatial_objects|grouped|united_touch|united_touch" => Dict( + ], + "input|spatial_objects|grouped|united_touch|united_touch" => [ + Dict( 2 => Set([ Object( [ @@ -590,6 +603,7 @@ using .Abstractors: UniteTouching ), ]), ), + ], ) task_data = make_taskdata( Dict( @@ -663,7 +677,8 @@ using .Abstractors: UniteTouching ), ) @test abstractor(task_data) == Dict( - "input|spatial_objects|grouped" => Dict{Any,Any}( + "input|spatial_objects|grouped" => [ + Dict{Any,Any}( 4 => Set([ Object( [ @@ -730,7 +745,9 @@ using .Abstractors: UniteTouching Object([1], (17, 18)), ]), ), - "input|spatial_objects|grouped|united_touch" => Dict( + ], + "input|spatial_objects|grouped|united_touch" => [ + Dict( 4 => Set([ Object( [ @@ -788,9 +805,11 @@ using .Abstractors: UniteTouching ), ]), ), + ], ) @test next_abstrractor(abstractor(task_data)) == Dict( - "input|spatial_objects|grouped" => Dict{Any,Any}( + "input|spatial_objects|grouped" => [ + Dict{Any,Any}( 4 => Set([ 
Object( [ @@ -857,7 +876,9 @@ using .Abstractors: UniteTouching Object([1], (17, 18)), ]), ), - "input|spatial_objects|grouped|united_touch" => Dict( + ], + "input|spatial_objects|grouped|united_touch" => [ + Dict( 4 => Set([ Object( [ @@ -915,7 +936,9 @@ using .Abstractors: UniteTouching ), ]), ), - "input|spatial_objects|grouped|united_touch|united_touch" => Dict( + ], + "input|spatial_objects|grouped|united_touch|united_touch" => [ + Dict( 4 => Set([ Object( [ @@ -973,6 +996,7 @@ using .Abstractors: UniteTouching ), ]), ), + ], ) end end From 61dff9d7df85c452d1d2b90d75501808b07f6406 Mon Sep 17 00:00:00 2001 From: Andrey Zakharevich Date: Sat, 3 Jul 2021 14:13:53 +0300 Subject: [PATCH 4/8] Fix full task tests --- src/abstractors/abstractors.jl | 14 ++++++--- src/abstractors/select_group.jl | 53 ++++++++++++++------------------- src/solution.jl | 4 +-- test/utils.jl | 3 +- 4 files changed, 36 insertions(+), 38 deletions(-) diff --git a/src/abstractors/abstractors.jl b/src/abstractors/abstractors.jl index 38bd1e5..b82c0bc 100644 --- a/src/abstractors/abstractors.jl +++ b/src/abstractors/abstractors.jl @@ -55,7 +55,12 @@ function (p::Abstractor)(taskdata) func = from_abstract_value end updated_values = [wrap_func_call_value_root(p, func, inputs...) for inputs in zip(input_values...)] - merge!(out_data, Dict(key => [values[key] for values in updated_values] for key in p.output_keys)) + for key in p.output_keys + if all(!haskey(values, key) for values in updated_values) + continue + end + out_data[key] = Any[values[key] for values in updated_values] + end return out_data end @@ -75,8 +80,10 @@ Base.:(==)(a::Abstractor, b::Abstractor) = a.cls == b.cls && a.to_abstract == b.to_abstract && a.input_keys == b.input_keys && a.output_keys == b.output_keys -fetch_input_values(p::Abstractor, task_data) = - [in(k, needed_input_keys(p)) ? 
task_data[k] : get(task_data, k, nothing) for k in p.input_keys] +fetch_input_values(p::Abstractor, task_data) = [ + in(k, needed_input_keys(p)) ? task_data[k] : get(task_data, k, fill(nothing, length(task_data["input"]))) for + k in p.input_keys +] using DataStructures: DefaultDict @@ -348,7 +355,6 @@ get_aux_values_for_task(cls::AbstractorClass, taskdata, key, solution) = function create_abstractors(cls::AbstractorClass, data, key, found_aux_keys) if haskey(data, "effective") && data["effective"] == false - @info 6 return [] end [( diff --git a/src/abstractors/select_group.jl b/src/abstractors/select_group.jl index c2bb94a..929d3f3 100644 --- a/src/abstractors/select_group.jl +++ b/src/abstractors/select_group.jl @@ -33,10 +33,7 @@ function create( )::Array{Tuple{Float64,NamedTuple{(:to_abstract, :from_abstract),Tuple{Abstractor,Abstractor}}},1} data = init_create_check_data(cls, key, solution) - if !all( - haskey(task_data, key) && check_task_value(cls, task_data[key], data, task_data) for - task_data in solution.taskdata - ) + if !all(!ismissing(val) && check_task_value(cls, val, data, []) for val in solution.taskdata[key]) return [] end output = [] @@ -48,39 +45,32 @@ end function init_create_check_data(::SelectGroup, key, solution) - data = Dict( - "existing_choices" => Set{String}(), - "key" => key, - "effective" => false, - "field_info" => solution.field_info, - ) + data = Dict("effective" => false, "allowed_choices" => Set{String}()) + existing_choices = Set{String}() matcher = Regex("$(replace(key, '|' => "\\|"))\\|selected_by\\|(.*)") - sample_task = solution.taskdata[1] - for k in keys(sample_task) + for k in keys(solution.taskdata) m = match(matcher, k) - if !isnothing(m) && haskey(sample_task, m.captures[1]) - push!(data["existing_choices"], m.captures[1]) + if !isnothing(m) && haskey(solution.taskdata, m.captures[1]) + push!(existing_choices, m.captures[1]) + end + end + for k in keys(solution.taskdata) + if k != key && + in(key, 
solution.field_info[k].previous_fields) && + !in(k, existing_choices) && + all( + isa(dict_value, AbstractDict) && haskey(dict_value, candidate_key) for + (dict_value, candidate_key) in zip(solution.taskdata[key], solution.taskdata[k]) + ) + push!(data["allowed_choices"], k) end end data end -function check_task_value(::SelectGroup, value::AbstractDict, data, task_data) +function check_task_value(::SelectGroup, value::AbstractDict, data, aux_values) data["effective"] |= length(value) > 1 - if !haskey(data, "allowed_choices") - data["allowed_choices"] = Set{String}() - for (key, data_value) in task_data - if key != data["key"] && - in(data["key"], data["field_info"][key].previous_fields) && - !in(key, data["existing_choices"]) && - haskey(value, data_value) - push!(data["allowed_choices"], key) - end - end - else - filter!(key -> haskey(value, task_data[key]), data["allowed_choices"]) - end - return !isempty(data["allowed_choices"]) + return true end function create_abstractors(cls::SelectGroup, data, key) @@ -117,9 +107,10 @@ function to_abstract_value(p::Abstractor{SelectGroup}, source_value::AbstractDic out = update_value( TaskData(Dict{String,Any}(), Dict{String,Any}(), Set()), p.output_keys[1], - source_value[selected_key], + [source_value[selected_key]], ) - update_value(out, p.output_keys[2], rejected) + out = update_value(out, p.output_keys[2], [rejected]) + Dict(p.output_keys[1] => out[p.output_keys[1]][1], p.output_keys[2] => out[p.output_keys[2]][1]) end function from_abstract_value(p::Abstractor{SelectGroup}, selected, rejected, selector_key) diff --git a/src/solution.jl b/src/solution.jl index a6eedd4..0c226a2 100644 --- a/src/solution.jl +++ b/src/solution.jl @@ -689,8 +689,8 @@ Base.show(io::IO, s::Solution) = print( "\n)", ) -function (solution::Solution)(input_grids::Vector{Array{Int,2}})::Array{Int,2} - observed_data = TaskData(Dict{String,Vector}("input" => input_grids), Dict{String,Any}(), Set()) +function 
(solution::Solution)(input_grids::Vector{Array{Int,2}})::Vector{Array{Int,2}} + observed_data = TaskData(Dict{String,Vector}("input" => input_grids), Dict{String,Vector}(), Set()) for block in solution.blocks observed_data = block(observed_data) end diff --git a/test/utils.jl b/test/utils.jl index fe8a3b7..19994a1 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -77,6 +77,7 @@ end using .FindSolution: validate_results function test_solution(solution, test_data) - answer = [solution(task["input"]) for task in test_data] + test_input = [task["input"] for task in test_data] + answer = solution(test_input) validate_results(test_data, [answer]) end From e802a6a254fb2398a7419e7b2248bc9c5a598c19 Mon Sep 17 00:00:00 2001 From: Andrey Zakharevich Date: Sat, 3 Jul 2021 14:22:08 +0300 Subject: [PATCH 5/8] Store number of examples in taskdata explicitly --- src/abstractors/abstractors.jl | 2 +- src/operations/inc_param.jl | 2 +- src/operations/mult_param.jl | 2 +- src/operations/set_const.jl | 2 +- src/solution.jl | 6 +++--- src/taskdata/taskdata.jl | 7 +++++++ 6 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/abstractors/abstractors.jl b/src/abstractors/abstractors.jl index b82c0bc..249c8d9 100644 --- a/src/abstractors/abstractors.jl +++ b/src/abstractors/abstractors.jl @@ -81,7 +81,7 @@ Base.:(==)(a::Abstractor, b::Abstractor) = fetch_input_values(p::Abstractor, task_data) = [ - in(k, needed_input_keys(p)) ? task_data[k] : get(task_data, k, fill(nothing, length(task_data["input"]))) for + in(k, needed_input_keys(p)) ? 
task_data[k] : get(task_data, k, fill(nothing, task_data.num_examples)) for k in p.input_keys ] diff --git a/src/operations/inc_param.jl b/src/operations/inc_param.jl index a9dd790..36bcf17 100644 --- a/src/operations/inc_param.jl +++ b/src/operations/inc_param.jl @@ -17,5 +17,5 @@ Base.hash(op::IncParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_key function (op::IncParam)(taskdata::TaskData) output_value = [apply_func(val, (a, b) -> a .+ b, op.shift) for val in taskdata[op.input_keys[1]]] data = update_value(taskdata, op.output_keys[1], output_value) - update_value(data, op.output_keys[2], fill(op.shift, length(taskdata["input"]))) + update_value(data, op.output_keys[2], fill(op.shift, taskdata.num_examples)) end diff --git a/src/operations/mult_param.jl b/src/operations/mult_param.jl index c32aae8..34f009d 100644 --- a/src/operations/mult_param.jl +++ b/src/operations/mult_param.jl @@ -17,5 +17,5 @@ Base.hash(op::MultParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_ke function (op::MultParam)(taskdata::TaskData) output_value = [apply_func(val, (a, b) -> a .* b, op.factor) for val in taskdata[op.input_keys[1]]] data = update_value(taskdata, op.output_keys[1], output_value) - update_value(data, op.output_keys[2], fill(op.factor, length(taskdata["input"]))) + update_value(data, op.output_keys[2], fill(op.factor, taskdata.num_examples)) end diff --git a/src/operations/set_const.jl b/src/operations/set_const.jl index 5aa6072..ffc5986 100644 --- a/src/operations/set_const.jl +++ b/src/operations/set_const.jl @@ -14,4 +14,4 @@ Base.hash(op::SetConst, h::UInt64) = hash(op.output_keys, h) + hash(op.value, h) (op::SetConst)(taskdata::TaskData) = - update_value(taskdata, op.output_keys[1], fill(op.value, length(taskdata["input"]))) + update_value(taskdata, op.output_keys[1], fill(op.value, taskdata.num_examples)) diff --git a/src/solution.jl b/src/solution.jl index 0c226a2..ea8e6c3 100644 --- a/src/solution.jl +++ b/src/solution.jl @@ -111,8 +111,8 @@ 
struct Solution input_transformed_fields, complexity_score::Float64, ) - inp_val_hashes = fill(Set{UInt64}(), length(taskdata["input"])) - out_val_hashes = fill(Set{UInt64}(), length(taskdata["input"])) + inp_val_hashes = fill(Set{UInt64}(), taskdata.num_examples) + out_val_hashes = fill(Set{UInt64}(), taskdata.num_examples) for (key, values) in taskdata for (i, value) in enumerate(values) if in(key, transformed_fields) || in(key, filled_fields) || in(key, unfilled_fields) @@ -721,7 +721,7 @@ end function get_score(taskdata::TaskData, complexity_score)::Int score = compare_grids( taskdata["output"], - get(taskdata, "projected|output", fill(Array{Int}(undef, 0, 0), length(taskdata["output"]))), + get(taskdata, "projected|output", fill(Array{Int}(undef, 0, 0), taskdata.num_examples)), ) # if complexity_score > 100 # score += floor(complexity_score) diff --git a/src/taskdata/taskdata.jl b/src/taskdata/taskdata.jl index 4f3d371..000b0a2 100644 --- a/src/taskdata/taskdata.jl +++ b/src/taskdata/taskdata.jl @@ -5,6 +5,13 @@ struct TaskData <: AbstractDict{String,Any} persistent_data::Dict{String,Vector} updated_values::Dict{String,Vector} keys_to_delete::Set{String} + num_examples::Int + TaskData(persistent_data, updated_values, keys_to_delete) = new( + persistent_data, + updated_values, + keys_to_delete, + length(first(isempty(persistent_data) ? 
updated_values : persistent_data)[2]), + ) end Base.show(io::IO, t::TaskData) = From 2e6b7c15b1d686dfc6e8d728abbdb100fb013d5c Mon Sep 17 00:00:00 2001 From: Andrey Zakharevich Date: Sat, 3 Jul 2021 15:18:14 +0300 Subject: [PATCH 6/8] Dynamically calculate number of examples --- src/abstractors/abstractors.jl | 3 ++- src/operations/inc_param.jl | 2 +- src/operations/mult_param.jl | 2 +- src/operations/operation.jl | 2 +- src/operations/set_const.jl | 2 +- src/solution.jl | 8 ++++---- src/taskdata/taskdata.jl | 9 ++------- 7 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/abstractors/abstractors.jl b/src/abstractors/abstractors.jl index 249c8d9..ea6cfc8 100644 --- a/src/abstractors/abstractors.jl +++ b/src/abstractors/abstractors.jl @@ -79,9 +79,10 @@ Base.show(io::IO, p::Abstractor) = print( Base.:(==)(a::Abstractor, b::Abstractor) = a.cls == b.cls && a.to_abstract == b.to_abstract && a.input_keys == b.input_keys && a.output_keys == b.output_keys +using ..Taskdata: num_examples fetch_input_values(p::Abstractor, task_data) = [ - in(k, needed_input_keys(p)) ? task_data[k] : get(task_data, k, fill(nothing, task_data.num_examples)) for + in(k, needed_input_keys(p)) ? 
task_data[k] : get(task_data, k, fill(nothing, num_examples(task_data))) for k in p.input_keys ] diff --git a/src/operations/inc_param.jl b/src/operations/inc_param.jl index 36bcf17..cc33990 100644 --- a/src/operations/inc_param.jl +++ b/src/operations/inc_param.jl @@ -17,5 +17,5 @@ Base.hash(op::IncParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_key function (op::IncParam)(taskdata::TaskData) output_value = [apply_func(val, (a, b) -> a .+ b, op.shift) for val in taskdata[op.input_keys[1]]] data = update_value(taskdata, op.output_keys[1], output_value) - update_value(data, op.output_keys[2], fill(op.shift, taskdata.num_examples)) + update_value(data, op.output_keys[2], fill(op.shift, num_examples(taskdata))) end diff --git a/src/operations/mult_param.jl b/src/operations/mult_param.jl index 34f009d..92860d8 100644 --- a/src/operations/mult_param.jl +++ b/src/operations/mult_param.jl @@ -17,5 +17,5 @@ Base.hash(op::MultParam, h::UInt64) = hash(op.output_keys, h) + hash(op.input_ke function (op::MultParam)(taskdata::TaskData) output_value = [apply_func(val, (a, b) -> a .* b, op.factor) for val in taskdata[op.input_keys[1]]] data = update_value(taskdata, op.output_keys[1], output_value) - update_value(data, op.output_keys[2], fill(op.factor, taskdata.num_examples)) + update_value(data, op.output_keys[2], fill(op.factor, num_examples(taskdata))) end diff --git a/src/operations/operation.jl b/src/operations/operation.jl index 08c62eb..c668187 100644 --- a/src/operations/operation.jl +++ b/src/operations/operation.jl @@ -11,7 +11,7 @@ get_sorting_keys(operation::Operation) = operation.output_keys needed_input_keys(operation::Operation) = operation.input_keys using ..PatternMatching: update_value, apply_func -using ..Taskdata: TaskData +using ..Taskdata: TaskData, num_examples include("set_const.jl") include("copy_param.jl") diff --git a/src/operations/set_const.jl b/src/operations/set_const.jl index ffc5986..657d023 100644 --- a/src/operations/set_const.jl +++ 
b/src/operations/set_const.jl @@ -14,4 +14,4 @@ Base.hash(op::SetConst, h::UInt64) = hash(op.output_keys, h) + hash(op.value, h) (op::SetConst)(taskdata::TaskData) = - update_value(taskdata, op.output_keys[1], fill(op.value, taskdata.num_examples)) + update_value(taskdata, op.output_keys[1], fill(op.value, num_examples(taskdata))) diff --git a/src/solution.jl b/src/solution.jl index ea8e6c3..06bd9c7 100644 --- a/src/solution.jl +++ b/src/solution.jl @@ -34,7 +34,7 @@ end Base.show(io::IO, b::Block) = print(io, "Block([\n", (vcat((["\t\t", op, ",\n"] for op in b.operations)...))..., "\t])") -using ..Taskdata: TaskData, persist_data +using ..Taskdata: TaskData, persist_data, num_examples function (block::Block)(observed_data::TaskData)::TaskData for op in block.operations @@ -111,8 +111,8 @@ struct Solution input_transformed_fields, complexity_score::Float64, ) - inp_val_hashes = fill(Set{UInt64}(), taskdata.num_examples) - out_val_hashes = fill(Set{UInt64}(), taskdata.num_examples) + inp_val_hashes = fill(Set{UInt64}(), num_examples(taskdata)) + out_val_hashes = fill(Set{UInt64}(), num_examples(taskdata)) for (key, values) in taskdata for (i, value) in enumerate(values) if in(key, transformed_fields) || in(key, filled_fields) || in(key, unfilled_fields) @@ -721,7 +721,7 @@ end function get_score(taskdata::TaskData, complexity_score)::Int score = compare_grids( taskdata["output"], - get(taskdata, "projected|output", fill(Array{Int}(undef, 0, 0), taskdata.num_examples)), + get(taskdata, "projected|output", fill(Array{Int}(undef, 0, 0), num_examples(taskdata))), ) # if complexity_score > 100 # score += floor(complexity_score) diff --git a/src/taskdata/taskdata.jl b/src/taskdata/taskdata.jl index 000b0a2..1242aaf 100644 --- a/src/taskdata/taskdata.jl +++ b/src/taskdata/taskdata.jl @@ -5,15 +5,10 @@ struct TaskData <: AbstractDict{String,Any} persistent_data::Dict{String,Vector} updated_values::Dict{String,Vector} keys_to_delete::Set{String} - num_examples::Int - 
TaskData(persistent_data, updated_values, keys_to_delete) = new( - persistent_data, - updated_values, - keys_to_delete, - length(first(isempty(persistent_data) ? updated_values : persistent_data)[2]), - ) end +num_examples(taskdata::TaskData) = isempty(taskdata) ? 0 : length(first(taskdata)[2]) + Base.show(io::IO, t::TaskData) = print(io, "TaskData(", t.persistent_data, ", ", t.updated_values, ", ", t.keys_to_delete, ")") From d526a759533f7acddf7e6263e4565be7bf605e97 Mon Sep 17 00:00:00 2001 From: Andrey Zakharevich Date: Tue, 6 Jul 2021 11:26:56 +0300 Subject: [PATCH 7/8] Store hash values for all examples together --- src/abstractors/abstractors.jl | 2 +- src/abstractors/select_group.jl | 1 + src/find_solution.jl | 12 +++++++----- src/pattern_matching/update_value.jl | 2 +- src/solution.jl | 22 ++++++++++------------ 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/abstractors/abstractors.jl b/src/abstractors/abstractors.jl index ea6cfc8..ecc3eab 100644 --- a/src/abstractors/abstractors.jl +++ b/src/abstractors/abstractors.jl @@ -312,7 +312,6 @@ function create( if length(found_aux_keys) != length(aux_keys(cls)) return [] end - data = init_create_check_data(cls, key, solution) if any(ismissing(value) for value in solution.taskdata[key]) return [] @@ -324,6 +323,7 @@ function create( aux_values = zip(aux_values...) 
end + data = init_create_check_data(cls, key, solution) if !all( wrap_check_task_value(cls, value, data, aux_vals) for (value, aux_vals) in zip(solution.taskdata[key], aux_values) diff --git a/src/abstractors/select_group.jl b/src/abstractors/select_group.jl index 929d3f3..0896f29 100644 --- a/src/abstractors/select_group.jl +++ b/src/abstractors/select_group.jl @@ -31,6 +31,7 @@ function create( solution, key, )::Array{Tuple{Float64,NamedTuple{(:to_abstract, :from_abstract),Tuple{Abstractor,Abstractor}}},1} + return [] data = init_create_check_data(cls, key, solution) if !all(!ismissing(val) && check_task_value(cls, val, data, []) for val in solution.taskdata[key]) diff --git a/src/find_solution.jl b/src/find_solution.jl index 0ac8495..8e71c16 100644 --- a/src/find_solution.jl +++ b/src/find_solution.jl @@ -108,11 +108,13 @@ function is_subsolution(old_sol::Solution, new_sol::Solution)::Bool if !issetequal(keys(new_sol.taskdata), keys(old_sol.taskdata)) equals = false end - for (new_inp_vals, new_out_vals, old_inp_vals, old_out_vals) in - zip(new_sol.inp_val_hashes, new_sol.out_val_hashes, old_sol.inp_val_hashes, old_sol.out_val_hashes) - if !issubset(new_out_vals, old_out_vals) || !issubset(new_inp_vals, old_inp_vals) - return false - end + new_inp_vals = Set(values(new_sol.inp_val_hashes)) + new_out_vals = Set(values(new_sol.out_val_hashes)) + old_inp_vals = Set(values(old_sol.inp_val_hashes)) + old_out_vals = Set(values(old_sol.out_val_hashes)) + + if !issubset(new_out_vals, old_out_vals) || !issubset(new_inp_vals, old_inp_vals) + return false end if equals && old_sol != new_sol return false diff --git a/src/pattern_matching/update_value.jl b/src/pattern_matching/update_value.jl index 9e4072b..2007199 100644 --- a/src/pattern_matching/update_value.jl +++ b/src/pattern_matching/update_value.jl @@ -29,7 +29,7 @@ function update_value(data::TaskData, path_keys::Array, value, ::Any)::TaskData data end -function update_value(data::TaskData, path_keys::Array, value, 
current_value::Vector)::TaskData +function update_value(data::TaskData, path_keys::Array, value, current_value::AbstractVector)::TaskData for (i, val) in enumerate(value) data = update_value(data, vcat(path_keys, [i]), val, current_value[i]) end diff --git a/src/solution.jl b/src/solution.jl index 06bd9c7..6b12080 100644 --- a/src/solution.jl +++ b/src/solution.jl @@ -97,8 +97,8 @@ struct Solution input_transformed_fields::Set{String} complexity_score::Float64 score::Int - inp_val_hashes::Vector{Set{UInt64}} - out_val_hashes::Vector{Set{UInt64}} + inp_val_hashes::Dict{String,UInt64} + out_val_hashes::Dict{String,UInt64} function Solution( taskdata, field_info, @@ -111,16 +111,14 @@ struct Solution input_transformed_fields, complexity_score::Float64, ) - inp_val_hashes = fill(Set{UInt64}(), num_examples(taskdata)) - out_val_hashes = fill(Set{UInt64}(), num_examples(taskdata)) + inp_val_hashes = Dict{String,UInt64}() + out_val_hashes = Dict{String,UInt64}() for (key, values) in taskdata - for (i, value) in enumerate(values) - if in(key, transformed_fields) || in(key, filled_fields) || in(key, unfilled_fields) - push!(out_val_hashes[i], hash(value)) - end - if in(key, unused_fields) || in(key, used_fields) || in(key, input_transformed_fields) - push!(inp_val_hashes[i], hash(value)) - end + if in(key, transformed_fields) || in(key, filled_fields) || in(key, unfilled_fields) + out_val_hashes[key] = hash(values) + end + if in(key, unused_fields) || in(key, used_fields) || in(key, input_transformed_fields) + inp_val_hashes[key] = hash(values) end end new( @@ -706,7 +704,7 @@ function check_task(solution::Solution, input_grids::Vector{Array{Int,2}}, targe compare_grids(targets, out) end -function compare_grids(targets::Vector, outputs::Vector) +function compare_grids(targets::AbstractVector, outputs::AbstractVector) result = 0 for (target, output) in zip(targets, outputs) if size(target) != size(output) From 81f50a8803a59f6eb7ed30a455696eac826791d1 Mon Sep 17 00:00:00 2001 
From: Andrey Zakharevich Date: Wed, 7 Jul 2021 13:54:34 +0300 Subject: [PATCH 8/8] check hashes only for needed example --- src/pattern_matching/aux_value.jl | 2 +- src/pattern_matching/either.jl | 28 +++++++++++++++++----------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/pattern_matching/aux_value.jl b/src/pattern_matching/aux_value.jl index bae89e8..a7041f5 100644 --- a/src/pattern_matching/aux_value.jl +++ b/src/pattern_matching/aux_value.jl @@ -22,7 +22,7 @@ unpack_value(p::AuxValue) = isa(p.value, Matcher) ? [] : unpack_value(p.value) unwrap_matcher(p::AuxValue) = [p.value] -_update_value(data::TaskData, value, current_value::AuxValue) = _update_value(data, value, current_value.value) +_update_value(data::TaskData, example_num, value, current_value::AuxValue) = _update_value(data, example_num, value, current_value.value) function _drop_hashes(data::AuxValue, hashes) modified, effective, mod_hashes = _drop_hashes(data.value, hashes) diff --git a/src/pattern_matching/either.jl b/src/pattern_matching/either.jl index abadf30..0b40806 100644 --- a/src/pattern_matching/either.jl +++ b/src/pattern_matching/either.jl @@ -93,36 +93,42 @@ update_value(data::TaskData, path_keys::Array, value::Either, current_value::Eit invoke(update_value, Tuple{TaskData,Array,Any,Any}, data, path_keys, value, current_value) function update_value(data::TaskData, path_keys::Array, value, current_value::Either)::TaskData - data = _update_value(data, value, current_value) + data = _update_value(data, path_keys[2], value, current_value) return invoke(update_value, Tuple{TaskData,Array,Any,Any}, data, path_keys, value, current_value) end -function _update_value(data::TaskData, value, current_value::Either)::TaskData +function _update_value(data::TaskData, example_num, value, current_value::Either)::TaskData + hashes_to_del = Set() + matched_options = [] for option in current_value.options if isnothing(common_value(value, option.value)) if 
!isnothing(option.option_hash) - hashes_to_del = Set([option.option_hash]) - while !isempty(hashes_to_del) - data, hashes_to_del = drop_hashes(data, hashes_to_del) - end + push!(hashes_to_del, option.option_hash) end else - data = _update_value(data, value, option.value) + push!(matched_options, option) end end + while !isempty(hashes_to_del) + data, hashes_to_del = drop_hashes(data, example_num, hashes_to_del) + end + for option in matched_options + data = _update_value(data, example_num, value, option.value) + end return data end -_update_value(data::TaskData, value, current_value) = data +_update_value(data::TaskData, example_num, value, current_value) = data -function drop_hashes(data::TaskData, hashes) +function drop_hashes(data::TaskData, example_num, hashes) data = copy(data) new_hashes = Set() for (key, value) in data - modified, effective, mod_hashes = _drop_hashes(value, hashes) + modified, effective, mod_hashes = _drop_hashes(value[example_num], hashes) if effective - data[key] = modified + data[key] = copy(value) + data[key][example_num] = modified union!(new_hashes, mod_hashes) end end