Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use single taskdata object for all the examples #4

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions src/abstractors/abstractors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,21 @@ function Abstractor(cls::AbstractorClass, key::String, to_abs::Bool, found_aux_k
end


function (p::Abstractor)(task_data)
out_data = copy(task_data)
function (p::Abstractor)(taskdata)
out_data = copy(taskdata)
input_values = fetch_input_values(p, out_data)
if p.to_abstract
func = to_abstract_value
else
func = from_abstract_value
end
merge!(out_data, wrap_func_call_value_root(p, func, input_values...))
updated_values = [wrap_func_call_value_root(p, func, inputs...) for inputs in zip(input_values...)]
for key in p.output_keys
if all(!haskey(values, key) for values in updated_values)
continue
end
out_data[key] = Any[values[key] for values in updated_values]
end
return out_data
end

Expand All @@ -73,9 +79,12 @@ Base.show(io::IO, p::Abstractor) = print(
Base.:(==)(a::Abstractor, b::Abstractor) =
a.cls == b.cls && a.to_abstract == b.to_abstract && a.input_keys == b.input_keys && a.output_keys == b.output_keys

using ..Taskdata: num_examples

fetch_input_values(p::Abstractor, task_data) =
[in(k, needed_input_keys(p)) ? task_data[k] : get(task_data, k, nothing) for k in p.input_keys]
fetch_input_values(p::Abstractor, task_data) = [
in(k, needed_input_keys(p)) ? task_data[k] : get(task_data, k, fill(nothing, num_examples(task_data))) for
k in p.input_keys
]

using DataStructures: DefaultDict

Expand Down Expand Up @@ -296,24 +305,33 @@ function create(
solution,
key,
)::Array{Tuple{Float64,NamedTuple{(:to_abstract, :from_abstract),Tuple{Abstractor,Abstractor}}},1}
if any(haskey(solution.taskdata[1], k) for k in abs_keys(cls, key))
if any(haskey(solution.taskdata, k) for k in abs_keys(cls, key))
return []
end
found_aux_keys = [aux_keys(cls, key, task) for task in solution.taskdata]
if !all(all(length(keys) == length(aux_keys(cls))) && keys == found_aux_keys[1] for keys in found_aux_keys)
found_aux_keys = aux_keys(cls, key, solution.taskdata)
if length(found_aux_keys) != length(aux_keys(cls))
return []
end
data = init_create_check_data(cls, key, solution)

if any(ismissing(value) for value in solution.taskdata[key])
return []
end
aux_values = get_aux_values_for_task(cls, solution.taskdata, key, solution)
if isempty(aux_values)
aux_values = fill([], length(solution.taskdata[key]))
else
aux_values = zip(aux_values...)
end

data = init_create_check_data(cls, key, solution)
if !all(
haskey(task_data, key) &&
wrap_check_task_value(cls, task_data[key], data, get_aux_values_for_task(cls, task_data, key, solution)) for
task_data in solution.taskdata
wrap_check_task_value(cls, value, data, aux_vals) for
(value, aux_vals) in zip(solution.taskdata[key], aux_values)
)
return []
end
output = []
for (priority, abstractor) in create_abstractors(cls, data, key, found_aux_keys[1])
for (priority, abstractor) in create_abstractors(cls, data, key, found_aux_keys)
push!(output, (priority * (1.1^(length(split(key, '|')) - 1)), abstractor))
end
output
Expand All @@ -333,8 +351,8 @@ using ..PatternMatching: Matcher, unwrap_matcher
wrap_check_task_value(cls::AbstractorClass, value::Matcher, data, aux_values) =
all(wrap_check_task_value(cls, v, data, aux_values) for v in unwrap_matcher(value))

get_aux_values_for_task(cls::AbstractorClass, task_data, key, solution) =
[task_data[k] for k in aux_keys(cls, key, task_data)]
get_aux_values_for_task(cls::AbstractorClass, taskdata, key, solution) =
[taskdata[k] for k in aux_keys(cls, key, taskdata)]

function create_abstractors(cls::AbstractorClass, data, key, found_aux_keys)
if haskey(data, "effective") && data["effective"] == false
Expand All @@ -356,7 +374,6 @@ include("background_color.jl")
include("solid_objects.jl")
include("group_obj_by_color.jl")
include("compact_similar_objects.jl")
# include("sort_array.jl")
include("transpose.jl")
include("repeat_object_infinite.jl")
include("unwrap_tuple.jl")
Expand Down
54 changes: 23 additions & 31 deletions src/abstractors/select_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,10 @@ function create(
solution,
key,
)::Array{Tuple{Float64,NamedTuple{(:to_abstract, :from_abstract),Tuple{Abstractor,Abstractor}}},1}
return []
data = init_create_check_data(cls, key, solution)

if !all(
haskey(task_data, key) && check_task_value(cls, task_data[key], data, task_data) for
task_data in solution.taskdata
)
if !all(!ismissing(val) && check_task_value(cls, val, data, []) for val in solution.taskdata[key])
return []
end
output = []
Expand All @@ -48,39 +46,32 @@ end


function init_create_check_data(::SelectGroup, key, solution)
data = Dict(
"existing_choices" => Set{String}(),
"key" => key,
"effective" => false,
"field_info" => solution.field_info,
)
data = Dict("effective" => false, "allowed_choices" => Set{String}())
existing_choices = Set{String}()
matcher = Regex("$(replace(key, '|' => "\\|"))\\|selected_by\\|(.*)")
sample_task = solution.taskdata[1]
for k in keys(sample_task)
for k in keys(solution.taskdata)
m = match(matcher, k)
if !isnothing(m) && haskey(sample_task, m.captures[1])
push!(data["existing_choices"], m.captures[1])
if !isnothing(m) && haskey(solution.taskdata, m.captures[1])
push!(existing_choices, m.captures[1])
end
end
for k in keys(solution.taskdata)
if k != key &&
in(key, solution.field_info[k].previous_fields) &&
!in(k, existing_choices) &&
all(
isa(dict_value, AbstractDict) && haskey(dict_value, candidate_key) for
(dict_value, candidate_key) in zip(solution.taskdata[key], solution.taskdata[k])
)
push!(data["allowed_choices"], k)
end
end
data
end

function check_task_value(::SelectGroup, value::AbstractDict, data, task_data)
function check_task_value(::SelectGroup, value::AbstractDict, data, aux_values)
data["effective"] |= length(value) > 1
if !haskey(data, "allowed_choices")
data["allowed_choices"] = Set{String}()
for (key, data_value) in task_data
if key != data["key"] &&
in(data["key"], data["field_info"][key].previous_fields) &&
!in(key, data["existing_choices"]) &&
haskey(value, data_value)
push!(data["allowed_choices"], key)
end
end
else
filter!(key -> haskey(value, task_data[key]), data["allowed_choices"])
end
return !isempty(data["allowed_choices"])
return true
end

function create_abstractors(cls::SelectGroup, data, key)
Expand Down Expand Up @@ -117,9 +108,10 @@ function to_abstract_value(p::Abstractor{SelectGroup}, source_value::AbstractDic
out = update_value(
TaskData(Dict{String,Any}(), Dict{String,Any}(), Set()),
p.output_keys[1],
source_value[selected_key],
[source_value[selected_key]],
)
update_value(out, p.output_keys[2], rejected)
out = update_value(out, p.output_keys[2], [rejected])
Dict(p.output_keys[1] => out[p.output_keys[1]][1], p.output_keys[2] => out[p.output_keys[2]][1])
end

function from_abstract_value(p::Abstractor{SelectGroup}, selected, rejected, selector_key)
Expand Down
22 changes: 0 additions & 22 deletions src/abstractors/sort_array.jl

This file was deleted.

10 changes: 5 additions & 5 deletions src/abstractors/transpose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,25 @@ function check_task_value(::Transpose, value::AbstractArray{Int,2}, data, aux_va
end


get_aux_values_for_task(::Transpose, task_data, key, solution) =
get_aux_values_for_task(::Transpose, taskdata, key, solution) =
in(key, solution.unfilled_fields) ?
values(
filter(
kv ->
isa(kv[2], Array{Int,2}) &&
isa(kv[2][1], Array{Int,2}) &&
kv[1] != key &&
(in(kv[1], solution.unfilled_fields) || in(kv[1], solution.transformed_fields)),
task_data,
taskdata,
),
) :
values(
filter(
kv ->
isa(kv[2], Array{Int,2}) &&
isa(kv[2][1], Array{Int,2}) &&
kv[1] != key &&
!in(kv[1], solution.unfilled_fields) &&
!in(kv[1], solution.transformed_fields),
task_data,
taskdata,
),
)

Expand Down
10 changes: 5 additions & 5 deletions src/data_transformers/find_const.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@

using ..Operations: SetConst

function find_const(taskdata::Vector{TaskData}, _, _, key::String)::Vector{SetConst}
function find_const(taskdata::TaskData, _, _, key::String)::Vector{SetConst}
result = nothing
if !in(key, updated_keys(taskdata))
return []
end
for task_data in taskdata
if !haskey(task_data, key)
for value in taskdata[key]
if ismissing(value)
continue
end
if isnothing(result)
result = task_data[key]
result = value
end
possible_value = common_value(result, task_data[key])
possible_value = common_value(result, value)
if isnothing(possible_value)
return []
end
Expand Down
12 changes: 5 additions & 7 deletions src/data_transformers/find_dependent_key.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@

using ..Operations: CopyParam

function find_dependent_key(taskdata::Vector{TaskData}, field_info, invalid_sources::AbstractSet{String}, key::String)
function find_dependent_key(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String)
upd_keys = updated_keys(taskdata)
skipmissing(
imap(keys(taskdata[1])) do input_key
imap(keys(taskdata)) do input_key
if in(input_key, invalid_sources) ||
field_info[key].type != field_info[input_key].type ||
(!in(key, upd_keys) && !in(input_key, upd_keys))
return missing
end
for task_data in taskdata
if !haskey(task_data, input_key)
for (input_value, out_value) in zip(taskdata[input_key], taskdata[key])
if ismissing(input_value)
return missing
end
if !haskey(task_data, key)
if ismissing(out_value)
continue
end
input_value = task_data[input_key]
out_value = task_data[key]
if !check_match(input_value, out_value)
return missing
end
Expand Down
25 changes: 7 additions & 18 deletions src/data_transformers/find_matching_obj_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,43 +28,32 @@ using ..Abstractors: Abstractor, SelectGroup
_check_group_type(::Type, ::Type) = false
_check_group_type(::Type{Dict{K,V}}, expected::Type) where {K,V} = V == expected

function _get_matching_transformers(
taskdata::Vector{TaskData},
field_info,
invalid_sources::AbstractSet{String},
key::String,
)
if endswith(key, "|selected_group") || any(!haskey(task_data, key) for task_data in taskdata)
function _get_matching_transformers(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String)
if endswith(key, "|selected_group") || any(ismissing(val) for val in taskdata[key])
return []
end
upd_keys = updated_keys(taskdata)
flatten(
imap(keys(taskdata[1])) do input_key
imap(keys(taskdata)) do input_key
if in(input_key, invalid_sources) ||
(!in(key, upd_keys) && !in(input_key, upd_keys)) ||
!_check_group_type(field_info[input_key].type, field_info[key].type) ||
any(!haskey(task, input_key) for task in taskdata) ||
all(length(task[input_key]) <= 1 for task in taskdata)
any(ismissing(input_value) for input_value in taskdata[input_key]) ||
all(length(input_value) <= 1 for input_value in taskdata[input_key])
return []
end
matching_groups = []
for task_data in taskdata
if !haskey(task_data, input_key)
return []
end
input_value = task_data[input_key]
out_value = task_data[key]
for (input_value, out_value) in zip(taskdata[input_key], taskdata[key])
if !check_matching_group(input_value, out_value, matching_groups)
return []

end
end
imap(unroll_groups(matching_groups)) do group_keys
key_name = key * "|selected_group"
to_abs = MapValues(
key_name,
"output",
Dict(task_data["output"] => value for (task_data, value) in zip(taskdata, group_keys)),
Dict(output_value => value for (output_value, value) in zip(taskdata["output"], group_keys)),
)
from_abs = Abstractor(SelectGroup(), true, [input_key, key_name], [key, key * "|rejected"], String[])
return (to_abstract = to_abs, from_abstract = from_abs)
Expand Down
7 changes: 3 additions & 4 deletions src/data_transformers/find_neg_shift_by_key.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@

using ..Operations: DecByParam

_shifted_neg_key_filter(shift_key, input_value, output_value, task_data) =
haskey(task_data, shift_key) &&
check_match(apply_func(input_value, (x, y) -> x .- y, task_data[shift_key]), output_value)
_shifted_neg_key_filter(shift_value, input_value, output_value) =
check_match(apply_func(input_value, (x, y) -> x .- y, shift_value), output_value)

find_neg_shifted_by_key(taskdata::Vector{TaskData}, field_info, invalid_sources::AbstractSet{String}, key::String) =
find_neg_shifted_by_key(taskdata::TaskData, field_info, invalid_sources::AbstractSet{String}, key::String) =
find_matching_for_key(
taskdata,
field_info,
Expand Down
Loading