Causal graphs

Before introducing the causal graph examples, we define a helper function for plotting directed graphs. It is used in all the examples below.

using Graphs, CairoMakie, GraphMakie

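function plotgraph(g)
    # Label each node with its index, in red.
    f, ax, p = graphplot(g,
        nlabels = repr.(1:nv(g)),
        nlabels_color = [:red for i in 1:nv(g)],
    )
    # Shift each label away from its node, in the direction pointing away from node 1.
    offsets = 0.05 * (p[:node_pos][] .- p[:node_pos][][1])
    # Node 1 would get a zero offset, so place its label manually.
    offsets[1] = Point2f(0, 0.2)
    p.nlabels_offset[] = offsets
    # Rescale the axis and remove ticks, grid and spines for a clean graph plot.
    autolimits!(ax)
    hidedecorations!(ax)
    hidespines!(ax)
    ax.aspect = DataAspect()
    return f
end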
plotgraph (generic function with 1 method)

Optimal causation entropy

Here, we use the OCE algorithm to infer a time series graph. We use a SurrogateTest for the initial step, and a LocalPermutationTest for the conditional steps.

using CausalityTools
using StableRNGs
rng = StableRNG(123)

# An example system where `X → Y → Z → W`.
sys = system(Logistic4Chain(; rng))
x, y, z, w = columns(first(trajectory(sys, 400, Ttr = 10000)))

# Independence tests for unconditional and conditional stages.
utest = SurrogateTest(MIShannon(), KSG2(k = 3, w = 1); rng, nshuffles = 150)
ctest = LocalPermutationTest(CMIShannon(), MesnerShalizi(k = 3, w = 1); rng, nshuffles = 150)

# Infer graph
alg = OCE(; utest, ctest, α = 0.05, τmax = 1)
parents = infer_graph(alg, [x, y, z, w])

# Convert to graph and inspect edges
g = SimpleDiGraph(parents)
collect(edges(g))
3-element Vector{Graphs.SimpleGraphs.SimpleEdge{Int64}}:
 Edge 1 => 2
 Edge 2 => 3
 Edge 3 => 4

The algorithm nicely recovers the true causal directions. We can also plot the graph using the function we made above.

plotgraph(g)
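
To read off the inferred parents programmatically rather than from the plot, one option is to query the graph with Graphs.jl. The sketch below rebuilds the graph by hand from the edge list printed above, so that it runs on its own. The `varnames` vector is an illustrative assumption that maps vertex indices to the input order `[x, y, z, w]`.

```julia
using Graphs

# Rebuild the inferred graph from the printed edge list (1 => 2, 2 => 3, 3 => 4).
g = SimpleDiGraph(4)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 3, 4)

# Hypothetical labels: vertex i corresponds to the i-th input time series.
varnames = ["x", "y", "z", "w"]

# For a directed graph, the in-neighbors of a vertex are its parents.
for v in vertices(g)
    pa = varnames[inneighbors(g, v)]
    println(varnames[v], " <- ", isempty(pa) ? "(no parents)" : join(pa, ", "))
end
```

In a real session, the same loop can be run directly on the `g` obtained from `SimpleDiGraph(parents)`.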

PC-algorithm

Correlation-based tests

Here, we demonstrate the use of the PC-algorithm with the correlation-based CorrTest for both the pairwise case (i.e. using PearsonCorrelation) and the conditional case (i.e. using PartialCorrelation).

We'll reproduce the first example from CausalInference.jl, where they also use a parametric correlation test to infer the skeleton graph for some normally distributed data.

using CausalityTools
using StableRNGs
rng = StableRNG(123)
n = 500
v = randn(rng, n)
x = v + randn(rng, n)*0.25
w = x + randn(rng, n)*0.25
z = v + w + randn(rng, n)*0.25
s = z + randn(rng, n)*0.25
X = [x, v, w, z, s]

# Infer a completed partially directed acyclic graph (CPDAG)
alg = PC(CorrTest(), CorrTest(); α = 0.05)
est_cpdag_parametric = infer_graph(alg, X; verbose = false)

# Plot the graph
plotgraph(est_cpdag_parametric)
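
To build intuition for why a conditional test is needed in addition to a pairwise one, the following sketch computes a plain and a partial correlation for the same data-generating process. By construction, `x` influences `z` only through `w` and shares the common cause `v` with it, so `x` and `z` are strongly correlated but should be close to uncorrelated once `v` and `w` are accounted for. This is a hand-rolled illustration using only the standard library. The helpers `residuals` and `partialcor` are hypothetical names, not part of CausalityTools, and this is not necessarily how CorrTest is implemented internally.

```julia
using Random, Statistics, LinearAlgebra

# Residuals of `y` after a least-squares fit on the columns of `Z` plus an intercept.
function residuals(y::AbstractVector, Z::AbstractMatrix)
    A = hcat(ones(length(y)), Z)
    return y .- A * (A \ y)
end

# Partial correlation of `a` and `b` given the columns of `Z`,
# computed as the correlation between the two sets of residuals.
partialcor(a, b, Z) = cor(residuals(a, Z), residuals(b, Z))

rng = MersenneTwister(123)
n = 5000
v = randn(rng, n)
x = v + randn(rng, n) * 0.25
w = x + randn(rng, n) * 0.25
z = v + w + randn(rng, n) * 0.25

r_pairwise = cor(x, z)                       # strong marginal association
r_partial = partialcor(x, z, hcat(v, w))     # association remaining given v and w
println("cor(x, z) = ", r_pairwise)
println("partial cor(x, z | v, w) = ", r_partial)
```

A pairwise test alone would keep an edge between `x` and `z`. The conditional test is what allows the PC algorithm to remove it.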

Nonparametric tests

The main difference between the PC algorithm implementation here and the one in CausalInference.jl is that our implementation automatically works with any compatible IndependenceTest, and thus with any combination of (nondirectional) AssociationMeasure and estimator.

Here, we replicate the example above, but using nonparametric tests. For the pairwise independence tests, we use a SurrogateTest with the Shannon mutual information measure MIShannon and the GaoOhViswanath estimator. For the conditional independence tests, we use a LocalPermutationTest with the conditional mutual information measure CMIShannon and the MesnerShalizi estimator.

rng = StableRNG(123)

# Use fewer observations, because MI/CMI takes longer to estimate
n = 400
v = randn(rng, n)
x = v + randn(rng, n)*0.25
w = x + randn(rng, n)*0.25
z = v + w + randn(rng, n)*0.25
s = z + randn(rng, n)*0.25
X = [x, v, w, z, s]

pairwise_test = SurrogateTest(MIShannon(), GaoOhViswanath(k = 10))
cond_test = LocalPermutationTest(CMIShannon(), MesnerShalizi(k = 10))
alg = PC(pairwise_test, cond_test; α = 0.05)
est_cpdag_nonparametric = infer_graph(alg, X; verbose = false)
plotgraph(est_cpdag_nonparametric)

We get the same graph as with the parametric tests. However, for general non-Gaussian data, the correlation-based tests (which assume normally distributed data) will not necessarily give the same results as other independence tests.
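
The point about non-Gaussian data can be illustrated with a small standard-library sketch. Here `y` depends deterministically on `x` (plus a little noise), but the dependence is purely nonlinear, so the Pearson correlation is close to zero in expectation. The variable names and the quadratic relationship are assumptions chosen for illustration, and the second statistic is tailored to this toy example rather than being a general-purpose dependence measure.

```julia
using Random, Statistics

rng = MersenneTwister(123)
n = 5000
x = randn(rng, n)
y = x .^ 2 .+ 0.1 .* randn(rng, n)   # y depends on x, but not linearly

# Pearson correlation only captures linear association.
r_linear = cor(x, y)

# Correlating y with x² exposes the dependence that the linear statistic misses.
r_nonlinear = cor(x .^ 2, y)

println("cor(x, y)  = ", r_linear)
println("cor(x², y) = ", r_nonlinear)
```

Measures such as mutual information do not rely on a linear or Gaussian model, which is why tests built on them can detect relationships like this one.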