commit 7efa33d730 (parent 982d2dc256)
Author: Zchen
Date:   2025-10-16 22:42:33 +08:00


@@ -148,11 +148,24 @@ class BrainToTextDecoderTrainerTF:
         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
+
+        # TPU-specific weight decay handling
+        self.manual_weight_decay = False
+        if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0:
+            self.manual_weight_decay = True
+            self.weight_decay_rate = self.args['weight_decay']
+            print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")

         if self.adv_enabled:
-            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
-                             f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
-                             f"noise_l2_weight={self.adv_noise_l2_weight}, "
-                             f"warmup_steps={self.adv_warmup_steps}")
+            if self.logger:
+                self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
+                                 f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
+                                 f"noise_l2_weight={self.adv_noise_l2_weight}, "
+                                 f"warmup_steps={self.adv_warmup_steps}")
+            else:
+                print(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
+                      f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
+                      f"noise_l2_weight={self.adv_noise_l2_weight}, "
+                      f"warmup_steps={self.adv_warmup_steps}")

     def _setup_logging(self):
         """Setup logging configuration"""
@@ -436,15 +449,19 @@ class BrainToTextDecoderTrainerTF:
         if isinstance(self.strategy, tf.distribute.TPUStrategy):
             print("Using TPU-optimized optimizer configuration")
             # TPU-specific optimizer configuration
+            # IMPORTANT: weight_decay is disabled on TPU due to distributed-training
+            # compatibility issues; manual L2 regularization is applied in the train step instead
             optimizer = tf.keras.optimizers.AdamW(
                 learning_rate=self.args['lr_max'],
                 beta_1=self.args['beta0'],
                 beta_2=self.args['beta1'],
                 epsilon=self.args['epsilon'],
-                weight_decay=self.args['weight_decay'],
+                weight_decay=0.0,  # Disabled for TPU compatibility
                 # TPU-specific settings
                 global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
             )
+            print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
+            print("💡 Manual L2 regularization is applied in the training step instead")
         else:
             print("Using standard optimizer configuration")
             optimizer = tf.keras.optimizers.AdamW(
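One caveat worth keeping next to this change: AdamW's `weight_decay` is decoupled, meaning it shrinks each weight directly during the update, while an L2 term added to the loss (as this commit does for TPU) flows through Adam's adaptive moments and gets rescaled per parameter, so the two are not numerically equivalent. If matching AdamW semantics ever matters, a manual decoupled update applied after `optimizer.apply_gradients(...)` is the closer substitute; a sketch, with a hypothetical helper that is not part of this trainer:

```python
import tensorflow as tf

def apply_decoupled_weight_decay(model: tf.keras.Model,
                                 decay_rate: float,
                                 learning_rate: float) -> None:
    """Shrink each trainable variable in place, mirroring AdamW's decoupled decay."""
    for var in model.trainable_variables:
        var.assign_sub(learning_rate * decay_rate * var)
```

Calling this once per step after the gradient update keeps the penalty out of the gradient statistics, which is the point of decoupling.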
@@ -574,6 +591,13 @@ class BrainToTextDecoderTrainerTF:
             loss = self.ctc_loss(loss_input, clean_logits)
             loss = tf.reduce_mean(loss)
+
+            # Add manual L2 regularization for TPU (since AdamW weight_decay is disabled there)
+            if self.manual_weight_decay:
+                l2_loss = tf.constant(0.0, dtype=loss.dtype)
+                for var in self.model.trainable_variables:
+                    l2_loss += tf.nn.l2_loss(var)
+                loss += self.weight_decay_rate * l2_loss

             # TensorFlow mixed precision needs no manual loss scaling: the Keras policy handles it automatically
             # TPU v5e-8 uses bfloat16, which does not require loss scaling
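Two optional refinements to the manual L2 loop above, sketched here rather than proposed as part of the commit: `tf.add_n` performs the summation in a single op, and weight decay is commonly skipped for biases and normalization parameters. The name-based filter assumes those substrings appear in this model's variable names:

```python
import tensorflow as tf

def l2_penalty(model: tf.keras.Model, rate: float,
               skip_substrings=('bias', 'norm')) -> tf.Tensor:
    """Scaled sum of tf.nn.l2_loss over trainable variables, skipping bias/norm params."""
    penalized = [v for v in model.trainable_variables
                 if not any(s in v.name.lower() for s in skip_substrings)]
    return rate * tf.add_n([tf.nn.l2_loss(v) for v in penalized])

# Hypothetical usage in the train step:
# loss += l2_penalty(self.model, self.weight_decay_rate)
```

Note that `tf.nn.l2_loss` computes `sum(v ** 2) / 2`, so `rate` here plays the same role as `self.weight_decay_rate` in the loop above.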