-"""import os
-import pickle
-from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union
-
-import torch
-from spacy.tokens import Doc
-from typing_extensions import NotRequired, TypedDict
-
-from edsnlp.core.pipeline import PipelineProtocol
-from edsnlp.core.torch_component import BatchInput, TorchComponent
-from edsnlp.pipes.base import BaseComponent
-from edsnlp.pipes.trainable.embeddings.typing import (
-    WordContextualizerComponent,
-    WordEmbeddingComponent,
-)
-from edsnlp.utils.bindings import Attributes
-
-DocClassifierBatchInput = TypedDict(
-    "DocClassifierBatchInput",
-    {
-        "embedding": BatchInput,
-        "targets": NotRequired[torch.Tensor],
-    },
-)
-
-DocClassifierBatchOutput = TypedDict(
-    "DocClassifierBatchOutput",
-    {
-        "loss": Optional[torch.Tensor],
-        "labels": Optional[torch.Tensor],
-    },
-)
-
-
-class TrainableDocClassifier(
-    TorchComponent[DocClassifierBatchOutput, DocClassifierBatchInput],
-    BaseComponent,
-):
-    def __init__(
-        self,
-        nlp: Optional[PipelineProtocol] = None,
-        name: str = "doc_classifier",
-        *,
-        embedding: Union[WordEmbeddingComponent, WordContextualizerComponent],
-        num_classes: int,
-        label_attr: str = "label",
-        loss_fn=None,
-    ):
-        self.label_attr: Attributes = label_attr
-        super().__init__(nlp, name)
-        self.embedding = embedding
-        self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()
-
-        if not hasattr(self.embedding, "output_size"):
-            raise ValueError(
-                "The embedding component must have an 'output_size' attribute."
-            )
-        embedding_size = self.embedding.output_size
-        self.classifier = torch.nn.Linear(embedding_size, num_classes)
-
-    def set_extensions(self) -> None:
-        super().set_extensions()
-        if not Doc.has_extension(self.label_attr):
-            Doc.set_extension(self.label_attr, default={})
-
-    def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
-        super().post_init(gold_data, exclude=exclude)
-
-    def preprocess(self, doc: Doc) -> Dict[str, Any]:
-        return {"embedding": self.embedding.preprocess(doc)}
-
-    def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
-        preps = self.preprocess(doc)
-        label = getattr(doc._, self.label_attr, None)
-        if label is None:
-            raise ValueError(
-                f"Document does not have a gold label in 'doc._.{self.label_attr}'"
-            )
-        return {
-            **preps,
-            "targets": torch.tensor(label, dtype=torch.long),
-        }
-
-    def collate(self, batch: Dict[str, Sequence[Any]]) -> DocClassifierBatchInput:
-        embeddings = self.embedding.collate(batch["embedding"])
-        batch_input: DocClassifierBatchInput = {"embedding": embeddings}
-        if "targets" in batch:
-            batch_input["targets"] = torch.stack(batch["targets"])
-        return batch_input
-
-    def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
-        pooled = self.embedding(batch["embedding"])
-        embeddings = pooled["embeddings"]
-
-        logits = self.classifier(embeddings)
-
-        output: DocClassifierBatchOutput = {}
-        if "targets" in batch:
-            loss = self.loss_fn(logits, batch["targets"])
-            output["loss"] = loss
-            output["labels"] = None
-        else:
-            output["loss"] = None
-            output["labels"] = torch.argmax(logits, dim=-1)
-        return output
-
-    def postprocess(self, docs, results, input):
-        labels = results["labels"]
-        if isinstance(labels, torch.Tensor):
-            labels = labels.tolist()
-        for doc, label in zip(docs, labels):
-            setattr(doc._, self.label_attr, label)
-            # doc._.label = label
-        return docs
-
-    def to_disk(self, path, *, exclude=set()):
-        repr_id = object.__repr__(self)
-        if repr_id in exclude:
-            return
-        exclude.add(repr_id)
-        os.makedirs(path, exist_ok=True)
-        data_path = path / "label_attr.pkl"
-        with open(data_path, "wb") as f:
-            pickle.dump({"label_attr": self.label_attr}, f)
-        return super().to_disk(path, exclude=exclude)
-"""
-
 import os
 import pickle
 from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union
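Note on usage: the component above reads gold labels from a Doc extension named by label_attr (default "label") in preprocess_supervised, and writes predicted class indices back to the same extension in postprocess. Below is a minimal sketch of that gold-label contract, assuming edsnlp is installed; the blank "eds" pipeline, the example texts, and the class indices are illustrative assumptions, not part of this diff.

# Minimal sketch of the gold-label contract (illustrative only).
import edsnlp
from spacy.tokens import Doc

nlp = edsnlp.blank("eds")  # empty pipeline, enough to produce Doc objects

# Register the extension the classifier reads; the component's own
# set_extensions() would normally do this when it is added to a pipeline.
if not Doc.has_extension("label"):
    Doc.set_extension("label", default=None)

train_data = [
    ("Patient admis pour pneumopathie.", 1),
    ("Consultation de suivi sans anomalie.", 0),
]

docs = []
for text, gold in train_data:
    doc = nlp(text)
    doc._.label = gold  # integer class index, later tensorized as the "targets" field
    docs.append(doc)

At inference time, after the trained component has run, the predicted class index can be read back from the same place, e.g. doc._.label.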