@@ -59,6 +59,8 @@ class BrainToTextDecoderTrainerTF:
         print(f"Strategy type: {type(self.strategy).__name__}")
         print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
         print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
+        print("⚠️ For best TPU performance, ensure create_input_fn uses padded_batch with fixed shapes")
+        print("   and drop_remainder=True to avoid dynamic shape warnings")
 
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
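The two added prints point at a concrete tf.data pattern. The sketch below shows one way such a create_input_fn could look; the signature, feature shapes, and the assumption that the dataset yields (features, day_idx) pairs are illustrative, not the repository's actual code.

import tensorflow as tf

def create_input_fn(dataset: tf.data.Dataset, batch_size: int,
                    max_time_steps: int = 512, n_features: int = 512):
    # AUTOTUNE lets tf.data pick parallelism for .map() and .prefetch() at runtime.
    dataset = dataset.map(
        lambda feats, day_idx: (tf.cast(feats, tf.float32), day_idx),
        num_parallel_calls=tf.data.AUTOTUNE)
    # padded_batch with fully specified padded_shapes plus drop_remainder=True
    # gives every batch the same static [B, T, F] shape, so XLA compiles the
    # TPU program once instead of recompiling for each dynamic shape.
    dataset = dataset.padded_batch(
        batch_size,
        padded_shapes=([max_time_steps, n_features], []),
        drop_remainder=True)
    return dataset.prefetch(tf.data.AUTOTUNE)

The mixed-precision branch that follows presumably sets a bfloat16 policy, which is the usual choice on TPU:

tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')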
@@ -99,8 +101,28 @@ class BrainToTextDecoderTrainerTF:
         print("🔧 Initializing optimizer for TPU training...")
         print(f"Optimizer type: {type(self.optimizer).__name__}")
 
+        # ========================= SOLUTION =========================
+        # Explicitly build the optimizer within the strategy scope before training.
+        # This forces creation of all slot variables (e.g., AdamW momentum),
+        # avoiding lazy initialization inside @tf.function, which loses context.
+        # Note: the model must be built first for .build() to work.
+        # The _log_model_info method builds the model via a forward pass.
+
+        # Ensure the model is built (it will be built later in _log_model_info anyway)
+        if not self.model.built:
+            dummy_batch_size = 2
+            dummy_time_steps = 100
+            dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features']))
+            dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)
+            _ = self.model(dummy_features, dummy_day_idx, training=False)
+
+        print("🔧 Building optimizer with model variables...")
+        self.optimizer.build(self.model.trainable_variables)
+        print("✅ Optimizer built successfully")
+        # ============================================================
+
         print("✅ Optimizer ready for TPU training")
         print("📝 Note: Optimizer slot variables will be created automatically during first training step")
 
         self.lr_scheduler = self._create_lr_scheduler()
         self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
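The SOLUTION block relies on the Keras optimizer API that exposes an explicit .build() method (TF 2.11+). A minimal standalone sketch of the same pattern, with a placeholder model and the default strategy standing in for the trainer's TPUStrategy:

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # TPUStrategy in the real trainer
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
    model.build((None, 4))  # variables must exist before the optimizer can build
    optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3)
    # Eagerly create AdamW's momentum/velocity slots for every trainable
    # variable, so no variable creation happens later inside @tf.function.
    optimizer.build(model.trainable_variables)

Without the explicit build, the slots would be created lazily on the first apply_gradients call; when that call is traced inside @tf.function, variable creation happens outside the strategy scope, which is exactly the failure mode the commit works around.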
@@ -682,6 +704,8 @@ class BrainToTextDecoderTrainerTF:
         step = 0
 
         self.logger.info("🔄 Starting training loop...")
+        self.logger.info("📋 Note: If you see 'TPU has inputs with dynamic shapes' warnings,")
+        self.logger.info("    consider using padded_batch with fixed shapes in create_input_fn")
 
         for batch in train_dist_dataset:
             if step >= self.args['num_training_batches']:
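For context, the loop this hunk sits in appears to iterate a distributed dataset with a step cap. A hedged sketch of that shape as a standalone function; the train_step callable and the loss reduction are assumptions, since the diff does not show the loop body.

import tensorflow as tf

def run_training(strategy, train_dist_dataset, train_step, num_training_batches):
    # Each iteration runs the (typically @tf.function-compiled) step on all
    # replicas, then averages the per-replica losses for logging.
    step = 0
    for batch in train_dist_dataset:
        if step >= num_training_batches:
            break
        per_replica_loss = strategy.run(train_step, args=(batch,))
        loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                               per_replica_loss, axis=None)
        print(f"step {step}: loss={float(loss):.4f}")
        step += 1
    return step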