Refactor dynamic padding shapes and update device placement configuration for TPU training
@@ -1003,16 +1003,16 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
     print(f"🔧 Using DYNAMIC padding -> batch augmentation approach")
     print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
 
-    # Define dynamic padded shapes - key insight: None allows for dynamic lengths
+    # Define dynamic padded shapes - use simple shapes instead of TensorSpec for padded_batch
     padded_shapes = {
-        'input_features': tf.TensorSpec(shape=(None, dataset_tf.feature_dim), dtype=tf.float32),
-        'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
-        'n_time_steps': tf.TensorSpec(shape=[], dtype=tf.int32), # scalar
-        'phone_seq_lens': tf.TensorSpec(shape=[], dtype=tf.int32), # scalar
-        'day_indices': tf.TensorSpec(shape=[], dtype=tf.int32), # scalar
-        'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
-        'block_nums': tf.TensorSpec(shape=[], dtype=tf.int32), # scalar
-        'trial_nums': tf.TensorSpec(shape=[], dtype=tf.int32) # scalar
+        'input_features': (None, dataset_tf.feature_dim),
+        'seq_class_ids': (None,),
+        'n_time_steps': (), # scalar
+        'phone_seq_lens': (), # scalar
+        'day_indices': (), # scalar
+        'transcriptions': (None,),
+        'block_nums': (), # scalar
+        'trial_nums': () # scalar
     }
 
     # Define padding values for each field
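
The change above reflects how tf.data splits these roles: tf.TensorSpec describes dataset signatures (for example from_generator's output_signature), while Dataset.padded_batch expects plain shapes such as (None, feature_dim) or (), optionally paired with padding_values. The standalone sketch below illustrates that distinction; it is not the repository's pipeline, and feature_dim, the toy generator, and the reduced field set are placeholders.

import tensorflow as tf

feature_dim = 512  # placeholder for dataset_tf.feature_dim

def gen():
    # Two toy variable-length trials using a subset of the fields above
    for n_steps, n_phones in [(7, 3), (11, 5)]:
        yield {
            'input_features': tf.random.normal((n_steps, feature_dim)),
            'seq_class_ids': tf.zeros((n_phones,), dtype=tf.int32),
            'n_time_steps': tf.constant(n_steps, dtype=tf.int32),
        }

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={                      # TensorSpec belongs here ...
        'input_features': tf.TensorSpec((None, feature_dim), tf.float32),
        'seq_class_ids': tf.TensorSpec((None,), tf.int32),
        'n_time_steps': tf.TensorSpec((), tf.int32),
    },
)

batched = ds.padded_batch(
    batch_size=2,
    padded_shapes={                         # ... but padded_batch takes plain shapes
        'input_features': (None, feature_dim),
        'seq_class_ids': (None,),
        'n_time_steps': (),                 # scalars need no padding
    },
    padding_values={
        'input_features': 0.0,
        'seq_class_ids': tf.constant(0, dtype=tf.int32),
        'n_time_steps': tf.constant(0, dtype=tf.int32),
    },
)
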
@@ -56,9 +56,10 @@ class BrainToTextDecoderTrainerTF:
         self.args = args
         self.logger = None
 
-        # Enable soft device placement for XLA unsupported ops (like CTC)
-        tf.config.set_soft_device_placement(True)
-        print("✅ Enabled soft device placement for CTC operations")
+        # Configure device placement for TPU training
+        # Note: classic_ctc_loss is TPU-compatible, disable soft placement to force TPU execution
+        tf.config.set_soft_device_placement(False)
+        print("✅ Disabled soft device placement - forcing operations to stay on TPU")
 
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
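
For context on this placement change, the sketch below shows the intended setup order: disable soft placement first, then build the TPU strategy. The body of create_tpu_strategy is not part of this diff, so the helper shown here is only a plausible stand-in assembled from standard TensorFlow APIs, not the repository's implementation.

import tensorflow as tf

def create_tpu_strategy():
    # Plausible stand-in for the repository's helper: resolve, connect to and
    # initialize the TPU system, then return a TPUStrategy.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)

# With soft placement disabled, an op that has no TPU kernel raises a placement
# error instead of silently falling back to the CPU, which is the behaviour
# this commit opts into for the (TPU-compatible) CTC loss.
tf.config.set_soft_device_placement(False)
strategy = create_tpu_strategy()
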
@@ -69,8 +70,7 @@ class BrainToTextDecoderTrainerTF:
         print(f"Strategy type: {type(self.strategy).__name__}")
         print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
         print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
-        print("⚠️ CTC operations will automatically fall back to CPU (expected behavior)")
-        print(" This has minimal performance impact as CTC is a small portion of computation")
+        print("✅ classic_ctc_loss should now execute on TPU (soft device placement disabled)")
 
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
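
The unchanged context lines above point at the AUTOTUNE pattern for the input pipeline; a minimal illustration follows, with ds and preprocess as placeholders rather than the repository's actual create_input_fn.

import tensorflow as tf

def preprocess(example):
    return example  # stand-in for the real per-trial preprocessing

ds = tf.data.Dataset.range(10)
ds = (ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE))
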
@@ -135,7 +135,7 @@ class BrainToTextDecoderTrainerTF:
         print("✅ Optimizer ready for TPU training")
 
         self.lr_scheduler = self._create_lr_scheduler()
-        # CTC loss is now handled using tf.nn.ctc_loss (TPU-compatible)
+        # CTC loss now uses classic_ctc_loss (TPU-compatible with soft placement disabled)
 
         # Create unified checkpoint management
         self.ckpt = tf.train.Checkpoint(
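
The removed comment referenced tf.nn.ctc_loss, while the new one points at classic_ctc_loss, whose implementation is not shown in this diff. As a hedged reference point only, the sketch below computes a dense-label CTC loss with tf.nn.ctc_loss using placeholder sizes; it is not the repository's classic_ctc_loss.

import tensorflow as tf

batch, max_time, max_label_len, num_classes = 2, 50, 12, 41  # placeholder sizes

logits = tf.random.normal((batch, max_time, num_classes))  # batch-major logits
labels = tf.random.uniform((batch, max_label_len), minval=1,
                           maxval=num_classes, dtype=tf.int32)
logit_length = tf.fill([batch], max_time)
label_length = tf.fill([batch], max_label_len)

loss = tf.nn.ctc_loss(
    labels=labels,
    logits=logits,
    label_length=label_length,
    logit_length=logit_length,
    logits_time_major=False,  # logits here are [batch, time, classes]
    blank_index=0,            # assumes class 0 is the CTC blank
)
mean_loss = tf.reduce_mean(loss)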