diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index d4598c6..1917585 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -511,18 +511,11 @@ class BrainToTextDecoderTrainerTF:
             loss = self.ctc_loss(loss_input, clean_logits)
             loss = tf.reduce_mean(loss)
 
-            # Scale loss for mixed precision
-            if self.mixed_precision:
-                scaled_loss = self.optimizer.get_scaled_loss(loss)
-            else:
-                scaled_loss = loss
+            # TensorFlow mixed precision: no manual loss scaling needed; the Keras policy handles it
+            # TPU v5e-8 uses bfloat16, which does not require loss scaling
 
-        # Calculate gradients
-        if self.mixed_precision:
-            scaled_gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
-            gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
-        else:
-            gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
+        # Calculate gradients - TensorFlow handles mixed precision automatically
+        gradients = tape.gradient(loss, self.model.trainable_variables)
 
         # Clip gradients
         if self.args['grad_norm_clip_value'] > 0:
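
For context, a minimal sketch of the pattern this patch moves to: with the Keras mixed-precision policy set to bfloat16, gradients are taken from the unscaled loss, so the `get_scaled_loss` / `get_unscaled_gradients` round-trip is unnecessary (bfloat16 keeps float32's exponent range, so underflow-driven loss scaling is not needed). The `model`, `optimizer`, and `loss_fn` names below are illustrative placeholders, not code from this repository:

```python
import tensorflow as tf

# Assumed policy; the trainer is described as running bfloat16 on TPU v5e-8.
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

def train_step(model, optimizer, loss_fn, x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        # Cast logits to float32 so the loss is accumulated in full precision.
        loss = tf.reduce_mean(loss_fn(y, tf.cast(logits, tf.float32)))
    # No loss scaling: gradients come directly from the unscaled loss.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```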