Switch optimizer from AdamW to Adam

Author: Zchen
Date: 2025-10-17 01:07:01 +08:00
parent a96e272f7b
commit 7df78244e6


@@ -137,12 +137,14 @@ class BrainToTextDecoderTrainerTF:
        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
-       # TPU-specific weight decay handling
+       # Manual weight decay handling for all environments (since we use Adam)
        self.manual_weight_decay = False
-       if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0:
+       if self.args.get('weight_decay', 0.0) > 0:
            self.manual_weight_decay = True
            self.weight_decay_rate = self.args['weight_decay']
            print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
        else:
            print("💡 No weight decay configured")
        if self.adv_enabled:
            if self.logger:
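
Because plain Adam has no weight_decay argument, the manual_weight_decay flag set above implies the penalty is applied by the trainer itself during the training step. Below is a minimal sketch of one common way to do this, adding an L2 term to the task loss before gradients are taken; the helper name add_l2_penalty and the bias-exclusion rule are illustrative assumptions, not this repository's actual code.

import tensorflow as tf

def add_l2_penalty(loss, trainable_variables, weight_decay_rate):
    # Stand-in for AdamW's built-in weight decay when plain Adam is used:
    # add an L2 penalty over trainable weights to the task loss, so the
    # penalty flows through the gradients computed in the training step.
    # Excluding biases is a common convention, not necessarily this repo's.
    l2_terms = [tf.nn.l2_loss(v) for v in trainable_variables
                if 'bias' not in v.name.lower()]
    if not l2_terms:
        return loss
    return loss + weight_decay_rate * tf.add_n(l2_terms)

Applied this way, the penalty passes through Adam's adaptive scaling, i.e. classic L2 regularization rather than AdamW-style decoupled decay, which matches the "Manual L2 regularization" wording in the diff.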
@@ -435,28 +437,19 @@ class BrainToTextDecoderTrainerTF:
        # For TPU training, we need to be more explicit about optimizer configuration
        # to avoid strategy context issues
-       if isinstance(self.strategy, tf.distribute.TPUStrategy):
-           print("Using TPU-optimized optimizer configuration")
-           # TPU-specific optimizer configuration
-           # IMPORTANT: Disable weight_decay for TPU due to distributed training compatibility issues
-           # We'll implement manual L2 regularization instead
-           optimizer = tf.keras.optimizers.AdamW(
-               learning_rate=self.args['lr_max'],
-               beta_1=self.args['beta0'],
-               beta_2=self.args['beta1'],
-               epsilon=self.args['epsilon'],
-               weight_decay=0.0  # Disabled for TPU compatibility
-               # REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm
-           )
-       else:
-           print("Using standard optimizer configuration")
-           optimizer = tf.keras.optimizers.AdamW(
-               learning_rate=self.args['lr_max'],
-               beta_1=self.args['beta0'],
-               beta_2=self.args['beta1'],
-               epsilon=self.args['epsilon'],
-               weight_decay=self.args['weight_decay']
-           )
+       # IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs
+       # AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0
+       # We implement manual L2 regularization (weight decay) in the training step instead
+       print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
+       print("💡 Manual L2 regularization will be applied in training step")
+       optimizer = tf.keras.optimizers.Adam(
+           learning_rate=self.args['lr_max'],
+           beta_1=self.args['beta0'],
+           beta_2=self.args['beta1'],
+           epsilon=self.args['epsilon']
+           # No weight_decay parameter in Adam - handled manually
+       )
        return optimizer
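
The removed comment about avoiding "double clipping with manual tf.clip_by_global_norm" indicates that gradient clipping is done once by hand before apply_gradients rather than through the optimizer's global_clipnorm argument. A hedged sketch of that pattern follows; the clip_and_apply helper name and the 10.0 norm threshold are assumptions for illustration only.

import tensorflow as tf

def clip_and_apply(optimizer, grads, variables, max_norm=10.0):
    # Clip the whole gradient list once by its global norm, then apply.
    # Leaving global_clipnorm out of the optimizer constructor ensures the
    # gradients are not clipped a second time inside apply_gradients.
    clipped, _ = tf.clip_by_global_norm(grads, max_norm)
    optimizer.apply_gradients(zip(clipped, variables))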