Commit eb058fe9d3 (parent 57360bec8a)
Author: Zchen
Date: 2025-10-17 11:57:10 +08:00


@@ -59,6 +59,8 @@ class BrainToTextDecoderTrainerTF:
print(f"Strategy type: {type(self.strategy).__name__}")
print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
print("⚠️ For best TPU performance, ensure create_input_fn uses padded_batch with fixed shapes")
print(" and drop_remainder=True to avoid dynamic shape warnings")
# Configure mixed precision for TPU v5e-8
if args.get('use_amp', True):
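
For reference, a minimal sketch of the padded_batch pattern the new warning points at. The shape constants and element keys below are illustrative assumptions, not values from this repo:

    import tensorflow as tf

    AUTOTUNE = tf.data.AUTOTUNE
    MAX_TIME_STEPS = 1000  # assumed fixed padding length (hypothetical)
    N_FEATURES = 512       # assumed input feature dimension (hypothetical)
    MAX_LABEL_LEN = 200    # assumed max label length (hypothetical)

    def create_input_fn_sketch(dataset: tf.data.Dataset, batch_size: int) -> tf.data.Dataset:
        # Pad every example to one fixed shape so XLA compiles a single
        # graph instead of recompiling for each new batch shape.
        dataset = dataset.padded_batch(
            batch_size,
            padded_shapes={
                'features': (MAX_TIME_STEPS, N_FEATURES),
                'labels': (MAX_LABEL_LEN,),
                'day_idx': (),
            },
            # Drop the smaller final batch: its different shape would
            # otherwise trigger the dynamic-shape warning on TPU.
            drop_remainder=True,
        )
        return dataset.prefetch(AUTOTUNE)

With drop_remainder=True the batch dimension is also static, so every batch the TPU sees has exactly one shape.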
@@ -99,8 +101,28 @@ class BrainToTextDecoderTrainerTF:
print("🔧 Initializing optimizer for TPU training...")
print(f"Optimizer type: {type(self.optimizer).__name__}")
# ========================= SOLUTION =========================
# Explicitly build optimizer within strategy scope before training.
# This forces creation of all slot variables (e.g., AdamW momentum)
# avoiding lazy initialization inside @tf.function which loses context.
# Note: Model must be built first for .build() to work.
# The _log_model_info method builds the model via forward pass.
# Ensure model is built (will be called later in _log_model_info anyway)
if not self.model.built:
dummy_batch_size = 2
dummy_time_steps = 100
dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features']))
dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)
_ = self.model(dummy_features, dummy_day_idx, training=False)
print("🔧 Building optimizer with model variables...")
self.optimizer.build(self.model.trainable_variables)
print("✅ Optimizer built successfully")
# ============================================================
print("✅ Optimizer ready for TPU training")
print("📝 Note: Optimizer slot variables will be created automatically during first training step")
self.lr_scheduler = self._create_lr_scheduler()
self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
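
The pattern in this hunk, reduced to a standalone sketch. The stand-in model below is hypothetical; in the trainer this runs inside strategy.scope(), so the slot variables land on the TPU:

    import tensorflow as tf

    # Hypothetical stand-in for the real decoder model.
    model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
    optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3)

    # A dummy forward pass creates the model's variables (model.built -> True).
    _ = model(tf.zeros((2, 4)), training=False)

    # Eagerly create AdamW's slot variables (momentum/velocity) now, in the
    # current scope, instead of lazily inside a traced @tf.function.
    optimizer.build(model.trainable_variables)
    print(len(optimizer.variables))  # slots already exist before step one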
@@ -682,6 +704,8 @@ class BrainToTextDecoderTrainerTF:
        step = 0

        self.logger.info("🔄 Starting training loop...")
        self.logger.info("📋 Note: If you see 'TPU has inputs with dynamic shapes' warnings,")
        self.logger.info("    consider using padded_batch with fixed shapes in create_input_fn")

        for batch in train_dist_dataset:
            if step >= self.args['num_training_batches']:
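
For context, the usual shape of the distributed step this loop drives. The batch unpacking, the distributed_train_step name, and the CTCLoss call signature are assumptions for illustration, not this repo's code:

    @tf.function
    def distributed_train_step(strategy, model, optimizer, ctc_loss, batch):
        def step_fn(inputs):
            features, labels, day_idx, input_lens, label_lens = inputs
            with tf.GradientTape() as tape:
                logits = model(features, day_idx, training=True)
                # reduction='none' gives per-example losses; average them here.
                loss = tf.reduce_mean(ctc_loss(labels, logits, input_lens, label_lens))
            grads = tape.gradient(loss, model.trainable_variables)
            # Because the optimizer was built explicitly above, this call
            # creates no new variables inside the traced function.
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            return loss
        per_replica_loss = strategy.run(step_fn, args=(batch,))
        return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)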