2025-10-12 15:21:30 +08:00
import torch
2025-10-12 09:11:32 +08:00
from torch . utils . data import DataLoader
from torch . optim . lr_scheduler import LambdaLR
import random
import time
import os
import numpy as np
import math
import pathlib
import logging
import sys
import json
import pickle
from dataset import BrainToTextDataset , train_test_split_indicies
from data_augmentations import gauss_smooth
import torchaudio . functional as F # for edit distance
from omegaconf import OmegaConf
2025-10-12 15:21:30 +08:00
# Import Accelerate for TPU support
from accelerate import Accelerator
from accelerate . utils import set_seed
2025-10-12 09:11:32 +08:00
torch . set_float32_matmul_precision ( ' high ' ) # makes float32 matmuls faster on some GPUs
torch . backends . cudnn . deterministic = True # makes training more reproducible
torch . _dynamo . config . cache_size_limit = 64
2025-10-12 09:35:26 +08:00
from rnn_model import TripleGRUDecoder
2025-10-12 09:11:32 +08:00
class BrainToTextDecoder_Trainer :
"""
    This class will initialize and train a brain-to-text phoneme decoder
    Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
"""
def __init__(self, args):
    '''
    Build the full training setup from the configuration: Accelerator, logging,
    model, datasets and dataloaders, optimizer, learning rate scheduler and CTC loss.
    Everything that takes part in training is finally handed to accelerator.prepare().

    args : dictionary of training arguments
    '''
    # The Accelerator owns device selection and mixed precision (TPU / multi-device support)
    precision = 'bf16' if args.get('use_amp', True) else 'no'
    self.accelerator = Accelerator(
        mixed_precision=precision,
        gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
        log_with=None,  # this class does its own logging
        project_dir=args.get('output_dir', './output'),
    )
    # even_batches is switched off to handle batch_size=None in the DataLoaders
    self.accelerator.even_batches = False

    # Trainer fields
    self.args = args
    self.logger = None
    self.device = self.accelerator.device  # device chosen by the Accelerator
    self.model = None
    self.optimizer = None
    self.learning_rate_scheduler = None
    self.ctc_loss = None
    self.best_val_PER = torch.inf  # best PER so far, used for checkpointing
    self.best_val_loss = torch.inf  # best loss so far, used for checkpointing
    self.train_dataset = None
    self.val_dataset = None
    self.train_loader = None
    self.val_loader = None
    self.transform_args = self.args['dataset']['data_transforms']

    training_mode = args['mode'] == 'train'

    # Output directory is only created when training
    if training_mode:
        os.makedirs(self.args['output_dir'], exist_ok=True)

    # Checkpoint directory is created whenever any kind of checkpoint will be written
    if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
        os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

    # Logging: start from a clean logger so handlers are not duplicated
    self.logger = logging.getLogger(__name__)
    for stale_handler in list(self.logger.handlers):
        self.logger.removeHandler(stale_handler)
    self.logger.setLevel(logging.INFO)
    log_format = logging.Formatter(fmt='%(asctime)s: %(message)s')

    if training_mode:
        # While training, also write the log to a file in the output directory
        log_file = str(pathlib.Path(self.args['output_dir'], 'training_log'))
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(log_format)
        self.logger.addHandler(file_handler)

    # Logs always go to stdout
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(log_format)
    self.logger.addHandler(stdout_handler)

    # Report the device / distributed state managed by the Accelerator
    self.logger.info(f'Using device: {self.device}')
    self.logger.info(f'Accelerator state: {self.accelerator.state}')
    if self.accelerator.num_processes > 1:
        self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')

    # Seed through accelerate's set_seed unless seeding is disabled with -1
    if self.args['seed'] != -1:
        set_seed(self.args['seed'])

    model_cfg = self.args['model']
    dataset_cfg = self.args['dataset']
    sessions = dataset_cfg['sessions']

    # Initialize the model
    self.model = TripleGRUDecoder(
        neural_dim=model_cfg['n_input_features'],
        n_units=model_cfg['n_units'],
        n_days=len(sessions),
        n_classes=dataset_cfg['n_classes'],
        rnn_dropout=model_cfg['rnn_dropout'],
        input_dropout=model_cfg['input_network']['input_layer_dropout'],
        patch_size=model_cfg['patch_size'],
        patch_stride=model_cfg['patch_stride'],
    )

    # torch.compile is intentionally not applied to the TripleGRUDecoder for now
    # TODO: Re-enable torch.compile once model is stable
    self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
    self.logger.info("Initialized RNN decoding model")
    self.logger.info(self.model)

    # Parameter counts: total, and those whose name contains 'day'
    total_params = sum(p.numel() for p in self.model.parameters())
    self.logger.info(f"Model has {total_params:,} parameters")
    day_params = sum(
        param.numel() for name, param in self.model.named_parameters() if 'day' in name
    )
    self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")

    # One train file and one val file per session
    dataset_dir = dataset_cfg["dataset_dir"]
    train_file_paths = [os.path.join(dataset_dir, s, 'data_train.hdf5') for s in sessions]
    val_file_paths = [os.path.join(dataset_dir, s, 'data_val.hdf5') for s in sessions]

    # A session listed twice would show up as a duplicate path
    if len(set(train_file_paths)) != len(train_file_paths):
        raise ValueError("There are duplicate sessions listed in the train dataset")
    if len(set(val_file_paths)) != len(val_file_paths):
        raise ValueError("There are duplicate sessions listed in the val dataset")

    # All trials of the train files go to training, all trials of the val files to validation
    train_trials, _ = train_test_split_indicies(
        file_paths=train_file_paths,
        test_percentage=0,
        seed=dataset_cfg['seed'],
        bad_trials_dict=None,
    )
    _, val_trials = train_test_split_indicies(
        file_paths=val_file_paths,
        test_percentage=1,
        seed=dataset_cfg['seed'],
        bad_trials_dict=None,
    )

    # Record which trials were train vs val in the output directory
    trials_path = os.path.join(self.args['output_dir'], 'train_val_trials.json')
    with open(trials_path, 'w') as f:
        json.dump({'train': train_trials, 'val': val_trials}, f)

    # Optionally restrict the dataset to a subset of neural features
    feature_subset = None
    if 'feature_subset' in dataset_cfg and dataset_cfg['feature_subset'] is not None:
        feature_subset = dataset_cfg['feature_subset']
        self.logger.info(f'Using only a subset of features: {feature_subset}')

    # Train dataset
    self.train_dataset = BrainToTextDataset(
        trial_indicies=train_trials,
        split='train',
        days_per_batch=dataset_cfg['days_per_batch'],
        n_batches=self.args['num_training_batches'],
        batch_size=dataset_cfg['batch_size'],
        must_include_days=None,
        random_seed=dataset_cfg['seed'],
        feature_subset=feature_subset,
    )

    # The worker count is read from a different config key when TPU is enabled
    if self.args.get('use_tpu', False):
        num_workers = dataset_cfg['dataloader_num_workers']
    else:
        num_workers = dataset_cfg['num_dataloader_workers']

    use_tpu = self.args.get('use_tpu', False)

    # Dataset.__getitem__() already returns whole batches, so batch_size=None is used,
    # except on TPU where batch_size=1 is used instead
    batch_size_setting = 1 if use_tpu else None
    pin_memory = not use_tpu  # pin_memory is turned off on TPU

    self.train_loader = DataLoader(
        self.train_dataset,
        batch_size=batch_size_setting,
        shuffle=dataset_cfg['loader_shuffle'],
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # Val dataset
    self.val_dataset = BrainToTextDataset(
        trial_indicies=val_trials,
        split='test',
        days_per_batch=None,
        n_batches=None,
        batch_size=dataset_cfg['batch_size'],
        must_include_days=None,
        random_seed=dataset_cfg['seed'],
        feature_subset=feature_subset,
    )
    self.val_loader = DataLoader(
        self.val_dataset,
        batch_size=batch_size_setting,
        shuffle=False,
        num_workers=0,  # validation loading stays in the main process
        pin_memory=pin_memory,
    )

    self.logger.info("Successfully initialized datasets")

    # Optimizer, learning rate scheduler, and loss
    self.optimizer = self.create_optimizer()

    scheduler_type = self.args['lr_scheduler_type']
    if scheduler_type == 'linear':
        self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer=self.optimizer,
            start_factor=1.0,
            end_factor=self.args['lr_min'] / self.args['lr_max'],
            total_iters=self.args['lr_decay_steps'],
        )
    elif scheduler_type == 'cosine':
        self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
    else:
        raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")

    self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none', zero_infinity=False)

    # Optionally resume from an existing checkpoint
    if self.args['init_from_checkpoint']:
        self.load_model_checkpoint(self.args['init_checkpoint_path'])

    # Freeze the rnn and/or the day-specific input layers if configured
    for name, param in self.model.named_parameters():
        freeze_rnn = not self.args['model']['rnn_trainable'] and 'gru' in name
        if freeze_rnn:
            param.requires_grad = False
            continue
        if not self.args['model']['input_network']['input_trainable'] and 'day' in name:
            param.requires_grad = False

    # Hand model, optimizer, scheduler and dataloaders to the Accelerator
    prepared = self.accelerator.prepare(
        self.model,
        self.optimizer,
        self.learning_rate_scheduler,
        self.train_loader,
        self.val_loader,
    )
    (
        self.model,
        self.optimizer,
        self.learning_rate_scheduler,
        self.train_loader,
        self.val_loader,
    ) = prepared

    self.logger.info("Prepared model and dataloaders with Accelerator")
