Use legacy Adam optimizer for TPU compatibility

Zchen
2025-10-17 01:26:02 +08:00
parent 7df78244e6
commit 0a72143513


@@ -443,13 +443,25 @@ class BrainToTextDecoderTrainerTF:
print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
print("💡 Manual L2 regularization will be applied in training step")
optimizer = tf.keras.optimizers.Adam(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon']
# No weight_decay parameter in Adam - handled manually
)
# Use legacy Adam optimizer for better TPU distributed training compatibility
# Legacy optimizers have more stable distributed training implementations
try:
optimizer = tf.keras.optimizers.legacy.Adam(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon']
)
print("✅ Using legacy Adam optimizer for better TPU compatibility")
except AttributeError:
# Fallback to standard Adam if legacy is not available
optimizer = tf.keras.optimizers.Adam(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon']
)
print("⚠️ Using standard Adam optimizer (legacy not available)")
return optimizer
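
Context for the fallback: the AttributeError branch exists because the tf.keras.optimizers.legacy namespace is not present in every TensorFlow/Keras build, as the added comment notes. The sketch below shows one common way an optimizer built this way is created inside a TPU strategy scope; the resolver argument and hyperparameter values are illustrative placeholders, not taken from this repo's config.

# Illustrative only: optimizer creation is expected to happen inside the TPU
# strategy scope so its slot variables are created as replicated variables.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')  # 'local' assumes a TPU VM
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    try:
        # Same selection logic as the diff above, with placeholder hyperparameters.
        optimizer = tf.keras.optimizers.legacy.Adam(
            learning_rate=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-7)
    except AttributeError:
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-7)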
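
The removed call also shows why the print about manual L2 regularization is there: plain Adam takes no weight_decay argument, so the penalty has to be added to the loss inside the training step. That step is outside this hunk; the following is a minimal sketch under assumed names (train_step, loss_fn, and l2_lambda are illustrative, not from the repo).

import tensorflow as tf

def train_step(model, optimizer, loss_fn, x, y, l2_lambda=1e-5):
    """Illustrative training step that adds an L2 penalty by hand."""
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
        # Manual L2 regularization: assumes Dense/Conv-style layers whose
        # weight variables contain 'kernel' in their names.
        l2_penalty = tf.add_n(
            [tf.nn.l2_loss(v) for v in model.trainable_variables if 'kernel' in v.name])
        loss += l2_lambda * l2_penalty
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss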