Skip to content

Commit

Permalink
OCE algorithm bug fix and verbose mode output improvement (#269)
Browse files Browse the repository at this point in the history
* Decrement patch version

* Use τs, js directly and fix verbose mode

* Elaborate τmax in docstring

* Elaborate on parent set

* Changelog and version

* Typo

* Use StableRNGs.jl for reproducible tests
  • Loading branch information
kahaaga authored Mar 9, 2023
1 parent 0682dc1 commit 0c64352
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 47 deletions.
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 2.1

- Bugfix for `OCE` for certain conditional variable cases.
- Improve docstring for `OCE`.
- Updated readme.
- Fixed bug related to `DifferentialEntropyEstimator` "unit" conversion.

Expand Down
124 changes: 82 additions & 42 deletions src/causal_graphs/oce/OCE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@ The optimal causation entropy (OCE) algorithm for causal discovery (Sun et al.,
## Description
The OCE algorithm has three steps to determine the parents of a variable `xᵢ`.
1. Perform pairwise association tests using `utest` and select the variable `xⱼ(-τ)`
1. Perform pairwise independence tests using `utest` and select the variable `xⱼ(-τ)`
that has the highest significant (i.e. with associated p-value below `α`)
association with `xᵢ`.
association with `xᵢ(0)`. Assign it to the set of selected parents `P`.
2. Perform conditional independence tests using `ctest`, finding the parent
`Pₖ` that has the highest association with `xᵢ` given the already selected parents.
`Pₖ` that has the highest association with `xᵢ` given the already selected parents,
and add it to `P`.
Repeat until no more variables with significant association are found.
3. Backwards elimination of parents `Pₖ` of `xᵢ` for which `xᵢ ⫫ Pₖ | P - {Pₖ}`,
3. Backwards elimination of parents `Pₖ` of `xᵢ(0)` for which `xᵢ(0) ⫫ Pₖ | P - {Pₖ}`,
where `P` is the set of parent nodes found in the previous steps.
`τmax` indicates the maximum lag `τ` between the target variable `xᵢ(0)` and
its potential parents `xⱼ(-τ)`.
## Returns
When used with [`infer_graph`](@ref), it returns a vector `p`, where `p[i]` are the
Expand Down Expand Up @@ -68,7 +72,6 @@ function select_parents(alg::OCE, x; verbose = false)
end
# Find the parents of each variable.
parents = [select_parents(alg, τs, js, 𝒫s, x, k; verbose) for k in eachindex(x)]

return parents
end

Expand All @@ -80,15 +83,21 @@ Base.@kwdef mutable struct OCESelectedParents{P, PJ, PT}
parents_τs::PT = Vector{Int}(undef, 0)
end

"""
    selected(o::OCESelectedParents)

Return a comma-separated string of `xj(τ)` labels, one per selected parent, built by
pairing each entry of `o.parents_js` with the entry of `o.parents_τs` at the same index.
"""
function selected(o::OCESelectedParents)
    js, τs = o.parents_js, o.parents_τs
    # Both vectors are indexed in lockstep below, so they must have equal length.
    @assert length(js) == length(τs)
    labels = map(n -> "x$(js[n])($(τs[n]))", eachindex(js))
    return join(labels, ", ")
end


# Display an `OCESelectedParents` as the target variable `x.i` followed by the labels of
# its selected parents.
function Base.show(io::IO, x::OCESelectedParents)
# One "xj(τ)" label per entry of `x.parents`, read from `parents_js` and `parents_τs`.
s = ["x$(x.parents_js[i])($(x.parents_τs[i]))" for i in eachindex(x.parents)]
all = "Parents for x$(x.i)(0): $(join(s, ", "))"
# NOTE(review): this second assignment overrides the string built on the previous line,
# so only this shorter form reaches `show`. The local name `all` also shadows `Base.all`
# inside this method.
all = "x$(x.i)(0) $(join(s, ", "))"
show(io, all)
end

function select_parents(alg::OCE, τs, js, 𝒫s, x, i::Int; verbose = false)
verbose && println("Finding parents for variable x$i(0)")
idxs_remaining = 1:length(𝒫s) |> collect
verbose && println("\nInferring parents for x$i(0)...")
# Account for the fact that the `𝒫ⱼ ∈ 𝒫s` are embedded. This means that some points are
# lost from the `xᵢ`s.
xᵢ = @views x[i][alg.τmax+1:end]
Expand All @@ -99,41 +108,51 @@ function select_parents(alg::OCE, τs, js, 𝒫s, x, i::Int; verbose = false)
# Forward search
###################################################################
# 1. Can we find a significant pairwise association?
significant_pairwise = select_first_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ; verbose)
verbose && println("˧ Querying pairwise associations...")

significant_pairwise = select_first_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose)

