328 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			328 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | ||
| """
 | ||
| 将PKL文件中的输出时间戳转换为原始数据时间戳
 | ||
| 支持三种映射方式:简单映射、保守映射、可能映射
 | ||
| """
 | ||
| 
 | ||
| import pickle
 | ||
| import numpy as np
 | ||
| from pathlib import Path
 | ||
| from collections import defaultdict
 | ||
| import argparse
 | ||
| from datetime import datetime
 | ||
| 
 | ||
# Sliding-window parameters (taken from rnn_args.yaml).
PATCH_SIZE = 14    # sliding-window (patch) size, in original time bins
PATCH_STRIDE = 4   # sliding-window stride, in original time bins
ORIGINAL_BIN_MS = 20  # duration of one original time bin (ms)


def convert_output_timestamp_to_original(output_start, output_end, patch_size=PATCH_SIZE, patch_stride=PATCH_STRIDE):
    """Map an inclusive range of model-output time steps back to original time bins.

    Output step ``t`` is computed from the patch covering original bins
    ``[t * patch_stride, t * patch_stride + patch_size - 1]``. Three mappings
    are produced:

    * ``simple_mapping`` -- the centre bin of the first and of the last patch.
      When ``patch_size`` is even the centre lies between two bins; the lower
      one is always chosen so the offset from the patch start is the same for
      every step (Python's ``round()`` rounds half to even, which would make
      the choice depend on the step's parity when ``patch_stride`` is odd).
    * ``conservative_mapping`` -- from the first bin of the first patch to the
      last bin of the last patch.
    * ``likely_mapping`` -- the centres widened by half a stride on each side,
      clamped to the conservative span.

    Args:
        output_start: first output time step (inclusive, >= 0).
        output_end: last output time step (inclusive, >= output_start).
        patch_size: sliding-window size, in original bins (> 0).
        patch_stride: sliding-window stride, in original bins (> 0).

    Returns:
        dict with ``'simple_mapping'``, ``'conservative_mapping'`` and
        ``'likely_mapping'`` as ``(start, end)`` tuples of original bin indices
        (inclusive), plus ``'metadata'`` holding the output range, the output
        duration in steps, the float centre positions and the first/last
        patch ranges.

    Raises:
        ValueError: if ``patch_size`` or ``patch_stride`` is not positive,
            ``output_start`` is negative, or ``output_end < output_start``.
    """
    if patch_size <= 0 or patch_stride <= 0:
        raise ValueError(
            f"patch_size and patch_stride must be positive, got {patch_size} and {patch_stride}"
        )
    if output_start < 0:
        raise ValueError(f"output_start must be non-negative, got {output_start}")
    if output_end < output_start:
        raise ValueError(
            f"output_end ({output_end}) must not be smaller than output_start ({output_start})"
        )

    # Centre positions (integer or x.5) of the first and last contributing patches.
    original_start_center = output_start * patch_stride + (patch_size - 1) / 2
    original_end_center = output_end * patch_stride + (patch_size - 1) / 2

    # 1. Simple mapping: the centre bin. Both centres are >= 0 after the checks
    #    above, so int() floors, i.e. an x.5 centre always maps to bin x.
    simple_start = int(original_start_center)
    simple_end = int(original_end_center)

    # 2. Conservative mapping: the full span covered by the first and last patches.
    patch_start_first = output_start * patch_stride
    patch_end_first = patch_start_first + patch_size - 1
    patch_start_last = output_end * patch_stride
    patch_end_last = patch_start_last + patch_size - 1

    conservative_start = patch_start_first
    conservative_end = patch_end_last

    # 3. Likely mapping: half a stride around the centres, clamped to the patch span.
    #    If the widened start goes negative, int() truncates toward zero, but the
    #    clamp with patch_start_first (>= 0) yields the same result as flooring.
    likely_start = max(patch_start_first, int(original_start_center - patch_stride / 2))
    likely_end = min(patch_end_last, int(original_end_center + patch_stride / 2))

    return {
        'simple_mapping': (simple_start, simple_end),
        'conservative_mapping': (conservative_start, conservative_end),
        'likely_mapping': (likely_start, likely_end),
        'metadata': {
            'output_range': (output_start, output_end),
            'output_duration': output_end - output_start + 1,
            'center_positions': (original_start_center, original_end_center),
            'patch_ranges': {
                'first_patch': (patch_start_first, patch_end_first),
                'last_patch': (patch_start_last, patch_end_last)
            }
        }
    }
 | ||