2025-10-12 09:11:32 +08:00
def create_optimizer(self):
    '''
    Create the AdamW optimizer with special param groups

    Groups (selected by parameter name):
      - 'bias'      : names containing 'gru.bias' or 'out.bias'; weight decay 0
      - 'day_layer' : names containing 'day_'; own learning rate (lr_max_day) and
                      weight decay (weight_decay_day). Omitted if there are none.
      - 'other'     : everything else; optimizer-wide defaults

    The group order (bias, [day_layer], other) matters: create_cosine_lr_scheduler
    assigns one schedule per group by position.

    Returns the torch.optim.AdamW instance.
    '''
    bias_params, day_params, other_params = [], [], []

    # Single pass over the parameters; the three membership tests are independent,
    # exactly as in the per-group filters they replace
    for name, p in self.model.named_parameters():
        is_bias = 'gru.bias' in name or 'out.bias' in name
        is_day = 'day_' in name
        if is_bias:
            bias_params.append(p)
        if is_day:
            day_params.append(p)
        if not is_day and not is_bias:
            other_params.append(p)

    param_groups = [{'params': bias_params, 'weight_decay': 0, 'group_type': 'bias'}]
    if len(day_params) != 0:
        param_groups.append({
            'params': day_params,
            'lr': self.args['lr_max_day'],
            'weight_decay': self.args['weight_decay_day'],
            'group_type': 'day_layer',
        })
    param_groups.append({'params': other_params, 'group_type': 'other'})

    # The fused AdamW kernel is only requested on CUDA. The Accelerator may place the
    # model on CPU or TPU, where asking for the fused implementation can fail.
    use_fused = self.device.type == 'cuda'

    optim = torch.optim.AdamW(
        param_groups,
        lr=self.args['lr_max'],
        betas=(self.args['beta0'], self.args['beta1']),
        eps=self.args['epsilon'],
        weight_decay=self.args['weight_decay'],
        fused=use_fused,
    )

    return optim
