From d92889d435310eae86ee2281e0e628740ce424f3 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Wed, 22 Oct 2025 10:23:22 +0800
Subject: [PATCH] Adjust safety margin in dataset shape analysis to account for data augmentation effects

---
 model_training_nnn_tpu/dataset_tf.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index eddd85f..6a762c7 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -934,15 +934,15 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'n_features': dataset_tf.feature_dim
     }
 
-    # 5. Add an appropriate safety margin - adjusted for the analysis scope
+    # 5. Add an appropriate safety margin - adjusted for the analysis scope and data augmentation
     if sample_size == -1:
-        # Full-dataset analysis: only a small margin is needed for possible rounding errors
-        safety_margin = 1.02  # 2% buffer for rounding errors
-        margin_reason = "minimal buffer for full dataset analysis"
+        # Full-dataset analysis: reserve headroom for data augmentation (especially Gaussian smoothing)
+        safety_margin = 1.15  # 15% buffer for data augmentation effects
+        margin_reason = "buffer for full dataset analysis + data augmentation"
     else:
-        # Sampled analysis: a larger margin is needed for extremes missed by sampling
-        safety_margin = 1.3  # 30% buffer for sampling uncertainty
-        margin_reason = f"larger buffer due to sampling only {sample_size} trials"
+        # Sampled analysis: a larger margin is needed for extremes missed by sampling + data augmentation
+        safety_margin = 1.35  # 35% buffer for sampling uncertainty + data augmentation
+        margin_reason = f"larger buffer due to sampling only {sample_size} trials + data augmentation"
 
     final_max_shapes = {
         'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
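
For context, below is a minimal runnable sketch of the margin logic this patch introduces. The apply_safety_margin helper and the example shape values are illustrative assumptions for this note, not code that exists in dataset_tf.py; only the margin constants and the int(... * safety_margin) padding mirror the diff above.

    # Sketch of the new safety-margin selection, assuming the same names as the diff.
    def apply_safety_margin(original_max_shapes: dict, sample_size: int) -> dict:
        """Pad measured maxima so padded batches survive augmentation (e.g. Gaussian smoothing)."""
        if sample_size == -1:
            # Full-dataset analysis: 15% headroom for augmentation effects.
            safety_margin = 1.15
        else:
            # Sampled analysis: 35% headroom for unsampled extremes plus augmentation.
            safety_margin = 1.35

        return {
            # Same truncating int(...) scaling as in final_max_shapes in the patch.
            'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
            # Feature dimension is fixed by the hardware/preprocessing, so no margin is applied.
            'n_features': original_max_shapes['n_features'],
        }

    if __name__ == "__main__":
        measured = {'max_time_steps': 800, 'n_features': 512}   # hypothetical measured maxima
        print(apply_safety_margin(measured, sample_size=-1))    # full-dataset margin (1.15)
        print(apply_safety_margin(measured, sample_size=1000))  # sampled margin (1.35)

The practical effect is that statically padded batches reserve roughly 15-35% more time steps than the measured maximum, so augmentation that lengthens or smears trials does not overflow the preallocated shapes.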