@@ -66,6 +66,7 @@ def __bool__(self):
6666
6767
6868INFER = _InferType ()
69+ CONTEXT = [{}]
6970
7071T = TypeVar ("T" )
7172
@@ -152,18 +153,23 @@ def __repr__(self):
152153
153154
154155class MapOp (Op ):
155- def __init__ (self , pipe , kwargs ):
156+ def __init__ (self , pipe , kwargs , context = None ):
156157 self .pipe = pipe
157158 self .kwargs = kwargs
158159 self .is_generator = deep_isgeneratorfunction (pipe )
159160 self .elementwise = not self .is_generator
161+ self .context = context or {}
160162
161163 def __call__ (self , items ):
162164 for item in items :
163165 if isinstance (item , StreamSentinel ):
164166 yield item
165167 continue
168+
169+ CONTEXT [0 ], old = self .context , CONTEXT [0 ]
166170 res = self .pipe (item , ** self .kwargs )
171+ CONTEXT [0 ] = old
172+
167173 if self .is_generator :
168174 yield from res
169175 else :
@@ -178,21 +184,24 @@ def __repr__(self):
178184
179185
180186class MapBatchesOp (Op ):
181- def __init__ (self , pipe , kwargs , elementwise = False ):
187+ def __init__ (self , pipe , kwargs , context = None , elementwise = False ):
182188 self .pipe = pipe
183189 self .kwargs = kwargs
184190 self .is_generator = deep_isgeneratorfunction (pipe )
185191 if elementwise and self .is_generator :
186192 raise ValueError ("Cannot use elementwise=True with a generator function" )
187193 self .elementwise = elementwise
194+ self .context = context or {}
188195
189196 def __call__ (self , batches ):
190197 if hasattr (self .pipe , "batch_process" ):
191198 for batch in batches :
192199 if isinstance (batch , StreamSentinel ):
193200 yield batch
194201 continue
202+ CONTEXT [0 ], old = self .context , CONTEXT [0 ]
195203 res = self .pipe .batch_process (batch , ** self .kwargs )
204+ CONTEXT [0 ] = old
196205 res = list (res ) if self .is_generator else (res ,)
197206 yield from res
198207 else :
@@ -202,11 +211,13 @@ def __call__(self, batches):
202211 continue
203212 results = []
204213 for item in batch :
214+ CONTEXT [0 ], old = self .context , CONTEXT [0 ]
205215 res = (
206216 item
207217 if isinstance (item , StreamSentinel )
208218 else self .pipe (item , ** self .kwargs )
209219 )
220+ CONTEXT [0 ] = old
210221 res = list (res ) if self .is_generator else (res ,)
211222 results .extend (res )
212223 yield results
@@ -727,6 +738,8 @@ def map_pipeline(
727738 )
728739 ):
729740 op .kwargs ["tokenizer" ] = tokenizer
741+ if isinstance (op , (MapOp , MapBatchesOp )):
742+ op .context ["tokenizer" ] = tokenizer
730743 new_ops .append (op )
731744 new_ops .append (MapOp (model ._ensure_doc , {}))
732745 batch_size , batch_by = self .validate_batching (batch_size , batch_by )
0 commit comments