44from tap import Tap
55from typing import Dict , Iterator , List , Optional , Union , Literal , Tuple
66import numpy as np
7-
8-
9- Metric = Literal ['roc-auc' , 'accuracy' , 'precision' , 'recall' , 'f1_score' ,
10- 'rmse' , 'mae' , 'mse' , 'r2' , 'max' ]
7+ from mgktools .evaluators .metric import Metric
118
129
1310class CommonArgs (Tap ):
@@ -47,10 +44,16 @@ class CommonArgs(Tap):
4744 """
4845 features_generator : List [str ] = None
4946 """Method(s) of generating additional features_mol."""
47+ features_combination : Literal ['concat' , 'mean' ] = None
48+ """How to combine features vector for mixtures."""
5049 target_columns : List [str ] = None
5150 """
5251 Name of the columns containing target values.
5352 """
53+ features_mol_normalize : bool = False
54+ """Nomralize the molecular features_mol."""
55+ features_add_normalize : bool = False
56+ """Nomralize the additonal features_mol."""
5457 group_reading : bool = False
5558 """Find unique input strings first, then read the data."""
5659 def __init__ (self , * args , ** kwargs ):
@@ -89,10 +92,12 @@ def process_args(self) -> None:
8992
9093
9194class KernelArgs (CommonArgs ):
92- graph_kernel_type : Literal ['graph' , 'preCalc ' ] = None
95+ graph_kernel_type : Literal ['graph' , 'pre-computed ' ] = None
9396 """The type of kernel to use."""
9497 graph_hyperparameters : List [str ] = None
9598 """hyperparameters file for graph kernel."""
99+ features_kernel_type : Literal ['dot_product' , 'rbf' ] = None
100+ """choose dot product kernel or rbf kernel for features."""
96101 features_hyperparameters : List [float ] = None
97102 """hyperparameters for molecular features."""
98103 features_hyperparameters_min : List [float ] = None
@@ -101,17 +106,16 @@ class KernelArgs(CommonArgs):
101106 """hyperparameters for molecular features."""
102107 features_hyperparameters_file : str = None
103108 """JSON file contains features hyperparameters"""
104- features_mol_normalize : bool = False
105- """Nomralize the molecular features_mol."""
106- features_add_normalize : bool = False
107- """Nomralize the additonal features_mol."""
108109 single_features_hyperparameter : bool = True
109110 """Use the same hyperparameter for all features."""
110111
111112 @property
112113 def features_hyperparameters_bounds (self ):
113114 if self .features_hyperparameters_min is None or self .features_hyperparameters_max is None :
114- return 'fixed'
115+ if self .features_hyperparameters is None :
116+ return None
117+ else :
118+ return 'fixed'
115119 else :
116120 return [(self .features_hyperparameters_min [i ], self .features_hyperparameters_max [i ])
117121 for i in range (len (self .features_hyperparameters ))]
@@ -148,19 +152,33 @@ def process_args(self) -> None:
148152 assert self .block_id [1 ] >= self .block_id [0 ]
149153
150154
155+ class DataSplitArgs (Tap ):
156+ split_type : Literal ['random' , 'scaffold_balanced' , 'loocv' ] = 'random'
157+ """Method of splitting the data into train/val/test."""
158+ split_sizes : Tuple [float , float ] = (0.8 , 0.2 )
159+ """Split proportions for train/validation/test sets."""
160+ num_folds : int = 1
161+ """Number of folds when performing cross validation."""
162+ save_dir : str
163+ """The output directory."""
164+ n_jobs : int = 1
165+ """The cpu numbers used for parallel computing."""
166+ data_path : str = None
167+ """The Path of input data CSV file."""
168+
169+
151170class TrainArgs (KernelArgs ):
152- dataset_type : Literal ['regression' , 'classification ' , 'multiclass ' ] = None
171+ task_type : Literal ['regression' , 'binary ' , 'multi-class ' ] = None
153172 """
154- Type of dataset . This determines the loss function used during training.
173+ Type of task . This determines the loss function used during training.
155174 """
156175 model_type : Literal ['gpr' , 'svc' , 'svr' , 'gpc' , 'gpr_nystrom' , 'gpr_nle' ]
157176 """Type of model to use"""
158177 loss : Literal ['loocv' , 'likelihood' ] = 'loocv'
159178 """The target loss function to minimize or maximize."""
160-
161- split_type : Literal ['random' , 'scaffold_balanced' , 'loocv' ] = 'random'
179+ split_type : Literal ['random' , 'scaffold_order' , 'scaffold_random' , 'stratified' , 'n_heavy' , 'loocv' ] = None
162180 """Method of splitting the data into train/val/test."""
163- split_sizes : Tuple [float , float ] = ( 0.8 , 0.2 )
181+ split_sizes : List [float ] = [ 0.8 , 0.2 ]
164182 """Split proportions for train/validation/test sets."""
165183 num_folds : int = 1
166184 """Number of folds when performing cross validation."""
@@ -181,7 +199,7 @@ class TrainArgs(KernelArgs):
181199 """The rule to combining prediction from estimators."""
182200 n_local : int = 500
183201 """The number of samples used in Naive Local Experts."""
184- n_core : int = 500
202+ n_core : int = None
185203 """The number of samples used in Nystrom core set."""
186204 metric : Metric = None
187205 """metric"""
@@ -195,23 +213,29 @@ class TrainArgs(KernelArgs):
195213 """If set True, 5 most similar molecules in the training set will be save in the test_*.log."""
196214 save_model : bool = False
197215 """Save the trained model file."""
216+ separate_test_path : str = None
217+ """Path to separate test set, optional."""
198218
199219 @property
200- def metrics (self ) -> List [str ]:
220+ def metrics (self ) -> List [Metric ]:
201221 return [self .metric ] + self .extra_metrics
202222
203223 @property
204224 def alpha_ (self ) -> float :
205- if isinstance (self .alpha , float ):
225+ if self .alpha is None :
226+ return None
227+ elif isinstance (self .alpha , float ):
206228 return self .alpha
207- if os .path .exists (self .alpha ):
229+ elif os .path .exists (self .alpha ):
208230 return float (open (self .alpha , 'r' ).read ())
209231 else :
210232 return float (self .alpha )
211233
212234 @property
213235 def C_ (self ) -> float :
214- if isinstance (self .C , float ):
236+ if self .C is None :
237+ return None
238+ elif isinstance (self .C , float ):
215239 return self .C
216240 elif os .path .exists (self .C ):
217241 return float (open (self .C , 'r' ).read ())
@@ -223,22 +247,19 @@ def kernel_args(self):
223247
224248 def process_args (self ) -> None :
225249 super ().process_args ()
226- if self .dataset_type == 'regression' :
250+ if self .task_type == 'regression' :
227251 assert self .model_type in ['gpr' , 'gpr_nystrom' , 'gpr_nle' , 'svr' ]
228252 for metric in self .metrics :
229253 assert metric in ['rmse' , 'mae' , 'mse' , 'r2' , 'max' ]
230- elif self .dataset_type == 'classification ' :
231- assert self .model_type in ['gpc' , 'svc' ]
254+ elif self .task_type == 'binary ' :
255+ assert self .model_type in ['gpc' , 'svc' , 'gpr' ]
232256 for metric in self .metrics :
233- assert metric in ['roc-auc' , 'accuracy' , 'precision' , 'recall' , 'f1_score' ]
234- else :
257+ assert metric in ['roc-auc' , 'accuracy' , 'precision' , 'recall' , 'f1_score' , 'mcc' ]
258+ elif self . task_type == 'multi-class' :
235259 assert self .model_type in ['gpc' , 'svc' ]
236260 for metric in self .metrics :
237261 assert metric in ['accuracy' , 'precision' , 'recall' , 'f1_score' ]
238262
239- if 'accuracy' in self .metrics :
240- assert self .no_proba
241-
242263 if self .split_type == 'loocv' :
243264 assert self .num_folds == 1
244265 assert self .model_type == 'gpr'
@@ -249,9 +270,6 @@ def process_args(self) -> None:
249270 if self .model_type == 'svc' :
250271 assert self .C is not None
251272
252- if self .split_type == 'loocv' :
253- assert self .dataset_type == 'regression'
254-
255273 if not hasattr (self , 'optimizer' ):
256274 self .optimizer = None
257275 if not hasattr (self , 'batch_size' ):
@@ -262,6 +280,9 @@ def process_args(self) -> None:
262280 assert self .split_sizes [0 ] > 0.99999
263281 assert self .model_type == 'gpr'
264282
283+ if self .ensemble :
284+ assert self .n_sample_per_model is not None
285+
265286
266287class PredictArgs (TrainArgs ):
267288 test_path : str
@@ -280,14 +301,18 @@ class HyperoptArgs(TrainArgs):
280301 """Bounds of alpha used in GPR."""
281302 alpha_uniform : float = None
282303 """"""
283- C_bounds : Tuple [float , float ] = (1e-3 , 1e3 )
304+ C_bounds : Tuple [float , float ] = None # (1e-3, 1e3)
284305 """Bounds of C used in SVC."""
285306 C_uniform : float = None
286307 """"""
287308 optimizer : Literal ['SLSQP' , 'L-BFGS-B' , 'BFGS' , 'fmin_l_bfgs_b' , 'sgd' , 'rmsprop' , 'adam' ] = None
288309 """Optimizer"""
289310 batch_size : int = None
290311 """batch_size"""
312+ num_splits : int = 1
313+ """split the dataset randomly into no. subsets."""
314+ save_all : bool = False
315+ """save all hyperparameters during bayesian optimization."""
291316
292317 @property
293318 def minimize_score (self ) -> bool :
@@ -312,7 +337,6 @@ def opt_C(self) -> bool:
312337
313338 def process_args (self ) -> None :
314339 super ().process_args ()
315- assert self .graph_kernel_type != 'preCalc'
316340 if self .optimizer in ['L-BFGS-B' ]:
317341 assert self .model_type == 'gpr'
318342
0 commit comments