diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index c4e97d1..75e7b1e 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -443,13 +443,25 @@ class BrainToTextDecoderTrainerTF:
         print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
         print("💡 Manual L2 regularization will be applied in training step")
 
-        optimizer = tf.keras.optimizers.Adam(
-            learning_rate=self.args['lr_max'],
-            beta_1=self.args['beta0'],
-            beta_2=self.args['beta1'],
-            epsilon=self.args['epsilon']
-            # No weight_decay parameter in Adam - handled manually
-        )
+        # Use legacy Adam optimizer for better TPU distributed training compatibility
+        # Legacy optimizers have more stable distributed training implementations
+        try:
+            optimizer = tf.keras.optimizers.legacy.Adam(
+                learning_rate=self.args['lr_max'],
+                beta_1=self.args['beta0'],
+                beta_2=self.args['beta1'],
+                epsilon=self.args['epsilon']
+            )
+            print("✅ Using legacy Adam optimizer for better TPU compatibility")
+        except AttributeError:
+            # Fallback to standard Adam if legacy is not available
+            optimizer = tf.keras.optimizers.Adam(
+                learning_rate=self.args['lr_max'],
+                beta_1=self.args['beta0'],
+                beta_2=self.args['beta1'],
+                epsilon=self.args['epsilon']
+            )
+            print("⚠️ Using standard Adam optimizer (legacy not available)")
 
         return optimizer
 
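
For context, here is a minimal standalone sketch of the fallback pattern this hunk introduces: prefer `tf.keras.optimizers.legacy.Adam` where the legacy namespace still exists (Keras 2.x), and fall back to the standard `Adam` where it has been removed. The `build_adam` helper name and the hyperparameter defaults are illustrative assumptions, not values taken from the trainer's config.

```python
# Minimal standalone sketch (not part of the diff) of the legacy-Adam fallback.
# Assumes TensorFlow is installed; defaults below are illustrative placeholders.
import tensorflow as tf


def build_adam(lr=1e-3, beta0=0.9, beta1=0.999, eps=1e-7):
    """Prefer the legacy Adam implementation, falling back to standard Adam."""
    try:
        # Keras 2.x exposes tf.keras.optimizers.legacy.Adam; its distributed
        # code path is the one the change above treats as more stable on TPU.
        return tf.keras.optimizers.legacy.Adam(
            learning_rate=lr, beta_1=beta0, beta_2=beta1, epsilon=eps
        )
    except AttributeError:
        # Newer stacks (Keras 3) drop the legacy namespace; note that the
        # exact exception raised there may differ from AttributeError.
        return tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=beta0, beta_2=beta1, epsilon=eps
        )


if __name__ == "__main__":
    optimizer = build_adam()
    print(f"Built optimizer: {type(optimizer).__module__}.{type(optimizer).__name__}")
```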