Add training parameter to analyze_dataset_shapes so shape analysis matches the pipeline's augmentation mode

Zchen
2025-10-22 14:13:26 +08:00
parent 7511f4cf68
commit 4a99f50afd


@@ -832,7 +832,8 @@ def train_test_split_indices(file_paths: List[str],
 def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
                            sample_size: int = 100,
-                           transform_args: Optional[Dict[str, Any]] = None) -> Dict[str, int]:
+                           transform_args: Optional[Dict[str, Any]] = None,
+                           training: bool = True) -> Dict[str, int]:
     """
     Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
     applying data augmentation first to capture post-augmentation shapes.
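
Because the new keyword defaults to True, existing call sites keep their behavior; only callers that analyze an evaluation split need to opt out. A hypothetical usage sketch (val_dataset_tf and transform_args are assumed to be defined as elsewhere in this module):

    # Training split: keep augmentation on (the default) so padded shapes
    # reflect post-augmentation sizes.
    train_shapes = analyze_dataset_shapes(dataset_tf, sample_size=100,
                                          transform_args=transform_args)

    # Evaluation split: analyze without augmentation to match the
    # inference-time pipeline.
    val_shapes = analyze_dataset_shapes(val_dataset_tf, sample_size=100,
                                        transform_args=transform_args,
                                        training=False)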
@@ -890,7 +891,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
         # Apply the data transforms
         augmented_features, augmented_n_time_steps = DataAugmentationTF.transform_data(
-            features, n_time_steps, transform_args, training=True
+            features, n_time_steps, transform_args, training=training  # Use the passed training parameter
         )
         # Remove the batch dimension and get the post-augmentation shapes
@@ -1042,7 +1043,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
     # Analyze dataset to get maximum shapes
     print("📊 Analyzing dataset for maximum shapes...")
-    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args)  # Pass transform_args
+    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training)  # Pass the correct training mode
     # Use static shapes based on analysis
     padded_shapes = {
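
Why the mode matters: augmentations in the training path can change the time dimension, so maximum shapes measured with training=True can disagree with what an evaluation pipeline (augmentation off) actually emits, and vice versa. Below is a minimal, self-contained sketch of the pattern; transform_data, analyze_max_time_steps, and the random-crop augmentation are illustrative stand-ins for the repository's DataAugmentationTF.transform_data and analyze_dataset_shapes, not the actual implementations:

    import tensorflow as tf

    def transform_data(features, n_time_steps, training=True):
        # Illustrative augmentation: a random temporal crop that only runs
        # in training mode and shortens sequences.
        if training:
            keep = tf.random.uniform([], minval=n_time_steps // 2,
                                     maxval=n_time_steps + 1, dtype=tf.int32)
            return features[:keep], keep
        return features, n_time_steps

    def analyze_max_time_steps(examples, training):
        # Mirrors the intent of analyze_dataset_shapes: measure the
        # post-transform shapes in the same mode the real input pipeline
        # will use, so padded_batch dimensions match what it produces.
        max_t = 0
        for features, n_time_steps in examples:
            _, t = transform_data(features, n_time_steps, training=training)
            max_t = max(max_t, int(t))
        return max_t

    examples = [(tf.zeros([120, 256]), tf.constant(120)),
                (tf.zeros([90, 256]), tf.constant(90))]
    # Before this commit, analysis always ran with training=True; for an
    # evaluation pipeline with augmentation disabled, that can mis-estimate
    # the true maximum time dimension.
    print(analyze_max_time_steps(examples, training=False))  # 120

The change to create_input_fn above forwards the caller's training flag into the analysis for exactly this reason: the shapes used for padded_batch are derived under the same augmentation mode as the batches they will pad.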