import os

import GPUtil
import hydra
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.wandb import WandbLogger

from data.pdb_dataloader import PdbDataModule
from experiments import utils as eu
from models.flow_module import FlowModule

# Log W&B runs offline; they can be uploaded later with `wandb sync`.
os.environ["WANDB_MODE"] = "offline"

log = eu.get_pylogger(__name__)
torch.set_float32_matmul_precision('high')


class Experiment:

    def __init__(self, *, cfg: DictConfig):
        self._cfg = cfg
        self._data_cfg = cfg.data
        self._exp_cfg = cfg.experiment
        self._datamodule: LightningDataModule = PdbDataModule(self._data_cfg)
        self._model: LightningModule = FlowModule(self._cfg)

    def train(self):
        callbacks = []
        if self._exp_cfg.debug:
            log.info("Debug mode.")
            logger = None
            self._exp_cfg.num_devices = 1
            self._data_cfg.loader.num_workers = 0
        else:
            logger = WandbLogger(
                **self._exp_cfg.wandb,
            )

        # Checkpoint directory.
        ckpt_dir = self._exp_cfg.checkpointer.dirpath
        os.makedirs(ckpt_dir, exist_ok=True)
        log.info(f"Checkpoints saved to {ckpt_dir}")

        # Model checkpointing.
        callbacks.append(ModelCheckpoint(**self._exp_cfg.checkpointer))

        # Save the resolved config next to the checkpoints and mirror it to W&B
        # (skipped in debug mode, where no logger is created).
        cfg_path = os.path.join(ckpt_dir, 'config.yaml')
        with open(cfg_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        cfg_dict = OmegaConf.to_container(self._cfg, resolve=True)
        flat_cfg = dict(eu.flatten_dict(cfg_dict))
        if logger is not None and isinstance(
                logger.experiment.config, wandb.sdk.wandb_config.Config):
            logger.experiment.config.update(flat_cfg)

        # Select GPUs, preferring those with the most free memory.
        devices = GPUtil.getAvailable(
            order='memory', limit=8)[:self._exp_cfg.num_devices]
        log.info(f"Using devices: {devices}")
        trainer = Trainer(
            **self._exp_cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            use_distributed_sampler=False,
            enable_progress_bar=True,
            enable_model_summary=True,
            devices=devices,
        )

        # Warm start: initialize weights from a previous checkpoint.
        # load_from_checkpoint is a classmethod, so call it on the class.
        if self._exp_cfg.warm_start is not None:
            self._model = FlowModule.load_from_checkpoint(
                self._exp_cfg.warm_start, strict=False, map_location="cpu")

        trainer.fit(
            model=self._model,
            datamodule=self._datamodule,
        )


@hydra.main(version_base=None, config_path="../configs", config_name="base.yaml")
def main(cfg: DictConfig):

    if cfg.experiment.warm_start is not None and cfg.experiment.warm_start_cfg_override:
        # Re-use the model config saved alongside the warm-start checkpoint so
        # the architecture matches the checkpointed weights.
        warm_start_cfg_path = os.path.join(
            os.path.dirname(cfg.experiment.warm_start), 'config.yaml')
        warm_start_cfg = OmegaConf.load(warm_start_cfg_path)

        OmegaConf.set_struct(cfg.model, False)
        OmegaConf.set_struct(warm_start_cfg.model, False)
        cfg.model = OmegaConf.merge(cfg.model, warm_start_cfg.model)
        OmegaConf.set_struct(cfg.model, True)
        log.info(f'Loaded warm start config from {warm_start_cfg_path}')

    exp = Experiment(cfg=cfg)
    exp.train()


if __name__ == "__main__":
    main()
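# Example invocations (a sketch; assumes this file is saved as train.py and that
# the Hydra config tree under ../configs defines the keys referenced above):
#   python train.py experiment.debug=True
#   python train.py experiment.num_devices=2 \
#       experiment.warm_start=/path/to/checkpoint.ckpt \
#       experiment.warm_start_cfg_override=True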