def create_cosine_lr_scheduler(self, optim):
    '''
    Build a LambdaLR scheduler with linear warmup followed by cosine decay.

    Each param group gets its own multiplier function. With three param groups the
    order is (biases, day params, rest of the model) and the day params follow their
    own schedule (lr_*_day settings); with two groups both use the main schedule.

    optim : optimizer whose param groups are scheduled
    Raises ValueError if the optimizer has neither 2 nor 3 param groups.
    '''
    cfg = self.args
    peak_lr, floor_lr = cfg['lr_max'], cfg['lr_min']
    decay_end = cfg['lr_decay_steps']
    peak_lr_day, floor_lr_day = cfg['lr_max_day'], cfg['lr_min_day']
    decay_end_day = cfg['lr_decay_steps_day']
    warmup_end = cfg['lr_warmup_steps']
    warmup_end_day = cfg['lr_warmup_steps_day']

    def make_multiplier(floor, peak, decay_steps, warmup_steps):
        '''Return step -> lr multiplier for one param group.'''
        def multiplier(step):
            # Lowest multiplier, reached once the cosine decay has finished
            floor_ratio = floor / peak

            # Warmup phase: ramp linearly from 0 to 1
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))

            # Past the decay window: hold at the floor
            if step >= decay_steps:
                return floor_ratio

            # Cosine decay phase: scale from 1.0 down to floor_ratio
            span = float(max(1, decay_steps - warmup_steps))
            progress = float(step - warmup_steps) / span
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            return max(floor_ratio, floor_ratio + (1 - floor_ratio) * cosine_decay)

        return multiplier

    n_groups = len(optim.param_groups)
    if n_groups not in (2, 3):
        raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")

    # biases first, rest of the model weights last
    lr_lambdas = [
        make_multiplier(floor_lr, peak_lr, decay_end, warmup_end),
        make_multiplier(floor_lr, peak_lr, decay_end, warmup_end),
    ]
    if n_groups == 3:
        # day params sit between the two and decay on their own schedule
        lr_lambdas.insert(
            1, make_multiplier(floor_lr_day, peak_lr_day, decay_end_day, warmup_end_day)
        )

    return LambdaLR(optim, lr_lambdas, -1)
def load_model_checkpoint(self, load_path):
    '''
    Restore model, optimizer and scheduler state, plus the best validation
    metrics, from a training checkpoint.

    load_path : path of a checkpoint dict written by save_model_checkpoint
    '''
    # Loaded onto CPU first to avoid OOM issues; the checkpoint is just a dict.
    # NOTE: weights_only=False unpickles arbitrary objects - only load trusted files.
    state = torch.load(load_path, map_location='cpu', weights_only=False)

    # State dicts are stored for the unwrapped (base) model
    base_model = self.accelerator.unwrap_model(self.model)
    base_model.load_state_dict(state['model_state_dict'])

    self.optimizer.load_state_dict(state['optimizer_state_dict'])
    self.learning_rate_scheduler.load_state_dict(state['scheduler_state_dict'])

    self.best_val_PER = state['val_PER']  # best phoneme error rate
    # Checkpoints without a stored loss fall back to infinity
    self.best_val_loss = state.get('val_loss', torch.inf)

    # No manual device move here; device handling is left to the Accelerator
    self.logger.info("Loaded model from checkpoint: " + load_path)
def save_model_checkpoint(self, save_path, PER, loss):
    '''
    Write a training checkpoint (model, optimizer and scheduler state plus the
    given validation metrics) and the args file next to it.

    save_path : file path of the checkpoint
    PER       : validation phoneme error rate stored under 'val_PER'
    loss      : validation loss stored under 'val_loss'

    Only the main process writes; every process then waits at a barrier.
    '''
    if self.accelerator.is_main_process:
        # Save the base model's weights, not the Accelerator wrapper's
        base_model = self.accelerator.unwrap_model(self.model)

        torch.save(
            {
                'model_state_dict': base_model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.learning_rate_scheduler.state_dict(),
                'val_PER': PER,
                'val_loss': loss,
            },
            save_path,
        )
        self.logger.info("Saved model to checkpoint: " + save_path)

        # Keep the args file alongside the checkpoint
        args_path = os.path.join(self.args['checkpoint_dir'], 'args.yaml')
        with open(args_path, 'w') as f:
            OmegaConf.save(config=self.args, f=f)

    # Barrier so no process runs ahead while the checkpoint is being written
    self.accelerator.wait_for_everyone()
