Skip to content

Commit 7902f5c

Browse files
authored
Enhance Axon.Loop: Add patience option for checkpointing and implement related test case (#610)
1 parent 15ace5c commit 7902f5c

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

lib/axon/loop.ex

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,7 +1147,6 @@ defmodule Axon.Loop do
11471147
{status, state} = fun.(state)
11481148
default = %{monitor => cur_criteria_value, :since_last_improvement => 0}
11491149
updated_handler_meta = Map.put(handler_meta, name, default)
1150-
11511150
{status, %{state | handler_metadata: updated_handler_meta}}
11521151
end
11531152
end
@@ -1221,6 +1220,7 @@ defmodule Axon.Loop do
12211220
filter: :always,
12221221
path: "checkpoint",
12231222
file_pattern: &default_checkpoint_file/1,
1223+
patience: 3,
12241224
mode: :min
12251225
])
12261226

@@ -1229,6 +1229,7 @@ defmodule Axon.Loop do
12291229
{filter, opts} = Keyword.pop!(opts, :filter)
12301230
{path, opts} = Keyword.pop!(opts, :path)
12311231
{file_pattern, opts} = Keyword.pop!(opts, :file_pattern)
1232+
{patience, opts} = Keyword.pop!(opts, :patience)
12321233
{mode, serialize_opts} = Keyword.pop!(opts, :mode)
12331234

12341235
checkpoint_fun = &checkpoint_impl(&1, path, file_pattern, serialize_opts)
@@ -1237,7 +1238,8 @@ defmodule Axon.Loop do
12371238
monitor(loop, criteria, checkpoint_fun, :checkpoint,
12381239
mode: mode,
12391240
event: event,
1240-
filter: filter
1241+
filter: filter,
1242+
patience: patience
12411243
)
12421244
else
12431245
handle_event(loop, event, checkpoint_fun, filter)

test/axon/loop_test.exs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,51 @@ defmodule Axon.LoopTest do
899899
assert ["checkpoint_0_1.ckpt", "checkpoint_1_1.ckpt", "checkpoint_2_1.ckpt"] ==
900900
File.ls!("checkpoint_custom_path") |> Enum.sort()
901901
end
902+
903+
test "with :criteria, saves checkpoint after :patience is exceeded when metric stops improving",
904+
%{loop: loop} do
905+
loop
906+
|> Loop.handle_event(:epoch_completed, fn
907+
%{epoch: 0} = state ->
908+
state = %{state | metrics: Map.put(state.metrics, "loss", Nx.tensor(15))}
909+
{:continue, state}
910+
911+
%{epoch: 1} = state ->
912+
# loss is improved (15 -> 10)
913+
# since_last_improvement = 0
914+
state = %{state | metrics: Map.put(state.metrics, "loss", Nx.tensor(10))}
915+
{:continue, state}
916+
917+
%{epoch: 2} = state ->
918+
# loss is improved (10 -> 5)
919+
# since_last_improvement = 0
920+
state = %{state | metrics: Map.put(state.metrics, "loss", Nx.tensor(5))}
921+
{:continue, state}
922+
923+
%{epoch: 3} = state ->
924+
# loss is NOT improved (5 -> 5)
925+
# since_last_improvement = 1
926+
state = %{state | metrics: Map.put(state.metrics, "loss", Nx.tensor(5))}
927+
{:continue, state}
928+
929+
%{epoch: 4} = state ->
930+
# loss is NOT improved (5 -> 5)
931+
# since_last_improvement = 2
932+
state = %{state | metrics: Map.put(state.metrics, "loss", Nx.tensor(5))}
933+
{:continue, state}
934+
935+
%{epoch: 5} = state ->
936+
# loss is NOT improved (5 -> 5)
937+
# since_last_improvement = 0 (goes back to 0 and waits for improvement)
938+
state = %{state | metrics: Map.put(state.metrics, "loss", Nx.tensor(5))}
939+
{:continue, state}
940+
end)
941+
|> Loop.checkpoint(criteria: "loss", mode: :min, patience: 2)
942+
|> Loop.run([{Nx.tensor([[0]]), Nx.tensor([[1]])}], Axon.ModelState.empty(), epochs: 6)
943+
944+
# checkpoint_{epoch}_{iteration}.ckpt
945+
assert ["checkpoint_5_1.ckpt"] == File.ls!("checkpoint")
946+
end
902947
end
903948

904949
describe "from_state" do

0 commit comments

Comments
 (0)