Zchen
2025-10-16 21:13:42 +08:00
parent a0b59c6987
commit dde6378481
2 changed files with 44 additions and 8 deletions


@@ -832,7 +832,16 @@ def create_tpu_strategy():
     print("   Common variables: COLAB_TPU_ADDR, TPU_NAME, TPU_WORKER_ID")
     print("🔄 Falling back to default strategy (CPU/GPU)")
-    return tf.distribute.get_strategy()
+    fallback_strategy = tf.distribute.get_strategy()
+    print(f"🎯 Fallback strategy created: {type(fallback_strategy).__name__}")
+    print(f"📊 Fallback strategy replicas: {fallback_strategy.num_replicas_in_sync}")
+
+    # Ensure we never return None
+    if fallback_strategy is None:
+        print("⚠️ Warning: Default strategy is None, creating OneDeviceStrategy")
+        fallback_strategy = tf.distribute.OneDeviceStrategy("/CPU:0")
+
+    return fallback_strategy

 def build_model_for_tpu(config):

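As a standalone illustration of the fallback guarantee above, a minimal sketch assuming only TensorFlow is installed and no TPU is attached (`get_fallback_strategy` is a hypothetical helper, not part of this repo):

    import tensorflow as tf

    def get_fallback_strategy():
        # The default strategy covers CPU/GPU; get_strategy() normally never
        # returns None, so the guard mirrors the defensive check added here.
        strategy = tf.distribute.get_strategy()
        if strategy is None:
            strategy = tf.distribute.OneDeviceStrategy("/CPU:0")
        return strategy

    strategy = get_fallback_strategy()
    print(type(strategy).__name__, strategy.num_replicas_in_sync)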

@@ -48,7 +48,11 @@ class BrainToTextDecoderTrainerTF:
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
+        if self.strategy is None:
+            raise RuntimeError("Failed to create TPU strategy - strategy is None")
         print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
+        print(f"Strategy type: {type(self.strategy).__name__}")

         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
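The mixed-precision branch is only referenced in the trailing context here; a hedged sketch of what such a setup usually looks like on TPU (the policy name is an assumption, not taken from this commit):

    from tensorflow.keras import mixed_precision

    if args.get('use_amp', True):
        # TPUs compute natively in bfloat16, so 'mixed_bfloat16' is the usual policy
        mixed_precision.set_global_policy('mixed_bfloat16')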
@@ -95,9 +99,21 @@ class BrainToTextDecoderTrainerTF:
         print("🔧 Pre-building optimizer state for TPU...")
         # Force optimizer to build its internal state within strategy scope
         # This prevents the 'NoneType' strategy error during first apply_gradients
-        dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
-        self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
-        print("✅ Optimizer state pre-built successfully")
+        try:
+            # Check if strategy is properly initialized before applying gradients
+            if hasattr(self.strategy, 'merge_call') and callable(getattr(self.strategy, 'merge_call')):
+                dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
+                self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
+                print("✅ Optimizer state pre-built successfully with TPU strategy")
+            else:
+                # Fallback: just build optimizer variables without applying gradients
+                print("⚠️ Strategy not fully initialized, skipping optimizer pre-build")
+                # Alternative: trigger optimizer variable creation
+                _ = self.optimizer.iterations
+                print("✅ Optimizer state initialized (fallback mode)")
+        except Exception as e:
+            print(f"⚠️ Warning: Could not pre-build optimizer state: {e}")
+            print("✅ Continuing without optimizer pre-build")

         print("📅 Setting up learning rate scheduler...")
         self.lr_scheduler = self._create_lr_scheduler()
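Outside the trainer, the pre-build trick boils down to one zero-gradient update inside the strategy scope so the optimizer creates its slot variables before the first real step; a minimal sketch assuming `strategy`, `model`, and `optimizer` already exist:

    import tensorflow as tf

    with strategy.scope():
        # A single zero-gradient update forces slot-variable creation
        # (note: it also advances optimizer.iterations by one).
        zero_grads = [tf.zeros_like(v) for v in model.trainable_variables]
        optimizer.apply_gradients(zip(zero_grads, model.trainable_variables))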
@@ -687,10 +703,21 @@ class BrainToTextDecoderTrainerTF:
         # Distributed training step
         self.logger.info("Running distributed training step...")

         # Ensure we're in the correct TPU strategy scope
-        with self.strategy.scope():
-            per_replica_losses, per_replica_grad_norms = self.strategy.run(
-                self._train_step, args=(batch, step)
-            )
+        try:
+            with self.strategy.scope():
+                per_replica_losses, per_replica_grad_norms = self.strategy.run(
+                    self._train_step, args=(batch, step)
+                )
+        except AttributeError as e:
+            if "merge_call" in str(e):
+                error_msg = f"Strategy merge_call error at step {step}: {e}"
+                print(error_msg)
+                if self.logger:
+                    self.logger.error(error_msg)
+                    self.logger.error("This indicates the strategy is not properly initialized")
+                raise RuntimeError(f"TPU strategy failed during training step {step}: {e}")
+            else:
+                raise

         # Reduce across replicas
         self.logger.info("Reducing results across replicas...")
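The cross-replica reduction that follows is outside this hunk; the usual pattern, with `strategy`, `train_step`, and `batch` assumed to exist and MEAN chosen as the reduction:

    import tensorflow as tf

    @tf.function
    def distributed_step(batch):
        per_replica_losses = strategy.run(train_step, args=(batch,))
        # Collapse the PerReplica loss values into a single scalar
        return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)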