2025-10-12 09:11:32 +08:00
def create_attention_mask(self, sequence_lengths):
    '''
    Build a boolean key-padding attention mask.

    sequence_lengths : 1-D integer tensor with the valid length of each sequence

    Returns a bool tensor of shape [batch_size, 1, max_length, max_length], where
    max_length is the largest entry of sequence_lengths. mask[b, 0, q, k] is True
    when key position k is a valid (non-padding) position of sequence b, for every
    query position q. Note that the result is a boolean mask, not an additive
    float (0 / -inf) mask.
    '''
    max_length = torch.max(sequence_lengths).item()
    batch_size = sequence_lengths.size(0)

    # Valid key positions (columns), shape [batch_size, max_length]
    positions = torch.arange(max_length, device=sequence_lengths.device)
    key_mask = positions.unsqueeze(0) < sequence_lengths.unsqueeze(1)

    # [batch_size, 1, 1, max_length], then broadcast across all query positions
    attention_mask = key_mask[:, None, None, :].expand(batch_size, 1, max_length, max_length)

    # Materialise the broadcast view so callers get an ordinary dense tensor
    return attention_mask.contiguous()
def transform_data(self, features, n_time_steps, mode='train'):
    '''
    Apply various augmentations and smoothing to data
    Performing augmentations is much faster on GPU than CPU

    features     : tensor indexed as [batch, time, channels] (an extra leading
                   dimension of size 1 is removed when use_tpu is set)
    n_time_steps : number of valid time steps per trial
    mode         : augmentations are applied only when mode == 'train';
                   Gaussian smoothing (if enabled) is applied in every mode

    Returns (features, n_time_steps); n_time_steps is reduced by the random cut.
    The noise augmentations modify the given features tensor in place.
    '''
    # With use_tpu the DataLoader runs with batch_size=1 and adds an extra leading dimension
    use_tpu = self.args.get('use_tpu', False)
    if use_tpu and features.dim() == 4 and features.size(0) == 1:
        features = features.squeeze(0)
        if isinstance(n_time_steps, torch.Tensor) and n_time_steps.dim() == 2:
            n_time_steps = n_time_steps.squeeze(0)

    data_shape = features.shape
    batch_size = data_shape[0]
    channels = data_shape[-1]
    cfg = self.transform_args

    # We only apply these augmentations in training
    if mode == 'train':
        # add static gain noise
        if cfg['static_gain_std'] > 0:
            # The identity is created on self.device so it matches the noise added to it
            # (a CPU identity plus device noise is a device mismatch on accelerators)
            identity = torch.eye(channels, device=self.device)
            warp_mat = torch.tile(torch.unsqueeze(identity, dim=0), (batch_size, 1, 1))
            warp_mat += torch.randn_like(warp_mat, device=self.device) * cfg['static_gain_std']
            features = torch.matmul(features, warp_mat)

        # add white noise
        if cfg['white_noise_std'] > 0:
            features += torch.randn(data_shape, device=self.device) * cfg['white_noise_std']

        # add constant offset noise (one offset per trial and channel)
        if cfg['constant_offset_std'] > 0:
            features += torch.randn((batch_size, 1, channels), device=self.device) * cfg['constant_offset_std']

        # add random walk noise (cumulative sum of white noise along the configured axis)
        if cfg['random_walk_std'] > 0:
            features += torch.cumsum(
                torch.randn(data_shape, device=self.device) * cfg['random_walk_std'],
                dim=cfg['random_walk_axis'],
            )

        # randomly cut off the start of the data timecourse
        if cfg['random_cut'] > 0:
            cut = np.random.randint(0, cfg['random_cut'])
            features = features[:, cut:, :]
            n_time_steps = n_time_steps - cut

    # Apply Gaussian smoothing to data
    # This is done in both training and validation
    if cfg['smooth_data']:
        features = gauss_smooth(
            inputs=features,
            device=self.device,
            smooth_kernel_std=cfg['smooth_kernel_std'],
            smooth_kernel_size=cfg['smooth_kernel_size'],
        )

    return features, n_time_steps
