adamw修复

This commit is contained in:
Zchen
2025-10-16 20:44:55 +08:00
parent c2661550ef
commit 1e7077bba7

View File

@@ -511,18 +511,11 @@ class BrainToTextDecoderTrainerTF:
loss = self.ctc_loss(loss_input, clean_logits)
loss = tf.reduce_mean(loss)
# Scale loss for mixed precision
if self.mixed_precision:
scaled_loss = self.optimizer.get_scaled_loss(loss)
else:
scaled_loss = loss
# TensorFlow混合精度处理不需要手动scaling，Keras policy自动处理
# TPU v5e-8使用bfloat16，不需要loss scaling
# Calculate gradients
if self.mixed_precision:
scaled_gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
else:
gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
# Calculate gradients - TensorFlow自动处理混合精度
gradients = tape.gradient(loss, self.model.trainable_variables)
# Clip gradients
if self.args['grad_norm_clip_value'] > 0: