adamw to adam
@@ -137,12 +137,14 @@ class BrainToTextDecoderTrainerTF:
         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
 
-        # TPU-specific weight decay handling
+        # Manual weight decay handling for all environments (since we use Adam)
         self.manual_weight_decay = False
-        if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0:
+        if self.args.get('weight_decay', 0.0) > 0:
             self.manual_weight_decay = True
             self.weight_decay_rate = self.args['weight_decay']
             print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
+        else:
+            print("💡 No weight decay configured")
 
         if self.adv_enabled:
             if self.logger:
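With decay removed from the optimizer, the `manual_weight_decay` / `weight_decay_rate` fields configured above are meant to be consumed inside the training step. Below is a minimal sketch of what that could look like with a `tf.GradientTape` step; the function and its arguments (`model`, `loss_fn`, `optimizer`) are illustrative placeholders, not the trainer's actual API.

import tensorflow as tf

def train_step_sketch(model, loss_fn, optimizer, features, labels,
                      manual_weight_decay=True, weight_decay_rate=1e-3):
    """Illustrative training step; not the trainer's actual method."""
    with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = loss_fn(labels, logits)
        if manual_weight_decay:
            # Classic L2 penalty added to the loss; bias and normalization
            # variables are commonly excluded from decay.
            l2 = tf.add_n([
                tf.nn.l2_loss(v)
                for v in model.trainable_variables
                if 'bias' not in v.name and 'norm' not in v.name.lower()
            ])
            loss = loss + weight_decay_rate * l2
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

Note that adding an L2 penalty to the loss couples the decay with Adam's adaptive scaling, which is not exactly equivalent to AdamW's decoupled weight decay; that is the trade-off the commit's comments accept in order to avoid the AdamW issues on TPU.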
@@ -435,28 +437,19 @@ class BrainToTextDecoderTrainerTF:
 
         # For TPU training, we need to be more explicit about optimizer configuration
         # to avoid strategy context issues
-        if isinstance(self.strategy, tf.distribute.TPUStrategy):
-            print("Using TPU-optimized optimizer configuration")
-            # TPU-specific optimizer configuration
-            # IMPORTANT: Disable weight_decay for TPU due to distributed training compatibility issues
-            # We'll implement manual L2 regularization instead
-            optimizer = tf.keras.optimizers.AdamW(
-                learning_rate=self.args['lr_max'],
-                beta_1=self.args['beta0'],
-                beta_2=self.args['beta1'],
-                epsilon=self.args['epsilon'],
-                weight_decay=0.0  # Disabled for TPU compatibility
-                # REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm
-            )
-        else:
-            print("Using standard optimizer configuration")
-            optimizer = tf.keras.optimizers.AdamW(
-                learning_rate=self.args['lr_max'],
-                beta_1=self.args['beta0'],
-                beta_2=self.args['beta1'],
-                epsilon=self.args['epsilon'],
-                weight_decay=self.args['weight_decay']
-            )
+        # IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs
+        # AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0
+        # We implement manual L2 regularization (weight decay) in the training step instead
+        print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
+        print("💡 Manual L2 regularization will be applied in training step")
+        optimizer = tf.keras.optimizers.Adam(
+            learning_rate=self.args['lr_max'],
+            beta_1=self.args['beta0'],
+            beta_2=self.args['beta1'],
+            epsilon=self.args['epsilon']
+            # No weight_decay parameter in Adam - handled manually
+        )
 
         return optimizer
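The removed `global_clipnorm` comment indicates that gradient clipping is done manually with `tf.clip_by_global_norm` rather than inside the optimizer. A small usage sketch of the resulting configuration with a single clipping point; the literal hyperparameter values and the `grad_clip_norm` key are illustrative, not taken from the repository's config.

import tensorflow as tf

# Illustrative hyperparameters; the real values come from self.args in the trainer.
args = {'lr_max': 1e-3, 'beta0': 0.9, 'beta1': 0.999, 'epsilon': 1e-8, 'grad_clip_norm': 10.0}

optimizer = tf.keras.optimizers.Adam(
    learning_rate=args['lr_max'],
    beta_1=args['beta0'],
    beta_2=args['beta1'],
    epsilon=args['epsilon'],
)

def apply_clipped_gradients(grads, variables, clip_norm=args['grad_clip_norm']):
    # Single clipping point: also passing global_clipnorm to the optimizer
    # would clip gradients twice, which the removed comment warns against.
    clipped, _ = tf.clip_by_global_norm(grads, clip_norm)
    optimizer.apply_gradients(zip(clipped, variables))

Keeping decay and clipping in the training step means the optimizer construction no longer depends on whether a `tf.distribute.TPUStrategy` is active, which is why the `if isinstance(self.strategy, tf.distribute.TPUStrategy)` branch could be dropped.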