adamw to adam
@@ -137,12 +137,14 @@ class BrainToTextDecoderTrainerTF:
         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
 
-        # TPU-specific weight decay handling
+        # Manual weight decay handling for all environments (since we use Adam)
         self.manual_weight_decay = False
-        if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0:
+        if self.args.get('weight_decay', 0.0) > 0:
             self.manual_weight_decay = True
             self.weight_decay_rate = self.args['weight_decay']
             print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
+        else:
+            print("💡 No weight decay configured")
 
         if self.adv_enabled:
             if self.logger:
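For reference, here is a minimal sketch of how the manual_weight_decay flag and weight_decay_rate configured above could be consumed inside the training step. This code is not part of the commit; the method name train_step and the attributes self.model, self.loss_fn, self.optimizer, and the clipping threshold are illustrative assumptions.

import tensorflow as tf

def train_step(self, batch):
    # Hypothetical training step (not from this repository) showing where a
    # manual L2 penalty slots in when the optimizer itself applies no decay.
    features, labels = batch
    with tf.GradientTape() as tape:
        logits = self.model(features, training=True)
        loss = self.loss_fn(labels, logits)
        if self.manual_weight_decay:
            # Classic L2 regularization: sum of squared kernel weights,
            # skipping biases and normalization parameters.
            l2 = tf.add_n([
                tf.nn.l2_loss(v)
                for v in self.model.trainable_variables
                if 'bias' not in v.name and 'norm' not in v.name.lower()
            ])
            loss = loss + self.weight_decay_rate * l2
    grads = tape.gradient(loss, self.model.trainable_variables)
    # Manual clipping, which is why global_clipnorm stays off the optimizer.
    grads, _ = tf.clip_by_global_norm(grads, 10.0)  # clip value is illustrative
    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
    return loss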
@@ -435,28 +437,19 @@ class BrainToTextDecoderTrainerTF:
 
-        # For TPU training, we need to be more explicit about optimizer configuration
-        # to avoid strategy context issues
-        if isinstance(self.strategy, tf.distribute.TPUStrategy):
-            print("Using TPU-optimized optimizer configuration")
-            # TPU-specific optimizer configuration
-            # IMPORTANT: Disable weight_decay for TPU due to distributed training compatibility issues
-            # We'll implement manual L2 regularization instead
-            optimizer = tf.keras.optimizers.AdamW(
-                learning_rate=self.args['lr_max'],
-                beta_1=self.args['beta0'],
-                beta_2=self.args['beta1'],
-                epsilon=self.args['epsilon'],
-                weight_decay=0.0  # Disabled for TPU compatibility
-                # REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm
-            )
-        else:
-            print("Using standard optimizer configuration")
-            optimizer = tf.keras.optimizers.AdamW(
-                learning_rate=self.args['lr_max'],
-                beta_1=self.args['beta0'],
-                beta_2=self.args['beta1'],
-                epsilon=self.args['epsilon'],
-                weight_decay=self.args['weight_decay']
-            )
+        # IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs
+        # AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0
+        # We implement manual L2 regularization (weight decay) in the training step instead
+        print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
+        print("💡 Manual L2 regularization will be applied in training step")
+
+        optimizer = tf.keras.optimizers.Adam(
+            learning_rate=self.args['lr_max'],
+            beta_1=self.args['beta0'],
+            beta_2=self.args['beta1'],
+            epsilon=self.args['epsilon']
+            # No weight_decay parameter in Adam - handled manually
+        )
 
         return optimizer
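One caveat: an L2 penalty added to the loss flows through Adam's adaptive moment scaling, so it is not mathematically identical to AdamW's decoupled weight decay, and the best weight_decay value may differ between the two schemes. If decoupled behaviour is wanted without relying on AdamW, the decay can instead be applied directly to the variables after the optimizer step. The helper below is a hypothetical sketch, not code from this repository; the variable filtering and argument names are assumptions.

import tensorflow as tf

def apply_decoupled_weight_decay(model, learning_rate, weight_decay_rate):
    # Hypothetical helper: shrink each kernel toward zero independently of the
    # Adam update, mimicking AdamW's decoupled decay. Call it immediately after
    # optimizer.apply_gradients(...) inside the replicated training step.
    for var in model.trainable_variables:
        if 'bias' in var.name or 'norm' in var.name.lower():
            continue  # biases and normalization layers are usually exempt
        var.assign_sub(learning_rate * weight_decay_rate * var)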