Skip to content

Commit 4c5e187

Browse files
authored
Fix tests (#245)
1 parent 6d23721 commit 4c5e187

File tree

2 files changed

+178
-64
lines changed

2 files changed

+178
-64
lines changed

src/FixedEffectModel.jl

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,31 @@ end
113113

114114
# predict, residuals, modelresponse
115115

116+
# Utility functions for checking whether FE/continuous interactions are in formula
117+
# These are currently not supported in predict
118+
# Return `true` when `x` is an `InteractionTerm` whose components include at least one
# plain `Term` and at least one `fe(...)` function term; return `false` otherwise.
function is_cont_fe_int(x)
    # Anything that is not an interaction cannot be an FE/continuous interaction.
    if !(x isa InteractionTerm)
        return false
    end

    # At least one component must be a plain variable ...
    any(t -> t isa Term, x.terms) || return false

    # ... and at least one component must be a fixed effect term.
    return any(t -> t isa FunctionTerm{typeof(fe), Vector{Term}}, x.terms)
end
122+
123+
# Does the formula have InteractionTerms?
124+
# Report whether the right-hand side of formula `x` contains an interaction between a
# fixed effect (`fe(...)`) and a plain term, as judged by `is_cont_fe_int`.
# `predict` uses this to reject formulas it cannot handle yet.
function has_cont_fe_interaction(x::FormulaTerm)
    rhs = x.rhs
    if rhs isa Tuple
        # Several right-hand-side terms: flag the formula if any one of them qualifies.
        return any(is_cont_fe_int, rhs)
    elseif hasfield(typeof(rhs), :lhs) # Is an IV term
        return false # Is this correct?
    else
        # A single right-hand-side term (e.g. `y ~ fe(g)&z`): test the term itself.
        # The formula `x` is never an InteractionTerm, so checking `x` would always
        # yield `false`.
        return is_cont_fe_int(rhs)
    end
end
116133

117134
function StatsAPI.predict(m::FixedEffectModel, data)
118135
Tables.istable(data) ||
119136
throw(ArgumentError("expected second argument to be a Table, got $(typeof(data))"))
120-
has_fe(m) &&
121-
throw("To predict for a model with high-dimensional fixed effects, run `reg` with the option save = true, and then access predicted values using `fe().")
137+
138+
has_cont_fe_interaction(m.formula) &&
139+
throw(ArgumentError("Interaction of fixed effect and continuous variable detected in formula; this is currently not supported in `predict`"))
140+
122141
cdata = StatsModels.columntable(data)
123142
cols, nonmissings = StatsModels.missing_omit(cdata, m.formula_schema.rhs)
124143
Xnew = modelmatrix(m.formula_schema, cols)
@@ -130,13 +149,25 @@ function StatsAPI.predict(m::FixedEffectModel, data)
130149
end
131150

132151
# Join FE estimates onto data and sum row-wise
133-
# This code does not work properly with missing or with interacted fixed effect, so deleted
134-
#if has_fe(m)
135-
# df = DataFrame(t; copycols = false)
136-
# fes = leftjoin(select(df, m.fekeys), unique(m.fe); on = m.fekeys, makeunique = true, #matchmissing = :equal)
137-
# fes = combine(fes, AsTable(Not(m.fekeys)) => sum)
138-
# out[nonmissings] .+= fes[nonmissings, 1]
139-
#end
152+
# This does not account for FEs interacted with continuous variables - to be implemented
153+
if has_fe(m)
154+
nrow(fe(m)) > 0 || throw(ArgumentError("Model has no estimated fixed effects. To store estimates of fixed effects, run `reg` the option save = :fe"))
155+
156+
df = DataFrame(data; copycols = false)
157+
fes = leftjoin(select(df, m.fekeys), dropmissing(unique(m.fe)); on = m.fekeys,
158+
makeunique = true, matchmissing = :equal, order = :left)
159+
fes = combine(fes, AsTable(Not(m.fekeys)) => sum)
160+
161+
if any(ismissing, Matrix(select(df, m.fekeys))) || any(ismissing, Matrix(fes))
162+
out = allowmissing(out)
163+
end
164+
165+
out[nonmissings] .+= fes[nonmissings, 1]
166+
167+
if any(.!nonmissings)
168+
out[.!nonmissings] .= missing
169+
end
170+
end
140171

141172
return out
142173
end

test/predict.jl

Lines changed: 138 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ using FixedEffectModels, DataFrames, CategoricalArrays, CSV, Test
1414
residuals(result, df)
1515
@test responsename(result) == "Sales"
1616

17-
1817
model = @formula Sales ~ CPI + (Price ~ Pimin)
1918
result = reg(df, model)
2019
coeftable(result)
@@ -38,62 +37,147 @@ using FixedEffectModels, DataFrames, CategoricalArrays, CSV, Test
3837
show(result)
3938
end
4039

41-
@testset "predict" begin
42-
43-
df = DataFrame(CSV.File(joinpath(dirname(pathof(FixedEffectModels)), "../dataset/Cigar.csv")))
44-
df.StateC = categorical(df.State)
45-
46-
model = @formula Sales ~ Price + StateC
47-
result = reg(df, model)
48-
@test predict(result, df)[1] ≈ 115.9849874
49-
50-
#model = @formula Sales ~ Price + fe(State)
51-
#result = reg(df, model, save = :fe)
52-
#@test predict(result)[1] ≈ 115.9849874
53-
54-
model = @formula Sales ~ Price * Pop + StateC
55-
result = reg(df, model)
56-
@test predict(result, df)[1] ≈ 115.643985352
57-
58-
#model = @formula Sales ~ Price * Pop + fe(State)
59-
#result = reg(df, model, save = :fe)
60-
#@test predict(result, df)[1] ≈ 115.643985352
61-
62-
model = @formula Sales ~ Price + Pop + Price & Pop + StateC
63-
result = reg(df, model)
64-
@test predict(result, df)[1] ≈ 115.643985352
65-
66-
#model = @formula Sales ~ Price + Pop + Price & Pop + fe(State)
67-
#result = reg(df, model, save = :fe)
68-
#@test predict(result, df)[1] ≈ 115.643985352
69-
70-
71-
72-
73-
74-
# Tests for predict method
75-
# Test that predicting from model without saved FE test throws
76-
model = @formula Sales ~ Price + fe(State)
77-
result = reg(df, model)
78-
@test_throws "No estimates for fixed effects found. Fixed effects need to be estimated using the option save = :fe or :all for prediction to work." predict(result, df)
79-
80-
# Test basic functionality - adding 1 to price should increase prediction by coef
81-
#model = @formula Sales ~ Price + fe(State)
82-
#result = reg(df, model, save = :fe)
83-
#x = predict(result, DataFrame(Price = [1.0, 2.0], State = [1, 1]))
84-
#@test last(x) - first(x) ≈ only(result.coef)
40+
# Checks `predict` for models with saved fixed effects (`save = :fe`): plain FEs,
# several FEs, FE-by-FE interactions, missing values in FE or covariates, and the
# still unsupported FE-by-continuous interactions (which must throw ArgumentError).
@testset "Predict" begin
    # Simple - one binary FE
    df = DataFrame(x = rand(100), g = rand(["a", "b"], 100))
    df.y = 1.0 .+ 0.5 .* df.x .+ (df.g .== "b")
    m = reg(df, @formula(y ~ x + fe(g)); save = :fe)
    pred = predict(m, df)
    @test pred ≈ df.y

    # One group only
    df = DataFrame(x = rand(100), g = "a")
    df.y = 1.0 .+ 0.5 .* df.x
    m = reg(df, @formula(y ~ x + fe(g)); save = :fe)
    pred = predict(m, df)
    @test pred ≈ df.y

    # Two groups and predict df has a level that's missing from model
    df = DataFrame(x = rand(100), g = rand(["a", "b"], 100))
    df.y = 1.0 .+ 0.5 .* df.x .+ (df.g .== "b")
    m = reg(df, @formula(y ~ x + fe(g)); save = :fe)
    pred = predict(m, DataFrame(x = [1.0, 2.0], g = ["a", "c"]))
    @test ismissing(pred[2])

    # Two groups + missing observation of FE
    df = DataFrame(x = rand(100), g = [missing; rand(["a", "b"], 99)])
    df.y = 1.0 .+ 0.5 .* df.x .+ isequal.(df.g, "b")
    m = reg(df, @formula(y ~ x + fe(g)), save = :fe)
    pred = predict(m, df)
    @test pred isa Vector{Union{Missing, Float64}}
    @test ismissing(pred[1])
    @test pred[2:end] ≈ df.y[2:end]

    # Two groups + missing observation of non-FE
    df = DataFrame(x = [missing; rand(99)], g = rand(["a", "b"], 100))
    df.y = 1.0 .+ 0.5 .* df.x .+ isequal.(df.g, "b")
    m = reg(df, @formula(y ~ x + fe(g)), save = :fe)
    pred = predict(m, df)
    @test pred isa Vector{Union{Missing, Float64}}
    @test ismissing(pred[1])
    @test pred[2] ≈ df.y[2]

    # Two groups + two FEs
    df = DataFrame(x = rand(100), g1 = rand(["a", "b"], 100), g2 = rand(["c", "d"], 100))
    df.y = 1.0 .+ 0.5 .* df.x .+ isequal.(df.g1, "b") .+ (df.g2 .== "d") * 2
    m = reg(df, @formula(y ~ x + fe(g1) + fe(g2)); save = :fe)
    pred = predict(m, df)
    @test pred isa Vector{Float64}
    @test pred ≈ df.y

    # Two groups + two FEs, missing one FE
    df = DataFrame(x = rand(100), g1 = rand(["a", "b"], 100), g2 = [missing; rand(["c", "d"], 99)])
    df.y = 1.0 .+ 0.5 .* df.x .+ isequal.(df.g1, "b") .+ isequal.(df.g2,"d") * 2
    m = reg(df, @formula(y ~ x + fe(g1) + fe(g2)); save = :fe)
    pred = predict(m, df)
    @test pred isa Vector{Union{Missing, Float64}}
    @test ismissing(pred[1])
    @test pred[2:end] ≈ df.y[2:end]

    # Three FEs, "middle" one missing
    df = DataFrame(x = rand(100), g1 = rand(["a", "b"], 100), g2 = [missing; rand(["c", "d"], 99)],
                   g3 = rand(["e", "f"], 100))
    df.y = 1.0 .+ 0.5 .* df.x .+ isequal.(df.g1, "b") .+ isequal.(df.g2, "d") * 2 .+
           isequal.(df.g3, "e")
    m = reg(df, @formula(y ~ x + fe(g1) + fe(g2) + fe(g3)); save = :fe)
    pred = predict(m, df)
    @test pred isa Vector{Union{Missing, Float64}}
    @test ismissing(pred[1])
    @test pred[2:end] ≈ df.y[2:end]

    # Interactive FE
    df = DataFrame(x = rand(100), g1 = rand(["a", "b"], 100), g2 = rand(["c", "d"], 100))
    df.y = 1.0 .+ 0.5 .* df.x .+ (df.g1 .== "b") .+ (df.g1 .== "b" .&& df.g2 .== "d")
    m = reg(df, @formula(y ~ x + fe(g1)&fe(g2)); save = :fe)
    pred = predict(m, df)
    @test pred isa Vector{Float64}
    @test pred ≈ df.y

    # Interactive FE + missing
    df = DataFrame(x = rand(100), g1 = rand(["a", "b"], 100), g2 = [missing; rand(["c", "d"], 99)])
    df.y = 1.0 .+ 0.5 .* df.x .+ (df.g1 .== "b") .+ (isequal.(df.g1, "b") .&& isequal.(df.g2, "d"))
    # Predict with the model fitted on *this* data; the fit used to be stored under a
    # different name, so the stale model from the previous case was tested instead.
    m = reg(df, @formula(y ~ x + fe(g1)&fe(g2)); save = :fe)
    pred = predict(m, df)
    @test pred isa Vector{Union{Missing, Float64}}
    @test length(pred) == nrow(df)
    @test ismissing(pred[1])
    @test pred[2:end] ≈ df.y[2:end]

    # Interaction with continuous variable
    df = DataFrame(x = rand(100), g = rand(["a", "b"], 100), z = rand(100))
    df.y = 1.0 .+ 0.5 .* df.x .+ 2.0 .* (df.g .== "b") .* df.z
    m = reg(df, @formula(y ~ x + fe(g)&z); save = :fe)
    @test_throws ArgumentError pred = predict(m, df)
    #@test pred ≈ df.y # Once implemented

    # Interaction with continuous variable, FE missing
    df = DataFrame(x = rand(100), g = [missing; rand(["a", "b"], 99)], z = rand(100))
    df.y = 1.0 .+ 0.5 .* df.x .+ 2.0 .* (df.g .== "b") .* df.z
    m = reg(df, @formula(y ~ x + fe(g)&z); save = :fe)
    @test_throws ArgumentError pred = predict(m, df)
    #@test pred ≈ df.y

    # Interaction with continuous variable, cont var missing
    df = DataFrame(x = rand(100), g = rand(["a", "b"], 100), z = [missing; rand(99)])
    df.y = 1.0 .+ 0.5 .* df.x .+ 2.0 .* (df.g .== "b") .* df.z
    m = reg(df, @formula(y ~ x + fe(g)&z); save = :fe)
    @test_throws ArgumentError pred = predict(m, df)
    #@test pred ≈ df.y

    # Regular FE + another FE interacted with continuous variable
    df = DataFrame(x = rand(100), g1 = rand(["a", "b"], 100), g2 = rand(["c", "d"], 100), z = rand(100))
    # g2 only takes the levels "c" and "d", so the slope shift is tied to level "d"
    df.y = 1.0 .+ 0.5 .* df.x .+ 2.0 .* (df.g2 .== "d") .* df.z .+ 3.0 .* (df.g1 .== "b")
    m = reg(df, @formula(y ~ x + fe(g1) + fe(g2)&z); save = :fe)
    @test_throws ArgumentError pred = predict(m, df)
    #@test pred ≈ df.y

    # Two continuous/FE interactions
    m = reg(df, @formula(y ~ x + fe(g1) + fe(g2)&z + fe(g1)&x); save = :fe)
    @test_throws ArgumentError pred = predict(m, df)

    # Regular FE + interacted FE + FE/continuous interaction
    df = DataFrame(x = rand(100), g1 = rand(["a", "b"], 100), g2 = rand(["c", "d"], 100),
                   g3 = rand(["e", "f"], 100), g4 = rand(["g", "h"], 100), z = rand(100))
    df.y = 1.0 .+ 0.5 .* df.x .+ 2.0 .* (df.g1 .== "b") .+ 3.0 .* (df.g2 .== "d") .* (df.g3 .== "f") .+
           4.0 .* (df.g4 .== "h") .* df.z
    m = reg(df, @formula(y ~ x + fe(g1) + fe(g2)&fe(g3) + fe(g4)&z))
    @test_throws ArgumentError pred = predict(m, df)
end
85166

86-
# Missing variables in covariates should yield missing prediction
87-
#x = predict(result, DataFrame(Price = [1.0, missing], State = [1, 1]))
88-
#@test ismissing(last(x))
167+
# Exercises the formula check that `predict` relies on to reject interactions between
# a fixed effect and a continuous variable.
@testset "Continuous/FE detection" begin
    detect = FixedEffectModels.has_cont_fe_interaction

    # Plain interactions are left to StatsModels, so they must not be flagged
    @test !detect(@formula(y ~ x + y&z))

    # An interaction made up solely of fixed effects is supported as well
    @test !detect(@formula(y ~ x + fe(y)&fe(z)))

    # A fixed effect interacted with a continuous variable needs special handling that
    # is currently not implemented, so it has to be flagged
    @test detect(@formula(y ~ x + fe(y)&z))
    @test detect(@formula(y ~ x + y&fe(z)))
    @test detect(@formula(y ~ x + fe(y)&fe(z)&a))

    # A continuous variable interacted with a non-FE function term is again handled
    # by StatsModels
    @test !detect(@formula(y ~ x + y^2&z))
end
98182

99183

@@ -177,7 +261,6 @@ end
177261
result = reg(df, model, save = :fe)
178262
@test "fe_State" ∈ names(fe(result))
179263

180-
181264
# iv recategorized
182265
df.Pimin2 = df.Pimin
183266
m = @formula Sales ~ (Pimin2 + Price ~ NDI + Pimin)

0 commit comments

Comments
 (0)