@@ -111,22 +111,34 @@ def process_request(
111111 print (f"Context tokens: { context_tokens .shape } " )
112112 responses : List [str ] = []
113113
114- for question in questions :
115- sparse_meta_data : Dict [str , Any ] = {}
114+ self .model .eval ()
115+ with torch .no_grad ():
116+ for question in questions :
117+ sparse_meta_data : Dict [str , Any ] = {}
116118
117- question_tokens = self .tokenizer .encode (question , return_tensors = "pt" )
118- if self .device is not None :
119- question_tokens = question_tokens .to (self .device )
119+ question_tokens = self .tokenizer .encode (question , return_tensors = "pt" )
120+ if self .device is not None :
121+ question_tokens = question_tokens .to (self .device )
120122
121- context_outputs = self .model (
122- context_tokens ,
123- past_key_values = None ,
124- use_cache = True ,
125- sparse_meta_data = sparse_meta_data ,
126- )
123+ context_outputs = self .model (
124+ context_tokens ,
125+ past_key_values = None ,
126+ use_cache = True ,
127+ sparse_meta_data = sparse_meta_data ,
128+ )
127129
128- if self ._sparse_attention_available :
129- with self .enable_sparse_mode ():
130+ if self ._sparse_attention_available :
131+ with self .enable_sparse_mode ():
132+ response_text = self ._generate_response (
133+ question_tokens ,
134+ context_outputs ,
135+ sparse_meta_data ,
136+ generation_kwargs ,
137+ ** kwargs ,
138+ )
139+ responses .append (response_text )
140+ else :
141+ # Dense-only mode: process questions with dense attention
130142 response_text = self ._generate_response (
131143 question_tokens ,
132144 context_outputs ,
@@ -135,16 +147,6 @@ def process_request(
135147 ** kwargs ,
136148 )
137149 responses .append (response_text )
138- else :
139- # Dense-only mode: process questions with dense attention
140- response_text = self ._generate_response (
141- question_tokens ,
142- context_outputs ,
143- sparse_meta_data ,
144- generation_kwargs ,
145- ** kwargs ,
146- )
147- responses .append (response_text )
148150
149151 if isinstance (request .questions , str ):
150152 return RequestResponse (responses = responses [0 ])
0 commit comments