Spaces:
Sleeping
Sleeping
Commit
·
136540f
1
Parent(s):
3fd3838
cleanup and train without train/val overlapping samples
Browse files
- predict.py +1 -16
- train.py +19 -16
predict.py
CHANGED
|
@@ -47,7 +47,7 @@ def predict(
|
|
| 47 |
model_path = "checkpoints/rf_alltasks.joblib"
|
| 48 |
model.load_model(model_path)
|
| 49 |
|
| 50 |
-
|
| 51 |
|
| 52 |
# make predictions
|
| 53 |
predictions = defaultdict(dict)
|
|
@@ -59,21 +59,6 @@ def predict(
|
|
| 59 |
preds = np.empty_like(is_clean, dtype=np.float64)
|
| 60 |
|
| 61 |
preds[~is_clean] = default_prediction
|
| 62 |
-
# selected_feat = X[:, rdkit_desc_idx].copy()
|
| 63 |
-
# quantiles = np.zeros_like(selected_feat)
|
| 64 |
-
|
| 65 |
-
# for column in range(selected_feat.shape[1]):
|
| 66 |
-
# raw_values = selected_feat[:, column].reshape(-1)
|
| 67 |
-
# ecdf = ecdfs[target][column]
|
| 68 |
-
# q = ecdf(raw_values)
|
| 69 |
-
# quantiles[:, column] = q
|
| 70 |
-
|
| 71 |
-
# X[:, rdkit_desc_idx] = quantiles
|
| 72 |
-
# X = X[:, feat_selec[target]]
|
| 73 |
-
|
| 74 |
-
# X = scalers[target].transform(X)
|
| 75 |
-
|
| 76 |
-
# preds[is_clean] = model[target].predict_proba(X)[:, 1]
|
| 77 |
preds[is_clean] = model.predict(target, X)
|
| 78 |
|
| 79 |
for smiles, pred in zip(smiles_list, preds):
|
|
|
|
| 47 |
model_path = "checkpoints/rf_alltasks.joblib"
|
| 48 |
model.load_model(model_path)
|
| 49 |
|
| 50 |
+
print(f"Loaded model from {model_path}")
|
| 51 |
|
| 52 |
# make predictions
|
| 53 |
predictions = defaultdict(dict)
|
|
|
|
| 59 |
preds = np.empty_like(is_clean, dtype=np.float64)
|
| 60 |
|
| 61 |
preds[~is_clean] = default_prediction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
preds[is_clean] = model.predict(target, X)
|
| 63 |
|
| 64 |
for smiles, pred in zip(smiles_list, preds):
|
train.py
CHANGED
|
@@ -3,6 +3,7 @@ Script for fitting and saving any preprocessing assets, as well as the fitted RF
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
|
|
|
| 6 |
import logging
|
| 7 |
import argparse
|
| 8 |
|
|
@@ -44,8 +45,6 @@ parser.add_argument(
|
|
| 44 |
|
| 45 |
ECFP_RADIUS = 3
|
| 46 |
ECFP_FPSIZE = 8192
|
| 47 |
-
FEATURE_SELECTION_PATH = "data/feat_selection.npz"
|
| 48 |
-
ECDFS_PATH = "data/ecdfs.pkl"
|
| 49 |
|
| 50 |
task_config = {
|
| 51 |
"NR-AR": {
|
|
@@ -158,19 +157,23 @@ def main(args):
|
|
| 158 |
logger.info(args)
|
| 159 |
|
| 160 |
# seeding
|
| 161 |
-
|
| 162 |
-
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
| 167 |
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
|
| 175 |
# # remove molecules that couldn't be sanitized
|
| 176 |
# mask = ~np.isnan(train_X).any(axis=1)
|
|
@@ -178,13 +181,13 @@ def main(args):
|
|
| 178 |
# train_y = train_y[mask]
|
| 179 |
|
| 180 |
full_data = np.load(
|
| 181 |
-
|
| 182 |
allow_pickle=True,
|
| 183 |
)
|
| 184 |
|
| 185 |
-
train_val_mask = full_data["sets"] != "test"
|
| 186 |
-
data = full_data["features"][train_val_mask]
|
| 187 |
-
labels = full_data["labels"][train_val_mask]
|
| 188 |
print("Train data shape:", data.shape)
|
| 189 |
|
| 190 |
test_mask = full_data["sets"] == "test"
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
| 6 |
+
import random
|
| 7 |
import logging
|
| 8 |
import argparse
|
| 9 |
|
|
|
|
| 45 |
|
| 46 |
ECFP_RADIUS = 3
|
| 47 |
ECFP_FPSIZE = 8192
|
|
|
|
|
|
|
| 48 |
|
| 49 |
task_config = {
|
| 50 |
"NR-AR": {
|
|
|
|
| 157 |
logger.info(args)
|
| 158 |
|
| 159 |
# seeding
|
| 160 |
+
random.seed(args.seed)
|
| 161 |
+
np.random.seed(args.seed)
|
| 162 |
|
| 163 |
+
train_data = np.load(os.path.join(args.data_folder, "tox21_train_cv4.npz"))
|
| 164 |
+
train_X = train_data[
|
| 165 |
+
"features"
|
| 166 |
+
] # np.concatenate([train_data[descr] for descr in KNOWN_DESCR], axis=1)
|
| 167 |
+
train_y = train_data["labels"]
|
| 168 |
|
| 169 |
+
val_data = np.load(os.path.join(args.data_folder, "tox21_validation_cv4.npz"))
|
| 170 |
+
val_X = val_data[
|
| 171 |
+
"features"
|
| 172 |
+
] # np.concatenate([val_data[descr] for descr in KNOWN_DESCR], axis=1)
|
| 173 |
+
val_y = val_data["labels"]
|
| 174 |
|
| 175 |
+
data = np.concatenate([train_X, val_X], axis=0)
|
| 176 |
+
labels = np.concatenate([train_y, val_y], axis=0)
|
| 177 |
|
| 178 |
# # remove molecules that couldn't be sanitized
|
| 179 |
# mask = ~np.isnan(train_X).any(axis=1)
|
|
|
|
| 181 |
# train_y = train_y[mask]
|
| 182 |
|
| 183 |
full_data = np.load(
|
| 184 |
+
"data/tox21_descriptors.npz",
|
| 185 |
allow_pickle=True,
|
| 186 |
)
|
| 187 |
|
| 188 |
+
# train_val_mask = full_data["sets"] != "test"
|
| 189 |
+
# data = full_data["features"][train_val_mask]
|
| 190 |
+
# labels = full_data["labels"][train_val_mask]
|
| 191 |
print("Train data shape:", data.shape)
|
| 192 |
|
| 193 |
test_mask = full_data["sets"] == "test"
|