47 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			47 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | # Copyright 2019 Mobvoi Inc. All Rights Reserved. | ||
|  | # Author: binbinzhang@mobvoi.com (Binbin Zhang) | ||
|  | 
 | ||
|  | import logging | ||
|  | import os | ||
|  | import re | ||
|  | 
 | ||
|  | import yaml | ||
|  | import torch | ||
|  | 
 | ||
|  | 
 | ||
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    """Load model weights from ``path`` and return the companion info dict.

    The state dict stored at ``path`` is loaded into ``model`` in place via
    ``load_state_dict``. If a YAML info file exists next to the checkpoint
    (``foo.pt`` -> ``foo.yaml``; any other name, e.g. ``foo.bin``, ->
    ``foo.bin.yaml``), it is parsed and returned.

    Args:
        model: module whose parameters are overwritten. Checkpoints written
            by ``save_checkpoint`` hold the unwrapped module's state dict,
            so pass the bare module here rather than a DataParallel wrapper.
        path: checkpoint file produced by ``torch.save``.

    Returns:
        dict: contents of the YAML info file, or ``{}`` if that file is
        missing or empty.
    """
    # NOTE: torch.load may unpickle arbitrary objects (depending on the
    # installed torch version's weights_only default); load trusted files only.
    if torch.cuda.is_available():
        logging.info('Checkpoint: loading from checkpoint %s for GPU', path)
        checkpoint = torch.load(path)
    else:
        logging.info('Checkpoint: loading from checkpoint %s for CPU', path)
        checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint)
    # Replace only a literal '.pt' suffix; any other name gets '.yaml'
    # appended, so the info path can never coincide with the checkpoint.
    if path.endswith('.pt'):
        info_path = path[:-len('.pt')] + '.yaml'
    else:
        info_path = path + '.yaml'
    configs = {}
    if os.path.exists(info_path):
        with open(info_path, 'r', encoding='utf-8') as fin:
            # FullLoader can construct some Python objects (e.g. tuples that
            # yaml.dump emits as !!python/tuple); only load trusted files.
            # An empty file loads as None, so fall back to {}.
            configs = yaml.load(fin, Loader=yaml.FullLoader) or {}
    return configs
|  | 
 | ||
|  | 
 | ||
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    """Save ``model``'s weights to ``path`` plus a YAML info file.

    Args:
        model: module to save. A ``DataParallel`` or
            ``DistributedDataParallel`` wrapper is unwrapped first, so the
            stored keys are those of the inner module (no ``module.`` prefix).
        path: destination passed to ``torch.save``.
        infos (dict or None): any info you want to save. It is written with
            ``yaml.dump`` to ``path`` with its ``.pt`` suffix replaced by
            ``.yaml`` (or to ``path + '.yaml'`` when ``path`` does not end in
            ``.pt``). ``None`` writes an empty mapping.
    """
    logging.info('Checkpoint: save to checkpoint %s', path)
    if isinstance(model, (torch.nn.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path)
    # Replace only a literal '.pt' suffix; otherwise append '.yaml'. The old
    # regex returned `path` unchanged for non-.pt names, so the YAML write
    # below overwrote the checkpoint that was just saved.
    if path.endswith('.pt'):
        info_path = path[:-len('.pt')] + '.yaml'
    else:
        info_path = path + '.yaml'
    if infos is None:
        infos = {}
    with open(info_path, 'w', encoding='utf-8') as fout:
        yaml.dump(infos, fout)