Skip to content

Commit b0c840a

Browse files
committed
speed up label mapping
1 parent b2af895 commit b0c840a

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

edsnlp/pipes/trainable/doc_classifier/doc_classifier.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,11 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
7676
if isinstance(label, str):
7777
labels.add(label)
7878
if labels:
79-
self.label2id = {label: i for i, label in enumerate(sorted(labels))}
80-
self.id2label = {i: label for label, i in self.label2id.items()}
79+
self.label2id = {}
80+
self.id2label = {}
81+
for i, label in enumerate(labels):
82+
self.label2id[label] = i
83+
self.id2label[i] = label
8184
print("num classes:", len(self.label2id))
8285
self.classifier = torch.nn.Linear(
8386
self.embedding.output_size, len(self.label2id)

0 commit comments

Comments
 (0)