Enhance safety margin calculations in dataset shape analysis to address double augmentation issues caused by random transformations. Implement intelligent detection of random vs deterministic augmentations, applying appropriate safety margins to prevent shape mismatch errors during training.
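The margin matters because the analyzed maxima become the fixed shapes used for padded batching; below is a minimal sketch of that downstream use, not part of this commit. The key names 'max_timesteps' and 'max_phone_seq_len' are assumptions for illustration; only 'n_features' is visible in the diff below.

import tensorflow as tf

def build_padded_shapes(max_shapes):
    # Sketch only: if a random augmentation later produces a sequence longer
    # than the analyzed maximum, these fixed shapes overflow and batching
    # fails mid-training -- the safety margin buys headroom against that.
    return {
        'inputs': tf.TensorShape([max_shapes['max_timesteps'], max_shapes['n_features']]),
        'labels': tf.TensorShape([max_shapes['max_phone_seq_len']]),
    }

# dataset = dataset.padded_batch(batch_size, padded_shapes=build_padded_shapes(max_shapes))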

This commit is contained in:
Zchen
2025-10-22 15:09:38 +08:00
parent 4a99f50afd
commit fde5ea20ad


@@ -20,6 +20,32 @@ class BrainToTextDatasetTF:
This class creates tf.data.Dataset objects that efficiently load and batch
brain-to-text data from HDF5 files with TPU-optimized operations.
🔧 IMPORTANT FIX APPLIED (v2024.10): Double Augmentation Safety Margin
The dataset now includes intelligent safety margin calculation to prevent shape
mismatch errors caused by random data augmentation uncertainty. The system:
1. Detects random vs deterministic augmentations in transform_args
2. Applies conservative safety margins (25-50%) for random augmentations
3. Uses efficient margins (5-15%) for deterministic-only augmentations
4. Prevents runtime errors from different random outcomes during shape analysis
vs actual training
This resolves the "double augmentation problem" where shape analysis and
training both apply random transformations but get different results.
Usage Example:
>>> # Random augmentation config (gets 25% safety margin)
>>> transform_args = {
... 'static_gain_std': 0.1, # Random
... 'white_noise_std': 0.05, # Random
... 'smooth_data': True # Deterministic
... }
>>> dataset = BrainToTextDatasetTF(...)
>>> input_fn = create_input_fn(dataset, transform_args, use_static_shapes=True)
🎲 Random augmentations detected: ['static_gain_std', 'white_noise_std']
📏 Using 25% safety margin for random augmentation uncertainty
"""
def __init__(
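The "double augmentation problem" named in the docstring comes down to two independent random draws: one during shape analysis, one during actual training. A self-contained sketch of the mismatch; the random_cut semantics here are assumed purely for illustration, not taken from this repo:

import numpy as np

rng = np.random.default_rng()

def random_cut(x, max_cut=20):
    # Assumed semantics: drop a random number of leading timesteps.
    n_cut = rng.integers(0, max_cut + 1)
    return x[n_cut:]

trial = np.zeros((500, 512), dtype=np.float32)  # (timesteps, features)

analysis_len = random_cut(trial).shape[0]  # e.g. 487, recorded as the "max"
training_len = random_cut(trial).shape[0]  # e.g. 495, may exceed that max

# Without a safety margin, padded shapes derived from analysis_len can be
# smaller than what training actually produces -> shape mismatch error.
print(analysis_len, training_len)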
@@ -838,13 +864,37 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
applying data augmentation first to capture post-augmentation shapes.
🔧 FIXED: Intelligent safety margin calculation for random augmentation uncertainty.
The function now detects random vs deterministic augmentations and applies appropriate
safety margins to handle the "double augmentation problem":
- Random augmentations (noise, random_cut, etc.): 25-50% safety margin
- Deterministic augmentations (gaussian smoothing): 5-15% safety margin
- No augmentation: 2-30% safety margin (depends on sampling)
This prevents shape mismatch errors caused by different random outcomes between
shape analysis and actual training data augmentation.
Args:
dataset_tf: Dataset instance to analyze
sample_size: Number of samples to analyze (set to -1 for all data)
transform_args: Data transformation configuration to apply during analysis
training: Whether to apply training-mode augmentations during analysis
Returns:
Dictionary with maximum dimensions after data augmentation and safety margins
Example:
>>> transform_args = {
... 'static_gain_std': 0.1, # Random augmentation detected
... 'white_noise_std': 0.05, # Random augmentation detected
... 'smooth_data': True # Deterministic augmentation detected
... }
>>> max_shapes = analyze_dataset_shapes(dataset, transform_args=transform_args)
🎲 Random augmentations detected: ['static_gain_std', 'white_noise_std']
🔄 Deterministic augmentations detected: ['gaussian_smoothing']
📏 Final max shapes (with 25% safety margin - conservative buffer for random augmentation uncertainty):
"""
print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
start_time = time.time()
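The margin tiers described in the docstring above are simple multipliers on the observed maxima; a quick worked example with a hypothetical observed maximum of 480 timesteps:

observed_max_timesteps = 480  # hypothetical maximum length seen during analysis

# Margin tiers this commit selects, applied to the observed maximum:
print(round(observed_max_timesteps * 1.25))  # 600 - random aug, full dataset analyzed
print(round(observed_max_timesteps * 1.50))  # 720 - random aug + sampling uncertainty
print(round(observed_max_timesteps * 1.05))  # 504 - deterministic aug only, full dataset
print(round(observed_max_timesteps * 1.15))  # 552 - deterministic aug + sampling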
@@ -964,17 +1014,48 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
'n_features': dataset_tf.feature_dim
}
    # 5. Add an appropriate safety margin - intelligently detect random augmentations and apply a conservative strategy
    if transform_args:
        # Detect whether any random augmentation operations are configured
        random_augmentation_keys = ['static_gain_std', 'white_noise_std', 'constant_offset_std',
                                    'random_walk_std', 'random_cut']
        active_random_augmentations = [
            key for key in random_augmentation_keys
            if transform_args.get(key, 0) > 0
        ]
        has_random_augmentation = len(active_random_augmentations) > 0

        # Diagnostics: report the detected augmentation types
        if has_random_augmentation:
            print(f"🎲 Random augmentations detected: {active_random_augmentations}")

        # Detect deterministic augmentations
        deterministic_augmentations = []
        if transform_args.get('smooth_data', False):
            deterministic_augmentations.append('gaussian_smoothing')
        if deterministic_augmentations:
            print(f"🔄 Deterministic augmentations detected: {deterministic_augmentations}")

        if has_random_augmentation:
            # Random augmentations present: use a conservative safety margin to cover randomness-induced uncertainty
            if sample_size == -1:
                safety_margin = 1.25  # 25% buffer for random augmentation uncertainty (full data)
                margin_reason = "conservative buffer for random augmentation uncertainty (full analysis)"
            else:
                safety_margin = 1.5  # 50% buffer for random augmentation + sampling uncertainty
                margin_reason = f"conservative buffer for random augmentation + sampling {sample_size} trials"
        else:
            # Only deterministic augmentations (e.g. Gaussian smoothing): a smaller margin suffices
            if sample_size == -1:
                safety_margin = 1.05  # 5% buffer for deterministic augmentation effects
                margin_reason = "small buffer for deterministic augmentation (full analysis)"
            else:
                safety_margin = 1.15  # 15% buffer for deterministic augmentation + sampling
                margin_reason = f"buffer for deterministic augmentation + sampling {sample_size} trials"
    else:
        # No data augmentation applied: keep the original logic
        print(f"🚫 No data augmentation detected")
        if sample_size == -1:
            safety_margin = 1.02  # 2% buffer for rounding errors
            margin_reason = "minimal buffer for full dataset analysis without augmentation"