diff --git a/changelog.md b/changelog.md index 19d9aea8d..21f3f507c 100644 --- a/changelog.md +++ b/changelog.md @@ -2,6 +2,8 @@ ## 2.1 +- Bugfix for `OCE` for certain conditional variable cases. +- Improved docstring for `OCE`. - Updated readme. - Fixed bug related to `DifferentialEntropyEstimator` "unit" conversion. diff --git a/src/causal_graphs/oce/OCE.jl b/src/causal_graphs/oce/OCE.jl index e65e16898..dbb33e712 100644 --- a/src/causal_graphs/oce/OCE.jl +++ b/src/causal_graphs/oce/OCE.jl @@ -13,15 +13,19 @@ The optimal causation entropy (OCE) algorithm for causal discovery (Sun et al., ## Description The OCE algorithm has three steps to determine the parents of a variable `xᵢ`. -1. Perform pairwise association tests using `utest` and select the variable `xⱼ(-τ)` +1. Perform pairwise independence tests using `utest` and select the variable `xⱼ(-τ)` that has the highest significant (i.e. with associated p-value below `α`) - association with `xᵢ`. + association with `xᵢ(0)`. Assign it to the set of selected parents `P`. 2. Perform conditional independence tests using `ctest`, finding the parent - `Pₖ` that has the highest association with `xᵢ` given the already selected parents. + `Pₖ` that has the highest association with `xᵢ(0)` given the already selected parents, + and adding it to `P`. Repeat until no more variables with significant association are found. -3. Backwards elimination of parents `Pₖ` of `xᵢ` for which `xᵢ ⫫ Pₖ | P - {Pₖ}`, +3. Backwards elimination of parents `Pₖ` of `xᵢ(0)` for which `xᵢ(0) ⫫ Pₖ | P - {Pₖ}`, where `P` is the set of parent nodes found in the previous steps. +`τmax` indicates the maximum lag `τ` between the target variable `xᵢ(0)` and +its potential parents `xⱼ(-τ)`. + ## Returns When used with [`infer_graph`](@ref), it returns a vector `p`, where `p[i]` are the @@ -68,7 +72,6 @@ function select_parents(alg::OCE, x; verbose = false) end # Find the parents of each variable. 
parents = [select_parents(alg, τs, js, 𝒫s, x, k; verbose) for k in eachindex(x)] - return parents end @@ -80,15 +83,21 @@ Base.@kwdef mutable struct OCESelectedParents{P, PJ, PT} parents_τs::PT = Vector{Int}(undef, 0) end +function selected(o::OCESelectedParents) + js, τs = o.parents_js, o.parents_τs + @assert length(js) == length(τs) + return join(["x$(js[i])($(τs[i]))" for i in eachindex(js)], ", ") +end + + function Base.show(io::IO, x::OCESelectedParents) s = ["x$(x.parents_js[i])($(x.parents_τs[i]))" for i in eachindex(x.parents)] - all = "Parents for x$(x.i)(0): $(join(s, ", "))" + all = "x$(x.i)(0) ← $(join(s, ", "))" show(io, all) end function select_parents(alg::OCE, τs, js, 𝒫s, x, i::Int; verbose = false) - verbose && println("Finding parents for variable x$i(0)") - idxs_remaining = 1:length(𝒫s) |> collect + verbose && println("\nInferring parents for x$i(0)...") # Account for the fact that the `𝒫ⱼ ∈ 𝒫s` are embedded. This means that some points are # lost from the `xᵢ`s. xᵢ = @views x[i][alg.τmax+1:end] @@ -99,31 +108,37 @@ function select_parents(alg::OCE, τs, js, 𝒫s, x, i::Int; verbose = false) # Forward search ################################################################### # 1. Can we find a significant pairwise association? - significant_pairwise = select_first_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ; verbose) + verbose && println("˧ Querying pairwise associations...") + + significant_pairwise = select_first_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose) if significant_pairwise + verbose && println("˧ Querying new variables conditioned on already selected variables...") # 2. 
Continue until there are no more significant conditional pairwise associations significant_cond = true k = 0 - verbose && println("Conditional tests") while significant_cond k += 1 - significant_cond = select_conditional_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ; verbose) + significant_cond = select_conditional_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose) end ################################################################### # Backward elimination ################################################################### - bw_significant = true - k = 0 - M = length(parents.parents) - verbose && println("Backwards elimination") - k = 1 - while length(parents.parents) >= 1 && k < length(parents.parents) - verbose && println("\tk=$k, length(parents) = $(length(parents.parents))") - bw_significant = backwards_eliminate!(parents, alg, xᵢ, k; verbose) - if bw_significant - k += 1 + if !(length(parents.parents) >= 2) + return parents + end + + verbose && println("˧ Backwards elimination...") + + eliminate = true + while eliminate && length(parents.parents) >= 2 + eliminate = false + for k in reverse(eachindex(parents.parents)) + length(parents.parents) >= 2 || break + if backwards_eliminate!(parents, alg, xᵢ, k; verbose) + eliminate = true + end end end end @@ -131,9 +146,13 @@ function select_parents(alg::OCE, τs, js, 𝒫s, x, i::Int; verbose = false) end # Pairwise associations -function select_first_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ; verbose = false) +function select_first_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose = false) M = length(𝒫s) + if isempty(𝒫s) + return false + end + # Association measure values and the associated p-values Is, pvals = zeros(M), zeros(M) for (i, Pj) in enumerate(𝒫s) @@ -143,7 +162,8 @@ function select_first_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ end if all(pvals .>= alg.α) - verbose && println("\tCouldn't find any significant pairwise associations.") + s = ["x$i(0) ⫫ x$j($τ) | ∅" for (τ, j) 
in zip(τs, js)] + verbose && println("\t$(join(s, "\n\t"))") return false end # Select the variable that has the highest significant association with xᵢ. @@ -152,21 +172,27 @@ function select_first_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ idx = findfirst(x -> x == Imax, Is) if Is[idx] > 0 - verbose && println("\tFound significant pairwise association with: x$(js[idx])($(τs[idx]))") + verbose && println("\tx$i(0) !⫫ x$(js[idx])($(τs[idx])) | ∅") push!(parents.parents, 𝒫s[idx]) push!(parents.parents_js, js[idx]) push!(parents.parents_τs, τs[idx]) - deleteat!(idxs_remaining, idx) + deleteat!(𝒫s, idx) + deleteat!(js, idx) + deleteat!(τs, idx) return true else - verbose && println("\tCouldn't find any significant pairwise associations.") + s = ["x$i(0) ⫫ x$j($τ) | ∅" for (τ, j) in zip(τs, js)] + verbose && println("\t$(join(s, "\n\t"))") return false end end -function select_conditional_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s, xᵢ; verbose) - P = StateSpaceSet(parents.parents...) +function select_conditional_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose) + if isempty(𝒫s) + return false + end + P = StateSpaceSet(parents.parents...) M = length(𝒫s) Is = zeros(M) pvals = zeros(M) @@ -178,21 +204,27 @@ function select_conditional_parent!(parents, idxs_remaining, alg, τs, js, 𝒫s # Select the variable that has the highest significant association with xᵢ. # "Significant" means a p-value strictly less than the significance level α. 
if all(pvals .>= alg.α) - verbose && println("\tCouldn't find any significant pairwise associations.") + s = ["x$i(0) ⫫ x$j($τ) | $(selected(parents))" for (τ, j) in zip(τs, js)] + verbose && println("\t$(join(s, "\n\t"))") return false end Imax = maximum(Is[pvals .< alg.α]) idx = findfirst(x -> x == Imax, Is) if Is[idx] > 0 - verbose && println("\tSignificant conditional association with: x$(js[idx])($(τs[idx]))") - push!(parents.parents, 𝒫s[idxs_remaining[idx]]) - push!(parents.parents_js, js[idxs_remaining[idx]]) - push!(parents.parents_τs, τs[idxs_remaining[idx]]) - deleteat!(idxs_remaining, idx) + τ = τs[idx] + j = js[idx] + verbose && println("\tx$i(0) !⫫ x$j($τ) | $(selected(parents))") + push!(parents.parents, 𝒫s[idx]) + push!(parents.parents_js, js[idx]) + push!(parents.parents_τs, τs[idx]) + deleteat!(𝒫s, idx) + deleteat!(τs, idx) + deleteat!(js, idx) return true else - verbose && println("\tCouldn't find any significant conditional associations.") + s = ["x$i(0) ⫫ x$j($τ) | $(selected(parents))" for (τ, j) in zip(τs, js)] + verbose && println("\t$(join(s, "\n\t"))") return false end end @@ -205,17 +237,25 @@ function backwards_eliminate!(parents, alg, xᵢ, k; verbose = false) test = independence(alg.ctest, xᵢ, Pj, remaining) τ, j = parents.parents_τs[k], parents.parents_js[k] I = test.m - # If p-value >= α, then we can't reject the null, i.e. the statistic I is, in - # the frequentist hypothesis testingworld, indistringuishable from zero. + # If p-value >= α, then we can't reject the null, i.e. the statistic I is + # indistinguishable from zero, so we claim independence. 
if test.pvalue >= alg.α - verbose && println("\tEliminating k = $k") + τs = parents.parents_τs + js = parents.parents_js + s = join(["x$(js[i])($(τs[i]))" for i in setdiff(1:M, k)], ", ") + r = "Removing x$(js[k])($(τs[k])) from parent set" + verbose && println("\tx$j($τ) ⫫ x$(js[k])($(τs[k])) | $s → $r") deleteat!(parents.parents, k) deleteat!(parents.parents_js, k) deleteat!(parents.parents_τs, k) - return false + return true # a variable was removed, so we decrement `k_remaining` in parent function else - verbose && println("\tpvalue(test)=$(pvalue(test)) > alg.α = $(alg.α)") - verbose && println("\tNot eliminating anything") - return true + idxs = setdiff(1:M, k) + τs = parents.parents_τs + js = parents.parents_js + s = join(["x$(js[i])($(τs[i]))" for i in idxs], ", ") + r = "Keeping x$(js[k])($(τs[k])) in parent set" + verbose && println("\tx$j($τ) !⫫ x$(js[k])($(τs[k])) | $s → $r") + return false end end diff --git a/test/Project.toml b/test/Project.toml index a7caf102e..8b70661fe 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Neighborhood = "645ca80c-8b79-4109-87ea-e1f58159d116" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StateSpaceSets = "40b095a5-5852-4c12-98c7-d43bf788e795" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/test/causal_graphs/oce.jl b/test/causal_graphs/oce.jl index 4b5e8991e..0aded76b8 100644 --- a/test/causal_graphs/oce.jl +++ b/test/causal_graphs/oce.jl @@ -1,13 +1,12 @@ using CausalityTools using Test -using Random +using StableRNGs -rng = MersenneTwister(1234) +rng = StableRNG(123) sys = system(Logistic4Chain(; rng)) -x, y, z, w = columns(trajectory(sys, 150, Ttr = 1000)) -X = [x, y, z, w] +X = columns(trajectory(sys, 350, Ttr = 10000)) -parents = 
infer_graph(OCE(τmax = 2), X) @test all(x ∉ parents[1].parents_js for x in (2, 3, 4)) @test all(x ∉ parents[2].parents_js for x in (3, 4)) @test all(x ∉ parents[3].parents_js for x in (4))