@@ -94,11 +94,17 @@ class BrainToTextDecoderTrainerTF:
         print("🔧 Initializing optimizer for TPU training...")
         print(f"Optimizer type: {type(self.optimizer).__name__}")
 
-        # Initialize optimizer slot variables within strategy scope
-        # This prevents the "different scope" error
-        print("🔧 Creating optimizer slot variables within TPU strategy scope...")
-        dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
-        self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
+        # Initialize optimizer slot variables using strategy.run
+        # This ensures we're in the correct replica context
+        print("🔧 Creating optimizer slot variables within TPU replica context...")
+
+        def init_optimizer_slots():
+            dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
+            self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
+            return tf.constant(True)  # Return something to satisfy strategy.run
+
+        # Run the slot initialization in replica context
+        self.strategy.run(init_optimizer_slots)
         print("✅ Optimizer ready for TPU training")
 
         self.lr_scheduler = self._create_lr_scheduler()
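For context, a minimal self-contained sketch of the pattern this commit switches to: forcing a Keras optimizer to create its slot variables inside the replica context by applying all-zero gradients through strategy.run. The model, optimizer, and MirroredStrategy below are stand-ins for illustration, not the repo's actual BrainToTextDecoderTrainerTF setup; on a real TPU the strategy would instead be a tf.distribute.TPUStrategy built from a TPUClusterResolver.

import tensorflow as tf

# Stand-in strategy: MirroredStrategy runs on CPU/GPU; the trainer itself
# targets a TPUStrategy, but the slot-initialization pattern is the same.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Variables (model weights, optimizer state) must be created under the
    # strategy scope so they are placed consistently across replicas.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8),
    ])
    optimizer = tf.keras.optimizers.Adam(1e-3)

def init_optimizer_slots():
    # Applying all-zero gradients once makes the optimizer create its slot
    # variables (e.g. Adam's momentum/variance) in the replica context, so
    # the first real training step does not trip a scope-mismatch error.
    dummy_gradients = [tf.zeros_like(v) for v in model.trainable_variables]
    optimizer.apply_gradients(zip(dummy_gradients, model.trainable_variables))
    return tf.constant(True)  # strategy.run expects a return value

strategy.run(init_optimizer_slots)
print("optimizer slot variables created")

Returning a constant from init_optimizer_slots keeps strategy.run happy without doing any extra work; the zero gradients leave the model weights numerically unchanged while still triggering slot creation.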