Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.10'
- '1.11'
os:
- ubuntu-latest
Expand All @@ -33,18 +32,8 @@ jobs:
- uses: julia-actions/julia-downgrade-compat@v1
if: ${{ matrix.version == '1.10' }}
- uses: julia-actions/cache@v1
- name: Set CAT packages to develop & resolve env
run: |
julia --project=test/ -e 'using Pkg;
Pkg.develop(path=".");
Pkg.resolve();
Pkg.instantiate()'
env:
R_HOME: '*'
- uses: julia-actions/julia-buildpkg@v1
- name: Run tests
run: |
cd test && julia --project=. --code-coverage=user ./runtests.jl
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: coverallsapp/github-action@v2
with:
Expand All @@ -59,7 +48,7 @@ jobs:
miniforge-version: latest
- uses: julia-actions/setup-julia@v1
with:
version: '1.10'
version: '1.11'
- name: Set CAT packages to develop & resolve env
run: |
julia --project=docs/ -e 'using Pkg;
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/benchmark_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ jobs:

steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: "1.10"
- uses: julia-actions/cache@v1
version: "1.11"
- uses: julia-actions/cache@v2
- name: Extract Package Name from Project.toml
id: extract-package-name
run: |
Expand Down
17 changes: 16 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
Accessors = "^0.1.12"
Aqua = "0.5.5, 0.6.5"
AutoHashEquals = "2"
ConstructionBase = "^1.2"
DataFrames = "1.6.1"
Expand All @@ -47,16 +48,30 @@ FittedItemBanks = "^0.6.3"
ForwardDiff = "0.10.24"
HypothesisTests = "^0.10.12, ^0.11.0"
Interpolations = "^0.14, ^0.15"
JET = "^0.9"
Lazy = "0.15"
LogarithmicNumbers = "1"
MacroTools = "^0.5.6"
Measurements = "^2.10.0"
Optim = "1.7.3"
OrderedCollections = "^1.6"
PsychometricsBazaarBase = "^0.8.1"
Reexport = "1"
ResumableFunctions = "^0.6"
Setfield = "^1"
StaticArrays = "1"
StatsBase = "^0.34"
StatsFuns = "^0.9.15, ^1"
Test = "^1.11"
UnPack = "1"
julia = "^1.10"
julia = "^1.11"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "JET", "Optim", "ResumableFunctions", "Test"]
42 changes: 32 additions & 10 deletions src/Comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ end

#phase_func=nothing;
function measure_all(comparison, system, cat, phase; kwargs...)
@info "measure_all" phase comparison.phases
@info "measure_all" phase system kwargs
if !(phase in keys(comparison.phases))
return
end
Expand Down Expand Up @@ -189,6 +189,7 @@ end
struct IncreaseItemBankSizeExecutionStrategy <: CatComparisonExecutionStrategy
item_bank::AbstractItemBank
sizes::AbstractVector{Int}
responses::Vector # XXX: Type
starting_responses::Int
shuffle::Bool
time_limit::Float64
Expand All @@ -205,24 +206,42 @@ function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes)
return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0, false, Inf)
end

function init_cat(cat::Stateful.StatefulCat, item_bank)
Stateful.set_item_bank!(cat, item_bank)
cat
end

function init_cat(cat, item_bank)
cat(item_bank)
end

function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExecutionStrategy})
strategy = comparison.strategy
current_cats = collect(pairs(comparison.rules))
next_current_cats = copy(current_cats)
next_current_cats = []
@info "sizes" strategy.sizes
for size in strategy.sizes
subsetted_item_bank = subset(strategy.item_bank, 1:size)
empty!(next_current_cats)
for (name, cat) in current_cats
Stateful.set_item_bank!(cat, subsetted_item_bank)
for _ in 1:(strategy.starting_responses)
Stateful.next_item(cat)
for (name, mk_cat) in current_cats
init_time = @timed begin
cat = init_cat(mk_cat, subsetted_item_bank)
end
response_add_time = @timed begin
for idx in 1:(strategy.starting_responses)
Stateful.add_response!(cat, idx, strategy.responses[idx])
end
end
@info "responses" Stateful.get_responses(cat)
measure_all(
comparison,
name,
cat,
:before_next_item
:before_next_item,
init_time = init_time.time,
response_add_time = response_add_time.time,
num_items=size,
system_name=name
)
timed_next_item = @timed Stateful.next_item(cat)
next_item = timed_next_item.value
Expand All @@ -232,14 +251,17 @@ function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExec
cat,
:after_next_item,
next_item = next_item,
timing = timed_next_item
timing = timed_next_item,
num_items=size,
system_name=name
)
@info "next_item" timed_next_item.time strategy.time_limit
@info "next_item" name timed_next_item.time strategy.time_limit
if timed_next_item.time < strategy.time_limit
push!(next_current_cats, name => cat)
end
end
current_cats, next_current_cats = next_current_cats, current_cats
current_cats = next_current_cats
next_current_cats = []
end
end

Expand Down
7 changes: 5 additions & 2 deletions src/ComputerAdaptiveTesting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ include("./hacks.jl")
using Pkg
using Reexport

export ConfigBase, Responses, Aggregators, NextItemRules, TerminationConditions
export CatConfig, Sim
# Modules
export ConfigBase, Responses, Aggregators
export NextItemRules, TerminationConditions
export CatConfig, Sim, DecisionTree
export Stateful, Comparison

# Vendored dependencies
include("./vendor/PushVectors.jl")
Expand Down
40 changes: 24 additions & 16 deletions src/Stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ using ..CatConfig: CatLoopConfig, CatRules
using ..Responses: BareResponses, Response
using ..NextItemRules: compute_criteria, best_item

