446 lines
18 KiB
Python
446 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Convert phoneme segmented data to phoneme classification dataset
|
|
将音素分段数据转换为音素分类数据集
|
|
"""
|
|
|
|
import pickle
|
|
import numpy as np
|
|
import torch
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
import os
|
|
import sys
|
|
|
|
# Add parent directory to path for imports
|
|
sys.path.append(str(Path(__file__).parent.parent))
|
|
|
|
def load_neural_data_for_trial(session, trial_metadata):
|
|
"""Load neural features for a specific trial"""
|
|
try:
|
|
from model_training.evaluate_model_helpers import load_h5py_file
|
|
import pandas as pd
|
|
|
|
# Try to load the session data
|
|
data_dir = Path(__file__).parent.parent / "data" / "hdf5_data_final"
|
|
train_file = data_dir / session / "data_train.hdf5"
|
|
|
|
if not train_file.exists():
|
|
return None
|
|
|
|
# Load CSV for metadata
|
|
csv_path = data_dir.parent / "b2txt_dataset_info.csv"
|
|
if csv_path.exists():
|
|
b2txt_csv_df = pd.read_csv(csv_path)
|
|
else:
|
|
b2txt_csv_df = None
|
|
|
|
data = load_h5py_file(str(train_file), b2txt_csv_df)
|
|
|
|
# Find the matching trial
|
|
trial_idx = trial_metadata.get('trial_idx')
|
|
if trial_idx is not None and trial_idx < len(data['neural_features']):
|
|
return data['neural_features'][trial_idx]
|
|
|
|
except Exception as e:
|
|
print(f"Warning: Could not load neural data for {session}, trial {trial_metadata.get('trial_idx', 'unknown')}: {e}")
|
|
|
|
return None
|
|
|
|
def validate_phoneme_against_ground_truth(segment, ctc_data):
|
|
"""
|
|
Validate a phoneme segment against ground truth sequence labels
|
|
返回: (is_valid, error_reason, ground_truth_phoneme)
|
|
"""
|
|
try:
|
|
session = segment['session']
|
|
trial_idx = segment.get('trial_idx')
|
|
trial_key = (session, trial_idx)
|
|
|
|
if trial_key not in ctc_data:
|
|
return False, "no_trial_data", None
|
|
|
|
trial_data = ctc_data[trial_key]
|
|
original_sequence = trial_data.get('original_sequence')
|
|
|
|
if original_sequence is None:
|
|
return False, "no_ground_truth", None
|
|
|
|
# Convert sequence IDs to phonemes using LOGIT_TO_PHONEME mapping
|
|
try:
|
|
from model_training.evaluate_model_helpers import LOGIT_TO_PHONEME
|
|
except:
|
|
# Fallback phoneme mapping if import fails
|
|
LOGIT_TO_PHONEME = {
|
|
0: 'BLANK', 1: 'AA', 2: 'AE', 3: 'AH', 4: 'AO', 5: 'AW', 6: 'AY',
|
|
7: 'B', 8: 'CH', 9: 'D', 10: 'DH', 11: 'EH', 12: 'ER', 13: 'EY', 14: 'F',
|
|
15: 'G', 16: 'HH', 17: 'IH', 18: 'IY', 19: 'JH', 20: 'K', 21: 'L', 22: 'M',
|
|
23: 'N', 24: 'NG', 25: 'OW', 26: 'OY', 27: 'P', 28: 'R', 29: 'S', 30: 'SH',
|
|
31: 'T', 32: 'TH', 33: 'UH', 34: 'UW', 35: 'V', 36: 'W', 37: 'Y', 38: 'Z',
|
|
39: 'ZH', 40: ' | '
|
|
}
|
|
|
|
# Convert ground truth sequence to phonemes (filter out zeros/padding)
|
|
ground_truth_phonemes = []
|
|
for seq_id in original_sequence:
|
|
if seq_id > 0 and seq_id in LOGIT_TO_PHONEME: # Skip padding/blank
|
|
ground_truth_phonemes.append(LOGIT_TO_PHONEME[seq_id])
|
|
|
|
# Find the position of this segment in the predicted sequence
|
|
predicted_sequence = trial_data.get('predicted_phonemes', [])
|
|
alignment_info = trial_data.get('alignment_info', [])
|
|
|
|
# Find this segment in the alignment info
|
|
segment_phoneme = segment['phoneme']
|
|
segment_start = segment['start_time']
|
|
segment_end = segment['end_time']
|
|
|
|
# Find matching alignment segment
|
|
segment_position = None
|
|
for i, (phoneme, start, end, conf) in enumerate(alignment_info):
|
|
if (phoneme == segment_phoneme and
|
|
start == segment_start and
|
|
end == segment_end):
|
|
segment_position = i
|
|
break
|
|
|
|
if segment_position is None:
|
|
return False, "segment_not_found_in_alignment", None
|
|
|
|
# Check if the phoneme at this position matches ground truth
|
|
if segment_position < len(ground_truth_phonemes):
|
|
expected_phoneme = ground_truth_phonemes[segment_position]
|
|
if segment_phoneme == expected_phoneme:
|
|
return True, "valid", expected_phoneme
|
|
else:
|
|
return False, "phoneme_mismatch", expected_phoneme
|
|
else:
|
|
return False, "position_out_of_range", None
|
|
|
|
except Exception as e:
|
|
return False, f"validation_error: {str(e)}", None
|
|
|
|
def create_phoneme_classification_dataset():
|
|
"""Create a phoneme classification dataset from segmented data with validation"""
|
|
|
|
# Load the latest phoneme dataset
|
|
data_dir = Path("phoneme_segmented_data")
|
|
dataset_files = list(data_dir.glob("phoneme_dataset_*.pkl"))
|
|
|
|
if not dataset_files:
|
|
print("No phoneme dataset files found!")
|
|
return
|
|
|
|
latest_dataset = max(dataset_files, key=lambda x: x.stat().st_mtime)
|
|
print(f"Loading dataset: {latest_dataset.name}")
|
|
|
|
with open(latest_dataset, 'rb') as f:
|
|
phoneme_data = pickle.load(f)
|
|
|
|
# Also load the corresponding CTC results for ground truth validation
|
|
ctc_file = latest_dataset.parent / latest_dataset.name.replace("phoneme_dataset_", "ctc_results_")
|
|
ctc_data = {}
|
|
|
|
if ctc_file.exists():
|
|
with open(ctc_file, 'rb') as f:
|
|
ctc_results = pickle.load(f)
|
|
# Create lookup dictionary for validation
|
|
for result in ctc_results:
|
|
key = (result['session'], result['trial_idx'])
|
|
ctc_data[key] = result
|
|
|
|
print(f"Loaded {len(phoneme_data)} phoneme types")
|
|
print(f"Associated CTC results: {len(ctc_data)} trials")
|
|
|
|
# Create classification dataset
|
|
classification_data = {
|
|
'features': [], # Neural features for each segment
|
|
'labels': [], # Phoneme labels
|
|
'phoneme_to_id': {}, # Phoneme to numeric ID mapping
|
|
'id_to_phoneme': {}, # Numeric ID to phoneme mapping
|
|
'metadata': [] # Additional metadata for each sample
|
|
}
|
|
|
|
# Create error tracking
|
|
error_data = {
|
|
'incorrect_segments': [], # Incorrect phoneme segments
|
|
'validation_stats': {} # Validation statistics
|
|
}
|
|
|
|
# Create phoneme to ID mapping
|
|
unique_phonemes = sorted(phoneme_data.keys())
|
|
for i, phoneme in enumerate(unique_phonemes):
|
|
classification_data['phoneme_to_id'][phoneme] = i
|
|
classification_data['id_to_phoneme'][i] = phoneme
|
|
|
|
print(f"\nPhoneme mapping created for {len(unique_phonemes)} phonemes:")
|
|
for i, phoneme in enumerate(unique_phonemes[:10]): # Show first 10
|
|
print(f" {i:2d}: '{phoneme}'")
|
|
if len(unique_phonemes) > 10:
|
|
print(f" ... and {len(unique_phonemes) - 10} more")
|
|
|
|
# Validation and extraction statistics
|
|
validation_stats = {
|
|
'total_segments': 0,
|
|
'valid_segments': 0,
|
|
'invalid_segments': 0,
|
|
'discarded_neighbors': 0,
|
|
'successful_extractions': 0,
|
|
'error_reasons': defaultdict(int)
|
|
}
|
|
|
|
print(f"\nValidating phoneme segments against ground truth...")
|
|
|
|
# First pass: validate all segments and mark invalid ones
|
|
segment_validity = {} # Maps (phoneme, segment_idx) -> (is_valid, error_reason, ground_truth)
|
|
|
|
for phoneme, segments in phoneme_data.items():
|
|
print(f"Validating '{phoneme}' ({len(segments)} segments)...")
|
|
|
|
for segment_idx, segment in enumerate(segments):
|
|
validation_stats['total_segments'] += 1
|
|
|
|
# Validate against ground truth
|
|
is_valid, error_reason, ground_truth_phoneme = validate_phoneme_against_ground_truth(segment, ctc_data)
|
|
segment_validity[(phoneme, segment_idx)] = (is_valid, error_reason, ground_truth_phoneme)
|
|
|
|
if is_valid:
|
|
validation_stats['valid_segments'] += 1
|
|
else:
|
|
validation_stats['invalid_segments'] += 1
|
|
validation_stats['error_reasons'][error_reason] += 1
|
|
|
|
# Save error information
|
|
error_data['incorrect_segments'].append({
|
|
'phoneme': phoneme,
|
|
'segment_idx': segment_idx,
|
|
'segment': segment,
|
|
'predicted_phoneme': phoneme,
|
|
'ground_truth_phoneme': ground_truth_phoneme,
|
|
'error_reason': error_reason
|
|
})
|
|
|
|
print(f"\nValidation completed:")
|
|
print(f" Total segments: {validation_stats['total_segments']}")
|
|
print(f" Valid segments: {validation_stats['valid_segments']}")
|
|
print(f" Invalid segments: {validation_stats['invalid_segments']}")
|
|
print(f" Validation accuracy: {validation_stats['valid_segments']/validation_stats['total_segments']*100:.1f}%")
|
|
|
|
print(f"\nError breakdown:")
|
|
for error_reason, count in validation_stats['error_reasons'].items():
|
|
print(f" {error_reason}: {count}")
|
|
|
|
# Second pass: extract features for valid segments (excluding neighbors of invalid ones)
|
|
print(f"\nExtracting neural features for validated segments...")
|
|
|
|
for phoneme, segments in phoneme_data.items():
|
|
phoneme_id = classification_data['phoneme_to_id'][phoneme]
|
|
print(f"Processing '{phoneme}' ({len(segments)} segments)...")
|
|
|
|
for segment_idx, segment in enumerate(segments):
|
|
# Check if this segment is valid
|
|
is_valid, error_reason, ground_truth_phoneme = segment_validity[(phoneme, segment_idx)]
|
|
|
|
# Check if neighboring segments are invalid (discard neighbors)
|
|
prev_invalid = (segment_idx > 0 and
|
|
not segment_validity[(phoneme, segment_idx - 1)][0])
|
|
next_invalid = (segment_idx < len(segments) - 1 and
|
|
not segment_validity[(phoneme, segment_idx + 1)][0])
|
|
|
|
if not is_valid:
|
|
continue # Skip invalid segments
|
|
|
|
if prev_invalid or next_invalid:
|
|
validation_stats['discarded_neighbors'] += 1
|
|
continue # Skip neighbors of invalid segments
|
|
|
|
# Get trial information
|
|
session = segment['session']
|
|
trial_key = (session, segment.get('trial_idx'))
|
|
|
|
# Try to get neural data for this trial
|
|
neural_features = None
|
|
if trial_key in ctc_data:
|
|
# We have the trial data, now extract the segment
|
|
trial_metadata = ctc_data[trial_key]
|
|
if 'neural_features' in trial_metadata:
|
|
neural_features = trial_metadata['neural_features']
|
|
else:
|
|
# Try to load from HDF5 files
|
|
neural_features = load_neural_data_for_trial(session, segment)
|
|
|
|
if neural_features is not None:
|
|
# Extract the specific time segment
|
|
start_time = int(segment['start_time'])
|
|
end_time = int(segment['end_time'])
|
|
|
|
# Ensure valid time range
|
|
if start_time <= end_time and end_time < len(neural_features):
|
|
# Extract neural features for this time segment
|
|
segment_features = neural_features[start_time:end_time+1] # Include end_time
|
|
|
|
# Convert to numpy array and handle different cases
|
|
if isinstance(segment_features, torch.Tensor):
|
|
segment_features = segment_features.numpy()
|
|
elif isinstance(segment_features, list):
|
|
segment_features = np.array(segment_features)
|
|
|
|
# For classification, we need a fixed-size feature vector
|
|
# Option 1: Use mean across time steps
|
|
if len(segment_features.shape) == 2: # (time, features)
|
|
feature_vector = np.mean(segment_features, axis=0)
|
|
elif len(segment_features.shape) == 1: # Already 1D
|
|
feature_vector = segment_features
|
|
else:
|
|
print(f"Unexpected feature shape: {segment_features.shape}")
|
|
continue
|
|
|
|
# Add to dataset
|
|
classification_data['features'].append(feature_vector)
|
|
classification_data['labels'].append(phoneme_id)
|
|
classification_data['metadata'].append({
|
|
'phoneme': phoneme,
|
|
'session': session,
|
|
'trial_num': segment.get('trial_num', -1),
|
|
'trial_idx': segment.get('trial_idx', -1),
|
|
'start_time': start_time,
|
|
'end_time': end_time,
|
|
'duration': end_time - start_time + 1,
|
|
'confidence': segment.get('confidence', 0.0),
|
|
'corpus': segment.get('corpus', 'unknown'),
|
|
'validated': True
|
|
})
|
|
|
|
validation_stats['successful_extractions'] += 1
|
|
|
|
# Progress update
|
|
if validation_stats['total_segments'] % 1000 == 0:
|
|
print(f" Processed {validation_stats['total_segments']} segments, extracted {validation_stats['successful_extractions']} features")
|
|
|
|
print(f"\nDataset creation completed!")
|
|
print(f"Total segments processed: {validation_stats['total_segments']}")
|
|
print(f"Valid segments (excluding neighbors): {validation_stats['valid_segments'] - validation_stats['discarded_neighbors']}")
|
|
print(f"Discarded neighbor segments: {validation_stats['discarded_neighbors']}")
|
|
print(f"Successful feature extractions: {validation_stats['successful_extractions']}")
|
|
print(f"Extraction success rate: {validation_stats['successful_extractions']/(validation_stats['valid_segments']-validation_stats['discarded_neighbors'])*100:.1f}%")
|
|
|
|
if validation_stats['successful_extractions'] == 0:
|
|
print("No features were extracted. Check neural data availability.")
|
|
return
|
|
|
|
# Convert to numpy arrays
|
|
classification_data['features'] = np.array(classification_data['features'])
|
|
classification_data['labels'] = np.array(classification_data['labels'])
|
|
|
|
print(f"\nFinal validated dataset shape:")
|
|
print(f"Features: {classification_data['features'].shape}")
|
|
print(f"Labels: {classification_data['labels'].shape}")
|
|
|
|
# Show class distribution
|
|
print(f"\nClass distribution:")
|
|
unique_labels, counts = np.unique(classification_data['labels'], return_counts=True)
|
|
for label_id, count in zip(unique_labels, counts):
|
|
phoneme = classification_data['id_to_phoneme'][label_id]
|
|
print(f" {label_id:2d} ('{phoneme}'): {count:4d} samples")
|
|
|
|
# Save the classification dataset
|
|
timestamp = latest_dataset.name.split('_')[-1].replace('.pkl', '')
|
|
output_file = f"phoneme_classification_dataset_validated_{timestamp}.pkl"
|
|
output_path = data_dir / output_file
|
|
|
|
# Add validation stats to the dataset
|
|
classification_data['validation_stats'] = validation_stats
|
|
|
|
with open(output_path, 'wb') as f:
|
|
pickle.dump(classification_data, f)
|
|
|
|
print(f"\nValidated classification dataset saved to: {output_file}")
|
|
|
|
# Save error data separately
|
|
error_data['validation_stats'] = validation_stats
|
|
error_file = f"phoneme_validation_errors_{timestamp}.pkl"
|
|
error_path = data_dir / error_file
|
|
|
|
with open(error_path, 'wb') as f:
|
|
pickle.dump(error_data, f)
|
|
|
|
print(f"Validation errors saved to: {error_file}")
|
|
|
|
# Create a simple train/test split example
|
|
create_train_test_split(classification_data, data_dir, timestamp)
|
|
|
|
return classification_data
|
|
|
|
def create_train_test_split(data, data_dir, timestamp):
|
|
"""Create train/test split for the classification dataset"""
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.preprocessing import StandardScaler
|
|
|
|
print(f"\nCreating train/test split...")
|
|
|
|
X = data['features']
|
|
y = data['labels']
|
|
metadata = data['metadata']
|
|
|
|
# Split by session to avoid data leakage
|
|
sessions = [meta['session'] for meta in metadata]
|
|
unique_sessions = list(set(sessions))
|
|
|
|
print(f"Available sessions: {len(unique_sessions)}")
|
|
|
|
if len(unique_sessions) >= 4:
|
|
# Use session-based split
|
|
train_sessions = unique_sessions[:int(len(unique_sessions) * 0.8)]
|
|
test_sessions = unique_sessions[int(len(unique_sessions) * 0.8):]
|
|
|
|
train_indices = [i for i, meta in enumerate(metadata) if meta['session'] in train_sessions]
|
|
test_indices = [i for i, meta in enumerate(metadata) if meta['session'] in test_sessions]
|
|
|
|
X_train, X_test = X[train_indices], X[test_indices]
|
|
y_train, y_test = y[train_indices], y[test_indices]
|
|
|
|
print(f"Session-based split:")
|
|
print(f" Train sessions: {train_sessions}")
|
|
print(f" Test sessions: {test_sessions}")
|
|
else:
|
|
# Use random split
|
|
X_train, X_test, y_train, y_test = train_test_split(
|
|
X, y, test_size=0.2, random_state=42, stratify=y
|
|
)
|
|
print(f"Random split (stratified):")
|
|
|
|
print(f" Train samples: {len(X_train)}")
|
|
print(f" Test samples: {len(X_test)}")
|
|
|
|
# Standardize features
|
|
scaler = StandardScaler()
|
|
X_train_scaled = scaler.fit_transform(X_train)
|
|
X_test_scaled = scaler.transform(X_test)
|
|
|
|
# Save split data
|
|
split_data = {
|
|
'X_train': X_train_scaled,
|
|
'X_test': X_test_scaled,
|
|
'y_train': y_train,
|
|
'y_test': y_test,
|
|
'scaler': scaler,
|
|
'phoneme_to_id': data['phoneme_to_id'],
|
|
'id_to_phoneme': data['id_to_phoneme']
|
|
}
|
|
|
|
split_file = f"phoneme_classification_split_{timestamp}.pkl"
|
|
split_path = data_dir / split_file
|
|
|
|
with open(split_path, 'wb') as f:
|
|
pickle.dump(split_data, f)
|
|
|
|
print(f"Train/test split saved to: {split_file}")
|
|
|
|
if __name__ == "__main__":
|
|
try:
|
|
classification_data = create_phoneme_classification_dataset()
|
|
except Exception as e:
|
|
print(f"Error creating classification dataset: {e}")
|
|
import traceback
|
|
traceback.print_exc() |