Skip to content

Commit 625c3f4

Browse files
FindHaometa-codesync[bot]
authored andcommitted
PR2: Load Tensor Compression Support (#157)
Summary: ## Overview This PR adds gzip compression support to the `load_tensor` function while maintaining backward compatibility with existing uncompressed tensor files. ## Key Changes ### Modified Files 1. `tritonparse/tools/load_tensor.py` 2. `tritonparse/reproducer/templates/example.py` ### Features **Compression Format Support** - Added support for `.bin.gz` format (gzip compressed tensors) - Maintains backward compatibility with existing `.bin` format - Auto-detects compression based on file extension **Hash Verification Updates** - Hash is computed on **decompressed data** for compressed files - Filename format: - Compressed: `{hash}.bin.gz` - Uncompressed: `{hash}.bin` (backward compatible) **Loading Process Improvements** - Read file contents first - Decompress if needed - Load tensor from memory buffer using `io.BytesIO` - Enhanced error handling with clear messages for decompression failures ## Technical Details ### New Dependencies - `gzip`: For decompression - `io`: For memory buffer operations ### Implementation Logic ```python # 1. Detect file format is_compressed = str(blob_path).endswith('.bin.gz') # 2. Read and decompress if needed with open(blob_path, "rb") as f: file_contents = f.read() if is_compressed: file_contents = gzip.decompress(file_contents) # 3. Verify hash (based on decompressed data) computed_hash = hashlib.blake2b(file_contents).hexdigest() # 4. Load from memory buffer = io.BytesIO(file_contents) tensor = torch.load(buffer, map_location=device) ``` ## Benefits - **Storage Optimization**: Compression significantly reduces tensor file sizes - **Backward Compatible**: Existing `.bin` files continue to work without changes - **Data Integrity**: Hash verification ensures data correctness - **Transparent**: Users don't need to worry about compression, API remains unchanged ## Impact - ✅ Fully backward compatible, no breaking changes - ✅ Applies to all scenarios using `load_tensor` - ✅ Reproducer templates automatically gain compression support Pull Request resolved: #157 Reviewed By: wychi Differential Revision: D84068071 Pulled By: FindHao fbshipit-source-id: 32fe673401e553b6d5d2c26ae19e8ac7229f0b7a
1 parent 5f28d03 commit 625c3f4

File tree

2 files changed

+57
-25
lines changed

2 files changed

+57
-25
lines changed

tritonparse/reproducer/templates/example.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
It contains a smallest testing example for a Triton kernel.
44
"""
55

6+
import gzip
67
import hashlib
78
import importlib
9+
import io
810
import json
911
import logging
1012
import sys
1113
from functools import lru_cache
1214
from pathlib import Path
15+
from typing import Union
1316

1417
import torch
1518

@@ -42,13 +45,14 @@ def _get_triton_tensor_types():
4245
)
4346

4447

45-
def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
48+
def load_tensor(tensor_file_path: Union[str, Path], device: str = None) -> torch.Tensor:
4649
"""
4750
Load a tensor from its file path and verify its integrity using the hash in the filename.
4851
4952
Args:
50-
tensor_file_path (str): Direct path to the tensor .bin file. The filename should be
51-
the hash of the file contents followed by .bin extension.
53+
tensor_file_path (str | Path): Direct path to the tensor file. Supports both:
54+
- .bin.gz: gzip-compressed tensor (hash is of uncompressed data)
55+
- .bin: uncompressed tensor (for backward compatibility)
5256
device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
5357
If None, keeps the tensor on its original device.
5458
@@ -65,13 +69,26 @@ def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
6569
if not blob_path.exists():
6670
raise FileNotFoundError(f"Tensor blob not found: {blob_path}")
6771

68-
# Extract expected hash from filename (remove .bin extension)
69-
expected_hash = blob_path.stem
72+
# Detect compression by file extension
73+
is_compressed = blob_path.name.endswith(".bin.gz")
7074

71-
# Compute actual hash of file contents
72-
with open(blob_path, "rb") as f:
73-
file_contents = f.read()
74-
computed_hash = hashlib.blake2b(file_contents).hexdigest()
75+
# Read file contents (decompress if needed)
76+
try:
77+
with open(blob_path, "rb") as f:
78+
file_obj = gzip.GzipFile(fileobj=f, mode="rb") if is_compressed else f
79+
file_contents = file_obj.read()
80+
except (OSError, gzip.BadGzipFile) as e:
81+
if is_compressed:
82+
raise RuntimeError(f"Failed to decompress gzip file {blob_path}: {str(e)}")
83+
else:
84+
raise RuntimeError(f"Failed to read file {blob_path}: {str(e)}")
85+
86+
# Extract expected hash from filename
87+
# abc123.bin.gz -> abc123 or abc123.bin -> abc123
88+
expected_hash = blob_path.name.removesuffix(".bin.gz" if is_compressed else ".bin")
89+
90+
# Compute hash of uncompressed data
91+
computed_hash = hashlib.blake2b(file_contents).hexdigest()
7592

7693
# Verify hash matches filename
7794
if computed_hash != expected_hash:
@@ -80,12 +97,11 @@ def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
8097
)
8198

