Skip to content

Commit ddcafcb

Browse files
committed
add unit test for doc_classifier
1 parent 6d379b1 commit ddcafcb

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
3+
import edsnlp
4+
import edsnlp.pipes as eds
5+
6+
pytest.importorskip("torch.nn")
7+
8+
9+
@pytest.mark.parametrize("pooling_mode", ["mean", "max", "cls", "sum"])
10+
@pytest.mark.parametrize("label_attr", ["label", "alive"])
11+
@pytest.mark.parametrize("num_classes", [2, 10])
12+
def test_doc_classifier(pooling_mode, label_attr, num_classes):
13+
nlp = edsnlp.blank("eds")
14+
doc = nlp.make_doc("Le patient est mort.")
15+
16+
nlp.add_pipe(
17+
eds.doc_classifier(
18+
embedding=eds.doc_pooler(
19+
pooling_mode=pooling_mode,
20+
embedding=eds.transformer(
21+
model="prajjwal1/bert-tiny",
22+
window=128,
23+
stride=96,
24+
),
25+
),
26+
num_classes=num_classes,
27+
label_attr=label_attr,
28+
),
29+
name="doc_classifier",
30+
)
31+
doc = nlp(doc)
32+
label = getattr(doc._, label_attr, None)
33+
assert label in range(0, num_classes)

0 commit comments

Comments
 (0)