Enhance safety margin calculations in dataset shape analysis to address double augmentation issues caused by random transformations. Implement intelligent detection of random vs deterministic augmentations, applying appropriate safety margins to prevent shape mismatch errors during training.
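In essence, the commit keys the safety margin off whether any randomness-introducing transform is active. Below is a minimal standalone sketch of that rule; pick_safety_margin is a hypothetical helper for illustration only (the committed logic is inline in analyze_dataset_shapes, third hunk below), and the 1.30 value for the sampled no-augmentation branch is an assumption taken from the docstring's "2-30%" range, since that branch is outside the shown hunks.

# Minimal sketch of the margin-selection rule introduced by this commit.
# pick_safety_margin is a hypothetical helper; the committed code computes
# safety_margin inline inside analyze_dataset_shapes (third hunk below).
RANDOM_AUGMENTATION_KEYS = ['static_gain_std', 'white_noise_std', 'constant_offset_std',
                            'random_walk_std', 'random_cut']

def pick_safety_margin(transform_args, sample_size):
    if not transform_args:
        # The full-data branch shown in the diff is 2%; the sampled branch is
        # outside the shown hunk, so 1.30 assumes the docstring's "2-30%" bound.
        return 1.02 if sample_size == -1 else 1.30
    if any(transform_args.get(key, 0) > 0 for key in RANDOM_AUGMENTATION_KEYS):
        return 1.25 if sample_size == -1 else 1.50  # conservative: random augmentation
    return 1.05 if sample_size == -1 else 1.15      # deterministic-only (e.g. smoothing)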
@@ -20,6 +20,32 @@ class BrainToTextDatasetTF:
     This class creates tf.data.Dataset objects that efficiently load and batch
     brain-to-text data from HDF5 files with TPU-optimized operations.
+
+    🔧 IMPORTANT FIX APPLIED (v2024.10): Double Augmentation Safety Margin
+
+    The dataset now includes intelligent safety margin calculation to prevent shape
+    mismatch errors caused by random data augmentation uncertainty. The system:
+
+    1. Detects random vs deterministic augmentations in transform_args
+    2. Applies conservative safety margins (25-50%) for random augmentations
+    3. Uses efficient margins (5-15%) for deterministic-only augmentations
+    4. Prevents runtime errors from different random outcomes during shape analysis
+       vs actual training
+
+    This resolves the "double augmentation problem" where shape analysis and
+    training both apply random transformations but get different results.
+
+    Usage Example:
+        >>> # Random augmentation config (gets 25% safety margin)
+        >>> transform_args = {
+        ...     'static_gain_std': 0.1,   # Random
+        ...     'white_noise_std': 0.05,  # Random
+        ...     'smooth_data': True       # Deterministic
+        ... }
+        >>> dataset = BrainToTextDatasetTF(...)
+        >>> input_fn = create_input_fn(dataset, transform_args, use_static_shapes=True)
+        🎲 Random augmentations detected: ['static_gain_std', 'white_noise_std']
+        📏 Using 25% safety margin for random augmentation uncertainty
     """
 
     def __init__(
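To make the margin concrete, here is a hypothetical illustration of how a multiplicative 25% buffer widens static padded_batch bounds. The dimension names and numbers are invented, and exactly how create_input_fn consumes the margin is not shown in this commit.

import math

measured = {'max_time_steps': 812, 'max_phone_seq_len': 94}  # invented analysis output
safety_margin = 1.25  # random augmentations detected

padded_bounds = {k: math.ceil(v * safety_margin) for k, v in measured.items()}
print(padded_bounds)  # {'max_time_steps': 1015, 'max_phone_seq_len': 118}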
@@ -838,13 +864,37 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
     Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
     applying data augmentation first to capture post-augmentation shapes.
 
+    🔧 FIXED: Intelligent safety margin calculation for random augmentation uncertainty.
+
+    The function now detects random vs deterministic augmentations and applies appropriate
+    safety margins to handle the "double augmentation problem":
+
+    - Random augmentations (noise, random_cut, etc.): 25-50% safety margin
+    - Deterministic augmentations (gaussian smoothing): 5-15% safety margin
+    - No augmentation: 2-30% safety margin (depends on sampling)
+
+    This prevents shape mismatch errors caused by different random outcomes between
+    shape analysis and actual training data augmentation.
+
     Args:
         dataset_tf: Dataset instance to analyze
         sample_size: Number of samples to analyze (set to -1 for all data)
         transform_args: Data transformation configuration to apply during analysis
         training: Whether to apply training-mode augmentations during analysis
 
     Returns:
-        Dictionary with maximum dimensions after data augmentation
+        Dictionary with maximum dimensions after data augmentation and safety margins
+
+    Example:
+        >>> transform_args = {
+        ...     'static_gain_std': 0.1,   # Random augmentation detected
+        ...     'white_noise_std': 0.05,  # Random augmentation detected
+        ...     'smooth_data': True       # Deterministic augmentation detected
+        ... }
+        >>> max_shapes = analyze_dataset_shapes(dataset, transform_args=transform_args)
+        🎲 Random augmentations detected: ['static_gain_std', 'white_noise_std']
+        🔄 Deterministic augmentations detected: ['gaussian_smoothing']
+        📏 Final max shapes (with 25% safety margin - conservative buffer for random augmentation uncertainty):
     """
     print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
     start_time = time.time()
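One consequence of the detection rule in the next hunk is worth noting: because the test is transform_args.get(key, 0) > 0, entries that are present but explicitly set to 0 do not count as random augmentations and do not trigger the conservative margin. A small self-contained check of that behavior:

transform_args = {
    'static_gain_std': 0.0,   # present but disabled -> not detected
    'white_noise_std': 0.05,  # active -> detected
    'smooth_data': True,      # deterministic
}
random_keys = ['static_gain_std', 'white_noise_std', 'constant_offset_std',
               'random_walk_std', 'random_cut']
active = [key for key in random_keys if transform_args.get(key, 0) > 0]
assert active == ['white_noise_std']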
@@ -964,17 +1014,48 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
         'n_features': dataset_tf.feature_dim
     }
 
-    # 5. Add an appropriate safety margin - Gaussian smoothing now uses VALID padding and does not increase length
+    # 5. Add an appropriate safety margin - intelligently detect random augmentations and apply a conservative strategy
     if transform_args:
-        # Augmentation already applied in the analysis: VALID padding guarantees no length growth, so only a small margin is needed
-        if sample_size == -1:
-            safety_margin = 1.05  # 5% buffer for minor augmentation effects (full data)
-            margin_reason = "small buffer after full analysis with length-preserving augmentation"
+        # Detect whether any random augmentation operations are configured
+        random_augmentation_keys = ['static_gain_std', 'white_noise_std', 'constant_offset_std',
+                                    'random_walk_std', 'random_cut']
+        active_random_augmentations = [
+            key for key in random_augmentation_keys
+            if transform_args.get(key, 0) > 0
+        ]
+        has_random_augmentation = len(active_random_augmentations) > 0
+
+        # Diagnostics: report the detected augmentation types
+        if has_random_augmentation:
+            print(f"🎲 Random augmentations detected: {active_random_augmentations}")
+
+        # Detect deterministic augmentations
+        deterministic_augmentations = []
+        if transform_args.get('smooth_data', False):
+            deterministic_augmentations.append('gaussian_smoothing')
+
+        if deterministic_augmentations:
+            print(f"🔄 Deterministic augmentations detected: {deterministic_augmentations}")
+
+        if has_random_augmentation:
+            # Random augmentation present: use a conservative safety margin to absorb the randomness
+            if sample_size == -1:
+                safety_margin = 1.25  # 25% buffer for random augmentation uncertainty (full data)
+                margin_reason = "conservative buffer for random augmentation uncertainty (full analysis)"
+            else:
+                safety_margin = 1.5  # 50% buffer for random augmentation + sampling uncertainty
+                margin_reason = f"conservative buffer for random augmentation + sampling {sample_size} trials"
         else:
-            safety_margin = 1.15  # 15% buffer for augmentation + sampling
-            margin_reason = f"buffer for augmentation + sampling {sample_size} trials"
+            # Deterministic augmentation only (e.g. Gaussian smoothing): a smaller margin suffices
+            if sample_size == -1:
+                safety_margin = 1.05  # 5% buffer for deterministic augmentation effects
+                margin_reason = "small buffer for deterministic augmentation (full analysis)"
+            else:
+                safety_margin = 1.15  # 15% buffer for deterministic augmentation + sampling
+                margin_reason = f"buffer for deterministic augmentation + sampling {sample_size} trials"
     else:
         # No augmentation applied: use the original logic
         print(f"🚫 No data augmentation detected")
         if sample_size == -1:
             safety_margin = 1.02  # 2% buffer for rounding errors
             margin_reason = "minimal buffer for full dataset analysis without augmentation"
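For reference, the margin branches visible in this hunk, exercised through the sketch helper from the top of this page (pick_safety_margin is hypothetical; the expected values match the inline safety_margin assignments above):

cases = [
    ({'white_noise_std': 0.05}, -1,  1.25),  # random augmentation, full analysis
    ({'white_noise_std': 0.05}, 500, 1.50),  # random augmentation + sampling
    ({'smooth_data': True},     -1,  1.05),  # deterministic only, full analysis
    ({'smooth_data': True},     500, 1.15),  # deterministic only + sampling
    (None,                      -1,  1.02),  # no augmentation, full analysis
]
for args, n, expected in cases:
    assert pick_safety_margin(args, n) == expected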