#!/usr/bin/env python3
"""
Convert phoneme segmented data to phoneme classification dataset.

The script reads the most recently modified ``phoneme_dataset_*.pkl`` file
from ``phoneme_segmented_data/``, validates every phoneme segment against the
ground-truth label sequence stored in the matching ``ctc_results_*.pkl`` file,
mean-pools the neural features of each accepted segment into one vector, and
writes three pickles next to the input: the validated classification dataset,
the validation errors, and a standardized train/test split.
"""

import functools
import pickle
import numpy as np
import torch
from pathlib import Path
from collections import defaultdict
import os
import sys

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))


# Used only when model_training.evaluate_model_helpers cannot be imported.
_FALLBACK_LOGIT_TO_PHONEME = {
    0: 'BLANK', 1: 'AA', 2: 'AE', 3: 'AH', 4: 'AO', 5: 'AW', 6: 'AY',
    7: 'B', 8: 'CH', 9: 'D', 10: 'DH', 11: 'EH', 12: 'ER', 13: 'EY',
    14: 'F', 15: 'G', 16: 'HH', 17: 'IH', 18: 'IY', 19: 'JH', 20: 'K',
    21: 'L', 22: 'M', 23: 'N', 24: 'NG', 25: 'OW', 26: 'OY', 27: 'P',
    28: 'R', 29: 'S', 30: 'SH', 31: 'T', 32: 'TH', 33: 'UH', 34: 'UW',
    35: 'V', 36: 'W', 37: 'Y', 38: 'Z', 39: 'ZH', 40: ' | ',
}


@functools.lru_cache(maxsize=1)
def _get_logit_to_phoneme():
    """Return the ID -> phoneme table, preferring the project's own mapping.

    Cached so that a failing import is attempted once rather than once per
    validated segment (failed imports are not memoized by Python itself).
    """
    try:
        from model_training.evaluate_model_helpers import LOGIT_TO_PHONEME
    except Exception:
        # Deliberate best effort: any import-time failure falls back to the
        # built-in table, but KeyboardInterrupt/SystemExit still propagate.
        return _FALLBACK_LOGIT_TO_PHONEME
    return LOGIT_TO_PHONEME


def _percent(numerator, denominator):
    """Return numerator/denominator as a percentage, or 0.0 if denominator <= 0."""
    if denominator <= 0:
        return 0.0
    return numerator / denominator * 100


def _extract_feature_vector(neural_features, start_time, end_time):
    """Mean-pool rows start_time..end_time (inclusive) into one feature vector.

    Returns None when the range is invalid (negative, reversed, or past the
    end of ``neural_features``) or when the sliced data is neither 1-D nor 2-D.
    """
    # Reject negative starts explicitly: they would otherwise wrap around and
    # silently slice from the end of the trial.
    if not (0 <= start_time <= end_time < len(neural_features)):
        return None

    segment_features = neural_features[start_time:end_time + 1]  # Include end_time

    if isinstance(segment_features, torch.Tensor):
        # detach()/cpu() make the conversion safe for grad-tracking or GPU tensors.
        segment_features = segment_features.detach().cpu().numpy()
    else:
        segment_features = np.asarray(segment_features)

    # For classification we need a fixed-size vector: average over time steps.
    if segment_features.ndim == 2:  # (time, features)
        return np.mean(segment_features, axis=0)
    if segment_features.ndim == 1:  # Already 1D
        return segment_features

    print(f"Unexpected feature shape: {segment_features.shape}")
    return None


