antoniaebner committed
Commit 97697e0 · 1 Parent(s): b004460

adapt load/saving, preprocessing, app, readme, modelcard

Files changed (9)
  1. MODELCARD.md → MODEL_CARD.md +1 -5
  2. README.md +27 -33
  3. app.py +2 -2
  4. config/config.json +5 -3
  5. predict.py +2 -3
  6. src/model.py +12 -7
  7. src/preprocess.py +75 -24
  8. src/utils.py +0 -2
  9. train.py +3 -4
MODELCARD.md → MODEL_CARD.md RENAMED
@@ -1,7 +1,7 @@
  # Model card - tox21_rf_classifier
  ### Model details
  - Model name: Random Forest Tox21 Baseline
- - Developer: ML-JKU
+ - Developer: JKU (Linz)
  - Paper URL: https://link.springer.com/article/10.1023/A:1010933404324
  - Model type / architecture:
  - Random Forest implemented using sklearn.RandomForestClassifier.
@@ -29,7 +29,3 @@ Tox21 training and validation sets.

  ### Evaluation data
  Tox21 test set.
-
- ### Hugging Face Space details
- - Space: MASKED-FOR-REVIEW
- - Git commit hash: MASKED-FOR-REVIEW
README.md CHANGED
@@ -6,31 +6,39 @@ colorTo: purple
  sdk: docker
  pinned: false
  license: cc-by-nc-4.0
- short_description: This is a RF classifier for the Tox21 test dataset
+ short_description: Random Forest Baseline for Tox21
  ---

  # Tox21 Random Forest Classifier

- This repository hosts a Hugging Face Space that provides an examplary API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/ml-jku/tox21_leaderboard).
+ This repository hosts a Hugging Face Space that provides an API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/tschouis/tox21_leaderboard).

- Here **Random Forest (RF)** models are trained on the Tox21 dataset, and the trained models are provided for
- inference. For each of the twelve toxic effects, a separate RF model is trained. The input to the model
- is a **SMILES** string of the small molecule, and the output are 12 numeric values for
- each of the toxic effects of the Tox21 dataset.
-
+ Here **Random Forest (RF)** models are trained on the Tox21 dataset, and the trained models are provided for inference. For each of the twelve toxic effects, a separate RF model is trained. The input to the model is a **SMILES** string of the small molecule, and the output is one numeric value for each of the twelve toxic effects of the Tox21 dataset.

- **Important:** For leaderboard submission, your Space does not need to include training code. It only needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it should take a list of SMILES strings as input and return a prediction dictionary as output, with SMILES and targets as keys. Therefore, any preprocessing of SMILES strings must be executed on-the-fly during inference.
+ **Important:** For leaderboard submission, your Space needs to include training code. The file `train.py` should train the model using the config specified inside the `config/` folder and save the final model parameters into a file inside the `checkpoints/` folder. The model should be trained using the [Tox21 dataset](https://huggingface.co/datasets/tschouis/tox21) provided on Hugging Face. The datasets can be loaded like this:
+ ```python
+ from datasets import load_dataset
+ ds = load_dataset("ml-jku/tox21", token=token)
+ train_df = ds["train"].to_pandas()
+ val_df = ds["validation"].to_pandas()
+ ```
+
+ Additionally, the Space needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it should take a list of SMILES strings as input and return a nested prediction dictionary as output, with SMILES as keys and dictionaries of target-name/prediction pairs as values. Therefore, any preprocessing of SMILES strings must be executed on-the-fly during inference.

  # Repository Structure
  - `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
  - `app.py` - FastAPI application wrapper (can be used as-is).
+ - `preprocess.py` - Preprocesses SMILES strings to generate feature descriptors and saves the results as NPZ files in `data/`.
+ - `train.py` - Trains and saves a model using the config in the `config/` folder.
+ - `config/` - The config file used by `train.py`.
+ - `logs/` - All logs of `train.py`, the saved model, and predictions on the validation set.
+ - `data/` - The RF works on numerical data; during preprocessing in `preprocess.py`, two NPZ files containing molecule features are created and saved here.
+ - `checkpoints/` - The saved model used by `predict.py` lives here.

  - `src/` - Core model & preprocessing logic:
- - `data.py` - SMILES preprocessing pipeline
- - `model.py` - Random Forest classifier wrapper
- - `train.py` - Script to train the classifier
- - `utils.py` – Constants and Helper functions
+ - `preprocess.py` - SMILES preprocessing logic
+ - `model.py` - RF model class with preprocessing, saving, and loading logic
+ - `utils.py` - Utility functions

  # Quickstart with Spaces

@@ -44,13 +52,17 @@ You can easily adapt this project in your own Hugging Face account:

  - Modify `predict()` inside `predict.py` to perform model inference while keeping the function skeleton unchanged to remain compatible with the leaderboard.

+ - Modify `train.py` and/or `preprocess.py` according to your model and preprocessing pipeline.
+
+ - Modify the file inside `config/` to contain all hyperparameters that are set in `train.py`.
+
  That’s it, your model will be available as an API endpoint for the Tox21 Leaderboard.

  # Installation
  To run (and train) the random forest, clone the repository and install dependencies:

  ```bash
- git clone https://huggingface.co/spaces/ml-jku/tox21_rf_classifier
+ git clone https://huggingface.co/spaces/tschouis/tox21_rf_classifier
  cd tox_21_rf_classifier

  conda create -n tox21_rf_cls python=3.11
@@ -58,22 +70,6 @@ conda activate tox21_rf_cls
  pip install -r requirements.txt
  ```

- # Training
-
- To train the Random Forest model from scratch:
-
- ```bash
- python -m src/train.py
- ```
-
- This will:
-
- 1. Load and preprocess the Tox21 training dataset.
- 2. Train a Random Forest classifier.
- 3. Save the trained model to the assets/ folder.
- 4. Evaluate the trained Random Forest classifier on the validation split.
-
-
  # Inference

  For inference, you only need `predict.py`.

@@ -101,8 +97,6 @@ The output will be a nested dictionary in the format:

  # Notes

- - Only adapting `predict.py` for your model inference is required for leaderboard submission.
-
- - Training (`src/train.py`) is provided for reproducibility.
+ - Adapting `predict.py`, `train.py`, `config/`, and `checkpoints/` is required for leaderboard submission.

- - Preprocessing (here inside `src/data.py`) must be applied at inference time, not just training.
+ - Preprocessing must be done inside `predict.py`, not just `train.py`.
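The `predict()` contract described above is the one hard requirement for leaderboard submissions. A minimal sketch of that skeleton, using the twelve Tox21 target names from `src/utils.py`; the constant 0.5 scores are placeholders, not the repository's actual inference logic:

```python
from typing import Dict, List

# The twelve Tox21 targets (as listed in src/utils.py).
TASKS = [
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]

def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """Leaderboard skeleton: list of SMILES in, nested prediction dict out."""
    predictions: Dict[str, Dict[str, float]] = {}
    for smiles in smiles_list:
        # Preprocess the SMILES on the fly here, then score with the loaded model.
        predictions[smiles] = {task: 0.5 for task in TASKS}  # placeholder scores
    return predictions
```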
app.py CHANGED
@@ -44,7 +44,7 @@ def root():
 @app.get("/metadata")
 def metadata():
     return {
-        "name": "Tox21RandomForest",
+        "name": "RF",
         "version": "1.0.0",
         "max_batch_size": 256,
         "tox_endpoints": [
@@ -74,5 +74,5 @@ def predict(request: Request):
     predictions = predict_func(request.smiles)
     return {
         "predictions": predictions,
-        "model_info": {"name": "Tox21RandomForest", "version": "1.0.0"},
+        "model_info": {"name": "RF", "version": "1.0.0"},
     }
config/config.json CHANGED
@@ -17,9 +17,11 @@
       "use": "true",
       "min_var": 0.01,
       "max_corr": 0.95,
-      "feature_keys": ["ecfps", "tox", "maccs", "rdkit_descrs"],
-      "independent_keys": "false",
-      "max_features": -1
+      "max_features": -1,
+      "min_var__feature_keys": ["ecfps", "tox", "maccs", "rdkit_descrs"],
+      "max_corr__feature_keys": ["ecfps", "tox", "maccs", "rdkit_descrs"],
+      "min_var__independent_keys": "false",
+      "max_corr__independent_keys": "false"
     },
     "feature_quantilization": {
       "use": "true",
predict.py CHANGED
@@ -57,12 +57,11 @@ def predict(
         scaler=config["scaler"],
     )

-    state = joblib.load(config["ckpt_path"])
-    model.set_state(state)
+    model.load(config["ckpt_path"])
     print(f"Loaded model from {config['ckpt_path']}")

     state = joblib.load(config["preprocessor_path"])
-    preprocessor.__setstate__(state)
+    preprocessor.set_state(state)
     print(f"Loaded preprocessor from {config['preprocessor_path']}")

     # make predictions
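Replacing `preprocessor.__setstate__(state)` with an explicit `set_state(state)` avoids piggybacking on pickle's special method, which sklearn's `BaseEstimator` machinery may also use. The preprocessor side of this API is not shown in the diff; a minimal sketch of what such a pair could look like:

```python
class FeaturePreprocessor:
    # ... fitted attributes (feature_selector, scaler, ...) live on self ...

    def get_state(self) -> dict:
        """Collect everything needed to restore a fitted preprocessor."""
        return self.__dict__.copy()

    def set_state(self, state: dict) -> None:
        """Restore a fitted preprocessor from a previously saved state dict."""
        self.__dict__.update(state)
```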
src/model.py CHANGED
@@ -6,6 +6,7 @@ SMILES and target names as keys.

 # ---------------------------------------------------------------------------------------
 # Dependencies
+import joblib
 import numpy as np
 from sklearn.ensemble import RandomForestClassifier

@@ -33,17 +34,21 @@ class Tox21RFClassifier:
         for task in self.tasks
     }

-    def set_state(self, state: dict) -> None:
-        """Sets the state of the model
+    def load(self, path: str) -> None:
+        """Load model from filepath

         Args:
-            state (dict): models state dict
+            path (str): filepath to model checkpoint
         """
-        self.models = state
+        self.models = joblib.load(path)

-    def get_state(self) -> None:
-        """Return model state dict"""
-        return self.models
+    def save(self, path: str) -> None:
+        """Save model to filepath
+
+        Args:
+            path (str): filepath to model checkpoint
+        """
+        joblib.dump(self.models, path)

     def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
         """Train the random forest for a given task
src/preprocess.py CHANGED
@@ -78,7 +78,7 @@ class SubSampler(TransformerMixin, BaseEstimator):
         _X = X.copy()
         _y = y.copy() if y is not None else None

-        if self.max_samples > 0:
+        if self.max_samples > 0 and _X.shape[0] > self.max_samples:
             resample_idxs = np.random.choice(
                 np.arange(_X.shape[0]), size=(self.max_samples,), replace=True
             )
@@ -127,21 +127,32 @@ class FeatureSelector(FeatureDictMixin, TransformerMixin, BaseEstimator):
         max_corr=1.0,
         max_features=-1,
         feature_keys=None,
-        independent_keys=False,
+        min_var__feature_keys=None,
+        max_corr__feature_keys=None,
+        max_features__feature_keys=None,
+        min_var__independent_keys=False,
+        max_corr__independent_keys=False,
+        max_features__independent_keys=False,
     ):
         self.min_var = min_var
         self.max_corr = max_corr
         self.max_features = max_features
-        self.independent_keys = independent_keys
-        self._feature_mask = None
+
+        self.min_var__feature_keys = min_var__feature_keys
+        self.max_corr__feature_keys = max_corr__feature_keys
+        self.max_features__feature_keys = max_features__feature_keys
+
+        self.min_var__independent_keys = min_var__independent_keys
+        self.max_corr__independent_keys = max_corr__independent_keys
+        self.max_features__independent_keys = max_features__independent_keys

         super().__init__(feature_keys=feature_keys)

-    def _get_min_var_feature_mask(self, X: np.ndarray) -> np.ndarray:
+    def _get_min_var_mask(self, X: np.ndarray, *args) -> np.ndarray:
         var_thresh = VarianceThreshold(threshold=self.min_var)
         return var_thresh.fit(X).get_support()  # mask

-    def _get_max_corr_feature_mask(
+    def _get_max_corr_mask(
         self, X: np.ndarray, prev_feature_mask: np.ndarray
     ) -> np.ndarray:
         _prev_feature_mask = prev_feature_mask.copy()
@@ -156,45 +167,86 @@ class FeatureSelector(FeatureDictMixin, TransformerMixin, BaseEstimator):
         _prev_feature_mask[_prev_feature_mask] = to_keep
         return _prev_feature_mask

+    def _get_max_features_mask(
+        self, X: np.ndarray, prev_feature_mask: np.ndarray
+    ) -> np.ndarray:
+        _prev_feature_mask = prev_feature_mask.copy()
+        # keep the max_features features with the highest variance
+        feature_vars = np.nanvar(X[:, _prev_feature_mask], axis=0)
+        order = np.argsort(feature_vars)[: -(self.max_features + 1) : -1]
+        keep_feat_idx = np.arange(len(_prev_feature_mask))[order]
+        _prev_feature_mask = np.isin(
+            np.arange(len(_prev_feature_mask)), keep_feat_idx, assume_unique=True
+        )
+        return _prev_feature_mask
+
+    def apply_filter(self, filter, X, prev_feature_mask):
+        mask = prev_feature_mask.copy()
+        func = self.__getattribute__(f"_get_{filter}_mask")
+        feature_keys = self.__getattribute__(f"{filter}__feature_keys")
+
+        if self.__getattribute__(f"{filter}__independent_keys"):
+            for key in feature_keys:
+                key_mask = self._curr_keys == key
+                mask[key_mask] = func(X[:, key_mask], mask[key_mask])
+
+        else:
+            feature_key_mask = np.isin(self._curr_keys, feature_keys)
+            mask[feature_key_mask] = func(
+                X[:, feature_key_mask], mask[feature_key_mask]
+            )
+        return mask
+
     def fit(self, X: dict[str, np.ndarray]):
         _X = self.dict_to_array(X)
         feature_mask = np.ones((_X.shape[1]), dtype=bool)

         # select features with at least min_var variation
         if self.min_var > 0.0:
-            if self.independent_keys:
-                for key in self.feature_keys:
+            if self.min_var__independent_keys:
+                for key in self.min_var__feature_keys:
                     key_mask = self._curr_keys == key
-                    subset = _X[:, key_mask]
-                    feature_mask[key_mask] = self._get_min_var_feature_mask(subset)
+                    feature_mask[key_mask] = self._get_min_var_mask(_X[:, key_mask])
             else:
-                feature_mask = self._get_min_var_feature_mask(_X)
+                feature_key_mask = np.isin(self._curr_keys, self.min_var__feature_keys)
+                feature_mask[feature_key_mask] = self._get_min_var_mask(
+                    _X[:, feature_key_mask]
+                )

         # drop features with pairwise correlation above max_corr
         if self.max_corr < 1.0:
-            if self.independent_keys:
-                for key in self.feature_keys:
+            if self.max_corr__independent_keys:
+                for key in self.max_corr__feature_keys:
                     key_mask = self._curr_keys == key
                     subset = _X[:, key_mask]
-                    feature_mask[key_mask] = self._get_max_corr_feature_mask(
+                    feature_mask[key_mask] = self._get_max_corr_mask(
                         subset, feature_mask[key_mask]
                     )
             else:
-                feature_mask = self._get_max_corr_feature_mask(_X, feature_mask)
+                feature_key_mask = np.isin(self._curr_keys, self.max_corr__feature_keys)
+                feature_mask[feature_key_mask] = self._get_max_corr_mask(
+                    _X[:, feature_key_mask], feature_mask[feature_key_mask]
+                )

         if self.max_features == 0:
             raise ValueError(
                 f"max_features (={self.max_features}) must be -1 or larger than 0."
             )
         elif self.max_features > 0:
-            # select features with at least max_var variation
-            feature_vars = np.nanvar(_X[:, feature_mask], axis=0)
-            order = np.argsort(feature_vars)[: -(self.max_features + 1) : -1]
-            keep_feat_idx = np.arange(feature_mask)[order]
-            feature_mask = np.isin(
-                np.arange(feature_mask), keep_feat_idx, assume_unique=True
-            )
+            if self.max_features__independent_keys:
+                for key in self.max_features__feature_keys:
+                    key_mask = self._curr_keys == key
+                    feature_mask[key_mask] = self._get_max_features_mask(
+                        _X[:, key_mask], feature_mask[key_mask]
+                    )
+            else:
+                feature_key_mask = np.isin(
+                    self._curr_keys, self.max_features__feature_keys
+                )
+                feature_mask[feature_key_mask] = self._get_max_features_mask(
+                    _X[:, feature_key_mask], feature_mask[feature_key_mask]
+                )

         self._feature_mask = feature_mask
         self.is_fitted_ = True
@@ -278,6 +330,7 @@ class FeaturePreprocessor(TransformerMixin, BaseEstimator):

         self.feature_selection_config = copy.deepcopy(feature_selection_config)
         self.use_feat_selec = self.feature_selection_config.pop("use")
+        self.feature_selection_config["feature_keys"] = descriptors
         self.feature_selector = FeatureSelector(**self.feature_selection_config)

         self.max_samples = max_samples
@@ -330,10 +383,8 @@ class FeaturePreprocessor(TransformerMixin, BaseEstimator):

         if self.use_feat_quant:
             _X = self.quantile_creator.transform(_X)
-
         if self.use_feat_selec:
             _X = self.feature_selector.transform(_X)
-
         _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
         _X = self.scaler.transform(_X)

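The new `apply_filter()` helper is a name-based dispatch: `_get_{filter}_mask` selects the mask function and `{filter}__feature_keys` / `{filter}__independent_keys` select its options, so every filter can run either per descriptor key or jointly over all selected columns. A self-contained toy version of the same pattern (the variance filter is illustrative, not the real `FeatureSelector` logic):

```python
import numpy as np

class ToySelector:
    def __init__(self):
        self.var__feature_keys = ["a", "b"]
        self.var__independent_keys = True
        self._curr_keys = np.array(["a", "a", "b", "b"])  # one key per column

    def _get_var_mask(self, X, prev_mask):
        # toy filter: keep columns with non-zero variance
        return X.var(axis=0) > 0

    def apply_filter(self, filter, X, prev_mask):
        mask = prev_mask.copy()
        func = getattr(self, f"_get_{filter}_mask")
        feature_keys = getattr(self, f"{filter}__feature_keys")
        if getattr(self, f"{filter}__independent_keys"):
            for key in feature_keys:  # filter each key block separately
                key_mask = self._curr_keys == key
                mask[key_mask] = func(X[:, key_mask], mask[key_mask])
        else:  # filter all selected columns jointly
            key_mask = np.isin(self._curr_keys, feature_keys)
            mask[key_mask] = func(X[:, key_mask], mask[key_mask])
        return mask

X = np.array([[1.0, 2.0, 3.0, 3.0], [2.0, 2.0, 4.0, 3.0]])
sel = ToySelector()
print(sel.apply_filter("var", X, np.ones(4, dtype=bool)))  # [ True False  True False]
```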
src/utils.py CHANGED
@@ -32,8 +32,6 @@ TASKS = [
     "SR-p53",
 ]

-KNOWN_DESCR = ["ecfps", "tox", "maccs", "rdkit_descrs"]
-
 USED_200_DESCR = [
     0,
     1,
train.py CHANGED
@@ -46,10 +46,10 @@ def main(config):
     )

     logger.info(f"Config: {config}")
-    model_config_repr = "Model configs: \n" + "\n".join(
+    model_config_repr = "Model config: \n" + "\n".join(
         [str(val) for val in config["model_config"].values()]
     )
-    logger.info(f"Model configs: \n{model_config_repr}")
+    logger.info(f"Model config: \n{model_config_repr}")

     # seeding
     random.seed(config["seed"])
@@ -111,8 +111,7 @@ def main(config):
     logger.info(log_text)

     if config["ckpt_path"]:
-        state = model.get_state()
-        joblib.dump(state, config["ckpt_path"])
+        model.save(config["ckpt_path"])
         logger.info(f"Save model as: {config['ckpt_path']}")

     if config["preprocessor_path"]: