# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: binbinzhang@mobvoi.com (Binbin Zhang)

import logging
import os
import re

import yaml
import torch


def _info_path(path: str) -> str:
    """Return the YAML side-car path that belongs to checkpoint ``path``.

    A trailing ``.pt`` suffix is replaced by ``.yaml``. Any other path gets
    ``.yaml`` appended, so the side-car file can never be the checkpoint file
    itself.
    """
    # The dot is escaped on purpose: an unescaped '.pt$' would also match
    # names such as 'model.ckpt' and mangle them.
    info_path, replaced = re.subn(r'\.pt\Z', '.yaml', path)
    if not replaced:
        info_path = path + '.yaml'
    return info_path


def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    """Load a state dict from ``path`` into ``model``.

    Args:
        model: module that receives the weights via ``load_state_dict``.
        path: checkpoint file written by ``torch.save``.

    Returns:
        The contents of the YAML side-car file next to the checkpoint, or an
        empty dict when that file is missing or empty.
    """
    # NOTE(security): torch.load unpickles data; only load trusted files.
    if torch.cuda.is_available():
        logging.info('Checkpoint: loading from checkpoint %s for GPU', path)
        checkpoint = torch.load(path)
    else:
        logging.info('Checkpoint: loading from checkpoint %s for CPU', path)
        checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint)

    info_path = _info_path(path)
    configs = {}
    if os.path.exists(info_path):
        with open(info_path, 'r') as fin:
            # NOTE(security): FullLoader is not meant for untrusted input;
            # consider yaml.safe_load if side-car files may come from
            # outside the project.
            # An empty YAML file parses to None; keep the dict contract.
            configs = yaml.load(fin, Loader=yaml.FullLoader) or {}
    return configs


def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    """Save the state dict of ``model`` to ``path`` plus a YAML side-car.

    Args:
        model: module to save. ``DataParallel`` and
            ``DistributedDataParallel`` wrappers are unwrapped so the saved
            keys carry no ``module.`` prefix.
        path: destination checkpoint file.
        infos (dict or None): any info you want to save; it is dumped as
            YAML next to the checkpoint.
    """
    logging.info('Checkpoint: save to checkpoint %s', path)
    wrappers = (torch.nn.DataParallel,
                torch.nn.parallel.DistributedDataParallel)
    if isinstance(model, wrappers):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path)

    if infos is None:
        infos = {}
    with open(_info_path(path), 'w') as fout:
        fout.write(yaml.dump(infos))