diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b0c21c2..dbbac65 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,7 +18,6 @@ jobs: fail-fast: false matrix: version: - - '1.10' - '1.11' os: - ubuntu-latest @@ -33,18 +32,8 @@ jobs: - uses: julia-actions/julia-downgrade-compat@v1 if: ${{ matrix.version == '1.10' }} - uses: julia-actions/cache@v1 - - name: Set CAT packages to develop & resolve env - run: | - julia --project=test/ -e 'using Pkg; - Pkg.develop(path="."); - Pkg.resolve(); - Pkg.instantiate()' - env: - R_HOME: '*' - uses: julia-actions/julia-buildpkg@v1 - - name: Run tests - run: | - cd test && julia --project=. --code-coverage=user ./runtests.jl + - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - uses: coverallsapp/github-action@v2 with: @@ -59,7 +48,7 @@ jobs: miniforge-version: latest - uses: julia-actions/setup-julia@v1 with: - version: '1.10' + version: '1.11' - name: Set CAT packages to develop & resolve env run: | julia --project=docs/ -e 'using Pkg; diff --git a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml index abca885..4878aeb 100644 --- a/.github/workflows/benchmark_pr.yml +++ b/.github/workflows/benchmark_pr.yml @@ -14,10 +14,10 @@ jobs: steps: - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: - version: "1.10" - - uses: julia-actions/cache@v1 + version: "1.11" + - uses: julia-actions/cache@v2 - name: Extract Package Name from Project.toml id: extract-package-name run: | diff --git a/Project.toml b/Project.toml index 185950a..5542e1f 100644 --- a/Project.toml +++ b/Project.toml @@ -36,6 +36,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] Accessors = "^0.1.12" +Aqua = "0.5.5, 0.6.5" AutoHashEquals = "2" ConstructionBase = "^1.2" DataFrames = "1.6.1" @@ -47,16 +48,30 @@ FittedItemBanks = "^0.6.3" ForwardDiff = "0.10.24" HypothesisTests = "^0.10.12, ^0.11.0" Interpolations = "^0.14, ^0.15" +JET = "^0.9" Lazy = "0.15" LogarithmicNumbers = "1" MacroTools = "^0.5.6" Measurements = "^2.10.0" +Optim = "1.7.3" OrderedCollections = "^1.6" PsychometricsBazaarBase = "^0.8.1" Reexport = "1" +ResumableFunctions = "^0.6" Setfield = "^1" StaticArrays = "1" StatsBase = "^0.34" StatsFuns = "^0.9.15, ^1" +Test = "^1.11" UnPack = "1" -julia = "^1.10" +julia = "^1.11" + +[extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" +ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Aqua", "JET", "Optim", "ResumableFunctions", "Test"] diff --git a/src/Comparison.jl b/src/Comparison.jl index f8cec75..a2efd68 100644 --- a/src/Comparison.jl +++ b/src/Comparison.jl @@ -159,7 +159,7 @@ end #phase_func=nothing; function measure_all(comparison, system, cat, phase; kwargs...) - @info "measure_all" phase comparison.phases + @info "measure_all" phase system kwargs if !(phase in keys(comparison.phases)) return end @@ -189,6 +189,7 @@ end struct IncreaseItemBankSizeExecutionStrategy <: CatComparisonExecutionStrategy item_bank::AbstractItemBank sizes::AbstractVector{Int} + responses::Vector # XXX: Type starting_responses::Int shuffle::Bool time_limit::Float64 @@ -205,24 +206,42 @@ function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes) return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0, false, Inf) end +function init_cat(cat::Stateful.StatefulCat, item_bank) + Stateful.set_item_bank!(cat, item_bank) + cat +end + +function init_cat(cat, item_bank) + cat(item_bank) +end + function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExecutionStrategy}) strategy = comparison.strategy current_cats = collect(pairs(comparison.rules)) - next_current_cats = copy(current_cats) + next_current_cats = [] @info "sizes" strategy.sizes for size in strategy.sizes subsetted_item_bank = subset(strategy.item_bank, 1:size) empty!(next_current_cats) - for (name, cat) in current_cats - Stateful.set_item_bank!(cat, subsetted_item_bank) - for _ in 1:(strategy.starting_responses) - Stateful.next_item(cat) + for (name, mk_cat) in current_cats + init_time = @timed begin + cat = init_cat(mk_cat, subsetted_item_bank) end + response_add_time = @timed begin + for idx in 1:(strategy.starting_responses) + Stateful.add_response!(cat, idx, strategy.responses[idx]) + end + end + @info "responses" Stateful.get_responses(cat) measure_all( comparison, name, cat, - :before_next_item + :before_next_item, + init_time = init_time.time, + response_add_time = response_add_time.time, + num_items=size, + system_name=name ) timed_next_item = @timed Stateful.next_item(cat) next_item = timed_next_item.value @@ -232,14 +251,17 @@ function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExec cat, :after_next_item, next_item = next_item, - timing = timed_next_item + timing = timed_next_item, + num_items=size, + system_name=name ) - @info "next_item" timed_next_item.time strategy.time_limit + @info "next_item" name timed_next_item.time strategy.time_limit if timed_next_item.time < strategy.time_limit push!(next_current_cats, name => cat) end end - current_cats, next_current_cats = next_current_cats, current_cats + current_cats = next_current_cats + next_current_cats = [] end end diff --git a/src/ComputerAdaptiveTesting.jl b/src/ComputerAdaptiveTesting.jl index 024de2f..29630ff 100644 --- a/src/ComputerAdaptiveTesting.jl +++ b/src/ComputerAdaptiveTesting.jl @@ -5,8 +5,11 @@ include("./hacks.jl") using Pkg using Reexport -export ConfigBase, Responses, Aggregators, NextItemRules, TerminationConditions -export CatConfig, Sim +# Modules +export ConfigBase, Responses, Aggregators +export NextItemRules, TerminationConditions +export CatConfig, Sim, DecisionTree +export Stateful, Comparison # Vendored dependencies include("./vendor/PushVectors.jl") diff --git a/src/Stateful.jl b/src/Stateful.jl index f34ac38..aba5196 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -6,6 +6,10 @@ using ..CatConfig: CatLoopConfig, CatRules using ..Responses: BareResponses, Response using ..NextItemRules: compute_criteria, best_item +export StatefulCat, StatefulCatConfig, run_cat +public next_item, ranked_items, item_criteria +public add_response!, rollback!, reset!, get_responses, get_ability + ## StatefulCat interface abstract type StatefulCat end @@ -56,61 +60,65 @@ end ## TODO: Materialise the cat into a decsision tree ## Implementation for CatConfig -struct StatefulCatConfig{ItemBankT <: AbstractItemBank} <: StatefulCat +struct StatefulCatConfig{TrackedResponsesT <: TrackedResponses} <: StatefulCat rules::CatRules - tracked_responses::TrackedResponses - item_bank::Ref{ItemBankT} + tracked_responses::Ref{TrackedResponsesT} end -function StatefulCatConfig(rules, item_bank) +function StatefulCatConfig(rules::CatRules, item_bank::AbstractItemBank) bare_responses = BareResponses(ResponseType(item_bank)) tracked_responses = TrackedResponses( bare_responses, item_bank, rules.ability_tracker ) - return StatefulCatConfig(rules, tracked_responses, Ref(item_bank)) + return StatefulCatConfig(rules, Ref(tracked_responses)) end function next_item(config::StatefulCatConfig) - return best_item(config.rules.next_item, config.tracked_responses, config.item_bank[]) + return best_item(config.rules.next_item, config.tracked_responses[]) end function ranked_items(config::StatefulCatConfig) return sortperm(compute_criteria( - config.rules.next_item, config.tracked_responses, config.item_bank[])) + config.rules.next_item, config.tracked_responses[])) end function item_criteria(config::StatefulCatConfig) return compute_criteria( - config.rules.next_item, config.tracked_responses, config.item_bank[]) + config.rules.next_item, config.tracked_responses[]) end function add_response!(config::StatefulCatConfig, index, response) + tracked_responses = config.tracked_responses[] Aggregators.add_response!( - config.tracked_responses, Response( - ResponseType(config.item_bank[]), index, response)) + tracked_responses, Response( + ResponseType(tracked_responses.item_bank), index, response)) end function rollback!(config::StatefulCatConfig) - pop_response!(config.tracked_responses) + pop_response!(config.tracked_responses[]) end function reset!(config::StatefulCatConfig) - empty!(config.tracked_responses) + empty!(config.tracked_responses[]) end function set_item_bank!(config::StatefulCatConfig, item_bank) - reset!(config) - config.item_bank[] = item_bank + bare_responses = BareResponses(ResponseType(item_bank)) + config.tracked_responses[] = TrackedResponses( + bare_responses, + item_bank, + config.rules.ability_tracker + ) end function get_responses(config::StatefulCatConfig) - return config.tracked_responses.responses + return config.tracked_responses[].responses end function get_ability(config::StatefulCatConfig) - return (config.rules.ability_estimator(config.tracked_responses), nothing) + return (config.rules.ability_estimator(config.tracked_responses[]), nothing) end ## TODO: Implementation for MaterializedDecisionTree diff --git a/src/next_item_rules/prelude/criteria.jl b/src/next_item_rules/prelude/criteria.jl index f95f247..277c65d 100644 --- a/src/next_item_rules/prelude/criteria.jl +++ b/src/next_item_rules/prelude/criteria.jl @@ -50,6 +50,13 @@ function compute_criteria( for item_idx in eachindex(items)] end +function compute_criteria( + criterion::ItemCriterion, + responses::TrackedResponses, +) + compute_criteria(criterion, responses, responses.item_bank) +end + function compute_criteria( rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT}, responses, @@ -58,6 +65,13 @@ function compute_criteria( compute_criteria(rule.criterion, responses, items) end +function compute_criteria( + rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT}, + responses::TrackedResponses +) where {StrategyT, ItemCriterionT <: ItemCriterion} + compute_criteria(rule.criterion, responses) +end + function compute_pointwise_criterion( ppic::PurePointwiseItemCriterion, tracked_responses, item_idx) compute_pointwise_criterion(ppic, ItemResponse(tracked_responses.item_bank, item_idx)) diff --git a/src/next_item_rules/prelude/next_item_rule.jl b/src/next_item_rules/prelude/next_item_rule.jl index cfd8263..bd708e8 100644 --- a/src/next_item_rules/prelude/next_item_rule.jl +++ b/src/next_item_rules/prelude/next_item_rule.jl @@ -51,3 +51,7 @@ function ItemStrategyNextItemRule(bits...; return ItemStrategyNextItemRule(strategy, criterion) end end + +function best_item(rule::NextItemRule, tracked_responses::TrackedResponses) + best_item(rule, tracked_responses, tracked_responses.item_bank) +end \ No newline at end of file diff --git a/src/next_item_rules/strategies/exhaustive.jl b/src/next_item_rules/strategies/exhaustive.jl index 8337ffb..7b47429 100644 --- a/src/next_item_rules/strategies/exhaustive.jl +++ b/src/next_item_rules/strategies/exhaustive.jl @@ -39,6 +39,5 @@ function best_item( responses::TrackedResponses, items ) where {ItemCriterionT <: ItemCriterion} - #, rule.strategy.parallel exhaustive_search(rule.criterion, responses, items)[1] -end +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index 2b60c28..0000000 --- a/test/Project.toml +++ /dev/null @@ -1,21 +0,0 @@ -[deps] -Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ComputerAdaptiveTesting = "5a0d4f34-1f62-4a66-80fe-87aba0485488" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -FittedItemBanks = "3f797b09-34e4-41d7-acf6-3302ae3248a5" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" -Optim = "429524aa-4258-5aef-a3af-852621145aeb" -PsychometricsBazaarBase = "b0d9cada-d963-45e9-a4c6-4746243987f1" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -Aqua = "0.5.5, 0.6.5" -Distributions = "0.25.88" -FittedItemBanks = "^0.6" -Optim = "1.7.3" -PsychometricsBazaarBase = "^0.8" -julia = "^1.10" diff --git a/test/dummy.jl b/test/dummy.jl index 0893ac7..5f71ddc 100644 --- a/test/dummy.jl +++ b/test/dummy.jl @@ -1,6 +1,5 @@ module Dummy -using Accessors using ComputerAdaptiveTesting.NextItemRules using ComputerAdaptiveTesting.Aggregators using ComputerAdaptiveTesting.Responses diff --git a/test/format.jl b/test/format.jl deleted file mode 100644 index 9a73ca2..0000000 --- a/test/format.jl +++ /dev/null @@ -1,8 +0,0 @@ -using JuliaFormatter -using ComputerAdaptiveTesting - -@testset "format" begin - dir = pkgdir(ComputerAdaptiveTesting) - @test format(dir * "/src"; overwrite = false) - @test format(dir * "/test"; overwrite = false) -end diff --git a/test/runtests.jl b/test/runtests.jl index 18bc266..a392689 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,5 +32,4 @@ using .Dummy include("./smoke.jl") include("./dt.jl") include("./stateful.jl") - include("./format.jl") end diff --git a/test/stateful.jl b/test/stateful.jl index f01a470..131e084 100644 --- a/test/stateful.jl +++ b/test/stateful.jl @@ -1,4 +1,18 @@ @testset "Stateful" begin + using ComputerAdaptiveTesting: CatRules + using FittedItemBanks.DummyData: dummy_full + using FittedItemBanks: OneDimContinuousDomain, SimpleItemBankSpec, StdModel3PL, + BooleanResponse + using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition + using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule + using ComputerAdaptiveTesting: Stateful + using ResumableFunctions + using Test: @test, @testset + + include("./dummy.jl") + using .Dummy + using Random + rng = Random.default_rng(42) # Create test data diff --git a/test/tests_top.jl b/test/tests_top.jl deleted file mode 100644 index 8b13789..0000000 --- a/test/tests_top.jl +++ /dev/null @@ -1 +0,0 @@ -