
Commit b0eac32

unified config with canonical base attributes; added config adapter/validator; updated dev env to uv; created and ran a single-node smoke test (Unsloth) locally
1 parent 6b7d0d5 commit b0eac32

17 files changed (+474, -57 lines)

.python-version

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+3.13

common/configs/adapter.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+import os
+from typing import Any, Dict, Tuple, List
+
+try:
+    import yaml  # type: ignore
+except Exception as e:
+    yaml = None
+
+
+def _ensure_dict(d: Dict[str, Any], key: str) -> Dict[str, Any]:
+    if key not in d or d.get(key) is None:
+        d[key] = {}
+    return d[key]
+
+
+def normalize_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    c = dict(cfg)
+
+    training = _ensure_dict(c, "training")
+    optimizer_obj = training.get("optimizer")
+    if isinstance(optimizer_obj, str):
+        name = optimizer_obj
+        lr = training.pop("lr", None)
+        training["optimizer"] = {"name": name}
+        if lr is not None:
+            training["optimizer"]["lr"] = lr
+    elif isinstance(optimizer_obj, dict):
+        if "lr" not in optimizer_obj and "lr" in training:
+            optimizer_obj["lr"] = training.pop("lr")
+    else:
+        lr = training.pop("lr", None)
+        if lr is not None:
+            training["optimizer"] = {"name": "adamw_torch", "lr": lr}
+
+    if "batch_size_per_gpu" not in training and "batch_size" in training:
+        training["batch_size_per_gpu"] = training.pop("batch_size")
+
+    data = _ensure_dict(c, "data")
+    if "num_workers" not in data and "num_proc" in data:
+        data["num_workers"] = data.pop("num_proc")
+
+    validation = data.get("validation")
+    if isinstance(validation, dict):
+        if "batch_size_per_gpu" not in validation and "batch_size" in validation:
+            validation["batch_size_per_gpu"] = validation.pop("batch_size")
+
+    model = _ensure_dict(c, "model")
+    if "tokenizer_name" not in data and "name" in model:
+        data["tokenizer_name"] = model["name"]
+
+    checkpoint = _ensure_dict(c, "checkpoint")
+    if "output_dir" not in checkpoint and "dir" in checkpoint:
+        checkpoint.setdefault("output_dir", checkpoint.get("dir"))
+
+    return c
+
+
+def validate_config(cfg: Dict[str, Any]) -> Tuple[List[str], List[str]]:
+    errors: List[str] = []
+    warnings: List[str] = []
+
+    def need(path: str):
+        nonlocal errors
+        node = cfg
+        for k in path.split("."):
+            if not isinstance(node, dict) or k not in node:
+                errors.append(f"Missing required key: {path}")
+                return None
+            node = node[k]
+        return node
+
+    need("model.name")
+    need("data.name")
+    need("data.prompt_template")
+
+    if need("training.batch_size_per_gpu") is not None:
+        v = cfg["training"]["batch_size_per_gpu"]
+        if not isinstance(v, int) or v <= 0:
+            errors.append("training.batch_size_per_gpu must be a positive int")
+
+    need("training.grad_accum_steps")
+    need("training.max_steps")
+    need("training.optimizer.name")
+    need("training.optimizer.lr")
+
+    need("checkpoint.save_interval")
+    if need("checkpoint.output_dir") is None and need("checkpoint.dir") is None:
+        warnings.append("checkpoint.output_dir is missing; will rely on SM_CHECKPOINT_DIR or checkpoint.dir if provided")
+
+    model = cfg.get("model", {})
+    if model.get("load_in_4bit") and model.get("dtype"):
+        warnings.append("model.load_in_4bit is set along with model.dtype; verify compatibility for the selected trainer")
+
+    data = cfg.get("data", {})
+    if data.get("format") == "parquet" and data.get("streaming") is True:
+        warnings.append("data.format=parquet with streaming=true may not be supported; verify dataset loader path")
+
+    return errors, warnings
+
+
+def resolve_checkpoint_dir(cfg: Dict[str, Any], env: Dict[str, str] | None = None) -> str:
+    e = env or os.environ
+    sm_dir = e.get("SM_CHECKPOINT_DIR")
+    if sm_dir:
+        return sm_dir
+    checkpoint = cfg.get("checkpoint", {})
+    if checkpoint.get("dir"):
+        return str(checkpoint["dir"])
+    if checkpoint.get("output_dir"):
+        return str(checkpoint["output_dir"])
+    return "./outputs"
+
+
+def load_config(path: str, env: Dict[str, str] | None = None) -> Tuple[Dict[str, Any], List[str], List[str]]:
+    if yaml is None:
+        raise RuntimeError("PyYAML is required to load config files")
+    with open(path, "r") as f:
+        raw = yaml.safe_load(f) or {}
+    norm = normalize_config(raw)
+    errors, warnings = validate_config(norm)
+    if errors:
+        raise ValueError("Config validation failed: " + "; ".join(errors))
+    _ = resolve_checkpoint_dir(norm, env)
+    return norm, errors, warnings
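For reference, a minimal sketch (not part of the commit) of how the adapter maps a legacy flat config onto the canonical attributes. The model/dataset names and prompt template below are illustrative, and the import path assumes common/ is importable as a package:

from common.configs.adapter import normalize_config, validate_config

legacy = {
    "model": {"name": "meta-llama/Llama-2-7b-hf"},  # illustrative
    "data": {
        "name": "tatsu-lab/alpaca",  # illustrative
        "prompt_template": "### Instruction: {instruction}\n### Response: {output}",
        "num_proc": 4,
    },
    "training": {"batch_size": 4, "grad_accum_steps": 1, "max_steps": 1000,
                 "lr": 2e-5, "optimizer": "adamw_8bit"},
    "checkpoint": {"dir": "/opt/ml/checkpoints", "save_interval": 100},
}

cfg = normalize_config(legacy)
# batch_size -> training.batch_size_per_gpu, num_proc -> data.num_workers,
# lr folds into training.optimizer.lr, and checkpoint.dir is mirrored to output_dir.
errors, warnings = validate_config(cfg)
assert not errors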

common/configs/base_config.yaml

Lines changed: 3 additions & 3 deletions
@@ -2,12 +2,12 @@ defaults:
   - fsdp_defaults # For FSDP jobs only
 
 training:
-  batch_size: 4 # Unsloth: per-device batch size; FSDP: may use batch_size_per_gpu
   batch_size_per_gpu: 4 # FSDP trainer expects this
   grad_accum_steps: 1
-  lr: 2e-5
   max_steps: 1000
-  optimizer: adamw_torch
+  optimizer:
+    name: adamw_torch
+    lr: 2e-5
 
 checkpoint:
   save_interval: 100

common/utils/logging_utils.py

Lines changed: 10 additions & 4 deletions
@@ -1,6 +1,9 @@
 import os
 import wandb
-import smdebug.pytorch as smd
+try:
+    import smdebug.pytorch as smd  # type: ignore
+except Exception:
+    smd = None
 import logging
 from typing import Any, Dict
 try:
@@ -18,9 +21,12 @@ def __init__(self, config):
         wandb.init(project=config["wandb_project"])
         self.loggers.append(("wandb", wandb.log))
 
-        if "SM_DEBUG" in os.environ:  # SageMaker Debugger
-            self.hook = smd.Hook.create_from_json_file()
-            self.loggers.append(("smdebug", self.hook.log_metric))
+        if smd is not None and "SM_DEBUG" in os.environ:  # SageMaker Debugger
+            try:
+                self.hook = smd.Hook.create_from_json_file()
+                self.loggers.append(("smdebug", self.hook.log_metric))
+            except Exception as e:
+                logging.error(f"Failed to initialize smdebug hook: {e}")
 
     def log_metrics(self, metrics, step):
         for name, logger in self.loggers:

main.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+def main():
+    print("Hello from fsdp-multi-gpu-training!")
+
+
+if __name__ == "__main__":
+    main()

pyproject.toml

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+[project]
+name = "fsdp-multi-gpu-training"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.13"
+dependencies = []

requirements.txt

Lines changed: 7 additions & 2 deletions
@@ -5,9 +5,11 @@ datasets>=2.19.0
 accelerate>=0.31.0
 omegaconf>=2.3.0
 safetensors>=0.4.3
+jupyter>=1.0.0
 
 # Quantization / 4-bit (for Unsloth or fallback paths)
-bitsandbytes>=0.43.1
+# Install only on Linux (x86_64/aarch64). No macOS arm64 wheels are available.
+bitsandbytes>=0.48.1; platform_system == "Linux" and (platform_machine == "x86_64" or platform_machine == "aarch64")
 
 # Optional but commonly required for LLM tokenizers
 sentencepiece>=0.1.99
@@ -17,12 +19,15 @@ wandb>=0.17.0
 
 # Storage / data access
 s3fs>=2024.6.1
+boto3>=1.34.0
 
 # Unsloth (when using the Unsloth strategy)
-unsloth>=2024.5.0
+# Install only on Linux and Python < 3.13 to avoid xformers build issues on macOS/Apple Silicon and Py3.13
+unsloth>=2024.5.0; platform_system == "Linux" and python_version < "3.13"
 peft>=0.11.1
 
 # Utilities
 pyyaml>=6.0.1
 tqdm>=4.66.0
 python-dotenv>=1.0.1
+psutil>=5.9.0
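As a quick sanity check of the new environment markers (standard PEP 508 markers), a small sketch using the packaging library, which is assumed to be installed and is not itself listed in requirements.txt:

from packaging.markers import Marker

# The same markers used above for bitsandbytes and unsloth.
bnb_marker = Marker('platform_system == "Linux" and (platform_machine == "x86_64" or platform_machine == "aarch64")')
unsloth_marker = Marker('platform_system == "Linux" and python_version < "3.13"')

# evaluate() checks the current interpreter/platform by default;
# an explicit environment dict can override values to test other targets.
print(bnb_marker.evaluate())  # False on macOS/Apple Silicon, True on Linux x86_64
print(unsloth_marker.evaluate({"platform_system": "Linux", "python_version": "3.13"}))  # False: 3.13 is excluded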

scripts/configs/fsdp/llama-70b.yaml

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ checkpoint:
   format: "sharded" # FSDP-required format
   s3_uri: "s3://${env:BUCKET_NAME}/fsdp-checkpoints/"
   save_optimizer: false # Saves VRAM
+  save_interval: 500
+  output_dir: ./outputs
 
 # Logging
 logging:

scripts/configs/unsloth/llama-7b.yaml

Lines changed: 7 additions & 5 deletions
@@ -18,26 +18,28 @@ data:
     ### Instruction: {instruction}
     ### Input: {input}
     ### Response: {output}{eos_token}
-  num_proc: 4 # Parallel loading
+  num_workers: 4 # Parallel loading
   validation:
     name: "s3://${env:BUCKET_NAME}/llm-data/validation" # or HF path
     split: "validation"
     interval: 200 # Steps between validations
-    batch_size: 4 # Unsloth
+    batch_size_per_gpu: 4 # Unsloth
 
 # Training Parameters
 training:
-  batch_size: 4 # Adjust based on VRAM
+  batch_size_per_gpu: 4 # Adjust based on VRAM
   grad_accum_steps: 2
   max_steps: 1000
-  lr: 2e-5
-  optimizer: "adamw_8bit"
+  optimizer:
+    name: "adamw_8bit"
+    lr: 2e-5
 
 # Checkpointing
checkpoint:
   dir: "/opt/ml/checkpoints" # SageMaker compatible
   save_interval: 100
   s3_uri: "s3://${env:BUCKET_NAME}/unsloth-checkpoints/"
+  output_dir: ./outputs
 
 # Logging
logging:
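With this config, the adapter's checkpoint-directory resolution follows a simple precedence: SM_CHECKPOINT_DIR from the environment, then checkpoint.dir, then checkpoint.output_dir, then a ./outputs fallback. A minimal sketch (the SageMaker path below is illustrative):

from common.configs.adapter import resolve_checkpoint_dir

cfg = {"checkpoint": {"dir": "/opt/ml/checkpoints", "output_dir": "./outputs"}}

# On SageMaker, SM_CHECKPOINT_DIR (when set) wins over everything in the config.
print(resolve_checkpoint_dir(cfg, env={"SM_CHECKPOINT_DIR": "/opt/ml/checkpoints/job-1"}))
# -> /opt/ml/checkpoints/job-1

# Locally, with no SM_CHECKPOINT_DIR set in os.environ, checkpoint.dir wins,
# then checkpoint.output_dir, then "./outputs".
print(resolve_checkpoint_dir(cfg))
# -> /opt/ml/checkpoints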

scripts/core/data_loader.py

Lines changed: 18 additions & 9 deletions
@@ -20,6 +20,7 @@ class DataLoaderConfig:
     streaming: bool = False
     cache_dir: Optional[str] = None
     hf_token: Optional[str] = None
+    config_name: Optional[str] = None  # HF dataset config (e.g., "wikitext-2-raw-v1")
 
 
 class DataLoader:
@@ -28,8 +29,14 @@ class DataLoader:
     Prefer this over ad-hoc loaders for consistent preprocessing and error context.
     """
     def __init__(self, config: Union[DictConfig, DataLoaderConfig]):
-        self.config = config if isinstance(config, DataLoaderConfig) \
-            else DataLoaderConfig(**config.data)
+        if isinstance(config, DataLoaderConfig):
+            self.config = config
+        elif isinstance(config, DictConfig):
+            self.config = DataLoaderConfig(**config.data)
+        elif isinstance(config, dict):
+            self.config = DataLoaderConfig(**config)
+        else:
+            raise TypeError(f"Unsupported config type for DataLoader: {type(config)}")
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.config.tokenizer_name,
@@ -45,13 +52,15 @@ def __init__(self, config: Union[DictConfig, DataLoaderConfig]):
     def _load_from_hf(self) -> Dataset:
         """Load dataset directly from Hugging Face Hub."""
         try:
-            return load_dataset(
-                self.config.name,
-                split=self.config.split,
-                streaming=self.config.streaming,
-                token=self.config.hf_token or os.getenv("HF_TOKEN"),
-                cache_dir=self.config.cache_dir
-            )
+            kwargs = {
+                "split": self.config.split,
+                "streaming": self.config.streaming,
+                "token": self.config.hf_token or os.getenv("HF_TOKEN"),
+                "cache_dir": self.config.cache_dir,
+            }
+            if self.config.config_name:
+                kwargs["name"] = self.config.config_name
+            return load_dataset(self.config.name, **kwargs)
         except Exception as e:
             msg = (
                 f"Failed to load HF dataset '{self.config.name}' split='{self.config.split}' "
