This commit is contained in:
Zchen
2025-10-17 01:58:28 +08:00
parent 8ee09b6b5e
commit 49700456b8

View File

@@ -98,6 +98,7 @@ class BrainToTextDecoderTrainerTF:
# This ensures we're in the correct replica context
print("🔧 Creating optimizer slot variables within TPU replica context...")
@tf.function
def init_optimizer_slots():
dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))