refactor: streamline model building and ensure dtype consistency in L2 loss calculation

Author: Zchen
Date:   2025-10-16 23:06:09 +08:00
Parent: 9453b70fad
Commit: 7a43ebfb71


@@ -90,12 +90,8 @@ class BrainToTextDecoderTrainerTF:
         with self.strategy.scope():
             print("🔨 Building model within TPU strategy scope...")
             self.model = self._build_model()
-            print("✅ Model built successfully")
-            print("⚙️ Creating optimizer...")
             self.optimizer = self._create_optimizer()
-            print("✅ Optimizer created")
-            print("🔧 Pre-building optimizer state for TPU...")
             # For TPU, we must ensure optimizer is completely ready before training
             # since @tf.function doesn't allow dynamic building
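
Context for the comment above: variable creation is not allowed inside a compiled @tf.function step on TPU, so every optimizer slot variable must already exist before the first training call. A minimal sketch of the idea, assuming a Keras 3-style optimizer that exposes build(); the strategy and build_model names are illustrative, not from this repo:

    import tensorflow as tf

    with strategy.scope():
        model = build_model()  # hypothetical helper
        optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3)
        # build() allocates slot variables (e.g. Adam momenta) for the given
        # variable list without applying any update, so nothing new gets
        # created later inside the compiled train step.
        optimizer.build(model.trainable_variables)

    @tf.function
    def train_step(batch):
        ...  # apply_gradients here no longer triggers variable creation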
@@ -595,7 +591,10 @@ class BrainToTextDecoderTrainerTF:
         if self.manual_weight_decay:
             l2_loss = tf.constant(0.0, dtype=loss.dtype)
             for var in self.model.trainable_variables:
-                l2_loss += tf.nn.l2_loss(var)
+                # Ensure dtype consistency for mixed precision training
+                var_l2 = tf.nn.l2_loss(var)
+                var_l2 = tf.cast(var_l2, dtype=loss.dtype)  # Cast to match loss dtype
+                l2_loss += var_l2
             loss += self.weight_decay_rate * l2_loss
             # TensorFlow mixed precision needs no manual loss scaling; the Keras policy handles it automatically
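
The cast matters because tf.nn.l2_loss returns a tensor with the dtype of its input, and TensorFlow does not implicitly promote dtypes: accumulating a (b)float16 result into a float32 loss raises an InvalidArgumentError. A standalone repro of the failure mode this hunk guards against (dtypes chosen for illustration):

    import tensorflow as tf

    loss = tf.constant(1.0, dtype=tf.float32)
    var = tf.Variable(tf.ones([4], dtype=tf.bfloat16))

    l2 = tf.nn.l2_loss(var)        # bfloat16, follows the variable's dtype
    # loss + l2 would fail here: float32 + bfloat16 is not auto-promoted
    l2 = tf.cast(l2, loss.dtype)   # explicit cast restores dtype consistency
    loss = loss + 0.001 * l2       # safe: both operands are float32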