Dockerfile (new file)
@@ -0,0 +1,10 @@
+FROM lmsysorg/sglang:v0.5.2rc2-cu126
+ENV BASE_MODEL nvidia/Llama-3.1-8B-Instruct-FP8
+ENV DRAFT_MODEL lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B
+ENV SGLANG_ARGS "--tp-size 1 --max-running-requests 32 --mem-fraction-static 0.8 --enable-torch-compile --speculative-algorithm EAGLE3 --speculative-num-steps 3 --speculative-eagle-topk 2 --speculative-num-draft-tokens 4 --dtype float16 --attention-backend fa3 --host 0.0.0.0 --port 30000"
+ENV SGL_HOST 0.0.0.0
+ENV SGL_PORT 30000
+ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN 1
+
+EXPOSE 30000
+ENTRYPOINT python3 -m sglang.launch_server --model-path $BASE_MODEL --speculative-draft-model-path $DRAFT_MODEL $SGLANG_ARGS
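For orientation (not part of the diff): the image serves an OpenAI-compatible API on the exposed port, so a quick smoke test against a locally running container could look like the sketch below. The image tag, prompt, and localhost address are illustrative assumptions.

import requests

# Minimal smoke test for the container above (illustrative sketch).
# Assumes the image was built as `sglang-eagle3` and started with:
#   docker run --gpus all -p 30000:30000 sglang-eagle3
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # SGLang's OpenAI-compatible route
    json={
        "model": "nvidia/Llama-3.1-8B-Instruct-FP8",  # base model baked into the image
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
    },
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])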
Notebook (.ipynb)
@@ -122,16 +122,18 @@
 "source": [
 "import matplotlib.pyplot as plt\n",
 "from scripts.utils import setup_workspace\n",
-"from scripts.dataset import prepare_finqa_dataset\n",
+"from scripts.dataset import prepare_finqa_dataset, prepare_sharegpt_dataset\n",
 "from scripts.run import get_run_metrics\n",
 "from scripts.reinforcement_learning import run_rl_training_pipeline\n",
 "from scripts.evaluation import run_evaluation_pipeline\n",
 "from scripts.speculative_decoding import (\n",
 "    run_draft_model_pipeline,\n",
 "    prepare_combined_model_for_deployment,\n",
 "    deploy_speculative_decoding_endpoint,\n",
+"    deploy_base_model_endpoint,\n",
+"    run_evaluation_speculative_decoding,\n",
 ")\n",
-"from scripts.deployment import create_managed_deployment, test_deployment"
+"from scripts.deployment import test_deployment"
 ]
 },
 {
@@ -150,7 +152,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"<p>Prepare dataset for Finetuning. This would save train, test and valid dataset under data folder</p>"
+"<p>Prepare the dataset for fine-tuning. This saves the train, test, and validation splits under the data folder.</p>"
 ]
 },
 {
@@ -484,6 +486,15 @@
 "<p><strong>Reference:</strong> <a href=\"https://arxiv.org/abs/2503.01840\">https://arxiv.org/abs/2503.01840</a></p>\n"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"draft_train_data_path = prepare_sharegpt_dataset()"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -498,7 +509,7 @@
 "    num_epochs=1, # Number of train epochs to be run by draft trainer.\n",
 "    monitor=False, # Set to True to wait for completion.\n",
 "    base_model_mlflow_path=\"azureml://registries/azureml-meta/models/Meta-Llama-3-8B-Instruct/versions/9\",\n",
-"    draft_train_data_path=\"./data_for_draft_model/train/sharegpt_train_small.jsonl\",\n",
+"    draft_train_data_path=draft_train_data_path,\n",
 ")"
 ]
 },
@@ -591,8 +602,7 @@
 "endpoint_name = deploy_speculative_decoding_endpoint(\n",
 "    ml_client=ml_client, # ML Client which specifies the workspace where endpoint gets deployed.\n",
 "    combined_model=combined_model, # Reference from previous steps where combined model is created.\n",
-"    instance_type=\"octagepu\", # Instance type Kubernetes Cluster\n",
-"    compute_name=\"k8s-a100-compute\",\n",
+"    instance_type=\"Standard_NC40ads_H100_v5\", # Instance type\n",
 ")"
 ]
 },
@@ -631,10 +641,9 @@
 "outputs": [],
 "source": [
 "# Deploy managed online endpoint with base model\n",
-"base_endpoint_name = create_managed_deployment( # Function to create endpoint for base model.\n",
+"base_endpoint_name = deploy_base_model_endpoint( # Function to create endpoint for base model.\n",
 "    ml_client=ml_client, # ML Client which specifies the workspace where endpoint gets deployed.\n",
-"    model_asset_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # Huggingface ID of the base model.\n",
-"    instance_type=\"Standard_ND96amsr_A100_v4\", # Compute SKU on which base model will be deployed.\n",
+"    instance_type=\"Standard_NC40ads_H100_v5\", # Compute SKU on which base model will be deployed.\n",
 ")"
 ]
 },
@@ -711,10 +720,12 @@
 "# Run evaluation job to compare base model and speculative decoding endpoints' performance\n",
 "evaluation_job = run_evaluation_speculative_decoding(\n",
 "    ml_client=ml_client,\n",
+"    registry_ml_client=registry_ml_client,\n",
 "    base_endpoint_name=base_endpoint_name, # Base model endpoint from previous step.\n",
 "    speculative_endpoint_name=endpoint_name, # Speculative endpoint from previous step.\n",
-"    base_model=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n",
-"    speculative_model=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n",
+"    base_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n",
+"    speculative_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n",
+"    compute_cluster=\"d13-v2\",\n",
 ")"
 ]
 },
@@ -735,7 +746,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"<img src=\"metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">"
+"<img src=\"./images/metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">"
 ]
 }
 ],
scripts/dataset.py
@@ -6,6 +6,13 @@
 from azure.ai.ml import MLClient
 from azure.ai.ml.entities import Data
 from azure.ai.ml.constants import AssetTypes
+from typing import Optional
+from json import JSONDecodeError
+import requests
+from tqdm import tqdm
+
+
+SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
 
 
 def register_dataset(ml_client: MLClient, dataset_name: str, file_path: str):
@@ -164,3 +171,100 @@ def map_fn(example: pd.Series, idx: int, split: str):
         return train_data.id, test_data.id, valid_data.id
 
     return train_dataset_path, test_dataset_path, valid_dataset_path
+
+
+def _is_file_valid_json(path):
+    if not os.path.isfile(path):
+        return False
+
+    try:
+        with open(path) as f:
+            json.load(f)
+        return True
+    except JSONDecodeError as e:
+        print(
+            f"{path} exists but JSON loading failed ({e=}); treating it as an invalid file"
+        )
+        return False
+
+
+def _download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Download and cache a file from a URL."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
+
+    # Check if the cache file already exists
+    if _is_file_valid_json(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
+
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
+
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
+
+
+def prepare_sharegpt_dataset(dataset_path="./data/draft_model/sharegpt_train_processed.jsonl") -> str:
+    """Prepare the ShareGPT dataset for training the draft model."""
+    # Download ShareGPT if necessary
+    if not os.path.isfile(dataset_path):
+        temp_dataset_path = _download_and_cache_file(SHAREGPT_URL)
+
+        # Load the dataset.
+        with open(temp_dataset_path) as f:
+            temp_dataset = json.load(f)
+        # Filter out conversations with fewer than 2 turns.
+        temp_dataset = [data for data in temp_dataset if len(data["conversations"]) >= 2]
+
+        # Convert each conversation into alternating user/assistant chat turns
+        new_dataset = []
+        for temp_data in temp_dataset:
+            if len(temp_data["conversations"]) % 2 != 0:
+                continue
+            if temp_data["conversations"][0]["from"] != "human":
+                continue
+
+            new_conversations = []
+
+            for i in range(0, len(temp_data["conversations"]), 2):
+                new_conversations.extend([
+                    {
+                        "role": "user",
+                        "content": temp_data["conversations"][i]["value"],
+                    },
+                    {
+                        "role": "assistant",
+                        "content": temp_data["conversations"][i + 1]["value"],
+                    }
+                ])
+
+            new_data = {}
+            new_data["id"] = temp_data.get("id", "")
+            new_data["conversations"] = new_conversations
+
+            new_dataset.append(new_data)
+
+        os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
+        with open(dataset_path, "w") as f:
+            for item in new_dataset:
+                f.write(json.dumps(item) + "\n")
+
+    return dataset_path
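As a usage sketch (not part of the diff): each record in the JSONL written by prepare_sharegpt_dataset carries an id plus alternating user/assistant turns, which can be spot-checked like this.

import json
from scripts.dataset import prepare_sharegpt_dataset

path = prepare_sharegpt_dataset()  # downloads and converts on the first call, then reuses the cached file
with open(path) as f:
    first = json.loads(f.readline())
print(first["id"], [turn["role"] for turn in first["conversations"][:2]])  # e.g. <id> ['user', 'assistant']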
scripts/deployment.py
@@ -36,6 +36,7 @@ def create_managed_deployment(
     ml_client: MLClient,
     model_asset_id: str,  # Asset ID of the model to deploy
     instance_type: str,  # Supported instance type for managed deployment
+    model_mount_path: Optional[str] = None,
     environment_asset_id: Optional[str] = None,  # Asset ID of the serving engine to use
     endpoint_name: Optional[str] = None,
     endpoint_description: str = "Sample endpoint",
@@ -65,6 +66,7 @@
         name=deployment_name,
         endpoint_name=endpoint_name,
         model=model_asset_id,
+        model_mount_path=model_mount_path,
         instance_type=instance_type,
         instance_count=1,
         environment=environment_asset_id,
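A hypothetical call showing how the new parameter flows through (all values below are illustrative, not from this PR); mounting the model at a fixed path lets the serving container look for the weights at a known location:

create_managed_deployment(
    ml_client=ml_client,
    model_asset_id="azureml:combined-spec-dec-model:1",  # hypothetical asset ID
    instance_type="Standard_NC40ads_H100_v5",
    model_mount_path="/models",  # hypothetical mount point the container reads from
)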
@@ -151,7 +153,10 @@ def test_deployment(ml_client, endpoint_name):
     """Run a test request against a deployed endpoint and print the result."""
     print("Testing endpoint...")
     # Retrieve endpoint URI and API key to authenticate test request
-    scoring_uri = ml_client.online_endpoints.get(endpoint_name).scoring_uri
+    scoring_uri = (
+        ml_client.online_endpoints.get(endpoint_name).scoring_uri.replace("/score", "/")
+        + "v1/chat/completions"
+    )
     if not scoring_uri:
         raise ValueError("Scoring URI not found for endpoint.")

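For illustration (the endpoint host below is a made-up example), the rewrite turns AzureML's default scoring route into the OpenAI-compatible chat route that the serving container exposes:

uri = "https://my-endpoint.eastus2.inference.ml.azure.com/score"  # hypothetical scoring URI
print(uri.replace("/score", "/") + "v1/chat/completions")
# -> https://my-endpoint.eastus2.inference.ml.azure.com/v1/chat/completions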