def load_neural_data_for_trial(session, trial_metadata):
    """Load neural features for a specific trial from the session's HDF5 file.

    Args:
        session: Name of the session directory under ``data/hdf5_data_final``.
        trial_metadata: Mapping that provides the trial's ``'trial_idx'``.

    Returns:
        ``data['neural_features'][trial_idx]`` from ``data_train.hdf5``, or
        None when the file is missing, the index is absent or out of range,
        or loading fails (a warning is printed in the failure case).
    """
    # NOTE(review): the whole session file is re-read on every call; callers
    # that hit this path per segment may want to cache per session.
    try:
        from model_training.evaluate_model_helpers import load_h5py_file
        import pandas as pd

        # Try to load the session data
        data_dir = Path(__file__).parent.parent / "data" / "hdf5_data_final"
        train_file = data_dir / session / "data_train.hdf5"
        if not train_file.exists():
            return None

        # Load CSV for metadata
        csv_path = data_dir.parent / "b2txt_dataset_info.csv"
        b2txt_csv_df = pd.read_csv(csv_path) if csv_path.exists() else None

        data = load_h5py_file(str(train_file), b2txt_csv_df)

        # Find the matching trial; negative indices are rejected so they
        # cannot wrap around to a different trial.
        trial_idx = trial_metadata.get('trial_idx')
        if trial_idx is not None and 0 <= trial_idx < len(data['neural_features']):
            return data['neural_features'][trial_idx]
    except Exception as e:
        print(f"Warning: Could not load neural data for {session}, trial {trial_metadata.get('trial_idx', 'unknown')}: {e}")

    return None


def validate_phoneme_against_ground_truth(segment, ctc_data):
    """Validate a phoneme segment against ground truth sequence labels.

    Args:
        segment: Mapping with ``'session'``, ``'phoneme'``, ``'start_time'``,
            ``'end_time'`` and (optionally) ``'trial_idx'``.
        ctc_data: Mapping from ``(session, trial_idx)`` to a CTC result that
            holds ``'original_sequence'`` and ``'alignment_info'``.

    Returns:
        ``(is_valid, error_reason, ground_truth_phoneme)``. ``error_reason``
        is ``"valid"`` on success, otherwise one of ``"no_trial_data"``,
        ``"no_ground_truth"``, ``"segment_not_found_in_alignment"``,
        ``"phoneme_mismatch"``, ``"position_out_of_range"`` or
        ``"validation_error: ..."`` for unexpected failures.
    """
    try:
        session = segment['session']
        trial_idx = segment.get('trial_idx')
        trial_key = (session, trial_idx)

        if trial_key not in ctc_data:
            return False, "no_trial_data", None

        trial_data = ctc_data[trial_key]
        original_sequence = trial_data.get('original_sequence')
        if original_sequence is None:
            return False, "no_ground_truth", None

        logit_to_phoneme = _get_logit_to_phoneme()

        # Convert ground truth IDs to phonemes, skipping padding/blank (<= 0).
        # int() normalizes numpy scalars and 0-d tensors so that the dict
        # lookup works regardless of how the sequence was stored.
        ground_truth_phonemes = [
            logit_to_phoneme[seq_id]
            for seq_id in map(int, original_sequence)
            if seq_id > 0 and seq_id in logit_to_phoneme
        ]

        alignment_info = trial_data.get('alignment_info', [])

        segment_phoneme = segment['phoneme']
        segment_start = segment['start_time']
        segment_end = segment['end_time']

        # Find this segment's position in the alignment info
        segment_position = None
        for i, (phoneme, start, end, _conf) in enumerate(alignment_info):
            if (phoneme == segment_phoneme
                    and start == segment_start
                    and end == segment_end):
                segment_position = i
                break

        if segment_position is None:
            return False, "segment_not_found_in_alignment", None

        # NOTE(review): this compares by index, i.e. it presumes alignment
        # entry i corresponds to ground-truth phoneme i -- confirm upstream.
        if segment_position >= len(ground_truth_phonemes):
            return False, "position_out_of_range", None

        expected_phoneme = ground_truth_phonemes[segment_position]
        if segment_phoneme == expected_phoneme:
            return True, "valid", expected_phoneme
        return False, "phoneme_mismatch", expected_phoneme

    except Exception as e:
        return False, f"validation_error: {str(e)}", None


