From 1b9e0d9bdf9715c60cceaf0eeb0d96c15d32d65e Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Thu, 16 Oct 2025 17:37:59 +0800
Subject: [PATCH] Adjust batch_size, restore model size, add TPU timing logs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Restore n_units to 768 and n_layers to 5, cut the training batch_size
from 1024 to 256 to fit TPU memory, and add timing logs around dataset
distribution and the first training-loop iteration.
---
 model_training_nnn_tpu/rnn_args.yaml |  6 +++---
 model_training_nnn_tpu/trainer_tf.py | 19 ++++++++++++++++++-
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/model_training_nnn_tpu/rnn_args.yaml b/model_training_nnn_tpu/rnn_args.yaml
index 227bc43..4977a59 100644
--- a/model_training_nnn_tpu/rnn_args.yaml
+++ b/model_training_nnn_tpu/rnn_args.yaml
@@ -1,9 +1,9 @@
 model:
   n_input_features: 512 # number of input features in the neural data. (2 features per electrode, 256 electrodes)
-  n_units: 256 # number of units per GRU layer (drastically reduced from 768 to 256, ~70% fewer parameters)
+  n_units: 768 # number of units per GRU layer (restored from 256 back to 768)
   rnn_dropout: 0.4 # dropout rate for the GRU layers
   rnn_trainable: true # whether the GRU layers are trainable
-  n_layers: 3 # number of GRU layers (reduced from 5 to 3 layers)
+  n_layers: 5 # number of GRU layers (restored from 3 back to 5)
 
   patch_size: 14 # size of the input patches (14 time steps)
   patch_stride: 4 # stride for the input patches (4 time steps)
@@ -74,7 +74,7 @@ dataset:
   smooth_kernel_std: 2 # standard deviation of the smoothing kernel applied to the data
 
   neural_dim: 512 # dimensionality of the neural data
-  batch_size: 1024 # batch size for training (reduced for TPU memory constraints)
+  batch_size: 256 # batch size for training (reduced from 1024 for TPU memory constraints)
   n_classes: 41 # number of classes (phonemes) in the dataset
   max_seq_elements: 500 # maximum number of sequence elements (phonemes) for any trial
   days_per_batch: 4 # number of randomly-selected days to include in each batch
diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 3628a7d..5438999 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -628,9 +628,18 @@ class BrainToTextDecoderTrainerTF:
             self.args['dataset']['data_transforms'],
             training=False
         )
-        # Distribute datasets
+        # Distribute datasets with timing
+        self.logger.info("🔄 Distributing training dataset across TPU cores...")
+        dist_start_time = time.time()
         train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
+        train_dist_time = time.time() - dist_start_time
+        self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
+
+        self.logger.info("🔄 Distributing validation dataset across TPU cores...")
+        val_start_time = time.time()
         val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
+        val_dist_time = time.time() - val_start_time
+        self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
         self.logger.info("Created distributed training and validation datasets")
 
         # Training metrics
@@ -644,7 +653,15 @@ class BrainToTextDecoderTrainerTF:
 
         # Training loop
         step = 0
+
+        # Timing diagnostic for the first batch iteration
+        self.logger.info("🔄 Starting training loop iteration...")
+        loop_start_time = time.time()
+
         for batch in train_dist_dataset:
+            if step == 0:
+                first_batch_iteration_time = time.time() - loop_start_time
+                self.logger.info(f"✅ First batch iteration completed in {first_batch_iteration_time:.2f}s")
             if step >= self.args['num_training_batches']:
                 self.logger.info("Reached maximum training batches, stopping training")
                 break
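
Note: with tf.distribute, the batch_size built into the dataset is treated as
the global batch size; experimental_distribute_dataset splits each batch
across replicas, and the returned distributed dataset is consumed lazily,
which is why the timing above wraps the first loop iteration rather than the
distribute call itself (the patch also assumes `time` is already imported in
trainer_tf.py). Below is a minimal sketch of the batch arithmetic, with
illustrative names only, not code from this repo:

    import tensorflow as tf

    # The real trainer uses a TPUStrategy; the default strategy used in this
    # standalone sketch reports a single replica.
    strategy = tf.distribute.get_strategy()

    GLOBAL_BATCH_SIZE = 256  # dataset.batch_size after this patch
    # Each replica receives GLOBAL_BATCH_SIZE / num_replicas_in_sync examples
    # per step, e.g. 256 / 8 = 32 on an 8-core TPU.
    per_replica_batch = GLOBAL_BATCH_SIZE // strategy.num_replicas_in_sync
    print(f"per-replica batch: {per_replica_batch}")

    # Build and distribute a toy dataset shaped like the neural features
    # (512-dimensional inputs, matching neural_dim in rnn_args.yaml).
    dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([2048, 512]))
    dataset = dataset.batch(GLOBAL_BATCH_SIZE, drop_remainder=True)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    # Distribution returns quickly; the first iteration below is where
    # input-pipeline setup cost typically shows up.
    for batch in dist_dataset:
        break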