b2txt25/language_model/wenet/utils/checkpoint.py

# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: binbinzhang@mobvoi.com (Binbin Zhang)

import logging
import os
import re

import torch
import yaml


def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    '''Load model weights from `path` and return the training configs
    stored in the sibling `.yaml` file, if one exists.'''
    if torch.cuda.is_available():
        logging.info('Checkpoint: loading from checkpoint %s for GPU' % path)
        checkpoint = torch.load(path)
    else:
        # Remap tensors saved on GPU onto the CPU.
        logging.info('Checkpoint: loading from checkpoint %s for CPU' % path)
        checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint)
    # Training info lives next to the checkpoint: model.pt -> model.yaml.
    info_path = re.sub(r'\.pt$', '.yaml', path)
    configs = {}
    if os.path.exists(info_path):
        with open(info_path, 'r') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    return configs


def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    '''
    Args:
        infos (dict or None): any info you want to save.
    '''
    logging.info('Checkpoint: save to checkpoint %s' % path)
    # Unwrap (Distributed)DataParallel so the saved keys carry no
    # `module.` prefix and the checkpoint loads into a bare model.
    if isinstance(model, (torch.nn.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path)
    # Write the extra info to a sibling YAML file: model.pt -> model.yaml.
    info_path = re.sub(r'\.pt$', '.yaml', path)
    if infos is None:
        infos = {}
    with open(info_path, 'w') as fout:
        data = yaml.dump(infos)
        fout.write(data)
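
A minimal usage sketch of the two functions above, assuming they are importable (e.g. from wenet.utils.checkpoint). The stand-in torch.nn.Linear model and the path exp/model.pt are hypothetical; any real wenet model and path would do.

# Hypothetical round trip through save_checkpoint/load_checkpoint;
# not part of checkpoint.py.
import os
import torch

os.makedirs('exp', exist_ok=True)
model = torch.nn.Linear(4, 2)  # stand-in for a real wenet model
save_checkpoint(model, 'exp/model.pt', infos={'epoch': 3, 'cv_loss': 1.23})
configs = load_checkpoint(model, 'exp/model.pt')  # also reads exp/model.yaml
print(configs['epoch'])  # -> 3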