Add training parameter to analyze_dataset_shapes for improved data augmentation handling
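Context for the change: `DataAugmentationTF.transform_data` applies augmentations only when `training=True`, and per the docstring below, augmentation can change example shapes. Shape analysis that hardcodes `training=True` therefore sizes `padded_batch` for augmented data even when the pipeline being built is an evaluation pipeline. A minimal sketch of the kind of shape-changing transform involved (the patching logic here is a hypothetical stand-in, not the repository's actual implementation):

import tensorflow as tf

def transform_data_sketch(features, n_time_steps, transform_args, training=True):
    # Hypothetical stand-in for DataAugmentationTF.transform_data: when
    # training, merge every `patch_size` consecutive frames into one wider
    # frame, shrinking the time dimension; at evaluation time the data
    # passes through unchanged.
    patch_size = (transform_args or {}).get("patch_size", 1)
    if training and patch_size > 1:
        t = tf.shape(features)[0]
        usable = (t // patch_size) * patch_size
        trimmed = features[:usable]  # drop the ragged tail
        features = tf.reshape(trimmed, (usable // patch_size, -1))
        n_time_steps = n_time_steps // patch_size
    return features, n_time_steps

Because the time dimension changes only in training mode, the maximum shapes handed to padded_batch must be computed with the same training flag the pipeline will run with.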
@@ -832,7 +832,8 @@ def train_test_split_indices(file_paths: List[str],
 
 def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
                            sample_size: int = 100,
-                           transform_args: Optional[Dict[str, Any]] = None) -> Dict[str, int]:
+                           transform_args: Optional[Dict[str, Any]] = None,
+                           training: bool = True) -> Dict[str, int]:
     """
     Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
     applying data augmentation first to capture post-augmentation shapes.
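With the new parameter, the analysis can be run to match the pipeline it feeds; a hedged usage sketch (the sample_size and transform_args values are illustrative):

# Training pipeline: analysis sees augmented examples, matching what the
# training input_fn will actually produce.
train_shapes = analyze_dataset_shapes(
    dataset_tf, sample_size=100, transform_args=transform_args, training=True
)

# Validation pipeline: no augmentation, so padded_batch is sized for the
# raw examples rather than for augmented ones.
val_shapes = analyze_dataset_shapes(
    dataset_tf, sample_size=100, transform_args=transform_args, training=False
)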
@@ -890,7 +891,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
 
         # Apply the data transform
         augmented_features, augmented_n_time_steps = DataAugmentationTF.transform_data(
-            features, n_time_steps, transform_args, training=True
+            features, n_time_steps, transform_args, training=training  # Use the passed training parameter
         )
 
         # Remove the batch dimension and get the post-augmentation shapes
@@ -1042,7 +1043,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
 
     # Analyze dataset to get maximum shapes
     print("📊 Analyzing dataset for maximum shapes...")
-    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args)  # Pass transform_args
+    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training)  # Pass correct training mode
 
     # Use static shapes based on analysis
     padded_shapes = {
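Downstream, create_input_fn forwards its own training mode into the analysis, so the analyzed maxima match what the built pipeline actually emits. A sketch of the intended call sites, assuming create_input_fn accepts a training parameter and that train_dataset_tf/val_dataset_tf are separately constructed datasets (neither is visible in this diff):

# Training pipeline: shape analysis runs with training=True.
train_input_fn = create_input_fn(
    train_dataset_tf, transform_args=transform_args, training=True
)

# Validation pipeline: shape analysis runs with training=False.
val_input_fn = create_input_fn(
    val_dataset_tf, transform_args=transform_args, training=False
)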