Files
b2txt25/data_analyse/convert_timestamps_to_original.py
2025-10-12 09:11:32 +08:00

328 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
将PKL文件中的输出时间戳转换为原始数据时间戳
支持三种映射方式:简单映射、保守映射、可能映射
"""
import pickle
import numpy as np
from pathlib import Path
from collections import defaultdict
import argparse
from datetime import datetime
# Sliding-window parameters (taken from rnn_args.yaml).
PATCH_SIZE: int = 14  # window size, in original time bins
PATCH_STRIDE: int = 4  # window step, in original time bins
ORIGINAL_BIN_MS: int = 20  # duration of one original time bin, in milliseconds
def convert_output_timestamp_to_original(output_start, output_end, patch_size=PATCH_SIZE, patch_stride=PATCH_STRIDE):
    """Map an inclusive range of output steps back to original time-bin indices.

    Output step ``t`` is treated as coming from a sliding-window patch that
    spans original bins ``[t * patch_stride, t * patch_stride + patch_size - 1]``.
    Three mappings of ``[output_start, output_end]`` are produced:

    * ``simple_mapping``: rounded centres of the first and last patch.
    * ``conservative_mapping``: first bin of the first patch through the last
      bin of the last patch (the full span covered by the segment's patches).
    * ``likely_mapping``: the patch centres widened by half a stride on each
      side, clipped to the conservative range.

    Args:
        output_start: First output step of the segment (inclusive).
        output_end: Last output step of the segment (inclusive).
        patch_size: Sliding-window size, in original bins.
        patch_stride: Sliding-window step, in original bins.

    Returns:
        dict with ``simple_mapping``, ``conservative_mapping`` and
        ``likely_mapping`` as ``(start, end)`` tuples of original bin indices,
        plus a ``metadata`` dict holding the output range and its length in
        steps, the (possibly fractional) patch centres, and the first/last
        patch ranges.

    Raises:
        ValueError: if ``patch_size`` or ``patch_stride`` is not positive, if
            ``output_start`` is negative, or if ``output_end < output_start``.
    """
    # Reject inputs that would otherwise silently yield reversed ranges and
    # negative/zero durations further down the pipeline.
    if patch_size <= 0 or patch_stride <= 0:
        raise ValueError(
            f"patch_size and patch_stride must be positive, got {patch_size} and {patch_stride}"
        )
    if output_start < 0 or output_end < output_start:
        raise ValueError(
            f"invalid output range: start={output_start}, end={output_end}"
        )

    # Original-bin span of the first and last patch of the segment.
    patch_start_first = output_start * patch_stride
    patch_end_first = patch_start_first + patch_size - 1
    patch_start_last = output_end * patch_stride
    patch_end_last = patch_start_last + patch_size - 1

    # Centre of the first / last patch (fractional when patch_size is even).
    original_start_center = output_start * patch_stride + (patch_size - 1) / 2
    original_end_center = output_end * patch_stride + (patch_size - 1) / 2

    # 1. Simple mapping: rounded patch centres. round() rounds exact halves to
    #    the even neighbour; with the default parameters the centres are
    #    4*t + 6.5, whose even neighbour is the lower one, so they round down.
    simple_start = int(round(original_start_center))
    simple_end = int(round(original_end_center))

    # 2. Conservative mapping: the full span covered by all patches.
    conservative_start = patch_start_first
    conservative_end = patch_end_last

    # 3. Likely mapping: centres widened by half a stride, clipped to the
    #    conservative range. int() truncates toward zero.
    likely_start = max(patch_start_first, int(original_start_center - patch_stride / 2))
    likely_end = min(patch_end_last, int(original_end_center + patch_stride / 2))

    return {
        'simple_mapping': (simple_start, simple_end),
        'conservative_mapping': (conservative_start, conservative_end),
        'likely_mapping': (likely_start, likely_end),
        'metadata': {
            'output_range': (output_start, output_end),
            # Inclusive range, hence +1.
            'output_duration': output_end - output_start + 1,
            'center_positions': (original_start_center, original_end_center),
            'patch_ranges': {
                'first_patch': (patch_start_first, patch_end_first),
                'last_patch': (patch_start_last, patch_end_last)
            }
        }
    }
def _build_converted_segment(segment):
    """Return a shallow copy of ``segment`` enriched with original-bin timestamps.

    Reads ``segment['start_time']`` and ``segment['end_time']`` (output steps,
    inclusive) and adds ``original_timestamps``, ``duration_info`` and
    ``conversion_metadata``. Errors from the lookups or from
    ``convert_output_timestamp_to_original`` propagate to the caller.
    """
    conversion = convert_output_timestamp_to_original(segment['start_time'], segment['end_time'])
    simple_start, simple_end = conversion['simple_mapping']
    conservative_start, conservative_end = conversion['conservative_mapping']
    likely_start, likely_end = conversion['likely_mapping']
    output_steps = conversion['metadata']['output_duration']

    new_segment = segment.copy()
    new_segment['original_timestamps'] = {
        'simple_start': simple_start,
        'simple_end': simple_end,
        'conservative_start': conservative_start,
        'conservative_end': conservative_end,
        'likely_start': likely_start,
        'likely_end': likely_end
    }
    new_segment['duration_info'] = {
        'output_duration_steps': output_steps,
        # Each output step advances the window by PATCH_STRIDE original bins
        # of ORIGINAL_BIN_MS each.
        'output_duration_ms': output_steps * PATCH_STRIDE * ORIGINAL_BIN_MS,
        'simple_duration_steps': simple_end - simple_start + 1,
        'conservative_duration_steps': conservative_end - conservative_start + 1,
        'likely_duration_steps': likely_end - likely_start + 1
    }
    new_segment['conversion_metadata'] = conversion['metadata']
    return new_segment


def process_phoneme_dataset(input_file_path, output_file_path):
    """Convert every segment of a phoneme dataset PKL and save the result.

    The input PKL must hold a mapping ``phoneme -> iterable of segment dicts``;
    each segment needs ``start_time`` and ``end_time`` (output steps,
    inclusive). Converted segments are shallow copies with
    ``original_timestamps``, ``duration_info`` and ``conversion_metadata``
    added. Segments that fail to convert are reported, counted in
    ``statistics['conversion_errors']`` and dropped.

    Args:
        input_file_path: Path of the input PKL file.
        output_file_path: Path of the output PKL file (overwritten).

    Returns:
        dict with ``phoneme_data`` (phoneme -> converted segments) and
        ``conversion_info`` (timestamp, parameters, statistics, mapping
        descriptions); the same object is pickled to ``output_file_path``.

    Raises:
        TypeError: if the loaded object is not a mapping.
    """
    from collections.abc import Mapping

    print(f"=== 处理phoneme dataset ===")
    print(f"输入文件: {input_file_path}")
    print(f"输出文件: {output_file_path}")

    # NOTE(security): pickle.load can execute arbitrary code; only run this on
    # trusted, locally produced dataset files.
    with open(input_file_path, 'rb') as f:
        phoneme_data = pickle.load(f)

    print(f"原始数据类型: {type(phoneme_data)}")
    # Fail early with a clear message instead of an AttributeError at .items().
    if not isinstance(phoneme_data, Mapping):
        raise TypeError(
            f"expected a mapping of phoneme -> segments, got {type(phoneme_data).__name__}"
        )
    print(f"音素数量: {len(phoneme_data)}")
    total_segments = sum(len(segments) for segments in phoneme_data.values())
    print(f"总segment数量: {total_segments}")

    converted_data = {}
    conversion_stats = {
        'total_segments': 0,
        'conversion_errors': 0,
        'phoneme_counts': defaultdict(int)
    }

    for phoneme, segments in phoneme_data.items():
        converted_segments = []
        for segment in segments:
            # Best-effort: a malformed segment is logged, counted and skipped
            # so one bad record does not abort the whole dataset.
            try:
                converted_segments.append(_build_converted_segment(segment))
            except Exception as e:
                print(f"转换segment时出错 (phoneme: {phoneme}): {e}")
                conversion_stats['conversion_errors'] += 1
                continue
            conversion_stats['total_segments'] += 1
            conversion_stats['phoneme_counts'][phoneme] += 1
        converted_data[phoneme] = converted_segments

    conversion_info = {
        'conversion_timestamp': datetime.now().isoformat(),
        'parameters': {
            'patch_size': PATCH_SIZE,
            'patch_stride': PATCH_STRIDE,
            'original_bin_ms': ORIGINAL_BIN_MS
        },
        # Store plain dicts only: dict(conversion_stats) alone would keep the
        # nested defaultdict inside the pickled output.
        'statistics': {
            **conversion_stats,
            'phoneme_counts': dict(conversion_stats['phoneme_counts'])
        },
        'mapping_methods': {
            'simple_mapping': '基于输出时间步中心位置的映射',
            'conservative_mapping': '基于完整patch范围的保守映射',
            'likely_mapping': '基于中心位置但考虑patch边界的调整映射'
        }
    }

    output_data = {
        'phoneme_data': converted_data,
        'conversion_info': conversion_info
    }
    with open(output_file_path, 'wb') as f:
        pickle.dump(output_data, f)

    print(f"\n=== 转换完成 ===")
    print(f"成功转换: {conversion_stats['total_segments']} 个segments")
    print(f"转换错误: {conversion_stats['conversion_errors']} 个segments")
    print(f"音素分布:")
    # Ten most frequent phonemes, highest count first.
    sorted_phonemes = sorted(conversion_stats['phoneme_counts'].items(),
                             key=lambda x: x[1], reverse=True)
    for phoneme, count in sorted_phonemes[:10]:
        print(f" {phoneme}: {count} segments")

    return output_data
def analyze_conversion_results(converted_data_path):
    """Print summary statistics for a PKL file written by process_phoneme_dataset.

    Reports the conversion timestamp and parameters, the total segment count,
    mean/median durations (in steps) for the output range and each mapping,
    the mean mapping/output duration ratios, and a table of up to five sample
    segments.

    Args:
        converted_data_path: Path of the converted PKL file; must contain
            ``phoneme_data`` and ``conversion_info``.
    """
    print(f"\n=== 分析转换结果 ===")
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # trusted files.
    with open(converted_data_path, 'rb') as f:
        data = pickle.load(f)

    phoneme_data = data['phoneme_data']
    conversion_info = data['conversion_info']
    print(f"转换时间: {conversion_info['conversion_timestamp']}")
    print(f"转换参数: {conversion_info['parameters']}")

    # Keep the dict key alongside each segment: segments copied from the input
    # need not carry their own 'phoneme' field.
    labelled_segments = [
        (phoneme, segment)
        for phoneme, segments in phoneme_data.items()
        for segment in segments
    ]
    print(f"总segment数量: {len(labelled_segments)}")
    if not labelled_segments:
        # Nothing to summarise; np.mean([]) would only warn and print nan.
        return

    duration_infos = [segment['duration_info'] for _, segment in labelled_segments]
    output_durations = [d['output_duration_steps'] for d in duration_infos]
    simple_durations = [d['simple_duration_steps'] for d in duration_infos]
    conservative_durations = [d['conservative_duration_steps'] for d in duration_infos]
    likely_durations = [d['likely_duration_steps'] for d in duration_infos]

    print(f"\n时长统计 (时间步):")
    print(f"输出时长: 平均 {np.mean(output_durations):.1f}, 中位数 {np.median(output_durations):.1f}")
    print(f"简单映射时长: 平均 {np.mean(simple_durations):.1f}, 中位数 {np.median(simple_durations):.1f}")
    print(f"保守映射时长: 平均 {np.mean(conservative_durations):.1f}, 中位数 {np.median(conservative_durations):.1f}")
    print(f"可能映射时长: 平均 {np.mean(likely_durations):.1f}, 中位数 {np.median(likely_durations):.1f}")

    # Ratio of mapped length to output length; non-positive output lengths are
    # skipped to avoid division by zero.
    simple_ratios = [s / o for s, o in zip(simple_durations, output_durations) if o > 0]
    conservative_ratios = [c / o for c, o in zip(conservative_durations, output_durations) if o > 0]
    likely_ratios = [l / o for l, o in zip(likely_durations, output_durations) if o > 0]
    print(f"\n映射比例 (原始/输出):")
    print(f"简单映射: 平均 {np.mean(simple_ratios):.1f}x")
    print(f"保守映射: 平均 {np.mean(conservative_ratios):.1f}x")
    print(f"可能映射: 平均 {np.mean(likely_ratios):.1f}x")

    print(f"\n=== 转换示例 ===")
    print(f"{'音素':4s} {'输出时间戳':12s} {'简单映射':12s} {'保守映射':12s} {'可能映射':12s} {'时长(ms)':8s}")
    print("-" * 70)
    for phoneme_key, segment in labelled_segments[:5]:
        # Prefer the segment's own label, fall back to the dict key; str()
        # so the ':4s' format spec also works for non-string labels.
        phoneme = str(segment.get('phoneme', phoneme_key))
        output_range = f"{segment['start_time']}-{segment['end_time']}"
        timestamps = segment['original_timestamps']
        simple = f"{timestamps['simple_start']}-{timestamps['simple_end']}"
        conservative = f"{timestamps['conservative_start']}-{timestamps['conservative_end']}"
        likely = f"{timestamps['likely_start']}-{timestamps['likely_end']}"
        duration_ms = segment['duration_info']['output_duration_ms']
        print(f"{phoneme:4s} {output_range:12s} {simple:12s} {conservative:12s} {likely:12s} {duration_ms:6.0f}")
def main():
    """Command-line entry point: convert a phoneme dataset PKL, then analyse it.

    Without ``--input_file`` the most recently modified
    ``phoneme_dataset_*.pkl`` in ``--input_dir`` is used, excluding files this
    script itself produced. With ``--analyze_only`` the newest
    ``*<output_suffix>.pkl`` in ``--output_dir`` is analysed instead.

    Raises:
        SystemExit: with status 1 when no suitable input file is found or the
            conversion fails.
    """
    parser = argparse.ArgumentParser(description='转换phoneme dataset的时间戳')
    parser.add_argument('--input_dir', type=str, default='../phoneme_segmented_data',
                        help='输入PKL文件目录')
    parser.add_argument('--output_dir', type=str, default='../phoneme_segmented_data',
                        help='输出PKL文件目录')
    parser.add_argument('--input_file', type=str, default=None,
                        help='指定输入文件名如果不指定则处理最新的phoneme_dataset文件')
    parser.add_argument('--output_suffix', type=str, default='_with_original_timestamps',
                        help='输出文件后缀')
    parser.add_argument('--analyze_only', action='store_true',
                        help='只分析已存在的转换结果,不进行转换')
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.analyze_only:
        converted_files = list(output_dir.glob(f"*{args.output_suffix}.pkl"))
        if not converted_files:
            print("未找到转换后的文件")
            raise SystemExit(1)
        latest_converted = max(converted_files, key=lambda p: p.stat().st_mtime)
        analyze_conversion_results(latest_converted)
        return

    if args.input_file:
        input_file = input_dir / args.input_file
    else:
        # Outputs are named "<input stem><suffix>.pkl"; with the default
        # directories they land next to the inputs and also match
        # "phoneme_dataset_*.pkl". Skip them so a re-run does not treat its own
        # previous output as the newest dataset and convert it again.
        suffix = args.output_suffix
        phoneme_files = [
            p for p in input_dir.glob("phoneme_dataset_*.pkl")
            if not (suffix and p.stem.endswith(suffix))
        ]
        if not phoneme_files:
            print(f"在目录 {input_dir} 中未找到phoneme_dataset文件")
            raise SystemExit(1)
        input_file = max(phoneme_files, key=lambda p: p.stat().st_mtime)

    if not input_file.exists():
        print(f"输入文件不存在: {input_file}")
        raise SystemExit(1)

    output_file = output_dir / (input_file.stem + args.output_suffix + '.pkl')

    try:
        process_phoneme_dataset(input_file, output_file)
        analyze_conversion_results(output_file)
        print(f"\n转换完成!结果已保存到: {output_file}")
    except Exception as e:
        # Top-level boundary: report with traceback, then signal failure to
        # the shell instead of exiting with status 0.
        print(f"处理过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        raise SystemExit(1) from e
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()