
Commit 4969250

remove old code
1 parent 28aae43 commit 4969250

File tree

1 file changed (+0, -127 lines)


edsnlp/pipes/trainable/doc_classifier/doc_classifier.py

Lines changed: 0 additions & 127 deletions
@@ -1,130 +1,3 @@
-"""import os
-import pickle
-from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union
-
-import torch
-from spacy.tokens import Doc
-from typing_extensions import NotRequired, TypedDict
-
-from edsnlp.core.pipeline import PipelineProtocol
-from edsnlp.core.torch_component import BatchInput, TorchComponent
-from edsnlp.pipes.base import BaseComponent
-from edsnlp.pipes.trainable.embeddings.typing import (
-    WordContextualizerComponent,
-    WordEmbeddingComponent,
-)
-from edsnlp.utils.bindings import Attributes
-
-DocClassifierBatchInput = TypedDict(
-    "DocClassifierBatchInput",
-    {
-        "embedding": BatchInput,
-        "targets": NotRequired[torch.Tensor],
-    },
-)
-
-DocClassifierBatchOutput = TypedDict(
-    "DocClassifierBatchOutput",
-    {
-        "loss": Optional[torch.Tensor],
-        "labels": Optional[torch.Tensor],
-    },
-)
-
-
-class TrainableDocClassifier(
-    TorchComponent[DocClassifierBatchOutput, DocClassifierBatchInput],
-    BaseComponent,
-):
-    def __init__(
-        self,
-        nlp: Optional[PipelineProtocol] = None,
-        name: str = "doc_classifier",
-        *,
-        embedding: Union[WordEmbeddingComponent, WordContextualizerComponent],
-        num_classes: int,
-        label_attr: str = "label",
-        loss_fn=None,
-    ):
-        self.label_attr: Attributes = label_attr
-        super().__init__(nlp, name)
-        self.embedding = embedding
-        self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()
-
-        if not hasattr(self.embedding, "output_size"):
-            raise ValueError(
-                "The embedding component must have an 'output_size' attribute."
-            )
-        embedding_size = self.embedding.output_size
-        self.classifier = torch.nn.Linear(embedding_size, num_classes)
-
-    def set_extensions(self) -> None:
-        super().set_extensions()
-        if not Doc.has_extension(self.label_attr):
-            Doc.set_extension(self.label_attr, default={})
-
-    def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
-        super().post_init(gold_data, exclude=exclude)
-
-    def preprocess(self, doc: Doc) -> Dict[str, Any]:
-        return {"embedding": self.embedding.preprocess(doc)}
-
-    def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
-        preps = self.preprocess(doc)
-        label = getattr(doc._, self.label_attr, None)
-        if label is None:
-            raise ValueError(
-                f"Document does not have a gold label in 'doc._.{self.label_attr}'"
-            )
-        return {
-            **preps,
-            "targets": torch.tensor(label, dtype=torch.long),
-        }
-
-    def collate(self, batch: Dict[str, Sequence[Any]]) -> DocClassifierBatchInput:
-        embeddings = self.embedding.collate(batch["embedding"])
-        batch_input: DocClassifierBatchInput = {"embedding": embeddings}
-        if "targets" in batch:
-            batch_input["targets"] = torch.stack(batch["targets"])
-        return batch_input
-
-    def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
-        pooled = self.embedding(batch["embedding"])
-        embeddings = pooled["embeddings"]
-
-        logits = self.classifier(embeddings)
-
-        output: DocClassifierBatchOutput = {}
-        if "targets" in batch:
-            loss = self.loss_fn(logits, batch["targets"])
-            output["loss"] = loss
-            output["labels"] = None
-        else:
-            output["loss"] = None
-            output["labels"] = torch.argmax(logits, dim=-1)
-        return output
-
-    def postprocess(self, docs, results, input):
-        labels = results["labels"]
-        if isinstance(labels, torch.Tensor):
-            labels = labels.tolist()
-        for doc, label in zip(docs, labels):
-            setattr(doc._, self.label_attr, label)
-            # doc._.label = label
-        return docs
-
-    def to_disk(self, path, *, exclude=set()):
-        repr_id = object.__repr__(self)
-        if repr_id in exclude:
-            return
-        exclude.add(repr_id)
-        os.makedirs(path, exist_ok=True)
-        data_path = path / "label_attr.pkl"
-        with open(data_path, "wb") as f:
-            pickle.dump({"label_attr": self.label_attr}, f)
-        return super().to_disk(path, exclude=exclude)
-"""
-
 import os
 import pickle
 from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union
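For context, the deleted block was a commented-out draft of a trainable document classifier: it takes a pooled document embedding, passes it through a single linear layer, and reads gold labels from or writes predictions to doc._.<label_attr>. A minimal usage sketch, based only on the constructor shown in the removed lines; the pipeline and embedding objects below are placeholders and not part of this commit:

import edsnlp
import torch

nlp = edsnlp.blank("eds")
embedding = ...  # placeholder: any edsnlp embedding component exposing `output_size`
                 # and returning per-document "embeddings" tensors

# Direct instantiation, mirroring the removed constructor signature
classifier = TrainableDocClassifier(
    nlp,
    "doc_classifier",
    embedding=embedding,
    num_classes=3,          # e.g. three document classes
    label_attr="label",     # labels are read from / written to doc._.label
    loss_fn=torch.nn.CrossEntropyLoss(),
)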
