refactor: streamline model building and ensure dtype consistency in L2 loss calculation
@@ -90,12 +90,8 @@ class BrainToTextDecoderTrainerTF:
        with self.strategy.scope():
            print("🔨 Building model within TPU strategy scope...")
            self.model = self._build_model()
            print("✅ Model built successfully")

            print("⚙️ Creating optimizer...")
            self.optimizer = self._create_optimizer()
            print("✅ Optimizer created")

            print("🔧 Pre-building optimizer state for TPU...")
            # For TPU, we must ensure optimizer is completely ready before training
            # since @tf.function doesn't allow dynamic building
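For context, "pre-building optimizer state" means forcing the optimizer to create its slot variables (e.g. Adam's moment accumulators) eagerly, inside the strategy scope, so that no variable creation has to happen inside the @tf.function-compiled train step. A minimal sketch of one way to do this, assuming the Keras optimizer API of TF >= 2.11 where Optimizer.build(var_list) is public; the toy model, optimizer, and shapes are illustrative, not taken from this commit:

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # stand-in for the trainer's TPU strategy

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
    model.build((None, 16))                     # create the model's variables
    optimizer = tf.keras.optimizers.Adam(1e-3)
    optimizer.build(model.trainable_variables)  # create slot variables up front

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # No optimizer state is created here; it was fully built above,
    # which is what a compiled TPU train step requires.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss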
@@ -595,7 +591,10 @@ class BrainToTextDecoderTrainerTF:
         if self.manual_weight_decay:
             l2_loss = tf.constant(0.0, dtype=loss.dtype)
             for var in self.model.trainable_variables:
-                l2_loss += tf.nn.l2_loss(var)
+                # Ensure dtype consistency for mixed precision training
+                var_l2 = tf.nn.l2_loss(var)
+                var_l2 = tf.cast(var_l2, dtype=loss.dtype)  # Cast to match loss dtype
+                l2_loss += var_l2
             loss += self.weight_decay_rate * l2_loss

             # TensorFlow mixed precision handling: no manual scaling needed, the Keras policy handles it automatically
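The loop change above matters under a mixed-precision policy: with mixed_float16 the compute dtype (and therefore the loss) is float16 while trainable variables stay float32, so tf.nn.l2_loss(var) returns a float32 scalar and adding it to a float16 accumulator raises a dtype error. A self-contained sketch of the fixed pattern; the toy model and the weight_decay_rate value are illustrative assumptions:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build((None, 8))
weight_decay_rate = 1e-4                        # illustrative, not from the commit

x = tf.random.normal((2, 8))
loss = tf.reduce_mean(model(x))                 # float16 under mixed_float16

l2_loss = tf.constant(0.0, dtype=loss.dtype)
for var in model.trainable_variables:           # variables remain float32
    var_l2 = tf.nn.l2_loss(var)                 # float32 result
    var_l2 = tf.cast(var_l2, dtype=loss.dtype)  # cast to match the loss dtype
    l2_loss += var_l2
loss += weight_decay_rate * l2_loss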
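On the translated comment above: with a mixed_float16 policy, Keras applies loss scaling automatically when training goes through Model.fit (the optimizer is wrapped in a LossScaleOptimizer for you). In a hand-written @tf.function loop, the explicit pattern from the TF 2.x mixed-precision guide looks like the sketch below; the toy model and Adam are illustrative assumptions:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build((None, 8))
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam())

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x)                                    # float16 output
        loss = tf.reduce_mean(tf.square(pred - tf.cast(y, pred.dtype)))
        # Scale the loss up so small float16 gradients do not underflow.
        scaled_loss = optimizer.get_scaled_loss(loss)
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    grads = optimizer.get_unscaled_gradients(scaled_grads)  # undo the scaling
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss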