Add support for fixed predefined shapes in create_input_fn to optimize shape handling and skip analysis
This commit is contained in:
@@ -1130,7 +1130,8 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
transform_args: Dict[str, Any],
|
||||
training: bool = True,
|
||||
cache_path: Optional[str] = None,
|
||||
use_static_shapes: bool = True) -> tf.data.Dataset:
|
||||
use_static_shapes: bool = True,
|
||||
fixed_shapes: Optional[Dict[str, int]] = None) -> tf.data.Dataset:
|
||||
"""
|
||||
Create input function for TPU training with configurable shape handling
|
||||
|
||||
@@ -1140,9 +1141,19 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
training: Whether this is for training (applies augmentations)
|
||||
cache_path: Optional path for disk caching to improve I/O performance
|
||||
use_static_shapes: If True, use pre-computed static shapes for XLA compatibility
|
||||
fixed_shapes: Optional dict with predefined shapes to skip analysis
|
||||
Format: {'max_time_steps': 2600, 'max_phone_seq_len': 80, 'max_transcription_len': 600}
|
||||
|
||||
Returns:
|
||||
tf.data.Dataset ready for TPU training
|
||||
|
||||
Example:
|
||||
>>> # Use automatic shape detection
|
||||
>>> dataset = create_input_fn(dataset_tf, transform_args)
|
||||
|
||||
>>> # Use fixed predefined shapes (faster, no analysis needed)
|
||||
>>> fixed_shapes = {'max_time_steps': 2600, 'max_phone_seq_len': 80, 'max_transcription_len': 600}
|
||||
>>> dataset = create_input_fn(dataset_tf, transform_args, fixed_shapes=fixed_shapes)
|
||||
"""
|
||||
|
||||
# Step 1: Create individual example dataset
|
||||
@@ -1163,11 +1174,32 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
if use_static_shapes:
|
||||
print(f"🔧 Using STATIC shapes for XLA compatibility")
|
||||
|
||||
# Analyze dataset to get maximum shapes
|
||||
print("📊 Analyzing dataset for maximum shapes...")
|
||||
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training) # Pass correct training mode
|
||||
# Option 1: Use predefined fixed shapes (fastest, no analysis needed)
|
||||
if fixed_shapes is not None:
|
||||
print(f"📐 Using FIXED predefined shapes (skipping analysis):")
|
||||
print(f" max_time_steps: {fixed_shapes.get('max_time_steps', 'not specified')}")
|
||||
print(f" max_phone_seq_len: {fixed_shapes.get('max_phone_seq_len', 'not specified')}")
|
||||
print(f" max_transcription_len: {fixed_shapes.get('max_transcription_len', 'not specified')}")
|
||||
|
||||
# Use static shapes based on analysis
|
||||
# Validate required keys
|
||||
required_keys = ['max_time_steps', 'max_phone_seq_len', 'max_transcription_len']
|
||||
missing_keys = [key for key in required_keys if key not in fixed_shapes]
|
||||
if missing_keys:
|
||||
raise ValueError(f"fixed_shapes missing required keys: {missing_keys}")
|
||||
|
||||
max_shapes = {
|
||||
'max_time_steps': fixed_shapes['max_time_steps'],
|
||||
'max_phone_seq_len': fixed_shapes['max_phone_seq_len'],
|
||||
'max_transcription_len': fixed_shapes['max_transcription_len'],
|
||||
'n_features': dataset_tf.feature_dim
|
||||
}
|
||||
|
||||
# Option 2: Analyze dataset to get maximum shapes (slower, but adaptive)
|
||||
else:
|
||||
print("📊 Analyzing dataset for maximum shapes...")
|
||||
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training)
|
||||
|
||||
# Use static shapes based on either fixed config or analysis
|
||||
padded_shapes = {
|
||||
'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim),
|
||||
'seq_class_ids': (max_shapes['max_phone_seq_len'],),
|
||||
@@ -1178,7 +1210,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
'block_nums': (),
|
||||
'trial_nums': ()
|
||||
}
|
||||
print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, "
|
||||
print(f"📏 Final static shapes: time_steps={max_shapes['max_time_steps']}, "
|
||||
f"phone_len={max_shapes['max_phone_seq_len']}, "
|
||||
f"transcription_len={max_shapes['max_transcription_len']}")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user