Zchen
2025-10-16 21:13:42 +08:00
parent a0b59c6987
commit dde6378481
2 changed files with 44 additions and 8 deletions


@@ -48,7 +48,11 @@ class BrainToTextDecoderTrainerTF:
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
+        if self.strategy is None:
+            raise RuntimeError("Failed to create TPU strategy - strategy is None")
         print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
+        print(f"Strategy type: {type(self.strategy).__name__}")
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
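For reference, create_tpu_strategy() is defined elsewhere in the repository and is not part of this diff; the None check added above only makes sense if the helper returns None instead of raising when no TPU is reachable. A minimal sketch of such a helper, under that assumption (the body here is illustrative, not the repository's actual code):

import tensorflow as tf

def create_tpu_strategy():
    """Connect to the TPU and return a TPUStrategy, or None if initialization fails."""
    try:
        # An empty tpu='' argument resolves the locally attached TPU (e.g. a v5e-8 VM)
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    except (ValueError, tf.errors.NotFoundError) as e:
        print(f"TPU initialization failed: {e}")
        return None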
@@ -95,9 +99,21 @@ class BrainToTextDecoderTrainerTF:
print("🔧 Pre-building optimizer state for TPU...")
# Force optimizer to build its internal state within strategy scope
# This prevents the 'NoneType' strategy error during first apply_gradients
dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
print("✅ Optimizer state pre-built successfully")
try:
# Check if strategy is properly initialized before applying gradients
if hasattr(self.strategy, 'merge_call') and callable(getattr(self.strategy, 'merge_call')):
dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
print("✅ Optimizer state pre-built successfully with TPU strategy")
else:
# Fallback: just build optimizer variables without applying gradients
print("⚠️ Strategy not fully initialized, skipping optimizer pre-build")
# Alternative: trigger optimizer variable creation
_ = self.optimizer.iterations
print("✅ Optimizer state initialized (fallback mode)")
except Exception as e:
print(f"⚠️ Warning: Could not pre-build optimizer state: {e}")
print("✅ Continuing without optimizer pre-build")
print("📅 Setting up learning rate scheduler...")
self.lr_scheduler = self._create_lr_scheduler()
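Note that merge_call is defined on tf.distribute.ReplicaContext rather than on the strategy object itself, so the hasattr guard above will typically take the fallback branch. A sketch of an alternative (not from this commit) that creates the optimizer's slot variables explicitly inside the strategy scope; optimizer.build() is available on the Keras optimizers shipped with TF 2.11 and later:

import tensorflow as tf

def prebuild_optimizer_state(strategy, model, optimizer):
    """Create optimizer slot variables inside the strategy scope before training starts."""
    with strategy.scope():
        if hasattr(optimizer, "build"):
            # Keras optimizers (TF >= 2.11) can create their slot variables directly
            optimizer.build(model.trainable_variables)
        else:
            # Legacy optimizers: applying all-zero gradients forces slot creation,
            # mirroring the dummy-gradient trick used in this commit
            zero_grads = [tf.zeros_like(v) for v in model.trainable_variables]
            optimizer.apply_gradients(zip(zero_grads, model.trainable_variables))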
@@ -687,10 +703,21 @@ class BrainToTextDecoderTrainerTF:
             # Distributed training step
             self.logger.info("Running distributed training step...")
             # Ensure we're in the correct TPU strategy scope
-            with self.strategy.scope():
-                per_replica_losses, per_replica_grad_norms = self.strategy.run(
-                    self._train_step, args=(batch, step)
-                )
+            try:
+                with self.strategy.scope():
+                    per_replica_losses, per_replica_grad_norms = self.strategy.run(
+                        self._train_step, args=(batch, step)
+                    )
+            except AttributeError as e:
+                if "merge_call" in str(e):
+                    error_msg = f"Strategy merge_call error at step {step}: {e}"
+                    print(error_msg)
+                    if self.logger:
+                        self.logger.error(error_msg)
+                        self.logger.error("This indicates the strategy is not properly initialized")
+                    raise RuntimeError(f"TPU strategy failed during training step {step}: {e}")
+                else:
+                    raise
             # Reduce across replicas
             self.logger.info("Reducing results across replicas...")