diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 75e7b1e..7fe6caa 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -90,38 +90,31 @@ class BrainToTextDecoderTrainerTF:
         with self.strategy.scope():
             self.model = self._build_model()
             self.optimizer = self._create_optimizer()
 
-            print("🔧 Pre-building optimizer state for TPU...")
-            # For TPU, we must ensure optimizer is completely ready before training
-            # since @tf.function doesn't allow dynamic building
+            print("🔧 Initializing optimizer for TPU training...")
+            # For TPU, we initialize the optimizer by accessing its basic properties
+            # The optimizer will be properly built when first used in training
             try:
-                print("✅ Building optimizer with model variables...")
+                print("✅ Checking optimizer initialization...")
 
-                # Explicitly build the optimizer with model variables
-                print(f"Building optimizer with {len(self.model.trainable_variables)} variables")
-                self.optimizer.build(self.model.trainable_variables)
-                print("✅ Optimizer built with model variables")
+                # Access optimizer properties to ensure it's properly initialized
+                # This is safe and works with all TensorFlow/Keras optimizer versions
+                print(f"Optimizer type: {type(self.optimizer).__name__}")
+                print(f"Learning rate: {self.optimizer.learning_rate}")
 
-                # Verify optimizer is properly built - just check iterations
-                print(f"Optimizer iterations: {self.optimizer.iterations}")
+                # Access iterations to ensure optimizer state tracking is ready
+                # This creates the iterations variable without building the full state
+                iterations = self.optimizer.iterations
+                print(f"Optimizer iterations initialized: {iterations}")
 
-                # For TPU training, we should also ensure the optimizer has all its state
-                # variables created. We can do this by creating dummy variables that match
-                # the model variables, but we don't apply them (avoid the replica context issue)
-                print("🔄 Ensuring optimizer state variables are created...")
-
-                # Force creation of optimizer variables by accessing them
-                # This is safe and doesn't require replica context
-                _ = self.optimizer.iterations # This ensures basic state is created
-
-                print("✅ Optimizer fully ready for TPU training")
-                print("📝 Note: Optimizer will work correctly in @tf.function context")
+                print("✅ Optimizer ready for TPU training")
+                print("📝 Note: Optimizer state will be built automatically during first training step")
             except Exception as e:
-                print(f"❌ CRITICAL: Could not pre-build optimizer state: {e}")
+                print(f"❌ CRITICAL: Could not initialize optimizer: {e}")
                 print(f"Error type: {type(e).__name__}")
                 import traceback
                 print(f"Full traceback: {traceback.format_exc()}")
-                raise RuntimeError(f"Optimizer pre-build failed: {e}") from e
+                raise RuntimeError(f"Optimizer initialization failed: {e}") from e
 
         self.lr_scheduler = self._create_lr_scheduler()
         self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
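
For context, a minimal standalone sketch (not part of this repository) of the lazy-build behavior the new code relies on: Keras optimizers create their iterations counter in __init__, while slot variables are only built on the first apply_gradients call. The toy model, the train_step function, and the plain CPU setup below are illustrative assumptions; the actual trainer runs inside a TPU distribution strategy scope.

# Illustrative sketch of lazy optimizer building (plain CPU setup,
# not the repo's TPU strategy scope).
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build((None, 8))  # create model weights up front, as the trainer does
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Safe pre-training checks, mirroring the diff: these touch only state that
# already exists after __init__ and do not build slot variables.
print(f"Optimizer type: {type(optimizer).__name__}")
print(f"Learning rate: {optimizer.learning_rate}")
print(f"Iterations before first step: {optimizer.iterations}")

@tf.function
def train_step(x, y):  # hypothetical step, not the trainer's real one
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # The first call builds the optimizer's slot variables inside the traced
    # function; subsequent calls reuse them.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

train_step(tf.random.normal((2, 8)), tf.random.normal((2, 4)))
print(f"Iterations after first step: {optimizer.iterations}")

The tradeoff the patch makes, per its own comments: deferring the build avoids calling optimizer.build() outside a replica context, at the cost of slot-variable creation happening during the first traced training step.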