@@ -1,4 +1,4 @@
-import os
+"""import os
 import pickle
 from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union
 
@@ -123,3 +123,166 @@ def to_disk(self, path, *, exclude=set()):
         with open(data_path, "wb") as f:
             pickle.dump({"label_attr": self.label_attr}, f)
         return super().to_disk(path, exclude=exclude)
+"""
+
+import os
+import pickle
+from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union
+
+import torch
+from spacy.tokens import Doc
+from typing_extensions import NotRequired, TypedDict
+
+from edsnlp.core.pipeline import PipelineProtocol
+from edsnlp.core.torch_component import BatchInput, TorchComponent
+from edsnlp.pipes.base import BaseComponent
+from edsnlp.pipes.trainable.embeddings.typing import (
+    WordContextualizerComponent,
+    WordEmbeddingComponent,
+)
+from edsnlp.utils.bindings import Attributes
+
+DocClassifierBatchInput = TypedDict(
+    "DocClassifierBatchInput",
+    {
+        "embedding": BatchInput,
+        "targets": NotRequired[torch.Tensor],
+    },
+)
+
+DocClassifierBatchOutput = TypedDict(
+    "DocClassifierBatchOutput",
+    {
+        "loss": Optional[torch.Tensor],
+        "labels": Optional[torch.Tensor],
+    },
+)
+
+
+class TrainableDocClassifier(
+    TorchComponent[DocClassifierBatchOutput, DocClassifierBatchInput],
+    BaseComponent,
+):
+    def __init__(
+        self,
+        nlp: Optional[PipelineProtocol] = None,
+        name: str = "doc_classifier",
+        *,
+        embedding: Union[WordEmbeddingComponent, WordContextualizerComponent],
+        num_classes: int,
+        label_attr: str = "label",
+        label2id: Optional[Dict[str, int]] = None,
+        id2label: Optional[Dict[int, str]] = None,
+        loss_fn=None,
+    ):
+        self.label_attr: Attributes = label_attr
+        self.label2id = label2id or {}
+        self.id2label = id2label or {}
+        super().__init__(nlp, name)
+        self.embedding = embedding
+        self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()
+
+        if not hasattr(self.embedding, "output_size"):
+            raise ValueError(
+                "The embedding component must have an 'output_size' attribute."
+            )
+        embedding_size = self.embedding.output_size
+        self.classifier = torch.nn.Linear(embedding_size, num_classes)
+
+    def set_extensions(self) -> None:
+        super().set_extensions()
+        if not Doc.has_extension(self.label_attr):
+            Doc.set_extension(self.label_attr, default=None)
+
+    def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
+        if not self.label2id:
+            labels = set()
+            for doc in gold_data:
+                label = getattr(doc._, self.label_attr, None)
+                if isinstance(label, str):
+                    labels.add(label)
+            if labels:
+                self.label2id = {label: i for i, label in enumerate(sorted(labels))}
+                self.id2label = {i: label for label, i in self.label2id.items()}
+        super().post_init(gold_data, exclude=exclude)
+
+    def preprocess(self, doc: Doc) -> Dict[str, Any]:
+        return {"embedding": self.embedding.preprocess(doc)}
+
+    def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
+        preps = self.preprocess(doc)
+        label = getattr(doc._, self.label_attr, None)
+        if label is None:
+            raise ValueError(
+                f"Document does not have a gold label in 'doc._.{self.label_attr}'"
+            )
+        if isinstance(label, str) and self.label2id:
+            if label not in self.label2id:
+                raise ValueError(f"Label '{label}' not in label2id mapping.")
+            label = self.label2id[label]
+        return {
+            **preps,
+            "targets": torch.tensor(label, dtype=torch.long),
+        }
+
+    def collate(self, batch: Dict[str, Sequence[Any]]) -> DocClassifierBatchInput:
+        embeddings = self.embedding.collate(batch["embedding"])
+        batch_input: DocClassifierBatchInput = {"embedding": embeddings}
+        if "targets" in batch:
+            batch_input["targets"] = torch.stack(batch["targets"])
+        return batch_input
+
+    def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
+        pooled = self.embedding(batch["embedding"])
+        embeddings = pooled["embeddings"]
+
+        logits = self.classifier(embeddings)
+
+        output: DocClassifierBatchOutput = {}
+        if "targets" in batch:
+            loss = self.loss_fn(logits, batch["targets"])
+            output["loss"] = loss
+            output["labels"] = None
+        else:
+            output["loss"] = None
+            output["labels"] = torch.argmax(logits, dim=-1)
+        return output
+
+    def postprocess(self, docs, results, input):
+        labels = results["labels"]
+        if isinstance(labels, torch.Tensor):
+            labels = labels.tolist()
+        for doc, label in zip(docs, labels):
+            if self.id2label and isinstance(label, int):
+                label = self.id2label.get(label, label)
+            setattr(doc._, self.label_attr, label)
+        return docs
+
+    def to_disk(self, path, *, exclude=set()):
+        repr_id = object.__repr__(self)
+        if repr_id in exclude:
+            return
+        exclude.add(repr_id)
+        os.makedirs(path, exist_ok=True)
+        data_path = path / "label_attr.pkl"
+        with open(data_path, "wb") as f:
+            pickle.dump(
+                {
+                    "label_attr": self.label_attr,
+                    "label2id": self.label2id,
+                    "id2label": self.id2label,
+                },
+                f,
+            )
+        return super().to_disk(path, exclude=exclude)
+
+    @classmethod
+    def from_disk(cls, path, **kwargs):
+        data_path = path / "label_attr.pkl"
+        with open(data_path, "rb") as f:
+            data = pickle.load(f)
+        obj = super().from_disk(path, **kwargs)
+        obj.label_attr = data.get("label_attr", "label")
+        obj.label2id = data.get("label2id", {})
+        obj.id2label = data.get("id2label", {})
+        return obj
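
Usage sketch (not part of the diff): a minimal illustration of the data flow the new component implements, going through `preprocess_supervised`, then `collate`, then `forward`. The `StubDocEmbedding`, the example texts, and the manual list-of-dicts to dict-of-lists batching are illustrative assumptions; the stub only mimics the interface the classifier relies on (`preprocess`, `collate`, `forward`, and an `output_size` attribute) and returns one pooled vector per document under the `"embeddings"` key. The sketch also assumes the component can be instantiated outside a pipeline, since its `nlp` argument defaults to `None`.

```python
import spacy
import torch
from spacy.tokens import Doc


class StubDocEmbedding(torch.nn.Module):
    """Toy stand-in for a document-pooling embedding component (assumption)."""

    output_size = 16  # read by TrainableDocClassifier to size its linear head

    def preprocess(self, doc):
        return {"n_words": len(doc)}

    def collate(self, batch):
        # batch arrives as a dict of lists: {"n_words": [len(doc_1), len(doc_2), ...]}
        return {"n_words": torch.as_tensor(batch["n_words"])}

    def forward(self, batch):
        # One random vector per document, standing in for a pooled representation.
        return {"embeddings": torch.randn(len(batch["n_words"]), self.output_size)}


if not Doc.has_extension("label"):
    Doc.set_extension("label", default=None)

nlp = spacy.blank("fr")
docs = [nlp("Patient sorti guéri."), nlp("Aggravation de l'état général.")]
docs[0]._.label = "pos"
docs[1]._.label = "neg"

clf = TrainableDocClassifier(
    embedding=StubDocEmbedding(),
    num_classes=2,
    label2id={"neg": 0, "pos": 1},
    id2label={0: "neg", 1: "pos"},
)

# Mimic the pipeline's batching step: list of per-doc dicts -> dict of lists.
preps = [clf.preprocess_supervised(doc) for doc in docs]
batch = {
    "embedding": {"n_words": [p["embedding"]["n_words"] for p in preps]},
    "targets": [p["targets"] for p in preps],
}
output = clf.forward(clf.collate(batch))
print(output["loss"])  # scalar cross-entropy loss over the two documents
```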