diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index bb44c53..3e5283f 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -851,23 +851,25 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
+                    max_shapes: Dict[str, int],
                     training: bool = True,
                     cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with DYNAMIC padding and data augmentation
+    Create input function for TPU training with PRE-ANALYZED FIXED shapes
 
-    This function uses dynamic shapes to avoid the "pad to a smaller size" error.
-    All variable-length dimensions use tf.TensorShape([None, ...]) to allow
-    TensorFlow to automatically determine the appropriate padding size for each batch.
+    This function uses pre-computed maximum shapes to create fixed-size batches,
+    ensuring XLA compilation success on TPU hardware.
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
+        max_shapes: Pre-computed maximum shapes dictionary with keys:
+                    'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
 
     Returns:
-        tf.data.Dataset ready for TPU training with dynamic shapes
+        tf.data.Dataset ready for TPU training with fixed shapes
     """
 
     # Create individual example dataset with file-grouping I/O optimization
@@ -916,22 +918,29 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
-    # ========================= DYNAMIC SHAPES SOLUTION =========================
-    # Use dynamic shapes to avoid the "pad to a smaller size" error
-    # This is the simplest and most robust solution
-    print("🔧 Using DYNAMIC shapes for maximum compatibility and robustness.")
+    # ========================= FIXED SHAPES SOLUTION =========================
+    # Use pre-analyzed fixed shapes to guarantee successful XLA compilation
+    print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
 
-    # Calculate number of features based on subset
-    n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+    # Read the shape information from the passed-in argument
+    max_time_steps = max_shapes['max_time_steps']
+    max_phone_seq_len = max_shapes['max_phone_seq_len']
+    max_transcription_len = max_shapes['max_transcription_len']
+    n_features = max_shapes['n_features']
 
-    # Define dynamic padded shapes - all variable dimensions use None
+    print(f"   Fixed time steps: {max_time_steps}")
+    print(f"   Fixed phone sequence length: {max_phone_seq_len}")
+    print(f"   Fixed transcription length: {max_transcription_len}")
+    print(f"   Number of features: {n_features}")
+
+    # Define fixed padded shapes - NO None values for XLA compatibility
     padded_shapes = {
-        'input_features': tf.TensorShape([None, n_features]),      # dynamic time dimension
-        'seq_class_ids': tf.TensorShape([None]),                    # dynamic sequence length
+        'input_features': tf.TensorShape([max_time_steps, n_features]),
+        'seq_class_ids': tf.TensorShape([max_phone_seq_len]),
         'n_time_steps': tf.TensorShape([]),                         # scalar
         'phone_seq_lens': tf.TensorShape([]),                       # scalar
         'day_indices': tf.TensorShape([]),                          # scalar
-        'transcriptions': tf.TensorShape([None]),                    # dynamic transcription length
+        'transcriptions': tf.TensorShape([max_transcription_len]),
         'block_nums': tf.TensorShape([]),                           # scalar
         'trial_nums': tf.TensorShape([])                            # scalar
     }
@@ -948,8 +957,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
 
-    # Create batches with dynamic padding - TensorFlow will automatically
-    # determine the appropriate padding size for each batch
+    # Create batches with FIXED padding - XLA compiler will be happy!
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
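Note (not part of the patch): a minimal sketch of the fixed-shape batching idea used above. The key names mirror create_input_fn, but the toy generator, the example max_shapes values, and the drop_remainder=True flag (which also fixes the batch dimension) are illustrative assumptions, not code from this repository.

import numpy as np
import tensorflow as tf

# Example maxima as produced by a shape-analysis pass (values are made up)
max_shapes = {'max_time_steps': 8, 'max_phone_seq_len': 4, 'n_features': 4}

def toy_examples():
    # Two variable-length examples; real data would come from BrainToTextDatasetTF
    for t in (3, 5):
        yield {
            'input_features': np.ones([t, max_shapes['n_features']], np.float32),
            'seq_class_ids': np.ones([t // 2 + 1], np.int32),
            'n_time_steps': np.int32(t),
        }

dataset = tf.data.Dataset.from_generator(
    toy_examples,
    output_signature={
        'input_features': tf.TensorSpec([None, max_shapes['n_features']], tf.float32),
        'seq_class_ids': tf.TensorSpec([None], tf.int32),
        'n_time_steps': tf.TensorSpec([], tf.int32),
    },
)

# Every batch is padded to the same pre-analyzed maxima, so XLA sees one static shape
batched = dataset.padded_batch(
    batch_size=2,
    padded_shapes={
        'input_features': tf.TensorShape([max_shapes['max_time_steps'], max_shapes['n_features']]),
        'seq_class_ids': tf.TensorShape([max_shapes['max_phone_seq_len']]),
        'n_time_steps': tf.TensorShape([]),
    },
    padding_values={'input_features': 0.0, 'seq_class_ids': 0, 'n_time_steps': 0},
    drop_remainder=True,  # assumption: keeps the batch dimension fixed as well
)

for batch in batched:
    print(batch['input_features'].shape)  # always (2, 8, 4)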
diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 9e7dee7..7b7d8c7 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -27,7 +27,8 @@ from dataset_tf import (
     BrainToTextDatasetTF,
     DataAugmentationTF,
     train_test_split_indices,
-    create_input_fn
+    create_input_fn,
+    analyze_dataset_shapes
 )
 
 
@@ -550,25 +551,25 @@ class BrainToTextDecoderTrainerTF:
 
         # Calculate losses
         if use_full:
-            # Clean CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+            # Clean CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
             # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
             clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
             clean_loss = tf.nn.ctc_loss(
-                labels=tf.cast(labels, tf.int32),
+                labels=tf.cast(labels, tf.int32),  # Use dense labels with fixed shapes
                 logits=clean_logits_time_major,
-                label_length=tf.cast(phone_seq_lens, tf.int32),
+                label_length=tf.cast(phone_seq_lens, tf.int32),  # Re-enable label_length
                 logit_length=tf.cast(adjusted_lens, tf.int32),
                 blank_index=0,
                 logits_time_major=True
             )
             clean_loss = tf.reduce_mean(clean_loss)
 
-            # Noisy CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+            # Noisy CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
             noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
             noisy_loss = tf.nn.ctc_loss(
-                labels=tf.cast(labels, tf.int32),
+                labels=tf.cast(labels, tf.int32),  # Use dense labels with fixed shapes
                 logits=noisy_logits_time_major,
-                label_length=tf.cast(phone_seq_lens, tf.int32),
+                label_length=tf.cast(phone_seq_lens, tf.int32),  # Re-enable label_length
                 logit_length=tf.cast(adjusted_lens, tf.int32),
                 blank_index=0,
                 logits_time_major=True
@@ -582,12 +583,12 @@ class BrainToTextDecoderTrainerTF:
 
             loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
         else:
-            # Standard CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+            # Standard CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
             logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
             loss = tf.nn.ctc_loss(
-                labels=tf.cast(labels, tf.int32),
+                labels=tf.cast(labels, tf.int32),  # Use dense labels with fixed shapes
                 logits=logits_time_major,
-                label_length=tf.cast(phone_seq_lens, tf.int32),
+                label_length=tf.cast(phone_seq_lens, tf.int32),  # Re-enable label_length
                 logit_length=tf.cast(adjusted_lens, tf.int32),
                 blank_index=0,
                 logits_time_major=True
@@ -652,13 +653,13 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)
 
-        # Calculate loss - use tf.nn.ctc_loss (TPU-compatible)
+        # Calculate loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
         # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
         logits_time_major = tf.transpose(logits, [1, 0, 2])
         loss = tf.nn.ctc_loss(
-            labels=tf.cast(labels, tf.int32),
+            labels=tf.cast(labels, tf.int32),  # Use dense labels with fixed shapes
             logits=logits_time_major,
-            label_length=tf.cast(phone_seq_lens, tf.int32),
+            label_length=tf.cast(phone_seq_lens, tf.int32),  # Re-enable label_length
             logit_length=tf.cast(adjusted_lens, tf.int32),
             blank_index=0,
             logits_time_major=True
@@ -679,28 +680,60 @@ class BrainToTextDecoderTrainerTF:
         initial_tpu_status = self._get_detailed_tpu_status()
         self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
 
-        # Create datasets using modern distribution API
-        def create_dist_dataset_fn(input_dataset_tf, training):
-            """Create distributed dataset function for modern TPU strategy"""
+        # ========================= DATASET SHAPE ANALYSIS =========================
+        # Perform one-time full dataset analysis for fixed shapes (TPU requirement)
+        self.logger.info("🚀 Performing one-time full dataset analysis for fixed shapes...")
+
+        # Analyze training dataset (all data for accurate max shapes)
+        train_analysis_start = time.time()
+        train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
+        train_analysis_time = time.time() - train_analysis_start
+        self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
+
+        # Analyze validation dataset (all data for accurate max shapes)
+        val_analysis_start = time.time()
+        val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
+        val_analysis_time = time.time() - val_analysis_start
+        self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
+
+        # Use maximum shapes across both datasets for consistent padding
+        final_max_shapes = {
+            'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
+            'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
+            'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
+            'n_features': train_max_shapes['n_features']
+        }
+
+        self.logger.info(f"📊 Final fixed shapes for TPU training:")
+        self.logger.info(f"   Time steps: {final_max_shapes['max_time_steps']}")
+        self.logger.info(f"   Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
+        self.logger.info(f"   Transcription length: {final_max_shapes['max_transcription_len']}")
+        self.logger.info(f"   Features: {final_max_shapes['n_features']}")
+        # =====================================================================
+
+        # Create datasets using modern distribution API with fixed shapes
+        def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
+            """Create distributed dataset function for modern TPU strategy with fixed shapes"""
             def dataset_fn(input_context):
-                # create_input_fn returns a complete, batched tf.data.Dataset
+                # create_input_fn now requires max_shapes parameter for fixed shapes
                 return create_input_fn(
                     input_dataset_tf,
                     self.args['dataset']['data_transforms'],
+                    max_shapes=max_shapes,  # Pass pre-analyzed shapes
                     training=training
                 )
             return self.strategy.distribute_datasets_from_function(dataset_fn)
 
-        # Distribute datasets using modern API
+        # Distribute datasets using modern API with fixed shapes
        self.logger.info("🔄 Distributing training dataset across TPU cores...")
         dist_start_time = time.time()
-        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
+        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
         train_dist_time = time.time() - dist_start_time
         self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
 
         self.logger.info("🔄 Distributing validation dataset across TPU cores...")
         val_start_time = time.time()
-        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
+        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
         val_dist_time = time.time() - val_start_time
         self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
 
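Note (not part of the patch): a self-contained sketch of the tf.nn.ctc_loss call pattern the trainer hunks rely on, i.e. dense zero-padded labels plus an explicit label_length so the padding is ignored. The toy sizes and tensors below are made up; only the argument layout mirrors the diff.

import tensorflow as tf

batch_size, max_time, num_classes = 2, 10, 6                     # assumed toy dimensions
logits = tf.random.normal([max_time, batch_size, num_classes])   # time-major, as after tf.transpose(..., [1, 0, 2])

labels = tf.constant([[1, 2, 3, 0],                              # dense labels, zero-padded to a fixed length
                      [4, 5, 0, 0]], tf.int32)
phone_seq_lens = tf.constant([3, 2], tf.int32)                   # true label lengths, so trailing zeros are ignored
adjusted_lens = tf.fill([batch_size], max_time)                  # valid logit length per example

loss = tf.nn.ctc_loss(
    labels=labels,
    logits=logits,
    label_length=phone_seq_lens,
    logit_length=tf.cast(adjusted_lens, tf.int32),
    blank_index=0,               # class 0 is the CTC blank, matching the diff
    logits_time_major=True,
)
loss = tf.reduce_mean(loss)
print(float(loss))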