Add optional fixed_shapes parameter to create_input_fn so predefined static shapes can be supplied directly and dataset shape analysis is skipped
This commit is contained in:
@@ -1130,7 +1130,8 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
|||||||
transform_args: Dict[str, Any],
|
transform_args: Dict[str, Any],
|
||||||
training: bool = True,
|
training: bool = True,
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
use_static_shapes: bool = True) -> tf.data.Dataset:
|
use_static_shapes: bool = True,
|
||||||
|
fixed_shapes: Optional[Dict[str, int]] = None) -> tf.data.Dataset:
|
||||||
"""
|
"""
|
||||||
Create input function for TPU training with configurable shape handling
|
Create input function for TPU training with configurable shape handling
|
||||||
|
|
||||||
@@ -1140,9 +1141,19 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
|||||||
training: Whether this is for training (applies augmentations)
|
training: Whether this is for training (applies augmentations)
|
||||||
cache_path: Optional path for disk caching to improve I/O performance
|
cache_path: Optional path for disk caching to improve I/O performance
|
||||||
use_static_shapes: If True, use pre-computed static shapes for XLA compatibility
|
use_static_shapes: If True, use pre-computed static shapes for XLA compatibility
|
||||||
|
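fixed_shapes: Optional dict of predefined maximum shapes; when provided, dataset shape analysis is skipped. All three keys shown below are required (a missing key raises ValueError).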
|
||||||
|
Format: {'max_time_steps': 2600, 'max_phone_seq_len': 80, 'max_transcription_len': 600}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tf.data.Dataset ready for TPU training
|
tf.data.Dataset ready for TPU training
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # Use automatic shape detection
|
||||||
|
>>> dataset = create_input_fn(dataset_tf, transform_args)
|
||||||
|
|
||||||
|
>>> # Use fixed predefined shapes (faster, no analysis needed)
|
||||||
|
>>> fixed_shapes = {'max_time_steps': 2600, 'max_phone_seq_len': 80, 'max_transcription_len': 600}
|
||||||
|
>>> dataset = create_input_fn(dataset_tf, transform_args, fixed_shapes=fixed_shapes)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Step 1: Create individual example dataset
|
# Step 1: Create individual example dataset
|
||||||
@@ -1163,11 +1174,32 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
|||||||
if use_static_shapes:
|
if use_static_shapes:
|
||||||
print(f"🔧 Using STATIC shapes for XLA compatibility")
|
print(f"🔧 Using STATIC shapes for XLA compatibility")
|
||||||
|
|
||||||
# Analyze dataset to get maximum shapes
|
# Option 1: Use predefined fixed shapes (fastest, no analysis needed)
|
||||||
print("📊 Analyzing dataset for maximum shapes...")
|
if fixed_shapes is not None:
|
||||||
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training) # Pass correct training mode
|
print(f"📐 Using FIXED predefined shapes (skipping analysis):")
|
||||||
|
print(f" max_time_steps: {fixed_shapes.get('max_time_steps', 'not specified')}")
|
||||||
|
print(f" max_phone_seq_len: {fixed_shapes.get('max_phone_seq_len', 'not specified')}")
|
||||||
|
print(f" max_transcription_len: {fixed_shapes.get('max_transcription_len', 'not specified')}")
|
||||||
|
|
||||||
# Use static shapes based on analysis
|
# Validate required keys
|
||||||
|
required_keys = ['max_time_steps', 'max_phone_seq_len', 'max_transcription_len']
|
||||||
|
missing_keys = [key for key in required_keys if key not in fixed_shapes]
|
||||||
|
if missing_keys:
|
||||||
|
raise ValueError(f"fixed_shapes missing required keys: {missing_keys}")
|
||||||
|
|
||||||
|
max_shapes = {
|
||||||
|
'max_time_steps': fixed_shapes['max_time_steps'],
|
||||||
|
'max_phone_seq_len': fixed_shapes['max_phone_seq_len'],
|
||||||
|
'max_transcription_len': fixed_shapes['max_transcription_len'],
|
||||||
|
'n_features': dataset_tf.feature_dim
|
||||||
|
}
|
||||||
|
|
||||||
|
# Option 2: Analyze dataset to get maximum shapes (slower, but adaptive)
|
||||||
|
else:
|
||||||
|
print("📊 Analyzing dataset for maximum shapes...")
|
||||||
|
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training)
|
||||||
|
|
||||||
|
# Use static shapes based on either fixed config or analysis
|
||||||
padded_shapes = {
|
padded_shapes = {
|
||||||
'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim),
|
'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim),
|
||||||
'seq_class_ids': (max_shapes['max_phone_seq_len'],),
|
'seq_class_ids': (max_shapes['max_phone_seq_len'],),
|
||||||
@@ -1178,7 +1210,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
|||||||
'block_nums': (),
|
'block_nums': (),
|
||||||
'trial_nums': ()
|
'trial_nums': ()
|
||||||
}
|
}
|
||||||
print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, "
|
print(f"📏 Final static shapes: time_steps={max_shapes['max_time_steps']}, "
|
||||||
f"phone_len={max_shapes['max_phone_seq_len']}, "
|
f"phone_len={max_shapes['max_phone_seq_len']}, "
|
||||||
f"transcription_len={max_shapes['max_transcription_len']}")
|
f"transcription_len={max_shapes['max_transcription_len']}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user