From 8ab5697081a2459abb3485880f00fa2f6247a300 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Wed, 22 Oct 2025 15:50:51 +0800 Subject: [PATCH] Add support for fixed predefined shapes in create_input_fn to optimize shape handling and skip analysis --- model_training_nnn_tpu/dataset_tf.py | 44 ++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py index df2ad75..6fbb947 100644 --- a/model_training_nnn_tpu/dataset_tf.py +++ b/model_training_nnn_tpu/dataset_tf.py @@ -1130,7 +1130,8 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, transform_args: Dict[str, Any], training: bool = True, cache_path: Optional[str] = None, - use_static_shapes: bool = True) -> tf.data.Dataset: + use_static_shapes: bool = True, + fixed_shapes: Optional[Dict[str, int]] = None) -> tf.data.Dataset: """ Create input function for TPU training with configurable shape handling @@ -1140,9 +1141,19 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, training: Whether this is for training (applies augmentations) cache_path: Optional path for disk caching to improve I/O performance use_static_shapes: If True, use pre-computed static shapes for XLA compatibility + fixed_shapes: Optional dict with predefined shapes to skip analysis + Format: {'max_time_steps': 2600, 'max_phone_seq_len': 80, 'max_transcription_len': 600} Returns: tf.data.Dataset ready for TPU training + + Example: + >>> # Use automatic shape detection + >>> dataset = create_input_fn(dataset_tf, transform_args) + + >>> # Use fixed predefined shapes (faster, no analysis needed) + >>> fixed_shapes = {'max_time_steps': 2600, 'max_phone_seq_len': 80, 'max_transcription_len': 600} + >>> dataset = create_input_fn(dataset_tf, transform_args, fixed_shapes=fixed_shapes) """ # Step 1: Create individual example dataset @@ -1163,11 +1174,32 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, if use_static_shapes: print(f"🔧 Using STATIC shapes for XLA compatibility") - # Analyze dataset to get maximum shapes - print("📊 Analyzing dataset for maximum shapes...") - max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training) # Pass correct training mode + # Option 1: Use predefined fixed shapes (fastest, no analysis needed) + if fixed_shapes is not None: + print(f"📐 Using FIXED predefined shapes (skipping analysis):") + print(f" max_time_steps: {fixed_shapes.get('max_time_steps', 'not specified')}") + print(f" max_phone_seq_len: {fixed_shapes.get('max_phone_seq_len', 'not specified')}") + print(f" max_transcription_len: {fixed_shapes.get('max_transcription_len', 'not specified')}") - # Use static shapes based on analysis + # Validate required keys + required_keys = ['max_time_steps', 'max_phone_seq_len', 'max_transcription_len'] + missing_keys = [key for key in required_keys if key not in fixed_shapes] + if missing_keys: + raise ValueError(f"fixed_shapes missing required keys: {missing_keys}") + + max_shapes = { + 'max_time_steps': fixed_shapes['max_time_steps'], + 'max_phone_seq_len': fixed_shapes['max_phone_seq_len'], + 'max_transcription_len': fixed_shapes['max_transcription_len'], + 'n_features': dataset_tf.feature_dim + } + + # Option 2: Analyze dataset to get maximum shapes (slower, but adaptive) + else: + print("📊 Analyzing dataset for maximum shapes...") + max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args, training=training) + + # Use static shapes based on either fixed config or analysis padded_shapes = { 'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim), 'seq_class_ids': (max_shapes['max_phone_seq_len'],), @@ -1178,7 +1210,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, 'block_nums': (), 'trial_nums': () } - print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, " + print(f"📏 Final static shapes: time_steps={max_shapes['max_time_steps']}, " f"phone_len={max_shapes['max_phone_seq_len']}, " f"transcription_len={max_shapes['max_transcription_len']}") else: