From dde637848154328c9ecfaa5cc111c021546c8a2a Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Thu, 16 Oct 2025 21:13:42 +0800
Subject: [PATCH] Guard against a None TPU strategy in fallback, trainer init,
 and the distributed train step

- create_tpu_strategy() validates its fallback and substitutes a CPU
  OneDeviceStrategy if the default strategy is ever None
- the trainer fails fast on a None strategy, wraps the optimizer
  pre-build in a capability check plus try/except, and surfaces
  merge_call AttributeErrors in the train step as RuntimeErrors
---
 model_training_nnn_tpu/rnn_model_tf.py | 11 ++++++-
 model_training_nnn_tpu/trainer_tf.py   | 41 +++++++++++++++++++++-----
 2 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/model_training_nnn_tpu/rnn_model_tf.py b/model_training_nnn_tpu/rnn_model_tf.py
index b7e8342..6a0e7c8 100644
--- a/model_training_nnn_tpu/rnn_model_tf.py
+++ b/model_training_nnn_tpu/rnn_model_tf.py
@@ -832,7 +832,16 @@ def create_tpu_strategy():
         print("   Common variables: COLAB_TPU_ADDR, TPU_NAME, TPU_WORKER_ID")
         print("🔄 Falling back to default strategy (CPU/GPU)")
 
-    return tf.distribute.get_strategy()
+    fallback_strategy = tf.distribute.get_strategy()
+
+    # Ensure we never return None (check before touching any attributes)
+    if fallback_strategy is None:
+        print("⚠️ Warning: Default strategy is None, creating OneDeviceStrategy")
+        fallback_strategy = tf.distribute.OneDeviceStrategy("/CPU:0")
+
+    print(f"🎯 Fallback strategy created: {type(fallback_strategy).__name__}")
+    print(f"📊 Fallback strategy replicas: {fallback_strategy.num_replicas_in_sync}")
+    return fallback_strategy
 
 
 def build_model_for_tpu(config):
diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 4183208..2264fc1 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -48,7 +48,11 @@ class BrainToTextDecoderTrainerTF:
 
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
+        if self.strategy is None:
+            raise RuntimeError("Failed to create TPU strategy - strategy is None")
+
         print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
+        print(f"Strategy type: {type(self.strategy).__name__}")
 
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
@@ -95,9 +99,21 @@ class BrainToTextDecoderTrainerTF:
             print("🔧 Pre-building optimizer state for TPU...")
             # Force optimizer to build its internal state within strategy scope
             # This prevents the 'NoneType' strategy error during first apply_gradients
-            dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
-            self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
-            print("✅ Optimizer state pre-built successfully")
+            try:
+                # Check that the strategy is properly initialized before applying gradients
+                if callable(getattr(self.strategy, 'merge_call', None)):
+                    dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
+                    self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
+                    print("✅ Optimizer state pre-built successfully with TPU strategy")
+                else:
+                    # Fallback: build optimizer variables without applying gradients
+                    print("⚠️ Strategy not fully initialized, skipping optimizer pre-build")
+                    # Accessing .iterations triggers optimizer variable creation
+                    _ = self.optimizer.iterations
+                    print("✅ Optimizer state initialized (fallback mode)")
+            except Exception as e:
+                print(f"⚠️ Warning: Could not pre-build optimizer state: {e}")
+                print("✅ Continuing without optimizer pre-build")
 
             print("📅 Setting up learning rate scheduler...")
             self.lr_scheduler = self._create_lr_scheduler()
@@ -687,10 +703,21 @@ class BrainToTextDecoderTrainerTF:
             # Distributed training step
             self.logger.info("Running distributed training step...")
             # Ensure we're in the correct TPU strategy scope
-            with self.strategy.scope():
-                per_replica_losses, per_replica_grad_norms = self.strategy.run(
-                    self._train_step, args=(batch, step)
-                )
+            try:
+                with self.strategy.scope():
+                    per_replica_losses, per_replica_grad_norms = self.strategy.run(
+                        self._train_step, args=(batch, step)
+                    )
+            except AttributeError as e:
+                if "merge_call" in str(e):
+                    error_msg = f"Strategy merge_call error at step {step}: {e}"
+                    print(error_msg)
+                    if self.logger:
+                        self.logger.error(error_msg)
+                        self.logger.error("This indicates the strategy is not properly initialized")
+                    raise RuntimeError(f"TPU strategy failed during training step {step}: {e}")
+                else:
+                    raise
 
             # Reduce across replicas
             self.logger.info("Reducing results across replicas...")
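
Note (illustrative, not part of the patch): a minimal sketch of how the
hardened fallback path could be exercised on a machine with no TPU attached.
It assumes TensorFlow is installed, that model_training_nnn_tpu is importable
as a package, and that no TPU is reachable; the zero-gradient pre-build below
mirrors the trainer's pre-build idea rather than reproducing trainer_tf.py
exactly.

    import tensorflow as tf
    from model_training_nnn_tpu.rnn_model_tf import create_tpu_strategy

    # With no TPU reachable this should fall back instead of raising,
    # and after this patch it must never return None.
    strategy = create_tpu_strategy()
    assert strategy is not None
    print(type(strategy).__name__, strategy.num_replicas_in_sync)

    # Mirror the trainer's optimizer pre-build: applying zero gradients once
    # inside the strategy scope forces optimizer variable creation up front,
    # so the first real apply_gradients has no lazy-initialization surprises.
    with strategy.scope():
        var = tf.Variable(1.0)
        opt = tf.keras.optimizers.Adam(1e-3)
        opt.apply_gradients([(tf.zeros_like(var), var)])
    print("optimizer iterations after pre-build:", int(opt.iterations.numpy()))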