From 8d94b706f71b381f3fbaf666bb3ea0e245a6a0de Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Wed, 22 Oct 2025 13:08:46 +0800
Subject: [PATCH] Enhance dataset shape analysis to incorporate data augmentation effects and adjust safety margins accordingly

---
 model_training_nnn_tpu/dataset_tf.py | 75 +++++++++++++++++++++-------
 1 file changed, 56 insertions(+), 19 deletions(-)

diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index 6a762c7..9377387 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -830,17 +830,20 @@ def train_test_split_indices(file_paths: List[str],
 
     return train_trials, test_trials
 
-def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
+def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
+                           sample_size: int = 100,
+                           transform_args: Optional[Dict[str, Any]] = None) -> Dict[str, int]:
     """
     Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
-    utilizing multiple CPU cores and the dataset's caching mechanism.
+    optionally applying data augmentation first so the maxima reflect post-augmentation shapes.
 
     Args:
         dataset_tf: Dataset instance to analyze
         sample_size: Number of samples to analyze (set to -1 for all data)
+        transform_args: Data transformation configuration to apply during analysis
 
     Returns:
-        Dictionary with maximum dimensions
+        Dictionary with maximum dimensions after data augmentation
     """
     print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
     start_time = time.time()
@@ -867,29 +870,55 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     total_trials_to_analyze = len(trials_to_check)
     print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")
 
-    # Helper function executed by each worker thread
+    # Helper function executed by each worker thread, now including data augmentation
     def analyze_single_trial(day_trial_pair):
-        """Loads and analyzes a single trial, returns its shapes."""
+        """Loads and analyzes a single trial with data augmentation, returns its shapes."""
         day, trial = day_trial_pair
         try:
             # Reuse dataset_tf's loading and caching logic
             trial_data = dataset_tf._load_trial_data(day, trial)
 
-            # Read the shape information directly from the loaded data
-            time_steps = int(trial_data['n_time_steps'])
+            # Simulate the data augmentation pipeline (if transform_args is provided)
+            if transform_args:
+                # Convert the numpy data to TensorFlow tensors
+                features = tf.constant(trial_data['input_features'], dtype=tf.float32)
+                n_time_steps = tf.constant(trial_data['n_time_steps'], dtype=tf.int32)
+
+                # Add a batch dimension (batch_size=1)
+                features = tf.expand_dims(features, axis=0)
+                n_time_steps = tf.expand_dims(n_time_steps, axis=0)
+
+                # Apply the data transformations
+                augmented_features, augmented_n_time_steps = DataAugmentationTF.transform_data(
+                    features, n_time_steps, transform_args, training=True
+                )
+
+                # Remove the batch dimension and read the augmented shapes
+                augmented_features = tf.squeeze(augmented_features, axis=0)
+                augmented_n_time_steps = tf.squeeze(augmented_n_time_steps, axis=0)
+
+                # Actual number of time steps after augmentation
+                time_steps = int(augmented_n_time_steps.numpy())
+                actual_features_shape = augmented_features.shape[0]  # may differ from n_time_steps
+                time_steps = max(time_steps, actual_features_shape)  # keep the larger value
+            else:
+                # No augmentation requested: use the raw trial data
+                time_steps = int(trial_data['n_time_steps'])
+
+            # Remaining dimensions are not affected by augmentation
             phone_seq_len = int(trial_data['phone_seq_lens'])
 
-            # Handle the transcription data - it may be an array
+            # Handle the transcription data
             transcription_data = trial_data['transcription']
             if hasattr(transcription_data, '__len__'):
                 transcription_len = len(transcription_data)
             else:
-                transcription_len = 1 # length 1 if it is a scalar
+                transcription_len = 1
 
             return (time_steps, phone_seq_len, transcription_len)
         except Exception as e:
             logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
-            return None # return None to signal failure
+            return None
 
     # 3. Run the analysis in parallel with a ThreadPoolExecutor
     # Use dynamic calculation based on CPU cores with reasonable upper limit
@@ -934,15 +963,23 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'n_features': dataset_tf.feature_dim
     }
 
-    # 5. Add an appropriate safety margin - scaled to analysis scope and data augmentation
-    if sample_size == -1:
-        # Full-data analysis: reserve headroom for data augmentation (especially Gaussian smoothing)
-        safety_margin = 1.15  # 15% buffer for data augmentation effects
-        margin_reason = "buffer for full dataset analysis + data augmentation"
+    # 5. Add an appropriate safety margin - the analysis itself now includes augmentation effects
+    if transform_args:
+        # Augmentation was applied during analysis: only a small margin is needed for randomness
+        if sample_size == -1:
+            safety_margin = 1.05  # 5% buffer for augmentation randomness (full data)
+            margin_reason = "small buffer for augmentation randomness after full analysis"
+        else:
+            safety_margin = 1.15  # 15% buffer for augmentation randomness + sampling
+            margin_reason = f"buffer for augmentation randomness + sampling {sample_size} trials"
     else:
-        # Sampled analysis: needs a larger margin for unsampled extremes + data augmentation
-        safety_margin = 1.35  # 35% buffer for sampling uncertainty + data augmentation
-        margin_reason = f"larger buffer due to sampling only {sample_size} trials + data augmentation"
+        # No augmentation applied: fall back to the original margins
+        if sample_size == -1:
+            safety_margin = 1.02  # 2% buffer for rounding errors
+            margin_reason = "minimal buffer for full dataset analysis without augmentation"
+        else:
+            safety_margin = 1.3  # 30% buffer for sampling uncertainty
+            margin_reason = f"larger buffer due to sampling only {sample_size} trials without augmentation"
 
     final_max_shapes = {
         'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
@@ -1005,7 +1042,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
 
     # Analyze dataset to get maximum shapes
     print("📊 Analyzing dataset for maximum shapes...")
-    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1)  # Analyze ALL data for maximum accuracy
+    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args)  # pass the training transform_args
 
     # Use static shapes based on analysis
     padded_shapes = {
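
Usage note (illustrative sketch, not part of the patch): the snippet below shows how the augmentation-aware analysis is meant to be driven, mirroring the call added to create_input_fn. The import path is an assumption (it presumes running from model_training_nnn_tpu/), and the BrainToTextDatasetTF instance plus the training transform_args are expected to come from the surrounding pipeline rather than being constructed here.

    # Assumes this runs from model_training_nnn_tpu/ so dataset_tf is importable.
    from dataset_tf import BrainToTextDatasetTF, analyze_dataset_shapes

    def augmentation_aware_max_shapes(dataset_tf: BrainToTextDatasetTF,
                                      transform_args: dict) -> dict:
        """Return padded_batch maxima that already include augmentation effects."""
        # With transform_args supplied, each analyzed trial is augmented before its
        # shape is measured, so only a small safety margin is applied afterwards
        # (5% when sample_size=-1, i.e. when every trial is analyzed).
        max_shapes = analyze_dataset_shapes(dataset_tf,
                                            sample_size=-1,
                                            transform_args=transform_args)
        # 'max_time_steps' and 'n_features' are keys visible in the patch above;
        # the remaining keys of the returned dict are not shown there.
        print(f"padded time dimension: {max_shapes['max_time_steps']}, "
              f"features: {max_shapes['n_features']}")
        return max_shapes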