if significant_pairwise
verbose && println("˧ Querying new variables conditioned on already selected variables...")
# 2. Continue until there are no more significant conditional pairwise associations
significant_cond = true
k = 0
verbose && println("Conditional tests")
while significant_cond
k += 1
significant_cond = select_conditional_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ; verbose)
significant_cond = select_conditional_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose)
end

###################################################################
# Backward elimination
###################################################################
bw_significant = true
k = 0
M = length(parents.parents)
verbose && println("Backwards elimination")
k = 1
while length(parents.parents) >= 1 && k < length(parents.parents)
verbose && println("\tk=$k, length(parents) = $(length(parents.parents))")
bw_significant = backwards_eliminate!(parents, alg, xᵢ, k; verbose)
if bw_significant
k += 1
if !(length(parents.parents) >= 2)
return parents
end

verbose && println("˧ Backwards elimination...")

eliminate = true
ks_remaining = Set(1:length(parents.parents))
while eliminate && length(ks_remaining) >= 2
for k in ks_remaining
eliminate = backwards_eliminate!(parents, alg, xᵢ, k; verbose)
if eliminate
filter!(x -> x == k, ks_remaining)
end
end
end
end
return parents
end

# Pairwise associations
function select_first_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ; verbose = false)
function select_first_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose = false)
M = length(𝒫s)

if isempty(𝒫s)
return false
end

# Association measure values and the associated p-values
Is, pvals = zeros(M), zeros(M)
for (i, Pj) in enumerate(𝒫s)
Expand All @@ -143,7 +162,8 @@ function select_first_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ
end

if all(pvals .>= alg.α)
verbose && println("\tCouldn't find any significant pairwise associations.")
s = ["x$i(0) ⫫ x$j(t) | ∅)" for (τ, j) in zip(τs, js)]
verbose && println("\t$(join(s, "\n\t"))")
return false
end
# Select the variable that has the highest significant association with xᵢ.
Expand All @@ -152,21 +172,27 @@ function select_first_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ
idx = findfirst(x -> x == Imax, Is)

if Is[idx] > 0
verbose && println("\tFound significant pairwise association with: x$(js[idx])($(τs[idx]))")
verbose && println("\tx$i(0) !⫫ x$(js[idx])($(τs[idx])) | ∅")
push!(parents.parents, 𝒫s[idx])
push!(parents.parents_js, js[idx])
push!(parents.parents_τs, τs[idx])
deleteat!(idxs_remaining, idx)
deleteat!(𝒫s, idx)
deleteat!(js, idx)
deleteat!(τs, idx)
return true
else
verbose && println("\tCouldn't find any significant pairwise associations.")
s = ["x$i(0) ⫫ x$j() | ∅)" for (τ, j) in zip(τs, js)]
verbose && println("\t$(join(s, "\n\t"))")
return false
end
end

function select_conditional_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ; verbose)
P = StateSpaceSet(parents.parents...)
function select_conditional_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose)
if isempty(𝒫s)
return false
end

P = StateSpaceSet(parents.parents...)
M = length(𝒫s)
Is = zeros(M)
pvals = zeros(M)
Expand All @@ -178,21 +204,27 @@ function select_conditional_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s
# Select the variable that has the highest significant association with xᵢ.
# "Significant" means a p-value strictly less than the significance level α.
if all(pvals .>= alg.α)
verbose && println("\tCouldn't find any significant pairwise associations.")
s = ["x$i(0) ⫫ x$j() | $(selected(parents))" for (τ, j) in zip(τs, js)]
verbose && println("\t$(join(s, "\n\t"))")
return false
end
Imax = maximum(Is[pvals .< alg.α])
idx = findfirst(x -> x == Imax, Is)

