@@ -899,6 +899,51 @@ defmodule Axon.LoopTest do
899899 assert [ "checkpoint_0_1.ckpt" , "checkpoint_1_1.ckpt" , "checkpoint_2_1.ckpt" ] ==
900900 File . ls! ( "checkpoint_custom_path" ) |> Enum . sort ( )
901901 end
902+
903+ test "with :criteria, saves checkpoint after :patience is exceeded when metric stops improving" ,
904+ % { loop: loop } do
905+ loop
906+ |> Loop . handle_event ( :epoch_completed , fn
907+ % { epoch: 0 } = state ->
908+ state = % { state | metrics: Map . put ( state . metrics , "loss" , Nx . tensor ( 15 ) ) }
909+ { :continue , state }
910+
911+ % { epoch: 1 } = state ->
912+ # loss is improved (15 -> 10)
913+ # since_last_improvement = 0
914+ state = % { state | metrics: Map . put ( state . metrics , "loss" , Nx . tensor ( 10 ) ) }
915+ { :continue , state }
916+
917+ % { epoch: 2 } = state ->
918+ # loss is improved (10 -> 5)
919+ # since_last_improvement = 0
920+ state = % { state | metrics: Map . put ( state . metrics , "loss" , Nx . tensor ( 5 ) ) }
921+ { :continue , state }
922+
923+ % { epoch: 3 } = state ->
924+ # loss is NOT improved (5 -> 5)
925+ # since_last_improvement = 1
926+ state = % { state | metrics: Map . put ( state . metrics , "loss" , Nx . tensor ( 5 ) ) }
927+ { :continue , state }
928+
929+ % { epoch: 4 } = state ->
930+ # loss is NOT improved (5 -> 5)
931+ # since_last_improvement = 2
932+ state = % { state | metrics: Map . put ( state . metrics , "loss" , Nx . tensor ( 5 ) ) }
933+ { :continue , state }
934+
935+ % { epoch: 5 } = state ->
936+ # loss is NOT improved (5 -> 5)
937+ # since_last_improvement = 0 (goes back to 0 and waits for improvement)
938+ state = % { state | metrics: Map . put ( state . metrics , "loss" , Nx . tensor ( 5 ) ) }
939+ { :continue , state }
940+ end )
941+ |> Loop . checkpoint ( criteria: "loss" , mode: :min , patience: 2 )
942+ |> Loop . run ( [ { Nx . tensor ( [ [ 0 ] ] ) , Nx . tensor ( [ [ 1 ] ] ) } ] , Axon.ModelState . empty ( ) , epochs: 6 )
943+
944+ # checkpoint_{epoch}_{iteration}.ckpt
945+ assert [ "checkpoint_5_1.ckpt" ] == File . ls! ( "checkpoint" )
946+ end
902947 end
903948
904949 describe "from_state" do
0 commit comments