Skip to content

Commit 874d36a

Browse files
committed
fix: support mlg-norm and eds-pseudo again
1 parent e76c22d commit 874d36a

File tree

7 files changed

+38
-8
lines changed

7 files changed

+38
-8
lines changed

edsnlp/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def find_spec(self, fullname, path, target=None): # pragma: no cover
4444
new_name = "edsnlp.metrics" + fullname[14:]
4545
spec = importlib.util.spec_from_loader(fullname, AliasLoader(new_name))
4646
return spec
47+
if fullname.startswith("edsnlp.metrics.span_classification"):
48+
new_name = "edsnlp.metrics.span_attributes" + fullname[34:]
49+
spec = importlib.util.spec_from_loader(fullname, AliasLoader(new_name))
50+
return spec
4751
if "span_qualifier" in fullname.split("."):
4852
new_name = fullname.replace("span_qualifier", "span_classifier")
4953
spec = importlib.util.spec_from_loader(fullname, AliasLoader(new_name))

edsnlp/data/converters.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343

4444
_DEFAULT_TOKENIZER = None
4545

46+
# For backward compatibility
47+
SequenceStr = AsList[str]
48+
4649

4750
def without_filename(d):
4851
d.pop(FILENAME, None)

edsnlp/metrics/span_attributes.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
def span_attribute_metric(
1212
examples: Examples,
1313
span_getter: SpanGetterArg,
14-
attributes: Attributes,
14+
attributes: Attributes = None,
1515
include_falsy: bool = False,
1616
default_values: Dict = {},
1717
micro_key: str = "micro",
1818
filter_expr: Optional[str] = None,
19+
**kwargs: Any,
1920
):
2021
"""
2122
Scores the attributes predictions between a list of gold and predicted spans.
@@ -47,6 +48,23 @@ def span_attribute_metric(
4748
-------
4849
Dict[str, float]
4950
"""
51+
if "qualifiers" in kwargs:
52+
warnings.warn(
53+
"The `qualifiers` argument of span_attribute_metric() is "
54+
"deprecated. Use `attributes` instead.",
55+
DeprecationWarning,
56+
)
57+
assert attributes is None
58+
attributes = kwargs.pop("qualifiers")
59+
if attributes is None:
60+
raise TypeError(
61+
"span_attribute_metric() missing 1 required argument: 'attributes'"
62+
)
63+
if kwargs:
64+
raise TypeError(
65+
f"span_attribute_metric() got unexpected keyword arguments: "
66+
f"{', '.join(kwargs.keys())}"
67+
)
5068
examples = make_examples(examples)
5169
if filter_expr is not None:
5270
filter_fn = eval(f"lambda doc: {filter_expr}")

edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@ def preprocess(
129129
"same."
130130
)
131131
aligned_contexts = (
132-
[[c] for c in contexts] if pre_aligned else align_spans(contexts, spans)
132+
[[c] for c in contexts]
133+
if pre_aligned
134+
else align_spans(contexts, spans, sort_by_overlap=True)
133135
)
134136
for i, (span, ctx) in enumerate(zip(spans, aligned_contexts)):
135137
if len(ctx) == 0 or ctx[0].start > span.start or ctx[0].end < span.end:
@@ -196,8 +198,9 @@ def forward(self, batch: SpanPoolerBatchInput) -> SpanPoolerBatchOutput:
196198
"""
197199
device = next(self.parameters()).device
198200
if len(batch["begins"]) == 0:
201+
span_embeds = torch.empty(0, self.output_size, device=device)
199202
return {
200-
"embeddings": torch.empty(0, self.output_size, device=device),
203+
"embeddings": batch["begins"].with_data(span_embeds),
201204
}
202205

203206
embeds = self.embedding(batch["embedding"])["embeddings"]

edsnlp/pipes/trainable/span_classifier/span_classifier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def __init__(
194194
embedding: SpanEmbeddingComponent,
195195
attributes: AttributesArg = None,
196196
qualifiers: AttributesArg = None,
197-
span_getter: SpanGetterArg = {"ents": True},
197+
span_getter: SpanGetterArg = None,
198198
context_getter: Optional[SpanGetterArg] = None,
199199
values: Optional[Dict[str, List[Any]]] = None,
200200
keep_none: bool = False,
@@ -217,6 +217,7 @@ def __init__(
217217
sub_span_getter is not None and span_getter is None
218218
): # pragma: no cover # noqa: E501
219219
span_getter = sub_span_getter
220+
span_getter = span_getter or {"ents": True}
220221
sub_context_getter = getattr(embedding, "context_getter", None)
221222
if (
222223
sub_context_getter is not None and context_getter is None

edsnlp/pipes/trainable/span_linker/span_linker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def __init__(
278278
rescale: float = 20,
279279
threshold: float = 0.5,
280280
attribute: str = "cui",
281-
span_getter: SpanGetterArg = {"ents": True},
281+
span_getter: SpanGetterArg = None,
282282
context_getter: Optional[SpanGetterArg] = None,
283283
reference_mode: Literal["concept", "synonym"] = "concept",
284284
probability_mode: Literal["softmax", "sigmoid"] = "sigmoid",
@@ -289,6 +289,7 @@ def __init__(
289289
sub_span_getter = getattr(embedding, "span_getter", None)
290290
if sub_span_getter is not None and span_getter is None: # pragma: no cover
291291
span_getter = sub_span_getter
292+
span_getter = span_getter or {"ents": True}
292293
sub_context_getter = getattr(embedding, "context_getter", None)
293294
if (
294295
sub_context_getter is not None and context_getter is None
@@ -309,7 +310,7 @@ def __init__(
309310
self.reference_mode = reference_mode
310311
self.probability_mode = probability_mode
311312
self.init_weights = init_weights
312-
self.context_getter: SpanGetter = context_getter or span_getter
313+
self.context_getter: SpanGetter = context_getter or self.span_getter
313314
with warnings.catch_warnings():
314315
warnings.simplefilter("ignore", UserWarning)
315316
self.classifier = Metric(

edsnlp/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from confit import Cli
22

3-
from edsnlp.core.registries import registry
4-
from edsnlp.training import train
3+
from edsnlp.training.trainer import * # noqa: F403
4+
from edsnlp.training.trainer import registry, train
55

66
app = Cli(pretty_exceptions_show_locals=False)
77
train_command = app.command(name="train", registry=registry)(train)

0 commit comments

Comments
 (0)