export StatefulCat, StatefulCatConfig, run_cat
public next_item, ranked_items, item_criteria
public add_response!, rollback!, reset!, get_responses, get_ability

## StatefulCat interface
abstract type StatefulCat end

Expand Down Expand Up @@ -56,61 +60,65 @@ end
## TODO: Materialise the cat into a decsision tree

## Implementation for CatConfig
struct StatefulCatConfig{ItemBankT <: AbstractItemBank} <: StatefulCat
struct StatefulCatConfig{TrackedResponsesT <: TrackedResponses} <: StatefulCat
rules::CatRules
tracked_responses::TrackedResponses
item_bank::Ref{ItemBankT}
tracked_responses::Ref{TrackedResponsesT}
end

function StatefulCatConfig(rules, item_bank)
function StatefulCatConfig(rules::CatRules, item_bank::AbstractItemBank)
bare_responses = BareResponses(ResponseType(item_bank))
tracked_responses = TrackedResponses(
bare_responses,
item_bank,
rules.ability_tracker
)
return StatefulCatConfig(rules, tracked_responses, Ref(item_bank))
return StatefulCatConfig(rules, Ref(tracked_responses))
end

function next_item(config::StatefulCatConfig)
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank[])
return best_item(config.rules.next_item, config.tracked_responses[])
end

function ranked_items(config::StatefulCatConfig)
return sortperm(compute_criteria(
config.rules.next_item, config.tracked_responses, config.item_bank[]))
config.rules.next_item, config.tracked_responses[]))
end

function item_criteria(config::StatefulCatConfig)
return compute_criteria(
config.rules.next_item, config.tracked_responses, config.item_bank[])
config.rules.next_item, config.tracked_responses[])
end

function add_response!(config::StatefulCatConfig, index, response)
tracked_responses = config.tracked_responses[]
Aggregators.add_response!(
config.tracked_responses, Response(
ResponseType(config.item_bank[]), index, response))
tracked_responses, Response(
ResponseType(tracked_responses.item_bank), index, response))
end

function rollback!(config::StatefulCatConfig)
pop_response!(config.tracked_responses)
pop_response!(config.tracked_responses[])
end

function reset!(config::StatefulCatConfig)
empty!(config.tracked_responses)
empty!(config.tracked_responses[])
end

function set_item_bank!(config::StatefulCatConfig, item_bank)
reset!(config)
config.item_bank[] = item_bank
bare_responses = BareResponses(ResponseType(item_bank))
config.tracked_responses[] = TrackedResponses(
bare_responses,
item_bank,
config.rules.ability_tracker
)
end

function get_responses(config::StatefulCatConfig)
return config.tracked_responses.responses
return config.tracked_responses[].responses
end

function get_ability(config::StatefulCatConfig)
return (config.rules.ability_estimator(config.tracked_responses), nothing)
return (config.rules.ability_estimator(config.tracked_responses[]), nothing)
end

## TODO: Implementation for MaterializedDecisionTree
Expand Down
14 changes: 14 additions & 0 deletions src/next_item_rules/prelude/criteria.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ function compute_criteria(
for item_idx in eachindex(items)]
end

function compute_criteria(
criterion::ItemCriterion,
responses::TrackedResponses,
)
compute_criteria(criterion, responses, responses.item_bank)
end

function compute_criteria(
rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT},
responses,
Expand All @@ -58,6 +65,13 @@ function compute_criteria(
compute_criteria(rule.criterion, responses, items)
end

function compute_criteria(
rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT},
responses::TrackedResponses
) where {StrategyT, ItemCriterionT <: ItemCriterion}
compute_criteria(rule.criterion, responses)
end

function compute_pointwise_criterion(
ppic::PurePointwiseItemCriterion, tracked_responses, item_idx)
compute_pointwise_criterion(ppic, ItemResponse(tracked_responses.item_bank, item_idx))
Expand Down
4 changes: 4 additions & 0 deletions src/next_item_rules/prelude/next_item_rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ function ItemStrategyNextItemRule(bits...;
return ItemStrategyNextItemRule(strategy, criterion)
end
end

function best_item(rule::NextItemRule, tracked_responses::TrackedResponses)
best_item(rule, tracked_responses, tracked_responses.item_bank)
end
3 changes: 1 addition & 2 deletions src/next_item_rules/strategies/exhaustive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,5 @@ function best_item(
responses::TrackedResponses,
items
) where {ItemCriterionT <: ItemCriterion}
#, rule.strategy.parallel
exhaustive_search(rule.criterion, responses, items)[1]
end
end
21 changes: 0 additions & 21 deletions test/Project.toml

This file was deleted.

1 change: 0 additions & 1 deletion test/dummy.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module Dummy

using Accessors
using ComputerAdaptiveTesting.NextItemRules
using ComputerAdaptiveTesting.Aggregators
using ComputerAdaptiveTesting.Responses
Expand Down
8 changes: 0 additions & 8 deletions test/format.jl

This file was deleted.

1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,4 @@ using .Dummy
include("./smoke.jl")
include("./dt.jl")
include("./stateful.jl")
include("./format.jl")
end
14 changes: 14 additions & 0 deletions test/stateful.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
@testset "Stateful" begin
using ComputerAdaptiveTesting: CatRules
using FittedItemBanks.DummyData: dummy_full
using FittedItemBanks: OneDimContinuousDomain, SimpleItemBankSpec, StdModel3PL,
BooleanResponse
using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition
using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule
using ComputerAdaptiveTesting: Stateful
using ResumableFunctions
using Test: @test, @testset

include("./dummy.jl")
using .Dummy
using Random

rng = Random.default_rng(42)

# Create test data
Expand Down
1 change: 0 additions & 1 deletion test/tests_top.jl

This file was deleted.

Loading