From 6c7abfcca862584a31e24bf0cf50c0977be19cf5 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Fri, 17 Oct 2025 10:53:58 +0800
Subject: [PATCH] f

---
 model_training_nnn_tpu/trainer_tf.py | 20 +-------------------
 1 file changed, 1 insertion(+), 19 deletions(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 32bff8a..d1a3a8c 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -93,25 +93,8 @@ class BrainToTextDecoderTrainerTF:
 
         print("🔧 Initializing optimizer for TPU training...")
         print(f"Optimizer type: {type(self.optimizer).__name__}")
-
-        # Initialize optimizer slot variables using strategy.run
-        # This ensures we're in the correct replica context
-        print("🔧 Creating optimizer slot variables within TPU replica context...")
-
-        @tf.function
-        def init_optimizer_slots():
-            # Use ALL trainable variables for slot initialization, not just filtered ones
-            # This ensures slot variables are created for all variables that might need gradients
-            all_variables = self.model.trainable_variables
-            dummy_gradients = [tf.zeros_like(var) for var in all_variables]
-
-            # Apply gradients for all variables to ensure all slots are created
-            self.optimizer.apply_gradients(zip(dummy_gradients, all_variables))
-            return tf.constant(True)  # Return something to satisfy strategy.run
-
-        # Run the slot initialization in replica context
-        self.strategy.run(init_optimizer_slots)
         print("✅ Optimizer ready for TPU training")
+        print("📝 Note: Optimizer slot variables will be created automatically during first training step")
 
         self.lr_scheduler = self._create_lr_scheduler()
         self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
@@ -503,7 +486,6 @@ class BrainToTextDecoderTrainerTF:
         else:
             print(f"Model has {total_params:,} trainable parameters")
 
-    @tf.function
     def _train_step(self, batch, step):
         """Single training step with gradient tape"""
         features = batch['input_features']
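
Why dropping the eager slot initialization is reasonable (a minimal sketch, not part of the patch): TensorFlow/Keras optimizers create their slot variables (e.g. Adam's momentum and velocity accumulators) lazily on the first apply_gradients() call, so the removed warm-up via strategy.run is redundant; the slots get built during the first real training step, exactly as the new print note says. The toy model, shapes, and names below are illustrative assumptions, not code from trainer_tf.py:

import tensorflow as tf

# Toy stand-ins for the real model and optimizer (illustrative assumption only).
inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dense(4)(inputs)
model = tf.keras.Model(inputs, outputs)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # Adam's slot variables are created automatically here, on the first
    # apply_gradients() call; no dummy-gradient warm-up is needed.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

loss = train_step(tf.random.normal([2, 8]), tf.random.normal([2, 4]))
print(float(loss))

Note also that the patch removes the @tf.function decorator from _train_step, so that method presumably runs eagerly unless the caller compiles it, for example by wrapping it with tf.function or invoking it inside a tf.function-decorated strategy.run step; where that wrapping happens is not shown in this diff.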