From eb058fe9d348db1188fa3105dee3aca70805f232 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Fri, 17 Oct 2025 11:57:10 +0800
Subject: [PATCH] ff

---
 model_training_nnn_tpu/trainer_tf.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 58e49ea..8e65b54 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -59,6 +59,8 @@ class BrainToTextDecoderTrainerTF:
         print(f"Strategy type: {type(self.strategy).__name__}")
         print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
         print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
+        print("⚠️ For best TPU performance, ensure create_input_fn uses padded_batch with fixed shapes")
+        print("   and drop_remainder=True to avoid dynamic shape warnings")
 
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
@@ -99,8 +101,28 @@ class BrainToTextDecoderTrainerTF:
 
         print("🔧 Initializing optimizer for TPU training...")
         print(f"Optimizer type: {type(self.optimizer).__name__}")
+
+        # ========================= SOLUTION =========================
+        # Explicitly build the optimizer within the strategy scope before training.
+        # This forces creation of all slot variables (e.g., AdamW momentum),
+        # avoiding lazy initialization inside @tf.function, which loses the strategy context.
+        # Note: the model must be built first for optimizer.build() to work;
+        # the _log_model_info method also builds the model via a forward pass.
+
+        # Ensure the model is built (it will be built again later in _log_model_info anyway)
+        if not self.model.built:
+            dummy_batch_size = 2
+            dummy_time_steps = 100
+            dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features']))
+            dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)
+            _ = self.model(dummy_features, dummy_day_idx, training=False)
+
+        print("🔧 Building optimizer with model variables...")
+        self.optimizer.build(self.model.trainable_variables)
+        print("✅ Optimizer built successfully")
+        # ============================================================
+
         print("✅ Optimizer ready for TPU training")
-        print("📝 Note: Optimizer slot variables will be created automatically during first training step")
 
         self.lr_scheduler = self._create_lr_scheduler()
         self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
@@ -682,6 +704,8 @@ class BrainToTextDecoderTrainerTF:
 
         step = 0
         self.logger.info("🔄 Starting training loop...")
+        self.logger.info("📋 Note: If you see 'TPU has inputs with dynamic shapes' warnings,")
+        self.logger.info("   consider using padded_batch with fixed shapes in create_input_fn")
 
         for batch in train_dist_dataset:
             if step >= self.args['num_training_batches']:
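
Note (not part of the patch): the patch only prints a recommendation to use padded_batch with fixed shapes and drop_remainder=True in create_input_fn. Below is a minimal sketch of what such an input pipeline could look like. It is an assumption-laden illustration, not code from the repository: the element keys ('features', 'day_idx', 'labels'), their dtypes, and the MAX_TIME_STEPS / MAX_LABEL_LEN / N_INPUT_FEATURES constants are hypothetical placeholders; only padded_batch, drop_remainder, and tf.data.AUTOTUNE themselves are standard tf.data API. Sequences are assumed to have already been truncated to at most the fixed lengths, since padded_batch raises an error for components longer than their padded shape.

    import tensorflow as tf

    # Illustrative fixed padded lengths; the real values would come from the
    # dataset statistics or the training config, not from this patch.
    MAX_TIME_STEPS = 1000    # assumed fixed padded length for the time axis
    MAX_LABEL_LEN = 200      # assumed fixed padded length for CTC label sequences
    N_INPUT_FEATURES = 512   # assumed neural feature dimension

    def create_input_fn(dataset: tf.data.Dataset, batch_size: int) -> tf.data.Dataset:
        """Batch variable-length examples into fixed-shape padded batches for TPU.

        Assumes each element is a dict with float32 'features' of shape
        [time, N_INPUT_FEATURES], an int32 scalar 'day_idx', and int32 'labels'.
        """
        return (
            dataset
            .padded_batch(
                batch_size,
                # Fixed (non-None) padded_shapes give every batch a static shape,
                # so XLA does not recompile the TPU program per sequence length.
                padded_shapes={
                    'features': (MAX_TIME_STEPS, N_INPUT_FEATURES),
                    'day_idx': (),
                    'labels': (MAX_LABEL_LEN,),
                },
                padding_values={
                    'features': 0.0,
                    'day_idx': 0,
                    'labels': 0,  # 0 matches the CTC blank index used by the trainer
                },
                # drop_remainder=True keeps the batch dimension static as well,
                # avoiding the 'TPU has inputs with dynamic shapes' warning.
                drop_remainder=True,
            )
            .prefetch(tf.data.AUTOTUNE)
        )

With every component padded to a fixed length and drop_remainder=True, each batch has a fully static shape, which is what the warning messages added by this patch are steering users toward.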