@@ -558,7 +558,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
558558 check_batch_size = max (len (self .model .device_ids ), check_batch_size )
559559 _check_code (dataset = train_data , model = self .model , losser = losser , forward_func = self ._forward_func , metrics = metrics ,
560560 dev_data = dev_dataset , metric_key = self .metric_key , check_level = check_code_level ,
561- batch_size = check_batch_size )
561+ batch_size = check_batch_size , pin_memory = self . pin_memory )
562562
563563 self .train_data = train_data
564564 self .dev_data = dev_data # If None, No validation.
@@ -950,7 +950,7 @@ def _get_value_info(_dict):
950950 return strs
951951
952952
953- def _check_code (dataset , model , losser , metrics , forward_func , batch_size = DEFAULT_CHECK_BATCH_SIZE ,
953+ def _check_code (dataset , model , losser , metrics , forward_func , pin_memory , batch_size = DEFAULT_CHECK_BATCH_SIZE ,
954954 dev_data = None , metric_key = None , check_level = 0 ):
955955 # check get_loss 方法
956956 model_device = _get_model_device (model = model )
@@ -1010,7 +1010,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL
10101010
10111011 if dev_data is not None :
10121012 tester = Tester (data = dev_data [:batch_size * DEFAULT_CHECK_NUM_BATCH ], model = model , metrics = metrics ,
1013- batch_size = batch_size , verbose = - 1 , use_tqdm = False )
1013+ batch_size = batch_size , verbose = - 1 , use_tqdm = False , pin_memory = pin_memory )
10141014 evaluate_results = tester .test ()
10151015 _check_eval_results (metrics = evaluate_results , metric_key = metric_key , metric_list = metrics )
10161016
0 commit comments