def train(self):
    '''
    Train the model

    Iterates over self.train_loader once. Every batch is one optimizer step
    (CTC loss, optional gradient clipping, scheduler step). Validation runs every
    'batches_per_val_step' batches and on the last configured batch; it drives
    best-checkpoint saving and early stopping.

    Returns a dict with 'train_losses', 'val_losses', 'val_PERs' and 'val_metrics'.
    '''
    # Set model to train mode (specifically to make sure dropout layers are engaged)
    self.model.train()

    # create vars to track performance
    train_losses = []
    val_losses = []
    val_PERs = []
    val_results = []

    val_steps_since_improvement = 0

    # training params
    save_best_checkpoint = self.args.get('save_best_checkpoint', True)
    early_stopping = self.args.get('early_stopping', True)
    early_stopping_val_steps = self.args['early_stopping_val_steps']

    train_start_time = time.time()

    # train for specified number of batches
    for i, batch in enumerate(self.train_loader):

        # validation() switches the model to eval mode, so re-enable train mode each batch
        self.model.train()
        self.optimizer.zero_grad()

        # Train step
        start_time = time.time()

        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indicies = batch['day_indicies']

        # Mixed precision is configured on the Accelerator
        with self.accelerator.autocast():

            # Apply augmentations to the data
            features, n_time_steps = self.transform_data(features, n_time_steps, 'train')

            # Sequence lengths after patching, as required by the CTC loss
            adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)

            # Get phoneme predictions; the model is called in its 'inference' mode
            logits = self.model(features, day_indicies, None, False, 'inference')

            # Calculate CTC Loss
            loss = self.ctc_loss(
                log_probs=torch.permute(logits.log_softmax(2), [1, 0, 2]),
                targets=labels,
                input_lengths=adjusted_lens,
                target_lengths=phone_seq_lens,
            )

            loss = torch.mean(loss)  # take mean loss over batches

        self.accelerator.backward(loss)

        # Clip gradient. grad_norm stays None when clipping is disabled so the
        # progress log below never touches an undefined name.
        grad_norm = None
        if self.args['grad_norm_clip_value'] > 0:
            grad_norm = self.accelerator.clip_grad_norm_(
                self.model.parameters(),
                max_norm=self.args['grad_norm_clip_value'],
            )

        self.optimizer.step()
        self.learning_rate_scheduler.step()

        # Save training metrics
        train_step_duration = time.time() - start_time
        train_losses.append(loss.detach().item())

        # Incrementally log training progress
        if i % self.args['batches_per_train_log'] == 0:
            grad_norm_text = 'n/a' if grad_norm is None else f'{grad_norm:.2f}'
            self.logger.info(f'Train batch {i}: ' +
                             f'loss: {(loss.detach().item()):.2f} ' +
                             f'grad norm: {grad_norm_text} '
                             f'time: {train_step_duration:.3f}')

        # Incrementally run a test step
        if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
            self.logger.info(f"Running test after training batch: {i}")

            # Calculate metrics on val data
            start_time = time.time()
            val_metrics = self.validation(
                loader=self.val_loader,
                return_logits=self.args['save_val_logits'],
                return_data=self.args['save_val_data'],
            )
            val_step_duration = time.time() - start_time

            # Log info
            self.logger.info(f'Val batch {i}: ' +
                             f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
                             f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
                             f'time: {val_step_duration:.3f}')

            if self.args['log_individual_day_val_PER']:
                for day in val_metrics['day_PERs'].keys():
                    self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")

            # Save metrics
            val_PERs.append(val_metrics['avg_PER'])
            val_losses.append(val_metrics['avg_loss'])
            val_results.append(val_metrics)

            # Determine if new best day. Based on if PER is lower, or in the case of a PER tie, if loss is lower
            new_best = False
            if val_metrics['avg_PER'] < self.best_val_PER:
                self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
                self.best_val_PER = val_metrics['avg_PER']
                self.best_val_loss = val_metrics['avg_loss']
                new_best = True
            elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
                self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                self.best_val_loss = val_metrics['avg_loss']
                new_best = True

            if new_best:
                # Checkpoint if metrics have improved
                if save_best_checkpoint:
                    self.logger.info("Checkpointing model")
                    self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)

                # save validation metrics to pickle file
                if self.args['save_val_metrics']:
                    with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                        pickle.dump(val_metrics, f)

                val_steps_since_improvement = 0
            else:
                val_steps_since_improvement += 1

            # Optionally save this validation checkpoint, regardless of performance.
            # save_model_checkpoint requires both PER and loss.
            if self.args['save_all_val_steps']:
                self.save_model_checkpoint(
                    f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}',
                    val_metrics['avg_PER'],
                    val_metrics['avg_loss'],
                )

            # Early stopping
            if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
                self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
                break

    # Log final training steps
    training_duration = time.time() - train_start_time

    self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
    self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

    # Save final model
    if self.args['save_final_model']:
        if val_PERs:
            self.save_model_checkpoint(
                f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}',
                val_PERs[-1],
                val_losses[-1],
            )
        else:
            # Nothing was trained or validated (empty loader): there is no batch index
            # or validation result to label a final checkpoint with
            self.logger.info('No training batches were processed. Skipping final checkpoint')

    train_stats = {}
    train_stats['train_losses'] = train_losses
    train_stats['val_losses'] = val_losses
    train_stats['val_PERs'] = val_PERs
    train_stats['val_metrics'] = val_results

    return train_stats
