diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 88f6df7..c4e97d1 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -137,12 +137,14 @@ class BrainToTextDecoderTrainerTF:
         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
 
-        # TPU-specific weight decay handling
+        # Manual weight decay handling for all environments (since we use Adam)
         self.manual_weight_decay = False
-        if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0:
+        if self.args.get('weight_decay', 0.0) > 0:
             self.manual_weight_decay = True
             self.weight_decay_rate = self.args['weight_decay']
             print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
+        else:
+            print("💡 No weight decay configured")
 
         if self.adv_enabled:
             if self.logger:
@@ -435,28 +437,19 @@ class BrainToTextDecoderTrainerTF:
 
         # For TPU training, we need to be more explicit about optimizer configuration
         # to avoid strategy context issues
-        if isinstance(self.strategy, tf.distribute.TPUStrategy):
-            print("Using TPU-optimized optimizer configuration")
-            # TPU-specific optimizer configuration
-            # IMPORTANT: Disable weight_decay for TPU due to distributed training compatibility issues
-            # We'll implement manual L2 regularization instead
-            optimizer = tf.keras.optimizers.AdamW(
-                learning_rate=self.args['lr_max'],
-                beta_1=self.args['beta0'],
-                beta_2=self.args['beta1'],
-                epsilon=self.args['epsilon'],
-                weight_decay=0.0  # Disabled for TPU compatibility
-                # REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm
-            )
-        else:
-            print("Using standard optimizer configuration")
-            optimizer = tf.keras.optimizers.AdamW(
-                learning_rate=self.args['lr_max'],
-                beta_1=self.args['beta0'],
-                beta_2=self.args['beta1'],
-                epsilon=self.args['epsilon'],
-                weight_decay=self.args['weight_decay']
-            )
+        # IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs
+        # AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0
+        # We implement manual L2 regularization (weight decay) in the training step instead
+        print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
+        print("💡 Manual L2 regularization will be applied in training step")
+
+        optimizer = tf.keras.optimizers.Adam(
+            learning_rate=self.args['lr_max'],
+            beta_1=self.args['beta0'],
+            beta_2=self.args['beta1'],
+            epsilon=self.args['epsilon']
+            # No weight_decay parameter in Adam - handled manually
+        )
 
         return optimizer
 
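
For context, a minimal sketch of what the "manual L2 regularization in the training step" referenced by the new comments could look like. This is an illustrative assumption, not the actual training step in trainer_tf.py: the function name, the loss_fn argument, and the grad_clip_norm parameter are hypothetical, while the pattern of adding an explicit L2 penalty alongside plain Adam matches the intent described in the diff.

import tensorflow as tf

def train_step_sketch(model, optimizer, loss_fn, features, labels,
                      weight_decay_rate=0.0, grad_clip_norm=10.0):
    """Illustrative training step with manual L2 regularization (hypothetical)."""
    with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = loss_fn(labels, logits)  # assumed task loss
        if weight_decay_rate > 0.0:
            # Plain Adam applies no weight decay itself, so add an explicit L2 penalty.
            l2_penalty = tf.add_n([
                tf.nn.l2_loss(v) for v in model.trainable_variables
                if 'bias' not in v.name  # biases are commonly excluded
            ])
            loss = loss + weight_decay_rate * l2_penalty
    grads = tape.gradient(loss, model.trainable_variables)
    # Single manual clip, consistent with the "avoid double clipping" note in the removed code.
    grads, _ = tf.clip_by_global_norm(grads, grad_clip_norm)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

Note that an L2 penalty added to the loss is coupled to Adam's adaptive scaling, unlike AdamW's decoupled weight decay; the diff explicitly frames the replacement as "manual L2 regularization", so the sketch follows that choice.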