if Is[idx] > 0
verbose && println("\tSignificant conditional association with: x$(js[idx])($(τs[idx]))")
push!(parents.parents, 𝒫s[idxs_remaining[idx]])
push!(parents.parents_js, js[idxs_remaining[idx]])
push!(parents.parents_τs, τs[idxs_remaining[idx]])
deleteat!(idxs_remaining, idx)
τ = τs[idx]
j = js[idx]
verbose && println("\tx$i(0) !⫫ x$j() | $(selected(parents))")
push!(parents.parents, 𝒫s[idx])
push!(parents.parents_js, js[idx])
push!(parents.parents_τs, τs[idx])
deleteat!(𝒫s, idx)
deleteat!(τs, idx)
deleteat!(js, idx)
return true
else
verbose && println("\tCouldn't find any significant conditional associations.")
s = ["x$i(1) ⫫ x$j() | $(selected(parents)))" for (τ, j) in zip(τs, js)]
verbose && println("\t$(join(s, "\n\t"))")
return false
end
end
Expand All @@ -205,17 +237,25 @@ function backwards_eliminate!(parents, alg, xᵢ, k; verbose = false)
test = independence(alg.ctest, xᵢ, Pj, remaining)
τ, j = parents.parents_τs[k], parents.parents_js[k]
I = test.m
# If p-value >= α, then we can't reject the null, i.e. the statistic I is, in
# the frequentist hypothesis testing world, indistinguishable from zero.
# If p-value >= α, then we can't reject the null, i.e. the statistic I is
# indistinguishable from zero, so we claim independence.
if test.pvalue >= alg.α
verbose && println("\tEliminating k = $k")
τ = parents.parents_τs[k]
j = parents.parents_τs[j]
s = join(["x$(js[i])($(τs[i]))" for i in idxs], ", ")
r = "Removing x$(js[k])($(τs[k])) from parent set"
verbose && println("\tx$j() ⫫ x$(js[k])($(τs[k])) | $s$r")
deleteat!(parents.parents, k)
deleteat!(parents.parents_js, k)
deleteat!(parents.parents_τs, k)
return false
return true # a variable was removed, so the caller updates `ks_remaining` accordingly
else
verbose && println("\tpvalue(test)=$(pvalue(test)) > alg.α = $(alg.α)")
verbose && println("\tNot eliminating anything")
return true
idxs = setdiff(1:M, k)
τs = parents.parents_τs
js = parents.parents_js
s = join(["x$(js[i])($(τs[i]))" for i in idxs], ", ")
r = "Keeping x$(js[k])($(τs[k])) in parent set"
verbose && println("\tx$j() !⫫ x$(js[k])($(τs[k])) | $s$r")
return false
end
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Neighborhood = "645ca80c-8b79-4109-87ea-e1f58159d116"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StateSpaceSets = "40b095a5-5852-4c12-98c7-d43bf788e795"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
9 changes: 4 additions & 5 deletions test/causal_graphs/oce.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
using CausalityTools
using Test
using Random
using StableRNGs

rng = MersenneTwister(1234)
rng = StableRNG(123)
sys = system(Logistic4Chain(; rng))
x, y, z, w = columns(trajectory(sys, 150, Ttr = 1000))
X = [x, y, z, w]
X = columns(trajectory(sys, 350, Ttr = 10000))

parents = infer_graph(OCE(τmax = 1), X)
parents = infer_graph(OCE(τmax = 2), X)
# The membership operator was missing from these generator expressions, which made them
# a syntax error. NOTE(review): `∉` is assumed because `Logistic4Chain` presumably couples
# the variables in a one-directional chain x1 → x2 → x3 → x4, so no variable should have a
# higher-indexed variable among its parents — confirm against the system's definition.
@test all(x ∉ parents[1].parents_js for x in (2, 3, 4))
@test all(x ∉ parents[2].parents_js for x in (3, 4))
@test all(x ∉ parents[3].parents_js for x in (4,))
Expand Down

0 comments on commit 0c64352

Please sign in to comment.