Skip to content

Commit e66ca94

Browse files
authored
005 (#17)
* add feature: atomic attribution for cross-validation. * bugfix of random forest. * update for mgktools==0.0.5
1 parent ec63dd5 commit e66ca94

File tree

7 files changed

+50
-41
lines changed

7 files changed

+50
-41
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ temperature, pressure, etc.
1717
</div>
1818

1919
## Installation
20-
GCC (7.*), NVIDIA Driver and CUDA toolkit(>=10.1).
20+
Requirements: GCC (7.*), an NVIDIA driver, and the CUDA toolkit (>=10.1).
21+
Python 3.10 is suggested.
2122
```
22-
conda env create -f environment.yml
23-
conda activate graphdot
23+
pip install numpy==1.22.3 git+https://gitlab.com/Xiangyan93/graphdot.git@feature/xy git+https://github.com/bp-kelley/descriptastorus typed-argument-parser mgktools
2424
```
2525
For some combinations of GCC and CUDA, only an older version of pycuda works: ```pip install pycuda==2020.1```
2626

chemml/args.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ class TrainArgs(KernelArgs):
215215
"""Save the trained model file."""
216216
separate_test_path: str = None
217217
"""Path to separate test set, optional."""
218+
atomic_attribution: bool = False
219+
"""Output interpretability."""
218220

219221
@property
220222
def metrics(self) -> List[Metric]:
@@ -283,6 +285,11 @@ def process_args(self) -> None:
283285
if self.ensemble:
284286
assert self.n_sample_per_model is not None
285287

288+
if self.atomic_attribution:
289+
assert self.graph_kernel_type == 'graph', 'Set graph_kernel_type to graph for interpretability'
290+
assert self.model_type == 'gpr', 'Set model_type to gpr for interpretability'
291+
assert self.ensemble is False
292+
286293

287294
class PredictArgs(TrainArgs):
288295
test_path: str
@@ -382,7 +389,7 @@ def process_args(self) -> None:
382389
self.cluster_size = self.add_size
383390
assert self.initial_size >= 2
384391
if self.surrogate_kernel is not None:
385-
assert self.graph_kernel_type == 'preCalc'
392+
assert self.graph_kernel_type == 'pre-computed'
386393

387394
if self.stop_uncertainty is None:
388395
self.stop_uncertainty = [-1.0]

chemml/model.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,26 @@
22
# -*- coding: utf-8 -*-
33
from .args import TrainArgs
44
from mgktools.models import GPR, GPC, LRAGPR, NLEGPR, SVC, SVR, ConsensusRegressor
5+
from mgktools.interpret.gpr import InterpretableGaussianProcessRegressor as IGPR
56

67

78
def set_model(args: TrainArgs,
89
kernel):
910
if args.model_type == 'gpr':
10-
model = GPR(
11-
kernel=kernel,
12-
optimizer=args.optimizer,
13-
alpha=args.alpha_,
14-
normalize_y=True,
15-
)
11+
if args.atomic_attribution:
12+
model = IGPR(
13+
kernel=kernel,
14+
optimizer=args.optimizer,
15+
alpha=args.alpha_,
16+
normalize_y=False,
17+
)
18+
else:
19+
model = GPR(
20+
kernel=kernel,
21+
optimizer=args.optimizer,
22+
alpha=args.alpha_,
23+
normalize_y=True,
24+
)
1625
if args.ensemble:
1726
model = ConsensusRegressor(
1827
model,

environment.yml

Lines changed: 0 additions & 19 deletions
This file was deleted.

run/ModelEvaluate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,16 @@ def main(args: TrainArgs) -> None:
5959
n_similar=None,
6060
kernel=None,
6161
n_core=args.n_core,
62+
atomic_attribution=args.atomic_attribution,
6263
seed=args.seed,
6364
verbose=True)
6465

6566
if args.separate_test_path is not None and args.target_columns is None:
6667
evaluator.fit(X=dataset.X, y=dataset.y)
6768
evaluator.predict(X=dataset_test.X, y=None, repr=dataset_test.repr.ravel()).to_csv(
6869
'%s/pred_ext.csv' % args.save_dir, sep='\t', index=False, float_format='%15.10f')
70+
if args.atomic_attribution:
71+
evaluator.interpret(dataset_test=dataset_test, output_tag='ext')
6972
else:
7073
evaluator.evaluate(external_test_dataset=dataset_test)
7174

run/RandomForest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ class RandomForestArgs(Tap):
4747
"""
4848
Type of task. This determines the loss function used during training.
4949
"""
50-
split_type: Literal['random', 'scaffold_balanced', 'loocv'] = None
50+
split_type: Literal['random', 'scaffold_order', 'scaffold_random', 'loocv'] = None
5151
"""Method of splitting the data into train/val/test."""
52-
split_sizes: Tuple[float, float] = (0.8, 0.2)
52+
split_sizes: List[float] = [0.8, 0.2]
5353
"""Split proportions for train/validation/test sets."""
5454
num_folds: int = 1
5555
"""Number of folds when performing cross validation."""
@@ -100,9 +100,9 @@ def main(args: RandomForestArgs) -> None:
100100
else:
101101
dataset_test = None
102102
if args.task_type == 'regression':
103-
model = RandomForestRegressor()
103+
model = RandomForestRegressor(random_state=args.seed)
104104
else:
105-
model = RFClassifier()
105+
model = RFClassifier(random_state=args.seed)
106106
Evaluator(save_dir=args.save_dir,
107107
dataset=dataset,
108108
model=model,

test/test_a_read_data.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
CWD = os.path.dirname(os.path.abspath(__file__))
66
import sys
7+
import shutil
78
sys.path.append('%s/..' % CWD)
89
from chemml.args import CommonArgs
910
from run.ReadData import main
@@ -18,7 +19,8 @@
1819
def test_ReadData_PureGraph(dataset):
1920
dataset, pure_columns, target_columns = dataset
2021
save_dir = '%s/data/_%s_%s_%s' % (CWD, dataset, ','.join(pure_columns), ','.join(target_columns))
21-
assert not os.path.exists(save_dir)
22+
if os.path.exists(save_dir):
23+
shutil.rmtree(save_dir)
2224
arguments = [
2325
'--save_dir', '%s' % save_dir,
2426
'--data_path', '%s/data/%s.csv' % (CWD, dataset),
@@ -40,7 +42,8 @@ def test_ReadData_PureGraph_FeaturesAdd(dataset, group_reading, features_scaling
4042
dataset, pure_columns, target_columns, features_columns = dataset
4143
save_dir = '%s/data/_%s_%s_%s_%s_%s' % (CWD, dataset, ','.join(pure_columns), ','.join(target_columns),
4244
group_reading, features_scaling)
43-
assert not os.path.exists(save_dir)
45+
if os.path.exists(save_dir):
46+
shutil.rmtree(save_dir)
4447
arguments = [
4548
'--save_dir', '%s' % save_dir,
4649
'--data_path', '%s/data/%s.csv' % (CWD, dataset),
@@ -76,7 +79,8 @@ def test_ReadData_PureGraph_FeaturesMol(dataset, features_generator, features_sc
7679
dataset, pure_columns, target_columns = dataset
7780
save_dir = '%s/data/_%s_%s_%s_%s_%s' % (CWD, dataset, ','.join(pure_columns), ','.join(target_columns),
7881
','.join(features_generator), features_scaling)
79-
assert not os.path.exists(save_dir)
82+
if os.path.exists(save_dir):
83+
shutil.rmtree(save_dir)
8084
arguments = [
8185
'--save_dir', '%s' % save_dir,
8286
'--data_path', '%s/data/%s.csv' % (CWD, dataset),
@@ -106,7 +110,8 @@ def test_ReadData_PureGraph_FeaturesAddMol(dataset, group_reading, features_gene
106110
dataset, pure_columns, target_columns, features_columns = dataset
107111
save_dir = '%s/data/_%s_%s_%s_%s_%s_%s' % (CWD, dataset, ','.join(pure_columns), ','.join(target_columns),
108112
group_reading, ','.join(features_generator), features_scaling)
109-
assert not os.path.exists(save_dir)
113+
if os.path.exists(save_dir):
114+
shutil.rmtree(save_dir)
110115
arguments = [
111116
'--save_dir', '%s' % save_dir,
112117
'--data_path', '%s/data/%s.csv' % (CWD, dataset),
@@ -137,7 +142,8 @@ def test_ReadData_PureGraph_FeaturesAddMol(dataset, group_reading, features_gene
137142
def test_ReadData_MixtureGraph(dataset):
138143
dataset, pure_columns, target_columns = dataset
139144
save_dir = '%s/data/_%s_%s_%s' % (CWD, dataset, ','.join(pure_columns), ','.join(target_columns))
140-
assert not os.path.exists(save_dir)
145+
if os.path.exists(save_dir):
146+
shutil.rmtree(save_dir)
141147
arguments = [
142148
'--save_dir', '%s' % save_dir,
143149
'--data_path', '%s/data/%s.csv' % (CWD, dataset),
@@ -159,7 +165,8 @@ def test_ReadData_MixtureGraph_FeaturesAdd(dataset, group_reading, features_scal
159165
dataset, pure_columns, target_columns, features_columns = dataset
160166
save_dir = '%s/data/_%s_%s_%s_%s_%s' % (CWD, dataset, ','.join(pure_columns), ','.join(target_columns),
161167
group_reading, features_scaling)
162-
assert not os.path.exists(save_dir)
168+
if os.path.exists(save_dir):
169+
shutil.rmtree(save_dir)
163170
arguments = [
164171
'--save_dir', '%s' % save_dir,
165172
'--data_path', '%s/data/%s.csv' % (CWD, dataset),
@@ -195,7 +202,8 @@ def test_ReadData_MixtureGraph_FeaturesMol(dataset, features_generator, features
195202
dataset, pure_columns, target_columns = dataset
196203
save_dir = '%s/data/_%s_%s_%s_%s_%s_%s' % (CWD, dataset, ','.join(pure_columns), ','.join(target_columns),
197204
','.join(features_generator), features_combination, features_scaling)
198-
assert not os.path.exists(save_dir)
205+
if os.path.exists(save_dir):
206+
shutil.rmtree(save_dir)
199207
arguments = [
200208
'--save_dir', '%s' % save_dir,
201209
'--data_path', '%s/data/%s.csv' % (CWD, dataset),
@@ -229,7 +237,8 @@ def test_ReadData_MixtureGraph_FeaturesMolAdd(dataset, group_reading, features_g
229237
save_dir = '%s/data/_%s_%s_%s_%s_%s_%s_%s' % (CWD, dataset, ','.join(pure_columns), ','.join(target_columns),
230238
group_reading, ','.join(features_generator), features_combination,
231239
features_scaling)
232-
assert not os.path.exists(save_dir)
240+
if os.path.exists(save_dir):
241+
shutil.rmtree(save_dir)
233242
arguments = [
234243
'--save_dir', '%s' % save_dir,
235244
'--data_path', '%s/data/%s.csv' % (CWD, dataset),

0 commit comments

Comments
 (0)