Files
b2txt25/language_model/tools/compute_cmvn_stats.py
2025-07-02 12:18:09 -07:00

138 lines
4.9 KiB
Python
Executable File

#!/usr/bin/env python3
# encoding: utf-8
import sys
import argparse
import json
import codecs
import yaml
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.data import Dataset, DataLoader
# Explicitly request the sox_io backend. Newer torchaudio releases deprecate
# this call (backend selection is automatic there) and it may be absent from
# some releases, so only call it when it exists instead of failing at import.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("sox_io")
class CollateFunc(object):
    ''' Collate function for AudioDataset.

    Instead of stacking samples, it computes Kaldi-style fbank features for
    every utterance in the batch and returns the per-bin sufficient
    statistics for global CMVN:
    ``(frame_count, sum_of_features, sum_of_squared_features)``.
    '''

    def __init__(self, feat_dim, resample_rate):
        '''
        Args:
            feat_dim: number of mel bins passed to ``kaldi.fbank``; also the
                length of the returned statistic vectors.
            resample_rate: target sample rate in Hz; 0 disables resampling.
        '''
        self.feat_dim = feat_dim
        self.resample_rate = resample_rate

    @staticmethod
    def _load_waveform(entry):
        ''' Load the audio referenced by one wav.scp value.

        Args:
            entry: either ``path`` (original wav.scp, whole file) or
                ``path,start,end`` (segmented wav.scp, times in seconds).

        Returns:
            ``(waveform, sample_rate)`` as returned by ``torchaudio.load``.

        Raises:
            ValueError: if ``entry`` has neither 1 nor 3 comma-separated
                fields.
        '''
        fields = entry.strip().split(",")
        if len(fields) not in (1, 3):
            raise ValueError(
                'expected "path" or "path,start,end" in wav.scp, '
                'got {!r}'.format(entry))
        wav_path = fields[0]
        if len(fields) == 1:
            return torchaudio.load(wav_path)
        # Segment bounds are given in seconds, so the file's native sample
        # rate is needed to turn them into a frame offset and count.
        sample_rate = torchaudio.info(wav_path).sample_rate
        start_frame = int(float(fields[1]) * sample_rate)
        end_frame = int(float(fields[2]) * sample_rate)
        return torchaudio.load(wav_path,
                               frame_offset=start_frame,
                               num_frames=end_frame - start_frame)

    def __call__(self, batch):
        ''' Accumulate fbank statistics over a batch.

        Args:
            batch: sequence of dataset items; ``item[1]`` is the wav.scp
                value (see ``_load_waveform``).

        Returns:
            ``(number, mean_stat, var_stat)``: total number of frames, per-bin
            sum of features and per-bin sum of squared features. The sums are
            float64 tensors because sums of squares over many frames get large
            enough for float32 to drop significant low-order digits.
        '''
        mean_stat = torch.zeros(self.feat_dim, dtype=torch.float64)
        var_stat = torch.zeros(self.feat_dim, dtype=torch.float64)
        number = 0
        for item in batch:
            waveform, sample_rate = self._load_waveform(item[1])
            # Rescale to the 16-bit integer range before fbank extraction.
            waveform = waveform * (1 << 15)
            if self.resample_rate != 0 and self.resample_rate != sample_rate:
                waveform = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.resample_rate)(waveform)
                sample_rate = self.resample_rate
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.feat_dim,
                              dither=0.0,
                              energy_floor=0.0,
                              sample_frequency=sample_rate).double()
            mean_stat += mat.sum(dim=0)
            var_stat += mat.square().sum(dim=0)
            number += mat.shape[0]
        return number, mean_stat, var_stat
class AudioDataset(Dataset):
    ''' Map-style dataset over the entries of a wav.scp file.

    Each non-blank line is ``<key> <value>``, where value is a wav path or
    ``path,start,end``. Items are ``(key, value)`` string tuples; only the
    first two whitespace-separated fields of a line are used.
    '''

    def __init__(self, data_file):
        ''' Read every entry of ``data_file`` (UTF-8 text) into memory.

        Raises:
            ValueError: if a non-blank line has fewer than two fields.
        '''
        self.items = []
        with open(data_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                arr = line.split()
                if not arr:
                    # Tolerate blank lines, e.g. a trailing empty line.
                    continue
                if len(arr) < 2:
                    raise ValueError('{}:{}: expected "<key> <wav>", got {!r}'
                                     .format(data_file, line_no,
                                             line.rstrip('\n')))
                self.items.append((arr[0], arr[1]))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
if __name__ == '__main__':
    # Compute global CMVN sufficient statistics (frame count, per-bin sum and
    # sum of squares of fbank features) over a wav.scp and dump them as JSON.
    parser = argparse.ArgumentParser(description='extract CMVN stats')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for processing')
    # Both inputs are mandatory; their previous defaults ('' / None) could
    # only fail later with an opaque error from open().
    parser.add_argument('--train_config',
                        required=True,
                        help='training yaml conf')
    parser.add_argument('--in_scp', required=True, help='wav scp file')
    parser.add_argument('--out_cmvn',
                        default='global_cmvn',
                        help='global cmvn file')
    args = parser.parse_args()

    with open(args.train_config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    feat_conf = configs['collate_conf']['feature_extraction_conf']
    feat_dim = feat_conf['mel_bins']
    resample_rate = 0
    if 'resample' in feat_conf:
        resample_rate = feat_conf['resample']
        print('using resample and new sample rate is {}'.format(resample_rate))

    collate_func = CollateFunc(feat_dim, resample_rate)
    dataset = AudioDataset(args.in_scp)
    batch_size = 20
    # The statistics are plain sums, so visiting order does not matter;
    # iterating without shuffling keeps float rounding, and therefore the
    # written file, reproducible from run to run.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=collate_func)

    with torch.no_grad():
        all_number = 0
        # float64 accumulators: corpus-wide sums of squares are large enough
        # for float32 to lose significant low-order digits.
        all_mean_stat = torch.zeros(feat_dim, dtype=torch.float64)
        all_var_stat = torch.zeros(feat_dim, dtype=torch.float64)
        wav_number = 0
        for number, mean_stat, var_stat in data_loader:
            all_mean_stat += mean_stat
            all_var_stat += var_stat
            all_number += number
            # Only the final batch can be short (drop_last is False), so
            # clamping to the dataset size gives the exact count so far.
            wav_number = min(wav_number + batch_size, len(dataset))
            if wav_number % 1000 == 0:
                print(f'processed {wav_number} wavs, {all_number} frames',
                      file=sys.stderr,
                      flush=True)

    cmvn_info = {
        'mean_stat': all_mean_stat.tolist(),
        'var_stat': all_var_stat.tolist(),
        'frame_num': all_number
    }
    with open(args.out_cmvn, 'w') as fout:
        json.dump(cmvn_info, fout)