This commit is contained in:
Zchen
2025-10-17 01:54:32 +08:00
parent a5a3179ca6
commit 8ee09b6b5e

View File

@@ -94,11 +94,17 @@ class BrainToTextDecoderTrainerTF:
print("🔧 Initializing optimizer for TPU training...")
print(f"Optimizer type: {type(self.optimizer).__name__}")
# Initialize optimizer slot variables within strategy scope
# This prevents the "different scope" error
print("🔧 Creating optimizer slot variables within TPU strategy scope...")
dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
# Initialize optimizer slot variables using strategy.run
# This ensures we're in the correct replica context
print("🔧 Creating optimizer slot variables within TPU replica context...")
def init_optimizer_slots():
dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
return tf.constant(True) # Return something to satisfy strategy.run
# Run the slot initialization in replica context
self.strategy.run(init_optimizer_slots)
print("✅ Optimizer ready for TPU training")
self.lr_scheduler = self._create_lr_scheduler()