Refactor dynamic padding shapes and update device placement configuration for TPU training

Zchen
2025-10-22 01:03:14 +08:00
parent 57f07434ac
commit c03441d8f3
2 changed files with 15 additions and 15 deletions


@@ -1003,16 +1003,16 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
     print(f"🔧 Using DYNAMIC padding -> batch augmentation approach")
     print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
 
-    # Define dynamic padded shapes - key insight: None allows for dynamic lengths
+    # Define dynamic padded shapes - use simple shapes instead of TensorSpec for padded_batch
     padded_shapes = {
-        'input_features': tf.TensorSpec(shape=(None, dataset_tf.feature_dim), dtype=tf.float32),
-        'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
-        'n_time_steps': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'phone_seq_lens': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'day_indices': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
-        'block_nums': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'trial_nums': tf.TensorSpec(shape=[], dtype=tf.int32)  # scalar
+        'input_features': (None, dataset_tf.feature_dim),
+        'seq_class_ids': (None,),
+        'n_time_steps': (),  # scalar
+        'phone_seq_lens': (),  # scalar
+        'day_indices': (),  # scalar
+        'transcriptions': (None,),
+        'block_nums': (),  # scalar
+        'trial_nums': ()  # scalar
     }
 
     # Define padding values for each field

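Why the change above works: tf.data's padded_batch takes plain shape tuples (or tf.TensorShape), not tf.TensorSpec objects, so the spec values had to go. Below is a minimal, self-contained sketch of how such tuple shapes drive per-batch padding; feature_dim, the field subset, and the sample data are invented for illustration:

import tensorflow as tf

feature_dim = 512  # placeholder; the real value is dataset_tf.feature_dim

def gen():
    # Two trials of different lengths, to make the dynamic padding visible.
    for t in (7, 10):
        yield {
            'input_features': tf.zeros((t, feature_dim), tf.float32),
            'seq_class_ids': tf.zeros((t // 2,), tf.int32),
            'n_time_steps': tf.constant(t, tf.int32),
        }

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'input_features': tf.TensorSpec((None, feature_dim), tf.float32),
        'seq_class_ids': tf.TensorSpec((None,), tf.int32),
        'n_time_steps': tf.TensorSpec((), tf.int32),
    },
)

# None axes are padded up to the longest element in each batch; passing
# tf.TensorSpec values here instead of shapes is rejected by padded_batch.
batched = ds.padded_batch(
    batch_size=2,
    padded_shapes={
        'input_features': (None, feature_dim),
        'seq_class_ids': (None,),
        'n_time_steps': (),
    },
)

for batch in batched:
    print(batch['input_features'].shape)  # (2, 10, 512)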

@@ -56,9 +56,10 @@ class BrainToTextDecoderTrainerTF:
         self.args = args
         self.logger = None
 
-        # Enable soft device placement for XLA-unsupported ops (like CTC)
-        tf.config.set_soft_device_placement(True)
-        print("✅ Enabled soft device placement for CTC operations")
+        # Configure device placement for TPU training
+        # Note: classic_ctc_loss is TPU-compatible, so disable soft placement to force TPU execution
+        tf.config.set_soft_device_placement(False)
+        print("✅ Disabled soft device placement - forcing operations to stay on TPU")
 
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
@@ -69,8 +70,7 @@ class BrainToTextDecoderTrainerTF:
         print(f"Strategy type: {type(self.strategy).__name__}")
         print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
         print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
-        print("⚠️ CTC operations will automatically fall back to CPU (expected behavior)")
-        print("   This has minimal performance impact as CTC is a small portion of computation")
+        print("✅ classic_ctc_loss should now execute on TPU (soft device placement disabled)")
 
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
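The use_amp flag gated above presumably enables bfloat16 mixed precision; a minimal sketch of the usual Keras setup for TPU follows, with the policy choice and the args stand-in assumed rather than taken from this diff:

import tensorflow as tf

args = {'use_amp': True}  # stand-in for the trainer's config dict

if args.get('use_amp', True):
    # bfloat16 is the TPU-native reduced-precision format: compute runs in
    # bfloat16 while variables stay in float32 for numerical stability.
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')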
@@ -135,7 +135,7 @@ class BrainToTextDecoderTrainerTF:
         print("✅ Optimizer ready for TPU training")
         self.lr_scheduler = self._create_lr_scheduler()
 
-        # CTC loss is now handled using tf.nn.ctc_loss (TPU-compatible)
+        # CTC loss now uses classic_ctc_loss (TPU-compatible with soft placement disabled)
 
         # Create unified checkpoint management
         self.ckpt = tf.train.Checkpoint(
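The behavioral difference behind this commit: with soft placement enabled, TensorFlow silently copies ops it cannot place on the TPU back to the CPU; with it disabled, placement failures surface immediately. A sketch of the intended setup, using tf.nn.ctc_loss as a hypothetical stand-in for the project's classic_ctc_loss (whose signature is not shown in this diff):

import tensorflow as tf

# Fail fast on mis-placed ops instead of silently falling back to CPU.
tf.config.set_soft_device_placement(False)

def ctc_loss(labels, logits, label_lengths, logit_lengths):
    # Dense labels + tf.nn.ctc_loss is the standard TPU-friendly CTC
    # formulation (it avoids SparseTensor ops, which XLA handles poorly).
    return tf.nn.ctc_loss(
        labels=labels,                # [batch, max_label_len] int32, dense
        logits=logits,                # [max_time, batch, num_classes]
        label_length=label_lengths,   # [batch] true label lengths
        logit_length=logit_lengths,   # [batch] valid frame counts
        logits_time_major=True,
        blank_index=0,                # assumed; depends on the label scheme
    )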