@@ -42,7 +42,7 @@ def __init__(
         name: str = "doc_classifier",
         *,
         embedding: Union[WordEmbeddingComponent, WordContextualizerComponent],
-        num_classes: int,
+        num_classes: Optional[int] = None,
         label_attr: str = "label",
         label2id: Optional[Dict[str, int]] = None,
         id2label: Optional[Dict[int, str]] = None,
@@ -60,7 +60,8 @@ def __init__(
6060 "The embedding component must have an 'output_size' attribute."
6161 )
6262 embedding_size = self .embedding .output_size
63- self .classifier = torch .nn .Linear (embedding_size , num_classes )
63+ if num_classes :
64+ self .classifier = torch .nn .Linear (embedding_size , num_classes )
6465
6566 def set_extensions (self ) -> None :
6667 super ().set_extensions ()
@@ -77,6 +78,10 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
         if labels:
             self.label2id = {label: i for i, label in enumerate(sorted(labels))}
             self.id2label = {i: label for label, i in self.label2id.items()}
+            print("num classes:", len(self.label2id))
+            self.classifier = torch.nn.Linear(
+                self.embedding.output_size, len(self.label2id)
+            )
         super().post_init(gold_data, exclude=exclude)

     def preprocess(self, doc: Doc) -> Dict[str, Any]:
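In plain terms, the change makes num_classes optional: when it is given, the classifier head is built in __init__ as before; when it is omitted, the head is created in post_init once the label set has been collected from the gold documents. The following standalone sketch (a hypothetical LazyDocClassifier, not the library's actual component, with the label collection simplified to a plain iterable of strings) mirrors that deferred-initialization behavior.

from typing import Dict, Iterable, Optional

import torch


class LazyDocClassifier(torch.nn.Module):
    # Hypothetical minimal module illustrating the deferred classifier head.
    def __init__(self, embedding_size: int, num_classes: Optional[int] = None):
        super().__init__()
        self.embedding_size = embedding_size
        self.label2id: Dict[str, int] = {}
        self.classifier: Optional[torch.nn.Linear] = None
        if num_classes is not None:
            # Eager path: build the head immediately when num_classes is given.
            self.classifier = torch.nn.Linear(embedding_size, num_classes)

    def post_init(self, gold_labels: Iterable[str]) -> None:
        # Lazy path: infer the label set from gold data, then build the head.
        labels = sorted(set(gold_labels))
        if labels:
            self.label2id = {label: i for i, label in enumerate(labels)}
            self.classifier = torch.nn.Linear(self.embedding_size, len(self.label2id))


# num_classes is omitted, so the head is sized from the gold labels in post_init.
clf = LazyDocClassifier(embedding_size=768)
clf.post_init(["positive", "negative", "neutral"])
assert clf.classifier.out_features == 3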