def create_phoneme_classification_dataset():
    """Create a phoneme classification dataset from segmented data with validation.

    Reads the most recently modified ``phoneme_dataset_*.pkl`` in
    ``phoneme_segmented_data/`` and the matching ``ctc_results_*.pkl`` (if
    present), validates each segment, extracts a mean-pooled feature vector
    for every valid segment whose list neighbours are also valid, and saves
    the dataset, the validation errors and a train/test split as pickles.

    Returns:
        The classification dataset dict (``features``, ``labels``,
        ``phoneme_to_id``, ``id_to_phoneme``, ``metadata``,
        ``validation_stats``), or None when no input file exists or no
        features could be extracted.
    """
    # Load the latest phoneme dataset
    data_dir = Path("phoneme_segmented_data")
    dataset_files = list(data_dir.glob("phoneme_dataset_*.pkl"))

    if not dataset_files:
        print("No phoneme dataset files found!")
        return

    latest_dataset = max(dataset_files, key=lambda x: x.stat().st_mtime)
    print(f"Loading dataset: {latest_dataset.name}")

    # NOTE: pickle.load executes arbitrary code; only use with trusted,
    # locally produced files.
    with open(latest_dataset, 'rb') as f:
        phoneme_data = pickle.load(f)

    # Also load the corresponding CTC results for ground truth validation
    ctc_file = latest_dataset.parent / latest_dataset.name.replace("phoneme_dataset_", "ctc_results_")
    ctc_data = {}
    if ctc_file.exists():
        with open(ctc_file, 'rb') as f:
            ctc_results = pickle.load(f)
        # Create lookup dictionary for validation
        for result in ctc_results:
            ctc_data[(result['session'], result['trial_idx'])] = result

    print(f"Loaded {len(phoneme_data)} phoneme types")
    print(f"Associated CTC results: {len(ctc_data)} trials")

    # Create classification dataset
    classification_data = {
        'features': [],        # Neural features for each segment
        'labels': [],          # Phoneme labels
        'phoneme_to_id': {},   # Phoneme to numeric ID mapping
        'id_to_phoneme': {},   # Numeric ID to phoneme mapping
        'metadata': [],        # Additional metadata for each sample
    }

    # Create error tracking
    error_data = {
        'incorrect_segments': [],  # Incorrect phoneme segments
        'validation_stats': {},    # Validation statistics
    }

    # Create phoneme to ID mapping
    unique_phonemes = sorted(phoneme_data.keys())
    for i, phoneme in enumerate(unique_phonemes):
        classification_data['phoneme_to_id'][phoneme] = i
        classification_data['id_to_phoneme'][i] = phoneme

    print(f"\nPhoneme mapping created for {len(unique_phonemes)} phonemes:")
    for i, phoneme in enumerate(unique_phonemes[:10]):  # Show first 10
        print(f" {i:2d}: '{phoneme}'")
    if len(unique_phonemes) > 10:
        print(f" ... and {len(unique_phonemes) - 10} more")

    # Validation and extraction statistics
    validation_stats = {
        'total_segments': 0,
        'valid_segments': 0,
        'invalid_segments': 0,
        'discarded_neighbors': 0,
        'successful_extractions': 0,
        'error_reasons': defaultdict(int),
    }

    print("\nValidating phoneme segments against ground truth...")

    # First pass: validate all segments and mark invalid ones.
    # Maps (phoneme, segment_idx) -> (is_valid, error_reason, ground_truth)
    segment_validity = {}

    for phoneme, segments in phoneme_data.items():
        print(f"Validating '{phoneme}' ({len(segments)} segments)...")

        for segment_idx, segment in enumerate(segments):
            validation_stats['total_segments'] += 1

            verdict = validate_phoneme_against_ground_truth(segment, ctc_data)
            segment_validity[(phoneme, segment_idx)] = verdict
            is_valid, error_reason, ground_truth_phoneme = verdict

            if is_valid:
                validation_stats['valid_segments'] += 1
                continue

            validation_stats['invalid_segments'] += 1
            validation_stats['error_reasons'][error_reason] += 1

            # Save error information
            error_data['incorrect_segments'].append({
                'phoneme': phoneme,
                'segment_idx': segment_idx,
                'segment': segment,
                'predicted_phoneme': phoneme,
                'ground_truth_phoneme': ground_truth_phoneme,
                'error_reason': error_reason
            })

    print("\nValidation completed:")
    print(f" Total segments: {validation_stats['total_segments']}")
    print(f" Valid segments: {validation_stats['valid_segments']}")
    print(f" Invalid segments: {validation_stats['invalid_segments']}")
    # _percent guards against an empty dataset (total_segments == 0).
    print(f" Validation accuracy: {_percent(validation_stats['valid_segments'], validation_stats['total_segments']):.1f}%")

    print("\nError breakdown:")
    for error_reason, count in validation_stats['error_reasons'].items():
        print(f" {error_reason}: {count}")

    # Second pass: extract features for valid segments (excluding neighbors of invalid ones)
    print("\nExtracting neural features for validated segments...")

    # Counts segments visited in this pass; total_segments is already final
    # here and cannot drive the progress output.
    processed_segments = 0

    for phoneme, segments in phoneme_data.items():
        phoneme_id = classification_data['phoneme_to_id'][phoneme]
        print(f"Processing '{phoneme}' ({len(segments)} segments)...")

        for segment_idx, segment in enumerate(segments):
            processed_segments += 1
            if processed_segments % 1000 == 0:
                print(f" Processed {processed_segments} segments, extracted {validation_stats['successful_extractions']} features")

            is_valid = segment_validity[(phoneme, segment_idx)][0]
            if not is_valid:
                continue  # Skip invalid segments

            # Discard segments adjacent (within this phoneme's list) to an
            # invalid one.
            # NOTE(review): adjacency is by position in the per-phoneme list,
            # which may not be temporal adjacency within a trial -- confirm.
            prev_invalid = (segment_idx > 0
                            and not segment_validity[(phoneme, segment_idx - 1)][0])
            next_invalid = (segment_idx < len(segments) - 1
                            and not segment_validity[(phoneme, segment_idx + 1)][0])
            if prev_invalid or next_invalid:
                validation_stats['discarded_neighbors'] += 1
                continue  # Skip neighbors of invalid segments

            # Get trial information
            session = segment['session']
            trial_key = (session, segment.get('trial_idx'))

            # Prefer features stored with the CTC result; otherwise fall back
            # to the session's HDF5 file.
            neural_features = None
            if trial_key in ctc_data:
                trial_metadata = ctc_data[trial_key]
                if 'neural_features' in trial_metadata:
                    neural_features = trial_metadata['neural_features']
                else:
                    neural_features = load_neural_data_for_trial(session, segment)

            if neural_features is None:
                continue

            start_time = int(segment['start_time'])
            end_time = int(segment['end_time'])

            feature_vector = _extract_feature_vector(neural_features, start_time, end_time)
            if feature_vector is None:
                continue

            # Add to dataset
            classification_data['features'].append(feature_vector)
            classification_data['labels'].append(phoneme_id)
            classification_data['metadata'].append({
                'phoneme': phoneme,
                'session': session,
                'trial_num': segment.get('trial_num', -1),
                'trial_idx': segment.get('trial_idx', -1),
                'start_time': start_time,
                'end_time': end_time,
                'duration': end_time - start_time + 1,
                'confidence': segment.get('confidence', 0.0),
                'corpus': segment.get('corpus', 'unknown'),
                'validated': True
            })
            validation_stats['successful_extractions'] += 1

    eligible_segments = validation_stats['valid_segments'] - validation_stats['discarded_neighbors']

    print("\nDataset creation completed!")
    print(f"Total segments processed: {validation_stats['total_segments']}")
    print(f"Valid segments (excluding neighbors): {eligible_segments}")
    print(f"Discarded neighbor segments: {validation_stats['discarded_neighbors']}")
    print(f"Successful feature extractions: {validation_stats['successful_extractions']}")
    # Guarded: with no eligible segments the unguarded division used to raise
    # before the "no features" message below could be shown.
    print(f"Extraction success rate: {_percent(validation_stats['successful_extractions'], eligible_segments):.1f}%")

    if validation_stats['successful_extractions'] == 0:
        print("No features were extracted. Check neural data availability.")
        return

    # Convert to numpy arrays
    classification_data['features'] = np.array(classification_data['features'])
    classification_data['labels'] = np.array(classification_data['labels'])

    print("\nFinal validated dataset shape:")
    print(f"Features: {classification_data['features'].shape}")
    print(f"Labels: {classification_data['labels'].shape}")

    # Show class distribution
    print("\nClass distribution:")
    unique_labels, counts = np.unique(classification_data['labels'], return_counts=True)
    for label_id, count in zip(unique_labels, counts):
        phoneme = classification_data['id_to_phoneme'][label_id]
        print(f" {label_id:2d} ('{phoneme}'): {count:4d} samples")

    # Save the classification dataset
    # NOTE(review): only the text after the last '_' is used; if the input
    # file's stamp itself contains '_' (e.g. date_time) part of it is dropped.
    timestamp = latest_dataset.name.split('_')[-1].replace('.pkl', '')
    output_file = f"phoneme_classification_dataset_validated_{timestamp}.pkl"
    output_path = data_dir / output_file

    # Add validation stats to the dataset
    classification_data['validation_stats'] = validation_stats

    with open(output_path, 'wb') as f:
        pickle.dump(classification_data, f)

    print(f"\nValidated classification dataset saved to: {output_file}")

    # Save error data separately
    error_data['validation_stats'] = validation_stats
    error_file = f"phoneme_validation_errors_{timestamp}.pkl"
    error_path = data_dir / error_file

    with open(error_path, 'wb') as f:
        pickle.dump(error_data, f)

    print(f"Validation errors saved to: {error_file}")

    # Create a simple train/test split example
    create_train_test_split(classification_data, data_dir, timestamp)

    return classification_data


