Enhance dataset shape analysis to incorporate data augmentation effects and adjust safety margins accordingly

Zchen
2025-10-22 13:08:46 +08:00
parent d92889d435
commit 8d94b706f7
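
As a rough usage sketch of the new signature (not part of this commit; the import path and the augmentation-config key are placeholders, while analyze_dataset_shapes, BrainToTextDatasetTF, transform_args, and the 'max_time_steps'/'n_features' keys come from the diff below):

from dataset_tf import BrainToTextDatasetTF, analyze_dataset_shapes  # assumed module path

dataset_tf = ...  # a BrainToTextDatasetTF instance, constructed as elsewhere in the project
transform_args = {'smooth_data': True}  # hypothetical augmentation config

# Analyze all trials with augmentation applied, so the returned maxima already
# reflect post-augmentation shapes and only a small safety margin is added.
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args)
print(max_shapes['max_time_steps'], max_shapes['n_features'])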


@@ -830,17 +830,20 @@ def train_test_split_indices(file_paths: List[str],
return train_trials, test_trials
def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
sample_size: int = 100,
transform_args: Optional[Dict[str, Any]] = None) -> Dict[str, int]:
"""
Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
utilizing multiple CPU cores and the dataset's caching mechanism.
Data augmentation is applied first so that post-augmentation shapes are captured.
Args:
dataset_tf: Dataset instance to analyze
sample_size: Number of samples to analyze (set to -1 for all data)
transform_args: Data transformation configuration to apply during analysis
Returns:
Dictionary with maximum dimensions
Dictionary with maximum dimensions after data augmentation
"""
print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
start_time = time.time()
@@ -867,29 +870,55 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
total_trials_to_analyze = len(trials_to_check)
print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")
# Define a helper function to be called by each thread
# Define a helper function to be called by each thread, including data augmentation
def analyze_single_trial(day_trial_pair):
"""Loads and analyzes a single trial, returns its shapes."""
"""Loads and analyzes a single trial with data augmentation, returns its shapes."""
day, trial = day_trial_pair
try:
# Reuse dataset_tf's loading and caching logic
trial_data = dataset_tf._load_trial_data(day, trial)
# Read the information directly from the loaded data
time_steps = int(trial_data['n_time_steps'])
# Simulate the data augmentation pipeline (if transform_args was provided)
if transform_args:
# Convert the numpy data to TensorFlow tensors
features = tf.constant(trial_data['input_features'], dtype=tf.float32)
n_time_steps = tf.constant(trial_data['n_time_steps'], dtype=tf.int32)
# Add a batch dimension (batch_size=1)
features = tf.expand_dims(features, axis=0)
n_time_steps = tf.expand_dims(n_time_steps, axis=0)
# Apply the data transformations
augmented_features, augmented_n_time_steps = DataAugmentationTF.transform_data(
features, n_time_steps, transform_args, training=True
)
# Remove the batch dimension and read the augmented shapes
augmented_features = tf.squeeze(augmented_features, axis=0)
augmented_n_time_steps = tf.squeeze(augmented_n_time_steps, axis=0)
# Get the actual number of time steps after augmentation
time_steps = int(augmented_n_time_steps.numpy())
actual_features_shape = augmented_features.shape[0] # may differ from n_time_steps
time_steps = max(time_steps, actual_features_shape) # take the larger value
else:
# No data augmentation; use the original data
time_steps = int(trial_data['n_time_steps'])
# Get the other dimensions (these are unaffected by data augmentation)
phone_seq_len = int(trial_data['phone_seq_lens'])
# Handle the transcription data - it may be an array
# Handle the transcription data
transcription_data = trial_data['transcription']
if hasattr(transcription_data, '__len__'):
transcription_len = len(transcription_data)
else:
transcription_len = 1 # if it is a scalar, the length is 1
transcription_len = 1
return (time_steps, phone_seq_len, transcription_len)
except Exception as e:
logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
return None # return None to signal failure
return None
# 3. Use a ThreadPoolExecutor for parallel processing
# Use a dynamic calculation based on CPU cores, with a reasonable upper limit
@@ -934,15 +963,23 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
'n_features': dataset_tf.feature_dim
}
# 5. Add an appropriate safety margin - adjusted for analysis scope and data augmentation
if sample_size == -1:
# Full-dataset analysis: leave room for data augmentation, especially Gaussian smoothing
safety_margin = 1.15 # 15% buffer for data augmentation effects
margin_reason = "buffer for full dataset analysis + data augmentation"
# 5. Add an appropriate safety margin - the analysis now already includes augmentation effects
if transform_args:
# Augmentation was applied during analysis: only a small margin is needed for randomness
if sample_size == -1:
safety_margin = 1.05 # 5% buffer for augmentation randomness (full data)
margin_reason = "small buffer for augmentation randomness after full analysis"
else:
safety_margin = 1.15 # 15% buffer for augmentation randomness + sampling
margin_reason = f"buffer for augmentation randomness + sampling {sample_size} trials"
else:
# Sampled analysis: a larger margin covers unsampled extremes + data augmentation
safety_margin = 1.35 # 35% buffer for sampling uncertainty + data augmentation
margin_reason = f"larger buffer due to sampling only {sample_size} trials + data augmentation"
# No augmentation applied: keep the original logic
if sample_size == -1:
safety_margin = 1.02 # 2% buffer for rounding errors
margin_reason = "minimal buffer for full dataset analysis without augmentation"
else:
safety_margin = 1.3 # 30% buffer for sampling uncertainty
margin_reason = f"larger buffer due to sampling only {sample_size} trials without augmentation"
final_max_shapes = {
'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
@@ -1005,7 +1042,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
# Analyze dataset to get maximum shapes
print("📊 Analyzing dataset for maximum shapes...")
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1) # Analyze ALL data for maximum accuracy
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args) # Pass transform_args
# Use static shapes based on analysis
padded_shapes = {
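
For intuition on the safety-margin logic above, a small standalone arithmetic sketch (the measured maxima and all key names other than 'max_time_steps' and 'n_features' are made up for illustration):

# Hypothetical post-augmentation maxima, not taken from this dataset.
measured = {'max_time_steps': 1200, 'max_phone_seq_len': 80, 'max_transcription_len': 400}

# Full-data analysis that already included augmentation -> small 5% margin,
# mirroring the new branch introduced in this commit.
safety_margin = 1.05
padded = {k: int(v * safety_margin) for k, v in measured.items()}
print(padded)  # {'max_time_steps': 1260, 'max_phone_seq_len': 84, 'max_transcription_len': 420}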