8299
try:
83-
# Load the tensor using torch.load (tensors are saved with torch.save)
84-
# If device is None, keep tensor on its original device, otherwise move to specified device
85-
tensor = torch.load(blob_path, map_location=device)
100+
# Load the tensor from memory buffer
101+
tensor = torch.load(io.BytesIO(file_contents), map_location=device)
86102
return tensor
87103
except Exception as e:
88-
raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}") from e
104+
raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}")
89105

90106

91107
def create_args_from_json_file(json_path):

tritonparse/tools/load_tensor.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,23 @@
66
tensor = load_tensor.load_tensor(tensor_file_path, device)
77
"""
88

9+
import gzip
910
import hashlib
11+
import io
1012
from pathlib import Path
13+
from typing import Union
1114

1215
import torch
1316

1417

15-
def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
18+
def load_tensor(tensor_file_path: Union[str, Path], device: str = None) -> torch.Tensor:
1619
"""
1720
Load a tensor from its file path and verify its integrity using the hash in the filename.
1821
1922
Args:
20-
tensor_file_path (str): Direct path to the tensor .bin file. The filename should be
21-
the hash of the file contents followed by .bin extension.
23+
tensor_file_path (str | Path): Direct path to the tensor file. Supports both:
24+
- .bin.gz: gzip-compressed tensor (hash is of uncompressed data)
25+
- .bin: uncompressed tensor (for backward compatibility)
2226
device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
2327
If None, keeps the tensor on its original device.
2428
@@ -35,13 +39,26 @@ def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
3539
if not blob_path.exists():
3640
raise FileNotFoundError(f"Tensor blob not found: {blob_path}")
3741

38-
# Extract expected hash from filename (remove .bin extension)
39-
expected_hash = blob_path.stem
42+
# Detect compression by file extension
43+
is_compressed = blob_path.name.endswith(".bin.gz")
4044

41-
# Compute actual hash of file contents
42-
with open(blob_path, "rb") as f:
43-
file_contents = f.read()
44-
computed_hash = hashlib.blake2b(file_contents).hexdigest()
45+
# Read file contents (decompress if needed)
46+
try:
47+
with open(blob_path, "rb") as f:
48+
file_obj = gzip.GzipFile(fileobj=f, mode="rb") if is_compressed else f
49+
file_contents = file_obj.read()
50+
except (OSError, gzip.BadGzipFile) as e:
51+
if is_compressed:
52+
raise RuntimeError(f"Failed to decompress gzip file {blob_path}: {str(e)}")
53+
else:
54+
raise RuntimeError(f"Failed to read file {blob_path}: {str(e)}")
55+
56+
# Extract expected hash from filename
57+
# abc123.bin.gz -> abc123 or abc123.bin -> abc123
58+
expected_hash = blob_path.name.removesuffix(".bin.gz" if is_compressed else ".bin")
59+
60+
# Compute hash of uncompressed data
61+
computed_hash = hashlib.blake2b(file_contents).hexdigest()
4562

4663
# Verify hash matches filename
4764
if computed_hash != expected_hash:
@@ -50,9 +67,8 @@ def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
5067
)
5168

5269
try:
53-
# Load the tensor using torch.load (tensors are saved with torch.save)
54-
# If device is None, keep tensor on its original device, otherwise move to specified device
55-
tensor = torch.load(blob_path, map_location=device)
70+
# Load the tensor from memory buffer
71+
tensor = torch.load(io.BytesIO(file_contents), map_location=device)
5672
return tensor
5773
except Exception as e:
5874
raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}")

0 commit comments

Comments
 (0)