
Commit 28aae43

add string label handling
1 parent ddcafcb commit 28aae43

1 file changed: +164 −1 lines changed

edsnlp/pipes/trainable/doc_classifier/doc_classifier.py

Lines changed: 164 additions & 1 deletion
@@ -1,4 +1,4 @@
-import os
+"""import os
 import pickle
 from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union
 
@@ -123,3 +123,166 @@ def to_disk(self, path, *, exclude=set()):
         with open(data_path, "wb") as f:
             pickle.dump({"label_attr": self.label_attr}, f)
         return super().to_disk(path, exclude=exclude)
+"""
+
+import os
+import pickle
+from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union
+
+import torch
+from spacy.tokens import Doc
+from typing_extensions import NotRequired, TypedDict
+
+from edsnlp.core.pipeline import PipelineProtocol
+from edsnlp.core.torch_component import BatchInput, TorchComponent
+from edsnlp.pipes.base import BaseComponent
+from edsnlp.pipes.trainable.embeddings.typing import (
+    WordContextualizerComponent,
+    WordEmbeddingComponent,
+)
+from edsnlp.utils.bindings import Attributes
+
+DocClassifierBatchInput = TypedDict(
+    "DocClassifierBatchInput",
+    {
+        "embedding": BatchInput,
+        "targets": NotRequired[torch.Tensor],
+    },
+)
+
+DocClassifierBatchOutput = TypedDict(
+    "DocClassifierBatchOutput",
+    {
+        "loss": Optional[torch.Tensor],
+        "labels": Optional[torch.Tensor],
+    },
+)
+
+
+class TrainableDocClassifier(
+    TorchComponent[DocClassifierBatchOutput, DocClassifierBatchInput],
+    BaseComponent,
+):
+    def __init__(
+        self,
+        nlp: Optional[PipelineProtocol] = None,
+        name: str = "doc_classifier",
+        *,
+        embedding: Union[WordEmbeddingComponent, WordContextualizerComponent],
+        num_classes: int,
+        label_attr: str = "label",
+        label2id: Optional[Dict[str, int]] = None,
+        id2label: Optional[Dict[int, str]] = None,
+        loss_fn=None,
+    ):
+        self.label_attr: Attributes = label_attr
+        self.label2id = label2id or {}
+        self.id2label = id2label or {}
+        super().__init__(nlp, name)
+        self.embedding = embedding
+        self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()
+
+        if not hasattr(self.embedding, "output_size"):
+            raise ValueError(
+                "The embedding component must have an 'output_size' attribute."
+            )
+        embedding_size = self.embedding.output_size
+        self.classifier = torch.nn.Linear(embedding_size, num_classes)
+
+    def set_extensions(self) -> None:
+        super().set_extensions()
+        if not Doc.has_extension(self.label_attr):
+            Doc.set_extension(self.label_attr, default={})
+
+    def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
+        if not self.label2id:
+            labels = set()
+            for doc in gold_data:
+                label = getattr(doc._, self.label_attr, None)
+                if isinstance(label, str):
+                    labels.add(label)
+            if labels:
+                self.label2id = {label: i for i, label in enumerate(sorted(labels))}
+                self.id2label = {i: label for label, i in self.label2id.items()}
+        super().post_init(gold_data, exclude=exclude)
+
+    def preprocess(self, doc: Doc) -> Dict[str, Any]:
+        return {"embedding": self.embedding.preprocess(doc)}
+
+    def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
+        preps = self.preprocess(doc)
+        label = getattr(doc._, self.label_attr, None)
+        if label is None:
+            raise ValueError(
+                f"Document does not have a gold label in 'doc._.{self.label_attr}'"
+            )
+        if isinstance(label, str) and self.label2id:
+            if label not in self.label2id:
+                raise ValueError(f"Label '{label}' not in label2id mapping.")
+            label = self.label2id[label]
+        return {
+            **preps,
+            "targets": torch.tensor(label, dtype=torch.long),
+        }
+
+    def collate(self, batch: Dict[str, Sequence[Any]]) -> DocClassifierBatchInput:
+        embeddings = self.embedding.collate(batch["embedding"])
+        batch_input: DocClassifierBatchInput = {"embedding": embeddings}
+        if "targets" in batch:
+            batch_input["targets"] = torch.stack(batch["targets"])
+        return batch_input
+
+    def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
+        pooled = self.embedding(batch["embedding"])
+        embeddings = pooled["embeddings"]
+
+        logits = self.classifier(embeddings)
+
+        output: DocClassifierBatchOutput = {}
+        if "targets" in batch:
+            loss = self.loss_fn(logits, batch["targets"])
+            output["loss"] = loss
+            output["labels"] = None
+        else:
+            output["loss"] = None
+            output["labels"] = torch.argmax(logits, dim=-1)
+        return output
+
+    def postprocess(self, docs, results, input):
+        labels = results["labels"]
+        if isinstance(labels, torch.Tensor):
+            labels = labels.tolist()
+        for doc, label in zip(docs, labels):
+            if self.id2label and isinstance(label, int):
+                label = self.id2label.get(label, label)
+            setattr(doc._, self.label_attr, label)
+        return docs
+
+    def to_disk(self, path, *, exclude=set()):
+        repr_id = object.__repr__(self)
+        if repr_id in exclude:
+            return
+        exclude.add(repr_id)
+        os.makedirs(path, exist_ok=True)
+        data_path = path / "label_attr.pkl"
+        with open(data_path, "wb") as f:
+            pickle.dump(
+                {
+                    "label_attr": self.label_attr,
+                    "label2id": self.label2id,
+                    "id2label": self.id2label,
+                },
                f,
+            )
+        return super().to_disk(path, exclude=exclude)
+
+    @classmethod
+    def from_disk(cls, path, **kwargs):
+        data_path = path / "label_attr.pkl"
+        with open(data_path, "rb") as f:
+            data = pickle.load(f)
+        obj = super().from_disk(path, **kwargs)
+        obj.label_attr = data.get("label_attr", "label")
+        obj.label2id = data.get("label2id", {})
+        obj.id2label = data.get("id2label", {})
+        return obj
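
A rough usage sketch of the string-label path this commit adds. Everything below is illustrative and not part of the commit: ToyEmbedding is a hypothetical stand-in that only satisfies the `output_size` check (a real trainable embedding pipe from edsnlp.pipes.trainable.embeddings would be needed for actual training and inference), and the texts and labels ("cardio", "ortho") are made up.

import edsnlp
import torch


class ToyEmbedding(torch.nn.Module):
    # Hypothetical stand-in for a real embedding component; it only exposes
    # the `output_size` attribute that TrainableDocClassifier checks for.
    output_size = 8


nlp = edsnlp.blank("eds")
clf = TrainableDocClassifier(
    nlp,
    embedding=ToyEmbedding(),
    num_classes=2,
    label_attr="label",  # labels are read from / written to doc._.label
)
clf.set_extensions()  # registers doc._.label; normally done when the pipe is set up

# Gold documents carry plain string labels; post_init collects them and
# builds the label2id / id2label mappings.
docs = []
for text, lab in [("chest pain", "cardio"), ("knee fracture", "ortho")]:
    d = nlp(text)
    d._.label = lab
    docs.append(d)
clf.post_init(docs, exclude=set())
print(clf.label2id)  # {'cardio': 0, 'ortho': 1}
print(clf.id2label)  # {0: 'cardio', 1: 'ortho'}

# With these mappings, preprocess_supervised turns "cardio" into the target id 0,
# and postprocess maps predicted ids back to strings through id2label.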
