adamw修复
This commit is contained in:
@@ -511,18 +511,11 @@ class BrainToTextDecoderTrainerTF:
|
||||
loss = self.ctc_loss(loss_input, clean_logits)
|
||||
loss = tf.reduce_mean(loss)
|
||||
|
||||
# Scale loss for mixed precision
|
||||
if self.mixed_precision:
|
||||
scaled_loss = self.optimizer.get_scaled_loss(loss)
|
||||
else:
|
||||
scaled_loss = loss
|
||||
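# Mixed precision: no manual loss scaling is done here; the Keras dtype policy only handles casting (loss scaling would still be required under mixed_float16)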
|
||||
# TPU v5e-8 uses bfloat16, which does not need loss scaling
|
||||
|
||||
# Calculate gradients
|
||||
if self.mixed_precision:
|
||||
scaled_gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
|
||||
gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
|
||||
else:
|
||||
gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
|
||||
# Calculate gradients from the unscaled loss - TensorFlow handles the mixed-precision casts automatically
|
||||
gradients = tape.gradient(loss, self.model.trainable_variables)
|
||||
|
||||
# Clip gradients
|
||||
if self.args['grad_norm_clip_value'] > 0:
|
||||
|
Reference in New Issue
Block a user