| 
 | ||
def _convert_segment(segment):
    """Return a shallow copy of *segment* with original-timestamp, duration and metadata fields added.

    Reads ``segment['start_time']`` and ``segment['end_time']`` (output time
    steps). Any exception from the lookup or the conversion propagates so the
    caller can log and count the failure.
    """
    conversion = convert_output_timestamp_to_original(segment['start_time'], segment['end_time'])
    simple_start, simple_end = conversion['simple_mapping']
    conservative_start, conservative_end = conversion['conservative_mapping']
    likely_start, likely_end = conversion['likely_mapping']
    output_steps = conversion['metadata']['output_duration']

    # Keep every existing field; only add the conversion results.
    new_segment = segment.copy()
    new_segment['original_timestamps'] = {
        'simple_start': simple_start,
        'simple_end': simple_end,
        'conservative_start': conservative_start,
        'conservative_end': conservative_end,
        'likely_start': likely_start,
        'likely_end': likely_end,
    }
    # Each output step advances the window by PATCH_STRIDE bins of ORIGINAL_BIN_MS each.
    new_segment['duration_info'] = {
        'output_duration_steps': output_steps,
        'output_duration_ms': output_steps * PATCH_STRIDE * ORIGINAL_BIN_MS,
        'simple_duration_steps': simple_end - simple_start + 1,
        'conservative_duration_steps': conservative_end - conservative_start + 1,
        'likely_duration_steps': likely_end - likely_start + 1,
    }
    new_segment['conversion_metadata'] = conversion['metadata']
    return new_segment


def process_phoneme_dataset(input_file_path, output_file_path):
    """Convert every segment's output timestamps to original time bins and pickle the result.

    The input pickle is expected to be a dict mapping each phoneme to a list
    of segment dicts that carry ``'start_time'`` / ``'end_time'`` output steps.
    Segments that fail to convert are reported, counted and skipped.

    Args:
        input_file_path: path of the input PKL file.
        output_file_path: path of the PKL file to write.

    Returns:
        dict with ``'phoneme_data'`` (phoneme -> converted segments) and
        ``'conversion_info'`` (timestamp, parameters, statistics, method
        descriptions); the same object is written to *output_file_path*.

    Raises:
        TypeError: if the loaded object is not a dict.
    """
    print("=== 处理phoneme dataset ===")
    print(f"输入文件: {input_file_path}")
    print(f"输出文件: {output_file_path}")

    # NOTE: pickle.load can execute arbitrary code -- only use on trusted files.
    with open(input_file_path, 'rb') as f:
        phoneme_data = pickle.load(f)

    print(f"原始数据类型: {type(phoneme_data)}")
    if not isinstance(phoneme_data, dict):
        # Fail with a clear message instead of an AttributeError on .items() below.
        raise TypeError(f"期望 dict[phoneme, list[segment]],实际为 {type(phoneme_data).__name__}")
    print(f"音素数量: {len(phoneme_data)}")
    print(f"总segment数量: {sum(len(segments) for segments in phoneme_data.values())}")

    converted_data = {}
    total_converted = 0
    conversion_errors = 0
    phoneme_counts = defaultdict(int)

    for phoneme, segments in phoneme_data.items():
        converted_segments = []
        for segment in segments:
            try:
                converted_segments.append(_convert_segment(segment))
            except Exception as e:
                # Deliberate best-effort: one malformed segment must not abort the run.
                print(f"转换segment时出错 (phoneme: {phoneme}): {e}")
                conversion_errors += 1
                continue
            total_converted += 1
            phoneme_counts[phoneme] += 1
        converted_data[phoneme] = converted_segments

    conversion_info = {
        'conversion_timestamp': datetime.now().isoformat(),
        'parameters': {
            'patch_size': PATCH_SIZE,
            'patch_stride': PATCH_STRIDE,
            'original_bin_ms': ORIGINAL_BIN_MS
        },
        # Plain dicts only, so readers of the pickle do not get a defaultdict
        # that silently invents zero counts for unknown phonemes.
        'statistics': {
            'total_segments': total_converted,
            'conversion_errors': conversion_errors,
            'phoneme_counts': dict(phoneme_counts),
        },
        'mapping_methods': {
            'simple_mapping': '基于输出时间步中心位置的映射',
            'conservative_mapping': '基于完整patch范围的保守映射',
            'likely_mapping': '基于中心位置但考虑patch边界的调整映射'
        }
    }

    output_data = {
        'phoneme_data': converted_data,
        'conversion_info': conversion_info
    }

    with open(output_file_path, 'wb') as f:
        pickle.dump(output_data, f)

    print("\n=== 转换完成 ===")
    print(f"成功转换: {total_converted} 个segments")
    print(f"转换错误: {conversion_errors} 个segments")
    print("音素分布:")

    # Top 10 phonemes by converted-segment count.
    sorted_phonemes = sorted(phoneme_counts.items(), key=lambda x: x[1], reverse=True)
    for phoneme, count in sorted_phonemes[:10]:
        print(f"  {phoneme}: {count} segments")

    return output_data
 | ||
| 
 | ||
def analyze_conversion_results(converted_data_path):
    """Print duration statistics, mapping ratios and a few examples for a converted dataset.

    Args:
        converted_data_path: path of a PKL file in the format written by
            ``process_phoneme_dataset`` (a dict with ``'phoneme_data'`` and
            ``'conversion_info'`` keys).
    """
    print("\n=== 分析转换结果 ===")

    # NOTE: pickle.load can execute arbitrary code -- only use on trusted files.
    with open(converted_data_path, 'rb') as f:
        data = pickle.load(f)

    phoneme_data = data['phoneme_data']
    conversion_info = data['conversion_info']

    print(f"转换时间: {conversion_info['conversion_timestamp']}")
    print(f"转换参数: {conversion_info['parameters']}")

    # Keep each segment paired with its phoneme key so the example table can
    # fall back to the key when a segment has no 'phoneme' field of its own.
    all_segments = [
        (phoneme, segment)
        for phoneme, segments in phoneme_data.items()
        for segment in segments
    ]
    print(f"总segment数量: {len(all_segments)}")
    if not all_segments:
        # Nothing to summarize; avoids NumPy empty-slice warnings and 'nan' output.
        print("没有可分析的segment")
        return

    duration_infos = [segment['duration_info'] for _, segment in all_segments]
    output_durations = [d['output_duration_steps'] for d in duration_infos]
    simple_durations = [d['simple_duration_steps'] for d in duration_infos]
    conservative_durations = [d['conservative_duration_steps'] for d in duration_infos]
    likely_durations = [d['likely_duration_steps'] for d in duration_infos]

    print("\n时长统计 (时间步):")
    duration_rows = (
        ("输出时长:     ", output_durations),
        ("简单映射时长: ", simple_durations),
        ("保守映射时长: ", conservative_durations),
        ("可能映射时长: ", likely_durations),
    )
    for label, values in duration_rows:
        print(f"{label}平均 {np.mean(values):.1f}, 中位数 {np.median(values):.1f}")

    # Ratio of original-bin duration to output-step duration; non-positive
    # output durations are skipped to avoid dividing by zero.
    print("\n映射比例 (原始/输出):")
    ratio_rows = (
        ("简单映射:    ", simple_durations),
        ("保守映射:    ", conservative_durations),
        ("可能映射:    ", likely_durations),
    )
    for label, durations in ratio_rows:
        ratios = [d / o for d, o in zip(durations, output_durations) if o > 0]
        if ratios:
            print(f"{label}平均 {np.mean(ratios):.1f}x")
        else:
            print(f"{label}无有效数据")

    print("\n=== 转换示例 ===")
    print(f"{'音素':4s} {'输出时间戳':12s} {'简单映射':12s} {'保守映射':12s} {'可能映射':12s} {'时长(ms)':8s}")
    print("-" * 70)

    for key, segment in all_segments[:5]:
        # str() so non-string phoneme labels (e.g. integer ids) work with ':4s'.
        phoneme = str(segment.get('phoneme', key))
        output_range = f"{segment['start_time']}-{segment['end_time']}"

        timestamps = segment['original_timestamps']
        simple = f"{timestamps['simple_start']}-{timestamps['simple_end']}"
        conservative = f"{timestamps['conservative_start']}-{timestamps['conservative_end']}"
        likely = f"{timestamps['likely_start']}-{timestamps['likely_end']}"

        duration_ms = segment['duration_info']['output_duration_ms']

        print(f"{phoneme:4s} {output_range:12s} {simple:12s} {conservative:12s} {likely:12s} {duration_ms:6.0f}")
 | ||
