Add support for fixed predefined shapes in create_input_fn, so static shape handling can skip the dataset analysis pass

This commit is contained in:
Zchen
2025-10-22 15:50:51 +08:00
parent 21b8e4f342
commit 8ab5697081

View File

@@ -1130,7 +1130,8 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
training: bool = True,
cache_path: Optional[str] = None,
use_static_shapes: bool = True) -> tf.data.Dataset:
use_static_shapes: bool = True,
fixed_shapes: Optional[Dict[str, int]] = None) -> tf.data.Dataset:
"""
Create input function for TPU training with configurable shape handling
@@ -1140,9 +1141,19 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
training: Whether this is for training (applies augmentations)
cache_path: Optional path for disk caching to improve I/O performance
use_static_shapes: If True, use pre-computed static shapes for XLA compatibility
fixed_shapes: Optional dict with predefined shapes to skip analysis
Format: {'max_time_steps': 2600, 'max_phone_seq_len': 80, 'max_transcription_len': 600}
Returns:
tf.data.Dataset ready for TPU training
Example:
>>> # Use automatic shape detection
>>> dataset = create_input_fn(dataset_tf, transform_args)
>>> # Use fixed predefined shapes (faster, no analysis needed)
>>> fixed_shapes = {'max_time_steps': 2600, 'max_phone_seq_len': 80, 'max_transcription_len': 600}
>>> dataset = create_input_fn(dataset_tf, transform_args, fixed_shapes=fixed_shapes)
"""
# Step 1: Create individual example dataset
@@ -1163,11 +1174,32 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
if use_static_shapes:
print(f"🔧 Using STATIC shapes for XLA compatibility")
# Analyze dataset to get maximum shapes
print("📊 Analyzing dataset for maximum shapes...")
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training) # Pass correct training mode
# Option 1: Use predefined fixed shapes (fastest, no analysis needed)
if fixed_shapes is not None:
print(f"📐 Using FIXED predefined shapes (skipping analysis):")
print(f" max_time_steps: {fixed_shapes.get('max_time_steps', 'not specified')}")
print(f" max_phone_seq_len: {fixed_shapes.get('max_phone_seq_len', 'not specified')}")
print(f" max_transcription_len: {fixed_shapes.get('max_transcription_len', 'not specified')}")
# Use static shapes based on analysis
# Validate required keys
required_keys = ['max_time_steps', 'max_phone_seq_len', 'max_transcription_len']
missing_keys = [key for key in required_keys if key not in fixed_shapes]
if missing_keys:
raise ValueError(f"fixed_shapes missing required keys: {missing_keys}")
max_shapes = {
'max_time_steps': fixed_shapes['max_time_steps'],
'max_phone_seq_len': fixed_shapes['max_phone_seq_len'],
'max_transcription_len': fixed_shapes['max_transcription_len'],
'n_features': dataset_tf.feature_dim
}
# Option 2: Analyze dataset to get maximum shapes (slower, but adaptive)
else:
print("📊 Analyzing dataset for maximum shapes...")
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training)
# Use static shapes based on either fixed config or analysis
padded_shapes = {
'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim),
'seq_class_ids': (max_shapes['max_phone_seq_len'],),
@@ -1178,7 +1210,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
'block_nums': (),
'trial_nums': ()
}
print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, "
print(f"📏 Final static shapes: time_steps={max_shapes['max_time_steps']}, "
f"phone_len={max_shapes['max_phone_seq_len']}, "
f"transcription_len={max_shapes['max_transcription_len']}")
else: