Causal graphs
Before introducing the causal graph examples, let's define a helper function for plotting directed graphs, which we'll use throughout.
using Graphs, CairoMakie, GraphMakie
function plotgraph(g; nlabels = repr.(1:nv(g)))
    f, ax, p = graphplot(g,
        ilabels = nlabels,
        ilabels_color = [:white for i in 1:nv(g)],
        node_color = :blue,
        node_size = 80,
        arrow_size = 15,
        figure_padding = 10
    )
    offsets = 0.02 * (p[:node_pos][] .- p[:node_pos][][1])
    offsets[1] = Point2f(0, 0.2)
    p.nlabels_offset[] = offsets
    autolimits!(ax)
    hidedecorations!(ax)
    hidespines!(ax)
    ax.aspect = DataAspect()
    return f
end
plotgraph (generic function with 1 method)
We'll also implement a set of chained logistic maps with unidirectional coupling.
using DynamicalSystemsBase
using Random # needed for `Random.default_rng()` below
Base.@kwdef struct Logistic4Chain{V, RX, RY, RZ, RW, C1, C2, C3, Σ1, Σ2, Σ3, RNG}
    xi::V = [0.1, 0.2, 0.3, 0.4]
    rx::RX = 3.9
    ry::RY = 3.6
    rz::RZ = 3.6
    rw::RW = 3.8
    c_xy::C1 = 0.4
    c_yz::C2 = 0.4
    c_zw::C3 = 0.35
    σ_xy::Σ1 = 0.05
    σ_yz::Σ2 = 0.05
    σ_zw::Σ3 = 0.05
    rng::RNG = Random.default_rng()
end
# One iteration of the map; `dx`, ..., `dw` are the next states, not derivatives.
function eom_logistic4_chain(u, p::Logistic4Chain, t)
    (; xi, rx, ry, rz, rw, c_xy, c_yz, c_zw, σ_xy, σ_yz, σ_zw, rng) = p
    x, y, z, w = u
    # Noisy unidirectional coupling terms: x drives y, y drives z, z drives w.
    f_xy = (y + c_xy*(x + σ_xy * rand(rng)) ) / (1 + c_xy*(1+σ_xy))
    f_yz = (z + c_yz*(y + σ_yz * rand(rng)) ) / (1 + c_yz*(1+σ_yz))
    f_zw = (w + c_zw*(z + σ_zw * rand(rng)) ) / (1 + c_zw*(1+σ_zw))
    dx = rx * x * (1 - x)
    dy = ry * (f_xy) * (1 - f_xy)
    dz = rz * (f_yz) * (1 - f_yz)
    dw = rw * (f_zw) * (1 - f_zw)
    return SVector{4}(dx, dy, dz, dw)
end
function system(definition::Logistic4Chain)
    return DiscreteDynamicalSystem(eom_logistic4_chain, definition.xi, definition)
end
system (generic function with 1 method)
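The normalization in the coupling terms keeps each driven map inside the unit interval: for states in [0, 1] and noise drawn from [0, 1), the numerator is at most 1 + c(1 + σ), which is exactly the denominator. The following is a minimal sketch of that check using only the standard library. The helper name `coupling` and the seed are illustrative assumptions, not part of the definitions above.

```julia
using Random

# Coupling term as in the equations of motion above (hypothetical helper name):
# mixes the driven state `y` with a noisy copy of the driver `x`.
coupling(x, y, c, σ, rng) = (y + c * (x + σ * rand(rng))) / (1 + c * (1 + σ))

rng = MersenneTwister(1)  # arbitrary seed, for reproducibility only
c, σ = 0.4, 0.05          # same values as c_xy and σ_xy above

# Evaluate the coupling term for many random states in [0, 1).
fs = [coupling(rand(rng), rand(rng), c, σ, rng) for _ in 1:10_000]

# The logistic update r * f * (1 - f) then also stays in [0, 1] for r ≤ 4.
r = 3.6
next_states = r .* fs .* (1 .- fs)

all(0 .<= fs .<= 1), all(0 .<= next_states .<= 1)
```

Because every driven variable is fed through the same bounded term, the chain cannot escape the unit interval regardless of the coupling strengths chosen.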
Optimal causation entropy
Here, we use the OCE algorithm to infer a time series graph. We use a SurrogateAssociationTest for the initial step, and a LocalPermutationTest for the conditional steps.
using Associations
using StableRNGs
rng = StableRNG(123)
# An example system where `X → Y → Z → W`.
sys = system(Logistic4Chain(; rng))
x, y, z, w = columns(first(trajectory(sys, 300, Ttr = 10000)))
# Independence tests for unconditional and conditional stages.
uest = KSG2(MIShannon(); k = 3, w = 1)
utest = SurrogateAssociationTest(uest; rng, nshuffles = 19)
cest = MesnerShalizi(CMIShannon(); k = 3, w = 1)
ctest = LocalPermutationTest(cest; rng, nshuffles = 19)
# Infer graph
alg = OCE(; utest, ctest, α = 0.05, τmax = 1)
parents = infer_graph(alg, [x, y, z, w])
# Convert to graph and inspect edges
g = SimpleDiGraph(parents)
collect(edges(g))
3-element Vector{Graphs.SimpleGraphs.SimpleEdge{Int64}}:
Edge 1 => 2
Edge 2 => 3
Edge 3 => 4
The algorithm recovers the true causal chain `X → Y → Z → W`, with no spurious edges. We can also plot the graph using the function we made above.
plotgraph(g; nlabels = ["x", "y", "z", "w"])
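To make the comparison with the known coupling structure explicit, one could build the ground-truth chain with Graphs.jl and check the inferred edges against it. This is a small sketch under the assumption that the variables are indexed in the order `[x, y, z, w]`; the edge list is copied by hand from the output above, and the variable names `truth` and `inferred` are illustrative.

```julia
using Graphs

# Ground-truth chain x → y → z → w, with variables indexed 1:4.
truth = SimpleDiGraph(4)
for (s, d) in [(1, 2), (2, 3), (3, 4)]
    add_edge!(truth, s, d)
end

# Edges copied by hand from the printed output above.
inferred = [(1, 2), (2, 3), (3, 4)]

# Every inferred edge exists in the true graph (no false positives),
# and the edge counts match (no missed edges, given the check above).
no_false_positives = all(has_edge(truth, s, d) for (s, d) in inferred)
no_misses = ne(truth) == length(inferred)
no_false_positives && no_misses
```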
PC-algorithm
Correlation-based tests
Here, we demonstrate the use of the PC-algorithm with the correlation-based CorrTest both for the pairwise (i.e. using PearsonCorrelation) and conditional (i.e. using PartialCorrelation) case.
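To see why a conditional test is needed in addition to a pairwise one, the sketch below computes a first-order partial correlation by hand using the textbook formula for a single conditioning variable. It is only an illustration with the standard library: the helper `partial_cor`, the seed, and the three-variable chain are assumptions made here, and nothing is implied about how CorrTest is implemented internally.

```julia
using Random, Statistics

# First-order partial correlation of `a` and `b` given `c` (hypothetical helper).
function partial_cor(a, b, c)
    rab, rac, rbc = cor(a, b), cor(a, c), cor(b, c)
    return (rab - rac * rbc) / sqrt((1 - rac^2) * (1 - rbc^2))
end

rng = MersenneTwister(42)  # arbitrary seed
n = 300
v = randn(rng, n)
x = v .+ 0.25 .* randn(rng, n)  # v drives x
w = x .+ 0.25 .* randn(rng, n)  # x drives w, so v and w are linked only via x

marginal = cor(v, w)               # strong: v and w are dependent
conditional = partial_cor(v, w, x) # near zero: v and w are independent given x
(marginal, conditional)
```

A pairwise test alone would keep the edge between `v` and `w`; conditioning on `x` is what allows it to be removed.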
We'll reproduce the first example from CausalInference.jl, where they also use a parametric correlation test to infer the skeleton graph for some normally distributed data.
using Associations
using StableRNGs
rng = StableRNG(123)
n = 300
v = randn(rng, n)
x = v + randn(rng, n)*0.25
w = x + randn(rng, n)*0.25
z = v + w + randn(rng, n)*0.25
s = z + randn(rng, n)*0.25
X = [x, v, w, z, s]
# Infer a completed partially directed acyclic graph (CPDAG)
alg = PC(CorrTest(), CorrTest(); α = 0.05)
est_cpdag_parametric = infer_graph(alg, X; verbose = false)
# Plot the graph
plotgraph(est_cpdag_parametric; nlabels = ["x", "v", "w", "z", "s"])
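When reading a CPDAG stored as a directed graph, it helps to separate the oriented edges from the undirected ones. The sketch below assumes the common convention (used, for example, by CausalInference.jl) that an undirected edge is stored as a pair of opposite arcs; whether the graph returned here follows that convention should be checked against the package documentation. The three-node graph is a made-up example, not the graph inferred above.

```julia
using Graphs

# Hypothetical CPDAG on three nodes: 1 - 2 is undirected, 2 → 3 is directed.
g = SimpleDiGraph(3)
add_edge!(g, 1, 2)
add_edge!(g, 2, 1)  # opposite arc marks 1 - 2 as undirected (assumed convention)
add_edge!(g, 2, 3)

# Undirected edges: both arcs present; report each pair once.
undirected = [(src(e), dst(e)) for e in edges(g)
              if src(e) < dst(e) && has_edge(g, dst(e), src(e))]

# Directed edges: the reverse arc is absent.
directed = [(src(e), dst(e)) for e in edges(g)
            if !has_edge(g, dst(e), src(e))]

(undirected, directed)
```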
Nonparametric tests
The main difference between the PC algorithm implementation here and the one in CausalInference.jl is that our implementation automatically works with any compatible IndependenceTest, and thus any combination of (nondirectional) AssociationMeasure and estimator.
Here, we replicate the example above, but using a nonparametric SurrogateAssociationTest with the Shannon mutual information MIShannon measure, estimated with the JointProbabilities estimator on binned data, for the pairwise independence tests, and a LocalPermutationTest with conditional mutual information CMIShannon and the MesnerShalizi estimator for the conditional tests.
using Associations
using StableRNGs
rng = StableRNG(123)
# Same number of observations as above; note that MI/CMI takes longer to estimate than correlation
n = 300
v = randn(rng, n)
x = v + randn(rng, n)*0.25
w = x + randn(rng, n)*0.25
z = v + w + randn(rng, n)*0.25
s = z + randn(rng, n)*0.25
X = [x, v, w, z, s]
est_pairwise = JointProbabilities(MIShannon(), CodifyVariables(ValueBinning(3)))
est_cond = MesnerShalizi(CMIShannon(); k = 5)
pairwise_test = SurrogateAssociationTest(est_pairwise; rng, nshuffles = 50)
cond_test = LocalPermutationTest(est_cond; rng, nshuffles = 50)
alg = PC(pairwise_test, cond_test; α = 0.05)
est_cpdag_nonparametric = infer_graph(alg, X; verbose = false)
plotgraph(est_cpdag_nonparametric; nlabels = ["x", "v", "w", "z", "s"])
We recover the same basic graph structure, but which directional associations are correctly ruled out varies with the test used. In general, different association measures combined with different independence tests, applied to general non-Gaussian data, will not give the same results as the correlation-based tests.
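The idea behind a shuffle-based pairwise independence test can be sketched without any packages: estimate an association measure, then compare it against values obtained after shuffling one variable, which destroys any dependence. The code below is an illustrative stand-in, not the implementation of SurrogateAssociationTest. The helper names, the equal-width binning, the plug-in mutual information estimate, and the `(exceed + 1) / (nshuffles + 1)` p-value convention are all assumptions made for this sketch, and the inputs are assumed to be non-constant.

```julia
using Random

# Assign each value to one of `nbins` equal-width bins over the data range.
function bin_indices(x, nbins)
    lo, hi = extrema(x)
    width = (hi - lo) / nbins
    return [clamp(floor(Int, (xi - lo) / width) + 1, 1, nbins) for xi in x]
end

# Plug-in Shannon mutual information (in bits) from a joint histogram.
function binned_mi(x, y; nbins = 3)
    bx, by = bin_indices(x, nbins), bin_indices(y, nbins)
    joint = zeros(nbins, nbins)
    for (i, j) in zip(bx, by)
        joint[i, j] += 1
    end
    joint ./= length(x)
    px, py = sum(joint, dims = 2), sum(joint, dims = 1)
    mi = 0.0
    for i in 1:nbins, j in 1:nbins
        p = joint[i, j]
        p > 0 && (mi += p * log2(p / (px[i] * py[j])))
    end
    return mi
end

# Fraction of shuffled replicates whose MI is at least the observed MI.
function shuffle_pvalue(x, y; nshuffles = 50, rng = Random.default_rng())
    observed = binned_mi(x, y)
    exceed = count(_ -> binned_mi(shuffle(rng, x), y) >= observed, 1:nshuffles)
    return (exceed + 1) / (nshuffles + 1)
end

rng = MersenneTwister(1)  # arbitrary seed
v = randn(rng, 300)
x = v .+ 0.25 .* randn(rng, 300)  # strongly dependent on v

p_dependent = shuffle_pvalue(v, x; nshuffles = 50, rng = rng)
p_dependent <= 0.05  # dependence is detected at α = 0.05
```

With this convention the smallest attainable p-value is `1 / (nshuffles + 1)`, so the number of shuffles limits how small a significance level `α` can meaningfully be.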