| 
 | ||
def main():
    """Command-line entry point.

    Converts ``--input_file`` (or, if not given, the newest
    ``phoneme_dataset_*.pkl`` in ``--input_dir`` that is not itself a
    conversion output), writes ``<stem><output_suffix>.pkl`` into
    ``--output_dir`` and prints an analysis of the result. With
    ``--analyze_only`` it only analyzes the newest ``*<output_suffix>.pkl``
    already present in ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description='转换phoneme dataset的时间戳')
    parser.add_argument('--input_dir', type=str, default='../phoneme_segmented_data',
                        help='输入PKL文件目录')
    parser.add_argument('--output_dir', type=str, default='../phoneme_segmented_data',
                        help='输出PKL文件目录')
    parser.add_argument('--input_file', type=str, default=None,
                        help='指定输入文件名,如果不指定则处理最新的phoneme_dataset文件')
    parser.add_argument('--output_suffix', type=str, default='_with_original_timestamps',
                        help='输出文件后缀')
    parser.add_argument('--analyze_only', action='store_true',
                        help='只分析已存在的转换结果,不进行转换')

    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)

    # Make sure the output directory exists.
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.analyze_only:
        # Analysis only: use the newest converted file.
        converted_files = list(output_dir.glob(f"*{args.output_suffix}.pkl"))
        if converted_files:
            latest_converted = max(converted_files, key=lambda x: x.stat().st_mtime)
            analyze_conversion_results(latest_converted)
        else:
            print("未找到转换后的文件")
        return

    if args.input_file:
        input_file = input_dir / args.input_file
    else:
        # Outputs are named "<input stem><suffix>.pkl" and, with the default
        # arguments, are written into the input directory, so they also match
        # "phoneme_dataset_*.pkl". Skip them; otherwise a re-run would pick the
        # newest *converted* file as its input.
        phoneme_files = [
            path for path in input_dir.glob("phoneme_dataset_*.pkl")
            if not (args.output_suffix and path.stem.endswith(args.output_suffix))
        ]
        if not phoneme_files:
            print(f"在目录 {input_dir} 中未找到phoneme_dataset文件")
            return

        input_file = max(phoneme_files, key=lambda x: x.stat().st_mtime)

    if not input_file.exists():
        print(f"输入文件不存在: {input_file}")
        return

    output_file = output_dir / (input_file.stem + args.output_suffix + '.pkl')

    # An empty suffix with identical input/output directories would overwrite the source data.
    if output_file.resolve() == input_file.resolve():
        print(f"输出文件与输入文件相同,拒绝覆盖: {output_file}")
        return

    try:
        process_phoneme_dataset(input_file, output_file)
        analyze_conversion_results(output_file)
        print(f"\n转换完成!结果已保存到: {output_file}")
    except Exception as e:
        # Top-level boundary: report the failure with a traceback instead of crashing.
        print(f"处理过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
 | ||
| 
 | ||
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
