
Commit d05fc36

committed
Fix CUDA OOM errors due to torch tracking gradients
Call self.model.eval() and wrap the per-question forward passes in a torch.no_grad() block before invoking the model, so autograd stops retaining activations during inference.
1 parent 255e485 commit d05fc36
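
For context, this is the standard PyTorch inference-only pattern: model.eval() switches modules such as dropout to inference behavior, while torch.no_grad() stops autograd from recording the forward pass, so intermediate activations are freed immediately instead of being retained for a backward pass that never runs. Below is a minimal, self-contained sketch of the pattern; the model, tokenizer, and questions are generic stand-ins, not this repository's adapter API.

```python
# Minimal sketch of the eval() + no_grad() inference pattern applied by this
# commit, using a generic Hugging Face causal LM as a stand-in; none of these
# names come from sparse_attention_hub itself.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

model.eval()  # dropout etc. switch to inference mode; gradients are still on
with torch.no_grad():  # autograd stops recording: no graph, no retained activations
    for question in ["What is sparse attention?", "Why did the old loop OOM?"]:
        tokens = tokenizer.encode(question, return_tensors="pt").to(device)
        outputs = model(tokens, use_cache=True)
        # Without no_grad(), every intermediate activation of this forward
        # pass would be kept alive for a backward pass that never happens,
        # inflating peak GPU memory on long contexts until CUDA raises OOM.
```

Note that eval() alone does not reduce memory, since it does not disable gradient tracking; that is why the commit adds both calls.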

File tree

1 file changed: +25 −23 lines changed


sparse_attention_hub/adapters/huggingface.py

Lines changed: 25 additions & 23 deletions
```diff
@@ -111,22 +111,34 @@ def process_request(
         print(f"Context tokens: {context_tokens.shape}")
         responses: List[str] = []
 
-        for question in questions:
-            sparse_meta_data: Dict[str, Any] = {}
+        self.model.eval()
+        with torch.no_grad():
+            for question in questions:
+                sparse_meta_data: Dict[str, Any] = {}
 
-            question_tokens = self.tokenizer.encode(question, return_tensors="pt")
-            if self.device is not None:
-                question_tokens = question_tokens.to(self.device)
+                question_tokens = self.tokenizer.encode(question, return_tensors="pt")
+                if self.device is not None:
+                    question_tokens = question_tokens.to(self.device)
 
-            context_outputs = self.model(
-                context_tokens,
-                past_key_values=None,
-                use_cache=True,
-                sparse_meta_data=sparse_meta_data,
-            )
+                context_outputs = self.model(
+                    context_tokens,
+                    past_key_values=None,
+                    use_cache=True,
+                    sparse_meta_data=sparse_meta_data,
+                )
 
-            if self._sparse_attention_available:
-                with self.enable_sparse_mode():
+                if self._sparse_attention_available:
+                    with self.enable_sparse_mode():
+                        response_text = self._generate_response(
+                            question_tokens,
+                            context_outputs,
+                            sparse_meta_data,
+                            generation_kwargs,
+                            **kwargs,
+                        )
+                        responses.append(response_text)
+                else:
+                    # Dense-only mode: process questions with dense attention
                     response_text = self._generate_response(
                         question_tokens,
                         context_outputs,
@@ -135,16 +147,6 @@ def process_request(
                         **kwargs,
                     )
                     responses.append(response_text)
-            else:
-                # Dense-only mode: process questions with dense attention
-                response_text = self._generate_response(
-                    question_tokens,
-                    context_outputs,
-                    sparse_meta_data,
-                    generation_kwargs,
-                    **kwargs,
-                )
-                responses.append(response_text)
 
         if isinstance(request.questions, str):
             return RequestResponse(responses=responses[0])
```
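
Most of the +25/−23 churn is re-indentation: the per-question loop body is unchanged apart from moving under the new with torch.no_grad(): block, and the sparse/dense branch logic is identical. On PyTorch 1.9+, torch.inference_mode() would be a drop-in alternative to torch.no_grad() that disables slightly more autograd bookkeeping; either context is enough to stop activations from being retained here.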
