Adjust batch_size

Zchen
2025-10-16 17:37:59 +08:00
parent be578f2e1d
commit 1b9e0d9bdf
2 changed files with 21 additions and 4 deletions


```diff
@@ -628,9 +628,18 @@ class BrainToTextDecoderTrainerTF:
             self.args['dataset']['data_transforms'],
             training=False
         )
-        # Distribute datasets
+        # Distribute datasets with timing
+        self.logger.info("🔄 Distributing training dataset across TPU cores...")
+        dist_start_time = time.time()
         train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
+        train_dist_time = time.time() - dist_start_time
+        self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
+
+        self.logger.info("🔄 Distributing validation dataset across TPU cores...")
+        val_start_time = time.time()
         val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
+        val_dist_time = time.time() - val_start_time
+        self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
 
         self.logger.info("Created distributed training and validation datasets")
 
         # Training metrics
```
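This hunk applies the same start/log/stop timing pattern twice, once per dataset. As a minimal, self-contained sketch of that pattern, a context manager keeps the two log calls and the `time.time()` bookkeeping in one place; the `timed` helper and the logger name are illustrative, not code from this repository:

```python
import logging
import time
from contextlib import contextmanager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("trainer")

@contextmanager
def timed(label: str):
    """Log the wall-clock duration of the enclosed block."""
    logger.info("🔄 %s...", label)
    start = time.time()
    try:
        yield
    finally:
        logger.info("✅ %s completed in %.2fs", label, time.time() - start)

# Usage mirroring the diff; `strategy` and `train_dataset` are placeholders
# for the trainer's TPUStrategy and tf.data pipeline:
# with timed("Distributing training dataset across TPU cores"):
#     train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
```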
```diff
@@ -644,7 +653,15 @@ class BrainToTextDecoderTrainerTF:
 
         # Training loop
         step = 0
+
+        # Add timing diagnostic for first batch iteration
+        self.logger.info("🔄 Starting training loop iteration...")
+        loop_start_time = time.time()
+
         for batch in train_dist_dataset:
+            if step == 0:
+                first_batch_iteration_time = time.time() - loop_start_time
+                self.logger.info(f"✅ First batch iteration completed in {first_batch_iteration_time:.2f}s")
+
             if step >= self.args['num_training_batches']:
                 self.logger.info("Reached maximum training batches, stopping training")
                 break
```
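The first pass through a `tf.data` iterator absorbs the pipeline setup cost (and, for a distributed dataset, the per-replica distribution work), which is why the diff singles out `step == 0`. Below is a toy, runnable version of that first-batch measurement; the in-memory dataset and the batch counts are assumptions standing in for the trainer's real distributed pipeline:

```python
import time
import tensorflow as tf

# Toy in-memory pipeline standing in for the distributed training dataset.
dataset = tf.data.Dataset.range(1000).batch(32)
num_training_batches = 5  # stand-in for self.args['num_training_batches']

loop_start_time = time.time()
for step, batch in enumerate(dataset):
    if step == 0:
        # The first iteration pays the one-time setup cost,
        # so it is the iteration worth timing separately.
        print(f"First batch iteration completed in {time.time() - loop_start_time:.2f}s")
    if step >= num_training_batches:
        break
```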