Adjust batch_size
@@ -1,9 +1,9 @@
 model:
   n_input_features: 512 # number of input features in the neural data. (2 features per electrode, 256 electrodes)
-  n_units: 256 # number of units per GRU layer (sharply reduced from 768→256, ~70% fewer parameters)
+  n_units: 768 # number of units per GRU layer (sharply reduced from 768→256, ~70% fewer parameters)
   rnn_dropout: 0.4 # dropout rate for the GRU layers
   rnn_trainable: true # whether the GRU layers are trainable
-  n_layers: 3 # number of GRU layers (reduced from 5 layers to 3)
+  n_layers: 5 # number of GRU layers (reduced from 5 layers to 3)
   patch_size: 14 # size of the input patches (14 time steps)
   patch_stride: 4 # stride for the input patches (4 time steps)

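Note: the model block above plausibly maps to a stacked-GRU encoder. A minimal sketch of how these keys could be consumed, assuming Keras; build_gru_stack and the config-dict shape are hypothetical, not the repo's actual model code:

import tensorflow as tf

# Hypothetical builder; the real model code is not part of this diff.
def build_gru_stack(cfg):
    inputs = tf.keras.Input(shape=(None, cfg['n_input_features']))  # (time, 512)
    x = inputs
    for _ in range(cfg['n_layers']):          # 5 layers after this commit
        x = tf.keras.layers.GRU(
            cfg['n_units'],                   # 768 units per layer
            return_sequences=True,            # keep the full time axis
            dropout=cfg['rnn_dropout'],       # 0.4
            trainable=cfg['rnn_trainable'],
        )(x)
    return tf.keras.Model(inputs, x)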
@@ -74,7 +74,7 @@ dataset:
   smooth_kernel_std: 2 # standard deviation of the smoothing kernel applied to the data

   neural_dim: 512 # dimensionality of the neural data
-  batch_size: 1024 # batch size for training (reduced for TPU memory constraints)
+  batch_size: 256 # batch size for training (reduced for TPU memory constraints)
   n_classes: 41 # number of classes (phonemes) in the dataset
   max_seq_elements: 500 # maximum number of sequence elements (phonemes) for any trial
   days_per_batch: 4 # number of randomly-selected days to include in each batch

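Note: whether batch_size here is global or per-replica depends on how the input pipeline is built; the TPU-memory comment suggests global. A sketch of the usual tf.distribute arithmetic, under that assumption:

import tensorflow as tf

# Assumption: the config's batch_size (256) is the *global* batch.
# tf.distribute splits a global batch evenly across replicas, so each
# TPU core sees global_batch / num_replicas examples per step.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

global_batch = 256
per_replica = global_batch // strategy.num_replicas_in_sync  # 32 on an 8-core TPU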
@@ -628,9 +628,18 @@ class BrainToTextDecoderTrainerTF:
             self.args['dataset']['data_transforms'],
             training=False
         )
-        # Distribute datasets
+        # Distribute datasets with timing
+        self.logger.info("🔄 Distributing training dataset across TPU cores...")
+        dist_start_time = time.time()
         train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
+        train_dist_time = time.time() - dist_start_time
+        self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
+
+        self.logger.info("🔄 Distributing validation dataset across TPU cores...")
+        val_start_time = time.time()
         val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
+        val_dist_time = time.time() - val_start_time
+        self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")

         self.logger.info("Created distributed training and validation datasets")
         # Training metrics
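Note: experimental_distribute_dataset wraps the dataset lazily, so these timed spans mostly measure wrapper construction rather than data movement; the expensive work surfaces at the first iteration, which the next hunk times. If the start/stop/log pattern keeps recurring, a hypothetical helper could factor it out:

import time
from contextlib import contextmanager

@contextmanager
def log_duration(logger, label):
    # Hypothetical refactor of the repeated timing/logging pattern above.
    start = time.time()
    yield
    logger.info(f"✅ {label} in {time.time() - start:.2f}s")

# Usage inside the trainer (sketch):
# with log_duration(self.logger, "Training dataset distributed"):
#     train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)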
@@ -644,7 +653,15 @@ class BrainToTextDecoderTrainerTF:

         # Training loop
         step = 0
+
+        # Add timing diagnostic for first batch iteration
+        self.logger.info("🔄 Starting training loop iteration...")
+        loop_start_time = time.time()
+
         for batch in train_dist_dataset:
+            if step == 0:
+                first_batch_iteration_time = time.time() - loop_start_time
+                self.logger.info(f"✅ First batch iteration completed in {first_batch_iteration_time:.2f}s")
             if step >= self.args['num_training_batches']:
                 self.logger.info("Reached maximum training batches, stopping training")
                 break
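Note: on TPU the first iteration is dominated by one-off costs (input-pipeline warm-up plus tf.function tracing and XLA compilation of the train step), so this diagnostic is not a steady-state measure. A hypothetical extension that also estimates steady-state step time:

import time

step_times = []          # per-step wall-clock durations
prev = time.time()
for step in range(1, 101):
    # ... run one training step here (placeholder) ...
    now = time.time()
    step_times.append(now - prev)
    prev = now

# Skip the first few steps, which include tracing/compilation.
steady = step_times[3:]
print(f"mean steady-state step: {sum(steady) / len(steady):.3f}s")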