
Commit 8766416

FIX fsdp_auto_wrap_policy for some models (#2167)
Some transformers models and custom models would throw an error when used with PEFT's fsdp_auto_wrap_policy. This is problematic because Trainer applies the policy automatically when PEFT and FSDP are detected. Now there is no error.
1 parent 569ea69 commit 8766416
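
Why it failed: fsdp_auto_wrap_policy builds a comma-separated string of layer class names and splits it. For models that provide no such names (the empty default is an assumption here, inferred from the guard added below), splitting an empty string in Python yields [""] rather than [], so the loop looked up an empty class name and raised. A hedged sketch of that pre-fix failure mode:

# Illustrative sketch of the old failure mode, not the exact PEFT code.
names = "".split(",")  # [''] -- Python returns one empty entry, not []
for layer_class in names:
    # Before the fix there was no empty-name guard here, so the lookup of ""
    # matched no module class and the function raised:
    #   Exception: Could not find the transformer layer class to wrap in the model.
    pass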

File tree

2 files changed: +7 -0 lines changed

src/peft/utils/other.py
tests/test_gpu_examples.py

src/peft/utils/other.py

Lines changed: 2 additions & 0 deletions
@@ -525,6 +525,8 @@ def fsdp_auto_wrap_policy(model):
     ).split(",")
     transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding}
     for layer_class in transformer_cls_names_to_wrap:
+        if len(layer_class) == 0:
+            continue
         transformer_cls = get_module_class_from_name(model, layer_class)
         if transformer_cls is None:
             raise Exception("Could not find the transformer layer class to wrap in the model.")
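
With the guard in place, the policy can be built for such models without raising. A hedged usage sketch (the toy model is illustrative; fsdp_auto_wrap_policy is the function patched above):

import torch.nn as nn
from peft.utils.other import fsdp_auto_wrap_policy

# A custom model with no _no_split_modules attribute used to trip the policy.
model = nn.Sequential(nn.Linear(8, 8))
policy = fsdp_auto_wrap_policy(model)  # no longer raises after this fix
# With an initialized torch.distributed process group, the policy would then
# be passed to FSDP, mirroring the test below:
#   FSDP(model, auto_wrap_policy=policy, use_orig_params=False, sync_module_states=True)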

tests/test_gpu_examples.py

Lines changed: 5 additions & 0 deletions
@@ -3774,6 +3774,11 @@ def test_bnb_4bit_wrap_fsdp(self):
         # check that this does not raise:
         FSDP(model, auto_wrap_policy=fsdp_auto_wrap_policy(model), use_orig_params=False, sync_module_states=True)
 
+    def test_fsdp_auto_wrap_policy_does_not_raise_on_custom_model(self):
+        # See #2167
+        # Avoid raising on custom models since Trainer uses fsdp_auto_wrap_policy automatically for PEFT + FSDP
+        fsdp_auto_wrap_policy(SimpleModel())  # does not raise
+
 
 class TestBOFT:
     """
