tpu
@@ -21,7 +21,7 @@ use_amp: true # whether to use automatic mixed precision (AMP) for training
 # TPU and distributed training settings
 use_tpu: true # whether to use TPU for training (set to true for TPU)
 num_tpu_cores: 8 # number of TPU cores to use (full TPU v3-8)
-gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
+gradient_accumulation_steps: 2 # number of gradient accumulation steps for distributed training (2x32=64 effective batch size)
 
 output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
 checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training
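The settings above would be consumed by the training loop. The following is a minimal, hedged sketch (not the repository's actual trainer; the cfg keys match this file, but the function name and the model/optimizer/loader objects are illustrative) of how use_tpu and the new gradient_accumulation_steps: 2 are typically applied with torch_xla: gradients from two successive batches are summed before a single optimizer step, and on TPU the step goes through xm.optimizer_step so gradients are all-reduced across cores.

import torch
import torch_xla.core.xla_model as xm

def train_one_epoch(model, loader, optimizer, cfg):
    # Hypothetical trainer sketch; cfg holds the YAML values from this commit.
    device = xm.xla_device() if cfg["use_tpu"] else torch.device("cuda")
    model = model.to(device)
    model.train()
    accum = cfg["gradient_accumulation_steps"]  # 2 after this change
    optimizer.zero_grad()
    for step, (features, targets) in enumerate(loader):
        features, targets = features.to(device), targets.to(device)
        # Scale the loss so the accumulated gradient averages over the accumulated batches.
        loss = model(features, targets) / accum
        loss.backward()
        if (step + 1) % accum == 0:
            if cfg["use_tpu"]:
                xm.optimizer_step(optimizer)  # all-reduce across TPU cores, then step
            else:
                optimizer.step()
            optimizer.zero_grad()

With num_tpu_cores: 8, a loop like this would typically be launched once per core via torch_xla.distributed.xla_multiprocessing.spawn(train_fn, args=(cfg,), nprocs=8); the actual launch code is not part of this diff.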
@@ -75,13 +75,12 @@ dataset:
 smooth_kernel_std: 2 # standard deviation of the smoothing kernel applied to the data
 
 neural_dim: 512 # dimensionality of the neural data
-batch_size: 64 # batch size for training
+batch_size: 32 # batch size for training (reduced for TPU memory constraints)
 n_classes: 41 # number of classes (phonemes) in the dataset
 max_seq_elements: 500 # maximum number of sequence elements (phonemes) for any trial
 days_per_batch: 4 # number of randomly-selected days to include in each batch
 seed: 1 # random seed for reproducibility
-num_dataloader_workers: 4 # number of workers for the data loader
-dataloader_num_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
+num_dataloader_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
 loader_shuffle: false # whether to shuffle the data loader
 must_include_days: null # specific days to include in the dataset
 test_percentage: 0.1 # percentage of data to use for testing
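Dropping to num_dataloader_workers: 0 keeps data loading in each TPU process itself, avoiding worker subprocesses nested inside torch_xla's per-core processes. A minimal sketch of how the dataset section might be wired into a loader, assuming a hypothetical dataset object and the config keys above (illustrative, not the repository's loader code):

from torch.utils.data import DataLoader

def build_loader(dataset, cfg):
    # num_workers=0 loads batches in the main process of each TPU core.
    return DataLoader(
        dataset,
        batch_size=cfg["batch_size"],               # 32 after this change
        shuffle=cfg["loader_shuffle"],              # false
        num_workers=cfg["num_dataloader_workers"],  # 0 on TPU
    )

With batch_size: 32 and gradient_accumulation_steps: 2, each optimizer step still accumulates 32 x 2 = 64 samples, matching the previous effective batch size; whether that 64 counts per core or globally depends on how the sampler shards data across the 8 cores.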