Skip to content

Commit 389d773

Browse files
authored
Merge pull request #217 from david-thrower/216-add-support-for-gradient-accumulation-steps
216 add support for gradient accumulation steps
2 parents a3552cd + 56095db commit 389d773

File tree

4 files changed

+32
-13
lines changed

4 files changed

+32
-13
lines changed

.github/workflows/automerge.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ name: Python application
66
on:
77
push:
88

9-
branches: [ "main", "208-refactor-nlp-example-to-tokenize-first" ]
9+
branches: [ "main", "216-add-support-for-gradient-accumulation-steps" ]
1010

1111

1212
permissions:

cerebros/neuralnetworkfuture/neural_network_future.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
metrics=[tf.keras.metrics.RootMeanSquaredError()],
5757
model_graph_file='test_model_graph.html',
5858
train_data_dtype=tf.float32,
59+
gradient_accumulation_steps=1,
5960
*args,
6061
**kwargs):
6162
print(level_number)
@@ -76,6 +77,7 @@ def __init__(
7677
self.compiled_materialized_neural_network = []
7778
self.model_graph_file = model_graph_file
7879
self.train_data_dtype = train_data_dtype
80+
self.gradient_accumulation_steps = gradient_accumulation_steps
7981

8082
# super().__init__(self,
8183
# *args,
@@ -328,15 +330,30 @@ def compile_neural_network(self):
328330
jit_compile = True
329331
else:
330332
jit_compile = False
333+
if not isinstance(self.gradient_accumulation_steps, int):
334+
raise ValueError("gradient_accumulation_steps must be an int >= 0. You set it as {self.gradient_accumulation_steps} type {type(self.gradient_accumulation_steps)}")
335+
if self.gradient_accumulation_steps > 1:
336+
self.materialized_neural_network.compile(
337+
loss=self.loss,
338+
metrics=self.metrics,
339+
optimizer=tf.keras.optimizers.AdamW(
340+
learning_rate=self.learning_rate,
341+
weight_decay=0.004, # Add weight decay parameter
342+
gradient_accumulation_steps=self.gradient_accumulation_steps
343+
),
344+
jit_compile=jit_compile)
345+
elif self.gradient_accumulation_steps == 1:
346+
self.materialized_neural_network.compile(
347+
loss=self.loss,
348+
metrics=self.metrics,
349+
optimizer=tf.keras.optimizers.AdamW(
350+
learning_rate=self.learning_rate,
351+
weight_decay=0.004, # Add weight decay parameter
352+
),
353+
jit_compile=jit_compile)
354+
else:
355+
raise ValueError("gradient_accumulation_steps must be an int >= 0. You set it as {self.gradient_accumulation_steps} type {type(self.gradient_accumulation_steps)}")
331356

332-
self.materialized_neural_network.compile(
333-
loss=self.loss,
334-
metrics=self.metrics,
335-
optimizer=tf.keras.optimizers.AdamW(
336-
learning_rate=self.learning_rate,
337-
weight_decay=0.004 # Add weight decay parameter
338-
),
339-
jit_compile=jit_compile)
340357

341358
def util_parse_connectivity_csv(self):
342359

cerebros/simplecerebrosrandomsearch/simple_cerebros_random_search.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def __init__(
314314
patience=7,
315315
project_name='cerebros-auto-ml-test',
316316
batch_size=200,
317+
gradient_accumulation_steps=1,
317318
meta_trial_number=0,
318319
base_models=[''],
319320
train_data_dtype=tf.float32,
@@ -373,6 +374,7 @@ def __init__(
373374
self.metrics = metrics
374375
self.epochs = epochs
375376
self.batch_size = batch_size
377+
self.gradient_accumulation_steps=gradient_accumulation_steps
376378
self.meta_trial_number = meta_trial_number
377379
self.base_models = base_models
378380
self.best_model_path = ""
@@ -480,7 +482,8 @@ def run_moity_permutations(self, spec, subtrial_number, lock):
480482
loss=self.loss,
481483
metrics=self.metrics,
482484
model_graph_file=model_graph_file,
483-
train_data_dtype=self.train_data_dtype
485+
train_data_dtype=self.train_data_dtype,
486+
gradient_accumulation_steps=self.gradient_accumulation_steps
484487
)
485488
tf.keras.backend.clear_session()
486489
collect()

phishing_email_detection_gpt2.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,8 @@ def from_config(cls, config):
503503
batch_size=batch_size,
504504
meta_trial_number=meta_trial_number,
505505
base_models=[cerebros_base_model],
506-
train_data_dtype=tf.int32) # Changed from tf.string to tf.int32
506+
train_data_dtype=tf.int32,
507+
gradient_accumulation_steps=2)
507508

508509
cerebros_t0 = time.time()
509510
result = cerebros_automl.run_random_search()
@@ -516,8 +517,6 @@ def from_config(cls, config):
516517

517518
print(f"Cerebros trained {models_tried} models FROM A COLD START in ONLY {cerebros_time_all_models_min} min. Cerebros took only {cerebros_time_per_model} minutes on average per model.")
518519
print(f"GPT2 took {gpt_time_on_one_model_min} just to FINE TUNE one PRE - TRAINED model for 3 epochs. Although this is a small scale test, this shows the advantage of scaling in ON timing VS ON**2 timing.")
519-
520-
521520
print(f'Cerebros best accuracy achieved is {result}')
522521
print(f'val set accuracy')
523522

0 commit comments

Comments
 (0)