def create_train_test_split(data, data_dir, timestamp):
    """Create and save a standardized train/test split of the dataset.

    With four or more sessions the split is by session (first 80% of the
    sorted session names train, the rest test) so that no session appears on
    both sides. Otherwise a seeded 80/20 random split is used, stratified by
    label when possible.

    Args:
        data: Dataset dict with ``'features'`` and ``'labels'`` arrays, a
            ``'metadata'`` list (one dict with ``'session'`` per sample) and
            the ``'phoneme_to_id'`` / ``'id_to_phoneme'`` mappings.
        data_dir: ``Path`` of the directory the split pickle is written to.
        timestamp: Suffix used in the output file name.
    """
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    print("\nCreating train/test split...")

    X = data['features']
    y = data['labels']
    metadata = data['metadata']

    # Split by session to avoid data leakage. Sorting makes the split
    # reproducible; plain set iteration order varies between interpreter runs.
    sessions = [meta['session'] for meta in metadata]
    unique_sessions = sorted(set(sessions))

    print(f"Available sessions: {len(unique_sessions)}")

    if len(unique_sessions) >= 4:
        # Use session-based split
        cutoff = int(len(unique_sessions) * 0.8)
        train_sessions = unique_sessions[:cutoff]
        test_sessions = unique_sessions[cutoff:]

        train_lookup = set(train_sessions)
        test_lookup = set(test_sessions)
        train_indices = [i for i, name in enumerate(sessions) if name in train_lookup]
        test_indices = [i for i, name in enumerate(sessions) if name in test_lookup]

        X_train, X_test = X[train_indices], X[test_indices]
        y_train, y_test = y[train_indices], y[test_indices]

        print("Session-based split:")
        print(f" Train sessions: {train_sessions}")
        print(f" Test sessions: {test_sessions}")
    else:
        # Use random split
        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42, stratify=y
            )
            print("Random split (stratified):")
        except ValueError as e:
            # Stratification is impossible when some class has too few
            # samples; keep going with a plain seeded split instead.
            print(f"Warning: stratified split failed ({e}); using unstratified random split")
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
            print("Random split (unstratified):")

    print(f" Train samples: {len(X_train)}")
    print(f" Test samples: {len(X_test)}")

    # Standardize features (statistics are fitted on the training set only)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Save split data
    split_data = {
        'X_train': X_train_scaled,
        'X_test': X_test_scaled,
        'y_train': y_train,
        'y_test': y_test,
        'scaler': scaler,
        'phoneme_to_id': data['phoneme_to_id'],
        'id_to_phoneme': data['id_to_phoneme']
    }

    split_file = f"phoneme_classification_split_{timestamp}.pkl"
    split_path = data_dir / split_file

    with open(split_path, 'wb') as f:
        pickle.dump(split_data, f)

    print(f"Train/test split saved to: {split_file}")


if __name__ == "__main__":
    try:
        classification_data = create_phoneme_classification_dataset()
    except Exception as e:
        print(f"Error creating classification dataset: {e}")
        import traceback
        traceback.print_exc()