Commit d501eb6

fix: support get_current_tokenizer again
1 parent: 91f8190

3 files changed: +18, -33 lines changed

edsnlp/core/stream.py

Lines changed: 15 additions & 2 deletions
@@ -66,6 +66,7 @@ def __bool__(self):
 
 
 INFER = _InferType()
+CONTEXT = [{}]
 
 T = TypeVar("T")
 
@@ -152,18 +153,23 @@ def __repr__(self):
 
 
 class MapOp(Op):
-    def __init__(self, pipe, kwargs):
+    def __init__(self, pipe, kwargs, context=None):
         self.pipe = pipe
         self.kwargs = kwargs
         self.is_generator = deep_isgeneratorfunction(pipe)
         self.elementwise = not self.is_generator
+        self.context = context or {}
 
     def __call__(self, items):
         for item in items:
             if isinstance(item, StreamSentinel):
                 yield item
                 continue
+
+            CONTEXT[0], old = self.context, CONTEXT[0]
             res = self.pipe(item, **self.kwargs)
+            CONTEXT[0] = old
+
             if self.is_generator:
                 yield from res
             else:
@@ -178,21 +184,24 @@ def __repr__(self):
 
 
 class MapBatchesOp(Op):
-    def __init__(self, pipe, kwargs, elementwise=False):
+    def __init__(self, pipe, kwargs, context=None, elementwise=False):
         self.pipe = pipe
         self.kwargs = kwargs
         self.is_generator = deep_isgeneratorfunction(pipe)
         if elementwise and self.is_generator:
             raise ValueError("Cannot use elementwise=True with a generator function")
         self.elementwise = elementwise
+        self.context = context or {}
 
     def __call__(self, batches):
         if hasattr(self.pipe, "batch_process"):
             for batch in batches:
                 if isinstance(batch, StreamSentinel):
                     yield batch
                     continue
+                CONTEXT[0], old = self.context, CONTEXT[0]
                 res = self.pipe.batch_process(batch, **self.kwargs)
+                CONTEXT[0] = old
                 res = list(res) if self.is_generator else (res,)
                 yield from res
         else:
@@ -202,11 +211,13 @@ def __call__(self, batches):
                     continue
                 results = []
                 for item in batch:
+                    CONTEXT[0], old = self.context, CONTEXT[0]
                     res = (
                         item
                         if isinstance(item, StreamSentinel)
                         else self.pipe(item, **self.kwargs)
                     )
+                    CONTEXT[0] = old
                     res = list(res) if self.is_generator else (res,)
                     results.extend(res)
                 yield results
@@ -727,6 +738,8 @@ def map_pipeline(
                 )
             ):
                 op.kwargs["tokenizer"] = tokenizer
+            if isinstance(op, (MapOp, MapBatchesOp)):
+                op.context["tokenizer"] = tokenizer
             new_ops.append(op)
         new_ops.append(MapOp(model._ensure_doc, {}))
         batch_size, batch_by = self.validate_batching(batch_size, batch_by)
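
Note on the mechanism introduced above: MapOp and MapBatchesOp now accept an optional context dict and swap it into the module-level CONTEXT slot only for the duration of the wrapped call, restoring the previous value right after, so anything executed from inside the pipe (such as get_current_tokenizer, patched in the next file) can read per-op settings without extra arguments being threaded through. Below is a minimal standalone sketch of the same pattern, with simplified names and a try/finally guard added for illustration; the commit itself restores the slot without a guard.

# Minimal sketch of the single-slot context swap used by MapOp / MapBatchesOp.
# Names are simplified for illustration; this is not the library code itself.
CONTEXT = [{}]  # one mutable slot holding the "current" context dict


def current_setting(key, default=None):
    # Counterpart of get_current_tokenizer: read a value from the active context.
    return CONTEXT[0].get(key, default)


class MapOp:
    def __init__(self, pipe, kwargs, context=None):
        self.pipe = pipe
        self.kwargs = kwargs
        self.context = context or {}

    def __call__(self, items):
        for item in items:
            # Install this op's context around the wrapped call only,
            # then restore whatever context was active before yielding.
            CONTEXT[0], old = self.context, CONTEXT[0]
            try:
                res = self.pipe(item, **self.kwargs)
            finally:
                CONTEXT[0] = old
            yield res


def shout(text):
    return current_setting("prefix", "") + text.upper()


op = MapOp(shout, {}, context={"prefix": ">> "})
print(list(op(["hello", "world"])))  # ['>> HELLO', '>> WORLD']

Storing the slot as a one-element list lets importing modules rebind the current context with CONTEXT[0] = ... while sharing a single object across modules, which is what converters.py relies on below.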

edsnlp/data/converters.py

Lines changed: 3 additions & 11 deletions
@@ -4,7 +4,6 @@
 Doc objects, and writers convert Doc objects to dictionaries.
 """
 
-import contextlib
 import inspect
 from copy import copy
 from types import FunctionType
@@ -26,6 +25,7 @@
 
 import edsnlp
 from edsnlp import registry
+from edsnlp.core.stream import CONTEXT
 from edsnlp.utils.bindings import BINDING_GETTERS
 from edsnlp.utils.span_getters import (
     SpanGetterArg,
@@ -136,21 +136,13 @@ def validate_attributes_mapping(value: AttributesMappingArg) -> Dict[str, str]:
 
 def get_current_tokenizer():
     global _DEFAULT_TOKENIZER
+    if "tokenizer" in CONTEXT[0]:
+        return CONTEXT[0]["tokenizer"]
     if _DEFAULT_TOKENIZER is None:
         _DEFAULT_TOKENIZER = edsnlp.blank("eds").tokenizer
     return _DEFAULT_TOKENIZER
 
 
-@contextlib.contextmanager
-def set_current_tokenizer(tokenizer):
-    global _DEFAULT_TOKENIZER
-    old = _DEFAULT_TOKENIZER
-    if tokenizer:
-        _DEFAULT_TOKENIZER = tokenizer
-    yield
-    _DEFAULT_TOKENIZER = old
-
-
 @registry.factory.register("eds.standoff_dict2doc", spacy_compatible=False)
 class StandoffDict2DocConverter:
     """

edsnlp/processing/utils.py

Lines changed: 0 additions & 20 deletions
This file was deleted.
