Skip to content

Commit b2af895

Browse files
committed
automatically set num_classes from labels if not provided by user
1 parent 4969250 commit b2af895

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

edsnlp/pipes/trainable/doc_classifier/doc_classifier.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def __init__(
4242
name: str = "doc_classifier",
4343
*,
4444
embedding: Union[WordEmbeddingComponent, WordContextualizerComponent],
45-
num_classes: int,
45+
num_classes: Optional[int] = None,
4646
label_attr: str = "label",
4747
label2id: Optional[Dict[str, int]] = None,
4848
id2label: Optional[Dict[int, str]] = None,
@@ -60,7 +60,8 @@ def __init__(
6060
"The embedding component must have an 'output_size' attribute."
6161
)
6262
embedding_size = self.embedding.output_size
63-
self.classifier = torch.nn.Linear(embedding_size, num_classes)
63+
if num_classes:
64+
self.classifier = torch.nn.Linear(embedding_size, num_classes)
6465

6566
def set_extensions(self) -> None:
6667
super().set_extensions()
@@ -77,6 +78,10 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
7778
if labels:
7879
self.label2id = {label: i for i, label in enumerate(sorted(labels))}
7980
self.id2label = {i: label for label, i in self.label2id.items()}
81+
print("num classes:", len(self.label2id))
82+
self.classifier = torch.nn.Linear(
83+
self.embedding.output_size, len(self.label2id)
84+
)
8085
super().post_init(gold_data, exclude=exclude)
8186

8287
def preprocess(self, doc: Doc) -> Dict[str, Any]:

0 commit comments

Comments
 (0)