def validation(self, loader, return_logits=False, return_data=False):
    '''
    Calculate metrics on the validation dataset

    loader        : iterable of validation batches (dicts of tensors)
    return_logits : also record logits and adjusted sequence lengths per batch
    return_data   : also record the input features per batch

    Batches whose day has dataset_probability_val == 0 are skipped.

    Returns a metrics dict with per-batch records, 'day_PERs' (edit distance and
    sequence length totals per day), 'avg_PER' and 'avg_loss'. If no batch was
    validated, 'avg_PER' and 'avg_loss' are NaN.
    '''
    self.model.eval()

    metrics = {}

    # Record metrics
    if return_logits:
        metrics['logits'] = []
        metrics['n_time_steps'] = []

    if return_data:
        metrics['input_features'] = []

    metrics['decoded_seqs'] = []
    metrics['true_seq'] = []
    metrics['phone_seq_lens'] = []
    metrics['transcription'] = []
    metrics['losses'] = []
    metrics['block_nums'] = []
    metrics['trial_nums'] = []
    metrics['day_indicies'] = []

    total_edit_distance = 0
    total_seq_length = 0

    val_probability = self.args['dataset']['dataset_probability_val']

    # Calculate PER for each specific day (only days with validation probability 1)
    day_per = {
        d: {'total_edit_distance': 0, 'total_seq_length': 0}
        for d in range(len(self.args['dataset']['sessions']))
        if val_probability[d] == 1
    }

    for i, batch in enumerate(loader):

        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indicies = batch['day_indicies']

        # Determine if we should perform validation on this batch
        # (the day of the first trial is taken as the day of the whole batch)
        day = day_indicies[0].item()
        if val_probability[day] == 0:
            if self.args['log_val_skip_logs']:
                self.logger.info(f"Skipping validation on day {day}")
            continue

        with torch.no_grad():
            with self.accelerator.autocast():
                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                # Sequence lengths after patching, as required by the CTC loss
                adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)

                logits = self.model(features, day_indicies, None, False, 'inference')

                loss = self.ctc_loss(
                    torch.permute(logits.log_softmax(2), [1, 0, 2]),
                    labels,
                    adjusted_lens,
                    phone_seq_lens,
                )
                loss = torch.mean(loss)

        # Calculate PER per day and also avg over entire validation set
        batch_edit_distance = 0
        decoded_seqs = []
        for iterIdx in range(logits.shape[0]):
            # Greedy CTC decoding: argmax, collapse repeats, drop blanks (class 0)
            decoded_seq = torch.argmax(logits[iterIdx, 0:adjusted_lens[iterIdx], :].clone().detach(), dim=-1)
            decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
            decoded_seq = decoded_seq.cpu().detach().numpy()
            decoded_seq = np.array([token for token in decoded_seq if token != 0])

            trueSeq = np.array(
                labels[iterIdx][0:phone_seq_lens[iterIdx]].cpu().detach()
            )

            batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)

            decoded_seqs.append(decoded_seq)

        batch_seq_length = torch.sum(phone_seq_lens).item()

        day_per[day]['total_edit_distance'] += batch_edit_distance
        day_per[day]['total_seq_length'] += batch_seq_length

        total_edit_distance += batch_edit_distance
        total_seq_length += batch_seq_length

        # Record metrics
        if return_logits:
            # Will be in bfloat16 if AMP is enabled, so need to set back to float32
            metrics['logits'].append(logits.cpu().float().numpy())
            metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())

        if return_data:
            metrics['input_features'].append(batch['input_features'].cpu().numpy())

        metrics['decoded_seqs'].append(decoded_seqs)
        metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
        metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
        metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
        # One loss entry per validated batch
        metrics['losses'].append(loss.detach().item())
        # .cpu() first: these tensors may live on the accelerator device
        metrics['block_nums'].append(batch['block_nums'].cpu().numpy())
        metrics['trial_nums'].append(batch['trial_nums'].cpu().numpy())
        metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())

    metrics['day_PERs'] = day_per

    if total_seq_length > 0:
        metrics['avg_PER'] = total_edit_distance / total_seq_length
        metrics['avg_loss'] = np.mean(metrics['losses'])
    else:
        # No batch was validated (empty loader or every day skipped)
        self.logger.warning('No validation batches were evaluated. Reporting NaN metrics')
        metrics['avg_PER'] = float('nan')
        metrics['avg_loss'] = float('nan')

    return metrics
