Causal graphs
Before introducing the causal graph examples, let's define a helper function for plotting directed graphs, which we'll use throughout this page.
using Graphs, CairoMakie, GraphMakie
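function plotgraph(g)
    # Draw the graph with red node labels 1, 2, ..., nv(g)
    f, ax, p = graphplot(g,
        nlabels = repr.(1:nv(g)),
        nlabels_color = [:red for i in 1:nv(g)],
    )
    # Offset each label away from node 1's position; node 1 itself gets a fixed upward offset
    offsets = 0.05 * (p[:node_pos][] .- p[:node_pos][][1])
    offsets[1] = Point2f(0, 0.2)
    p.nlabels_offset[] = offsets
    # Fit the limits, hide ticks/grid/spines, and keep a 1:1 aspect ratio
    autolimits!(ax)
    hidedecorations!(ax)
    hidespines!(ax)
    ax.aspect = DataAspect()
    return f
end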
plotgraph (generic function with 1 method)
Optimal causation entropy
Here, we use the OCE algorithm to infer a time series graph. We use a SurrogateTest for the initial (unconditional) step and a LocalPermutationTest for the conditional steps.
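To make the two stages concrete, here is a minimal sketch of the forward-selection/backward-elimination idea behind OCE, written with only Julia's standard library and assuming jointly Gaussian data. Under that assumption, conditional mutual information is a monotone function of the partial correlation, I(X; Y | Z) = -½ log(1 - ρ²), so this sketch swaps in a Fisher z-test on partial correlations instead of the surrogate and local permutation tests used above. The helpers residualize, partialcor, isdependent and oce_parents are hypothetical names for illustration, not part of CausalityTools.

```julia
using LinearAlgebra, Statistics, Random

# Residuals of `v` after an ordinary least-squares fit on the columns of `Z` (plus intercept).
function residualize(v::AbstractVector, Z::AbstractMatrix)
    A = hcat(ones(length(v)), Z)
    return v - A * (A \ v)
end

# Partial correlation of `x` and `y` given the columns of `Z` (plain correlation if `Z` has no columns).
partialcor(x, y, Z) = cor(residualize(x, Z), residualize(y, Z))

# Fisher z-test; returns `true` if independence is rejected (default ≈ two-sided α = 0.05).
function isdependent(x, y, Z; zcrit = 1.959964)
    z = atanh(partialcor(x, y, Z)) * sqrt(length(x) - size(Z, 2) - 3)
    return abs(z) > zcrit
end

# Hypothetical OCE-style parent search for a single target.
# Each column of `C` is one candidate driver (e.g. a lagged variable).
function oce_parents(target, C; zcrit = 1.959964)
    S = Int[]                          # selected parents
    remaining = collect(1:size(C, 2))  # candidates not yet selected
    # Forward stage: greedily add the candidate most associated with the target given S.
    while !isempty(remaining)
        Z = C[:, S]
        scores = [abs(partialcor(C[:, j], target, Z)) for j in remaining]
        best = remaining[argmax(scores)]
        isdependent(C[:, best], target, Z; zcrit) || break
        push!(S, best)
        filter!(!=(best), remaining)
    end
    # Backward stage: drop parents that are independent of the target given the other parents.
    for j in copy(S)
        others = filter(!=(j), S)
        isdependent(C[:, j], target, C[:, others]; zcrit) || filter!(!=(j), S)
    end
    return sort(S)
end

# Toy linear chain x → y → z with one-step lags.
rng = Xoshiro(1)
n = 1000
x = randn(rng, n); y = zeros(n); z = zeros(n)
for t in 2:n
    y[t] = 0.8 * x[t-1] + 0.3 * randn(rng)
    z[t] = 0.8 * y[t-1] + 0.3 * randn(rng)
end
C = hcat(x[1:end-1], y[1:end-1], z[1:end-1])  # lag-1 candidates: columns 1, 2, 3
parents_z = oce_parents(z[2:end], C)          # lagged y (column 2) should be selected
```

The forward stage alone can pick up spurious candidates that look informative only before the true parents are conditioned on; the backward stage re-tests each selected candidate given all the others, which is what prunes such indirect links.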
using CausalityTools
using StableRNGs
rng = StableRNG(123)
# An example system where `X → Y → Z → W`.
sys = system(Logistic4Chain(; rng))
x, y, z, w = columns(first(trajectory(sys, 400, Ttr = 10000)))
# Independence tests for unconditional and conditional stages.
utest = SurrogateTest(MIShannon(), KSG2(k = 3, w = 1); rng, nshuffles = 150)
ctest = LocalPermutationTest(CMIShannon(), MesnerShalizi(k = 3, w = 1); rng, nshuffles = 150)
# Infer graph
alg = OCE(; utest, ctest, α = 0.05, τmax = 1)
parents = infer_graph(alg, [x, y, z, w])
# Convert to graph and inspect edges
g = SimpleDiGraph(parents)
collect(edges(g))
3-element Vector{Graphs.SimpleGraphs.SimpleEdge{Int64}}:
Edge 1 => 2
Edge 2 => 3
Edge 3 => 4
The algorithm recovers the true causal chain X → Y → Z → W (edges 1 => 2, 2 => 3 and 3 => 4). We can also plot the graph using the function defined above.
plotgraph(g)
PC-algorithm
Correlation-based tests
Here, we demonstrate the use of the PC-algorithm with the correlation-based CorrTest, both for the pairwise case (i.e. using PearsonCorrelation) and the conditional case (i.e. using PartialCorrelation).
We'll reproduce the first example from CausalInference.jl, where they also use a parametric correlation test to infer the skeleton graph for some normally distributed data.
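For intuition about what a parametric test does inside the PC algorithm, here is a minimal standard-library sketch of the skeleton (adjacency) phase, using a Fisher z-test on partial correlations as the conditional independence test. It is an illustrative re-implementation under a Gaussian assumption, not the code used by CausalityTools or CausalInference.jl; the functions subsets_of_size, ci_independent and pc_skeleton are hypothetical, and the orientation phase that turns the skeleton into a CPDAG is omitted. The data are generated by the same process as in the example that follows, but with Julia's built-in Xoshiro RNG instead of StableRNG.

```julia
using LinearAlgebra, Statistics, Random

# All subsets of `v` with exactly `k` elements (Base has no `combinations`).
function subsets_of_size(v::Vector{Int}, k::Int)
    k == 0 && return [Int[]]
    length(v) < k && return Vector{Int}[]
    with_first = [[v[1]; s] for s in subsets_of_size(v[2:end], k - 1)]
    return vcat(with_first, subsets_of_size(v[2:end], k))
end

# Fisher z-test on the partial correlation of columns i and j of `X` given columns `S`.
# Returns `true` if independence is NOT rejected (default ≈ two-sided α = 0.05).
function ci_independent(X::Matrix{Float64}, i, j, S; zcrit = 1.959964)
    A = hcat(ones(size(X, 1)), X[:, S])
    ri = X[:, i] - A * (A \ X[:, i])
    rj = X[:, j] - A * (A \ X[:, j])
    z = atanh(cor(ri, rj)) * sqrt(size(X, 1) - length(S) - 3)
    return abs(z) <= zcrit
end

# Skeleton phase: start from the complete undirected graph and, for growing conditioning-set
# sizes ℓ, remove edge i - j as soon as some S ⊆ adj(i) \ {j} with |S| = ℓ makes i and j independent.
function pc_skeleton(X::Matrix{Float64}; indep = ci_independent)
    p = size(X, 2)
    adj = trues(p, p)
    for i in 1:p
        adj[i, i] = false
    end
    ℓ = 0
    while any(count(adj[i, :]) - 1 >= ℓ for i in 1:p)
        for i in 1:p, j in 1:p
            (i != j && adj[i, j]) || continue
            candidates = [k for k in 1:p if adj[i, k] && k != j]
            for S in subsets_of_size(candidates, ℓ)
                if indep(X, i, j, S)
                    adj[i, j] = adj[j, i] = false
                    break
                end
            end
        end
        ℓ += 1
    end
    return adj
end

rng = Xoshiro(123)
n = 500
v = randn(rng, n)
x = v + randn(rng, n) * 0.25
w = x + randn(rng, n) * 0.25
z = v + w + randn(rng, n) * 0.25
s = z + randn(rng, n) * 0.25
adj = pc_skeleton(hcat(x, v, w, z, s))  # columns: x, v, w, z, s
```

Passing the test as a function argument (indep) is a design choice that lets any conditional independence test with the same call signature be swapped in without touching the search itself.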
using CausalityTools
using StableRNGs
rng = StableRNG(123)
n = 500
v = randn(rng, n)
x = v + randn(rng, n)*0.25
w = x + randn(rng, n)*0.25
z = v + w + randn(rng, n)*0.25
s = z + randn(rng, n)*0.25
X = [x, v, w, z, s]
# Infer a completed partially directed acyclic graph (CPDAG)
alg = PC(CorrTest(), CorrTest(); α = 0.05)
est_cpdag_parametric = infer_graph(alg, X; verbose = false)
# Plot the graph
plotgraph(est_cpdag_parametric)
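When turning a skeleton into a CPDAG, the PC algorithm first orients colliders (v-structures). The following standard-library sketch illustrates that rule only. Instead of being estimated, the skeleton and separating sets are hypothetical inputs written down from the data-generating process above, with variables numbered as in X = [x, v, w, z, s]. The helper find_colliders is a hypothetical name, not part of CausalityTools.

```julia
# Hypothetical inputs for (x, v, w, z, s) = (1, 2, 3, 4, 5), derived from the generating process:
# the undirected skeleton, and for each non-adjacent pair a set that separates it.
skeleton = Set([(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)])
sepsets = Dict((1, 4) => [2, 3], (1, 5) => [4], (2, 3) => [1], (2, 5) => [4], (3, 5) => [4])

adjacent(i, j) = (min(i, j), max(i, j)) in skeleton

# Collider rule: for every unshielded triple i - k - j (i and j not adjacent),
# orient i → k ← j if k is NOT in the separating set of (i, j).
function find_colliders(p, adjacent, sepsets)
    colliders = Tuple{Int,Int,Int}[]
    for i in 1:p, j in (i + 1):p, k in 1:p
        k in (i, j) && continue
        adjacent(i, k) && adjacent(k, j) && !adjacent(i, j) || continue
        k in sepsets[(i, j)] || push!(colliders, (i, k, j))
    end
    return colliders
end

colliders = find_colliders(5, adjacent, sepsets)  # [(2, 4, 3)], i.e. v → z ← w
```

For these inputs the only collider is v → z ← w; the remaining orientation rules (Meek's rules) then propagate directions to other undirected edges where possible.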
Nonparametric tests
The main difference between the PC algorithm implementation here and the one in CausalInference.jl is that our implementation automatically works with any compatible IndependenceTest, and thus with any combination of (nondirectional) AssociationMeasure and estimator.
Here, we replicate the example above, but using a nonparametric SurrogateTest with the Shannon mutual information measure MIShannon and the GaoOhViswanath estimator for the pairwise independence tests, and a LocalPermutationTest with the conditional mutual information measure CMIShannon and the MesnerShalizi estimator for the conditional independence tests.
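The two nonparametric tests differ mainly in how they generate surrogate data under the null hypothesis. A minimal standard-library sketch of the idea, for intuition only: for a pairwise test, x is shuffled globally, which destroys any dependence on y; for a conditional test, x is only permuted among samples with similar z, which roughly preserves the x-z relationship while destroying dependence between x and y given z. To keep it short, the statistic is swapped to absolute (partial) correlation instead of MI/CMI, and the local permutation is simplified to shuffling x within blocks of k samples with neighboring one-dimensional z values. This is a simplification of the nearest-neighbor scheme of local permutation tests, not CausalityTools' implementation; permutation_pvalue, global_shuffle, local_shuffle and residual are hypothetical names.

```julia
using LinearAlgebra, Random, Statistics

# Surrogate p-value: share of surrogate statistics at least as large as the observed one.
# The +1 terms count the observed statistic itself, so the p-value is never exactly zero.
function permutation_pvalue(statistic, x, make_surrogate; nshuffles = 150)
    observed = statistic(x)
    n_ge = count(_ -> statistic(make_surrogate(x)) >= observed, 1:nshuffles)
    return (n_ge + 1) / (nshuffles + 1)
end

# Pairwise null (x ⊥ y): shuffle x globally.
global_shuffle(rng) = x -> shuffle(rng, x)

# Conditional null (x ⊥ y | z), simplified: sort samples by z and shuffle x only
# within consecutive blocks of k samples, so each x value stays near similar z values.
function local_shuffle(rng, z; k = 10)
    order = sortperm(z)
    return function (x)
        xs = copy(x)
        for start in 1:k:length(order)
            block = order[start:min(start + k - 1, end)]
            xs[block] = x[shuffle(rng, block)]
        end
        return xs
    end
end

# Residuals of v after least squares on z (with intercept), for a partial-correlation statistic.
residual(v, z) = v .- [ones(length(z)) z] * ([ones(length(z)) z] \ v)

rng = Xoshiro(42)
n = 400
z = randn(rng, n)
x = z + 0.5 * randn(rng, n)
y = z + 0.5 * randn(rng, n)   # x and y are related only through z

p_pair = permutation_pvalue(x -> abs(cor(x, y)), x, global_shuffle(rng))
pcor_stat(x) = abs(cor(residual(x, z), residual(y, z)))
p_cond = permutation_pvalue(pcor_stat, x, local_shuffle(rng, z))
```

Here the pairwise test should reject independence (x and y share the driver z), while the conditional test has no true dependence to detect once z is accounted for.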
rng = StableRNG(123)
# Use fewer observations, because MI/CMI takes longer to estimate
n = 400
v = randn(rng, n)
x = v + randn(rng, n)*0.25
w = x + randn(rng, n)*0.25
z = v + w + randn(rng, n)*0.25
s = z + randn(rng, n)*0.25
X = [x, v, w, z, s]
pairwise_test = SurrogateTest(MIShannon(), GaoOhViswanath(k = 10))
cond_test = LocalPermutationTest(CMIShannon(), MesnerShalizi(k = 10))
alg = PC(pairwise_test, cond_test; α = 0.05)
est_cpdag_nonparametric = infer_graph(alg, X; verbose = false)
plotgraph(est_cpdag_nonparametric)
We get the same graph as with the parametric tests. However, for general non-Gaussian data, the correlation-based tests (which assume normally distributed data and only capture linear dependence) will generally not give the same results as other independence tests.
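A small standard-library illustration of why correlation-based tests can miss dependence outside the Gaussian setting: take y = x² with x symmetric around zero. Then y is a deterministic function of x, yet their Pearson correlation is approximately zero, because the dependence is purely nonlinear. Measures such as mutual information are zero only under independence, which is why MI-based tests like the SurrogateTest with MIShannon above can pick up this kind of relationship.

```julia
using Random, Statistics

rng = Xoshiro(1)
x = randn(rng, 100_000)
y = x .^ 2               # y is fully determined by x, but not linearly
r_xy = cor(x, y)         # close to zero: no linear association
r_x2y = cor(x .^ 2, y)   # 1 (up to rounding): linear after the right transform
```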