def inference(self, features, day_indicies, n_time_steps, mode='inference'):
    '''
    Generate phoneme logits for already-unpacked inputs.

    The inputs go through transform_data in 'val' mode (no augmentation) and the
    model is called in eval mode, without gradients, under the Accelerator's autocast.

    features     : input features passed to the model
    day_indicies : day index per trial, passed to the model
    n_time_steps : time steps per trial (only used by transform_data)
    mode         : forwarded to the model as its mode argument

    Returns the logits produced by the model.
    '''
    self.model.eval()

    with torch.no_grad(), self.accelerator.autocast():
        # The transformed lengths are not needed here, only the transformed features
        prepared_features, _ = self.transform_data(features, n_time_steps, 'val')

        # Get phoneme predictions
        return self.model(prepared_features, day_indicies, None, False, mode)
def inference_batch(self, batch, mode='inference'):
    '''
    Generate phoneme logits for a full batch dict.

    batch : dict providing 'input_features', 'day_indicies' and 'n_time_steps'
    mode  : forwarded to the model as its mode argument

    Returns (logits, adjusted_lens), where adjusted_lens are the int32 sequence
    lengths after patching: (n_time_steps - patch_size) / patch_stride + 1.
    '''
    self.model.eval()

    patch_size = self.args['model']['patch_size']
    patch_stride = self.args['model']['patch_stride']

    with torch.no_grad(), self.accelerator.autocast():
        # Apply data transformations (no augmentation for inference)
        features, n_time_steps = self.transform_data(
            batch['input_features'], batch['n_time_steps'], 'val'
        )

        # Calculate adjusted sequence lengths for CTC
        adjusted_lens = ((n_time_steps - patch_size) / patch_stride + 1).to(torch.int32)

        # Get phoneme predictions
        logits = self.model(features, batch['day_indicies'], None, False, mode)

    return logits, adjusted_lens