import os

# XLA multi-threading optimization - MUST be set before importing torch_xla.
# These environment variables are set early so that they take effect when the
# XLA runtime is initialized by the torch_xla import below.
if 'TPU_CORES' in os.environ or 'COLAB_TPU_ADDR' in os.environ:
    # Enable XLA multi-threading for compilation speedup
    os.environ.setdefault(
        'XLA_FLAGS',
        '--xla_cpu_multi_thread_eigen=true '
        + '--xla_cpu_enable_fast_math=true '
        + f'--xla_force_host_platform_device_count={os.cpu_count()}',
    )
    # Set PyTorch XLA threading
    os.environ.setdefault('PYTORCH_XLA_COMPILATION_THREADS', str(os.cpu_count()))
    print(f"Set XLA compilation threads to {os.cpu_count()}")

import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import random
import time
import numpy as np
import math
import pathlib
import logging
import sys
import json
import pickle
from contextlib import nullcontext

from dataset import BrainToTextDataset, train_test_split_indicies
from data_augmentations import gauss_smooth

import torchaudio.functional as F  # for edit distance
from omegaconf import OmegaConf

# Import Accelerate for TPU support
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import set_seed

# Import XLA after setting environment variables
import torch_xla.core.xla_model as xm

torch.set_float32_matmul_precision('high')  # makes float32 matmuls faster on some GPUs
torch.backends.cudnn.deterministic = True  # makes training more reproducible
torch._dynamo.config.cache_size_limit = 64

from rnn_model import TripleGRUDecoder


def _passthrough_collate(batch):
    '''
    Unwrap the single pre-built batch produced by BrainToTextDataset.

    The dataset's __getitem__ already returns a full batch (a dict), and the
    DataLoaders below use batch_size=1, so `batch` is normally a one-element
    list holding that dict. Any other structure is returned unchanged.

    Defined at module level (not as a closure inside __init__) so that it can
    be pickled when DataLoader workers are started with the 'spawn' method.
    '''
    if len(batch) == 1 and isinstance(batch[0], dict):
        return batch[0]
    # Fallback for unexpected batch structure
    return batch


def _to_numpy(value):
    '''
    Convert a batch entry to a numpy array.

    Tensors are moved to the CPU first. Batches from an Accelerate-prepared
    DataLoader may live on the accelerator device, and Tensor.numpy() only
    works on CPU tensors. Non-tensor values go through np.asarray.
    '''
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


class BrainToTextDecoder_Trainer:
    """
    This class will initialize and train a brain-to-text phoneme decoder

    Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
    """

    def __init__(self, args):
        '''
        Build the accelerator, logger, model, datasets/dataloaders, optimizer,
        LR scheduler and CTC loss, then hand everything to Accelerator.prepare.

        args : dictionary of training arguments (keys read below, e.g. 'mode',
               'output_dir', 'checkpoint_dir', 'dataset', 'model', 'lr_*')
        '''
        # Configure DataLoader behavior for TPU compatibility
        dataloader_config = DataLoaderConfiguration(
            even_batches=False  # Required for batch_size=None DataLoaders on TPU
        )

        # Initialize Accelerator for TPU/multi-device support
        self.use_xla = bool(xm.get_xla_supported_devices())
        self.amp_requested = args.get('use_amp', True)
        mixed_precision_mode = 'bf16' if self.amp_requested else 'no'

        # NOTE(review): gradient_accumulation_steps is forwarded here, but train() does not
        # wrap steps in `accelerator.accumulate(...)`, so the optimizer still steps every batch.
        self.accelerator = Accelerator(
            mixed_precision=mixed_precision_mode,
            gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
            log_with=None,  # We'll use our own logging
            project_dir=args.get('output_dir', './output'),
            dataloader_config=dataloader_config,
        )

        # Trainer fields
        self.args = args
        self.logger = None
        self.device = self.accelerator.device  # Use accelerator device instead of manual device selection
        self.model = None
        self.optimizer = None
        self.learning_rate_scheduler = None
        self.ctc_loss = None

        self.best_val_PER = torch.inf  # track best PER for checkpointing
        self.best_val_loss = torch.inf  # track best loss for checkpointing

        self.train_dataset = None
        self.val_dataset = None
        self.train_loader = None
        self.val_loader = None

        self.transform_args = self.args['dataset']['data_transforms']

        # Adversarial training config (safe defaults if not provided)
        adv_cfg = self.args.get('adversarial', {})
        self.adv_enabled = adv_cfg.get('enabled', False)
        self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))  # GRL strength
        self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))  # weight for noisy branch CTC
        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))  # optional L2 on noise output
        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))  # delay enabling adversarial after N steps

        # Create output directory
        if args['mode'] == 'train':
            os.makedirs(self.args['output_dir'], exist_ok=True)

        # Create checkpoint directory
        if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
            os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

        # Set up logging. The logger is module-global, so drop (and close) handlers left over
        # from a previous trainer instance to avoid duplicate output and leaked file handles.
        self.logger = logging.getLogger(__name__)
        for handler in self.logger.handlers[:]:  # make a copy of the list
            self.logger.removeHandler(handler)
            handler.close()
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

        if args['mode'] == 'train':
            # During training, save logs to file in output directory
            fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

        # Always print logs to stdout
        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self.logger.addHandler(sh)

        # Log device information (managed by Accelerator)
        self.logger.info(f'Using device: {self.device}')
        self.logger.info(f'Accelerator state: {self.accelerator.state}')
        if self.accelerator.num_processes > 1:
            self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
        if self.use_xla and self.amp_requested:
            self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.')

        # Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
        if self.args['seed'] != -1:
            set_seed(self.args['seed'])

        # Initialize the model
        self.model = TripleGRUDecoder(
            neural_dim=self.args['model']['n_input_features'],
            n_units=self.args['model']['n_units'],
            n_days=len(self.args['dataset']['sessions']),
            n_classes=self.args['dataset']['n_classes'],
            rnn_dropout=self.args['model']['rnn_dropout'],
            input_dropout=self.args['model']['input_network']['input_layer_dropout'],
            patch_size=self.args['model']['patch_size'],
            patch_stride=self.args['model']['patch_stride'],
        )

        if self.use_xla and self.amp_requested:
            self.model = self.model.to(torch.bfloat16)
            self.logger.info('Converted model parameters to bfloat16 for TPU training.')

        self.model_dtype = next(self.model.parameters()).dtype

        # TODO: Re-enable torch.compile once the TripleGRUDecoder architecture is stable.
        self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")

        self.logger.info(f"Initialized RNN decoding model")

        self.logger.info(self.model)

        # Log how many parameters are in the model
        total_params = sum(p.numel() for p in self.model.parameters())
        self.logger.info(f"Model has {total_params:,} parameters")

        # Determine how many day-specific parameters are in the model
        day_params = sum(param.numel() for name, param in self.model.named_parameters() if 'day' in name)
        self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")

        # Create datasets and dataloaders
        train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5') for s in self.args['dataset']['sessions']]
        val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5') for s in self.args['dataset']['sessions']]

        # Ensure that there are no duplicate days
        if len(set(train_file_paths)) != len(train_file_paths):
            raise ValueError("There are duplicate sessions listed in the train dataset")
        if len(set(val_file_paths)) != len(val_file_paths):
            raise ValueError("There are duplicate sessions listed in the val dataset")

        # Split trials into train and test sets
        train_trials, _ = train_test_split_indicies(
            file_paths=train_file_paths,
            test_percentage=0,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=None,
        )
        _, val_trials = train_test_split_indicies(
            file_paths=val_file_paths,
            test_percentage=1,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=None,
        )

        # Save dictionaries to output directory to know which trials were train vs val
        with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
            json.dump({'train': train_trials, 'val': val_trials}, f)

        # Determine if only a subset of neural features should be used
        feature_subset = None
        if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] is not None:
            feature_subset = self.args['dataset']['feature_subset']
            self.logger.info(f'Using only a subset of features: {feature_subset}')

        # train dataset and dataloader
        self.train_dataset = BrainToTextDataset(
            trial_indicies=train_trials,
            split='train',
            days_per_batch=self.args['dataset']['days_per_batch'],
            n_batches=self.args['num_training_batches'],
            batch_size=self.args['dataset']['batch_size'],
            must_include_days=None,
            random_seed=self.args['dataset']['seed'],
            feature_subset=feature_subset,
        )

        # DataLoader configuration compatible with Accelerate
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=1,  # Use batch_size=1 since dataset returns full batches
            shuffle=self.args['dataset']['loader_shuffle'],
            num_workers=self.args['dataset']['num_dataloader_workers'],
            pin_memory=True,
            collate_fn=_passthrough_collate,
        )

        # val dataset and dataloader
        self.val_dataset = BrainToTextDataset(
            trial_indicies=val_trials,
            split='test',
            days_per_batch=None,
            n_batches=None,
            batch_size=self.args['dataset']['batch_size'],
            must_include_days=None,
            random_seed=self.args['dataset']['seed'],
            feature_subset=feature_subset,
        )

        # Validation DataLoader with same collate function
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=1,  # Use batch_size=1 since dataset returns full batches
            shuffle=False,
            num_workers=0,  # Keep validation dataloader single-threaded for consistency
            pin_memory=True,
            collate_fn=_passthrough_collate,
        )

        self.logger.info("Successfully initialized datasets")

        # Create optimizer, learning rate scheduler, and loss
        self.optimizer = self.create_optimizer()

        if self.args['lr_scheduler_type'] == 'linear':
            self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer=self.optimizer,
                start_factor=1.0,
                end_factor=self.args['lr_min'] / self.args['lr_max'],
                total_iters=self.args['lr_decay_steps'],
            )
        elif self.args['lr_scheduler_type'] == 'cosine':
            self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
        else:
            raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")

        self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none', zero_infinity=False)

        # If a checkpoint is provided, then load from checkpoint
        if self.args['init_from_checkpoint']:
            self.load_model_checkpoint(self.args['init_checkpoint_path'])

        # Set rnn and/or input layers to not trainable if specified
        for name, param in self.model.named_parameters():
            if not self.args['model']['rnn_trainable'] and 'gru' in name:
                param.requires_grad = False
            elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
                param.requires_grad = False

        # Prepare model, optimizer, scheduler, and dataloaders for distributed training
        # Let Accelerator handle everything automatically for both GPU and TPU
        (
            self.model,
            self.optimizer,
            self.learning_rate_scheduler,
            self.train_loader,
            self.val_loader,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.learning_rate_scheduler,
            self.train_loader,
            self.val_loader,
        )

        self.model_dtype = next(self.model.parameters()).dtype

        self.logger.info("Prepared model and dataloaders with Accelerator")
        if self.adv_enabled:
            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")

    def autocast_context(self):
        """Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
        if self.device.type == 'xla':
            return nullcontext()
        return self.accelerator.autocast()

    def _compute_adjusted_lens(self, n_time_steps):
        '''
        Convert raw per-trial time-step counts into CTC input lengths.

        Computes (n - patch_size) / patch_stride + 1 in float and truncates to int32.
        Presumably this is the number of logit frames the model emits after patching
        the time axis; the result is used both as CTC input_lengths and to slice logits.
        '''
        patch_size = self.args['model']['patch_size']
        patch_stride = self.args['model']['patch_stride']
        return ((n_time_steps.float() - patch_size) / patch_stride + 1).to(torch.int32)

    def create_optimizer(self):
        '''
        Create the optimizer with special param groups

        Biases and day weights should not be decayed
        Day weights should have a separate learning rate
        '''
        bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
        day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
        other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name]

        if len(day_params) != 0:
            param_groups = [
                {'params': bias_params, 'weight_decay': 0, 'group_type': 'bias'},
                {'params': day_params, 'lr': self.args['lr_max_day'], 'weight_decay': self.args['weight_decay_day'], 'group_type': 'day_layer'},
                {'params': other_params, 'group_type': 'other'},
            ]
        else:
            param_groups = [
                {'params': bias_params, 'weight_decay': 0, 'group_type': 'bias'},
                {'params': other_params, 'group_type': 'other'},
            ]

        # NOTE(review): fused=True is only supported on some device types; confirm it is
        # accepted for the parameter device/dtype used here (CPU at construction, bf16 on TPU).
        optim = torch.optim.AdamW(
            param_groups,
            lr=self.args['lr_max'],
            betas=(self.args['beta0'], self.args['beta1']),
            eps=self.args['epsilon'],
            weight_decay=self.args['weight_decay'],
            fused=True,
        )

        return optim

    def create_cosine_lr_scheduler(self, optim):
        '''
        Build a LambdaLR with linear warmup followed by cosine decay to a floor.

        Day-specific params (when present as the middle param group) use their own
        lr_*_day schedule; every other group follows the main schedule.
        '''
        lr_max = self.args['lr_max']
        lr_min = self.args['lr_min']
        lr_decay_steps = self.args['lr_decay_steps']

        lr_max_day = self.args['lr_max_day']
        lr_min_day = self.args['lr_min_day']
        lr_decay_steps_day = self.args['lr_decay_steps_day']

        lr_warmup_steps = self.args['lr_warmup_steps']
        lr_warmup_steps_day = self.args['lr_warmup_steps_day']

        def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
            '''
            Create lr lambdas for each param group that implement cosine decay

            Different lr lambda decaying for day params vs rest of the model
            '''
            # Warmup phase
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))

            # Cosine decay phase
            if current_step < decay_steps:
                progress = float(current_step - warmup_steps) / float(
                    max(1, decay_steps - warmup_steps)
                )
                cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
                # Scale from 1.0 to min_lr_ratio
                return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)

            # After cosine decay is complete, maintain min_lr_ratio
            return min_lr_ratio

        if len(optim.param_groups) == 3:
            lr_lambdas = [
                lambda step: lr_lambda(step, lr_min / lr_max, lr_decay_steps, lr_warmup_steps),  # biases
                lambda step: lr_lambda(step, lr_min_day / lr_max_day, lr_decay_steps_day, lr_warmup_steps_day),  # day params
                lambda step: lr_lambda(step, lr_min / lr_max, lr_decay_steps, lr_warmup_steps),  # rest of model weights
            ]
        elif len(optim.param_groups) == 2:
            lr_lambdas = [
                lambda step: lr_lambda(step, lr_min / lr_max, lr_decay_steps, lr_warmup_steps),  # biases
                lambda step: lr_lambda(step, lr_min / lr_max, lr_decay_steps, lr_warmup_steps),  # rest of model weights
            ]
        else:
            raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")

        return LambdaLR(optim, lr_lambdas, -1)

    def load_model_checkpoint(self, load_path):
        '''
        Load a training checkpoint for distributed training
        '''
        # Load checkpoint on CPU first to avoid OOM issues.
        # SECURITY: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(load_path, map_location='cpu', weights_only=False)  # checkpoint is just a dict

        # Get unwrapped model for loading state dict
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.load_state_dict(checkpoint['model_state_dict'])

        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_val_PER = checkpoint['val_PER']  # best phoneme error rate
        self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf

        # Device handling is managed by Accelerator, no need to manually move to device

        self.logger.info("Loaded model from checkpoint: " + load_path)

    def save_model_checkpoint(self, save_path, PER, loss):
        '''
        Save a training checkpoint using Accelerator for distributed training

        Only the main process writes the checkpoint and args.yaml; all processes
        then synchronize on accelerator.wait_for_everyone().
        '''
        # Only save on main process to avoid conflicts
        if self.accelerator.is_main_process:
            # Unwrap model to get base model for saving
            unwrapped_model = self.accelerator.unwrap_model(self.model)

            checkpoint = {
                'model_state_dict': unwrapped_model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.learning_rate_scheduler.state_dict(),
                'val_PER': PER,
                'val_loss': loss,
            }

            torch.save(checkpoint, save_path)

            self.logger.info("Saved model to checkpoint: " + save_path)

            # Save the args file alongside the checkpoint
            with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
                OmegaConf.save(config=self.args, f=f)

        # Wait for all processes to complete checkpoint saving
        self.accelerator.wait_for_everyone()

    def create_attention_mask(self, sequence_lengths):
        '''
        Build a boolean key-padding mask of shape [batch_size, 1, max_length, max_length].

        Entry [b, 0, q, k] is True when key position k < sequence_lengths[b] (a valid
        position) and False for padding. The same key mask is repeated for every query
        position q. The mask is boolean, NOT an additive float mask (no -inf values).
        '''
        max_length = torch.max(sequence_lengths).item()
        batch_size = sequence_lengths.size(0)

        # Mask of valid key positions (columns), shape [batch_size, max_length]
        key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
        key_mask = key_mask < sequence_lengths.unsqueeze(1)

        # [batch_size, 1, 1, max_length], then broadcast across all query positions
        key_mask = key_mask.unsqueeze(1).unsqueeze(1)
        attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)

        # Materialize the expanded view into a real contiguous bool tensor
        return attention_mask.contiguous()

    def transform_data(self, features, n_time_steps, mode='train'):
        '''
        Apply various augmentations and smoothing to data
        Performing augmentations is much faster on GPU than CPU

        Noise augmentations and random cutting are applied only when mode == 'train';
        Gaussian smoothing (if enabled) is applied in every mode. Returns the
        transformed features (cast to the model dtype) and the possibly shortened
        n_time_steps.
        '''
        data_shape = features.shape
        batch_size = data_shape[0]
        channels = data_shape[-1]

        # We only apply these augmentations in training
        if mode == 'train':
            # add static gain noise
            if self.transform_args['static_gain_std'] > 0:
                # Build the identity+noise mixing matrix directly on the training device;
                # creating it on the CPU and adding device noise in place fails on GPU/TPU.
                eye = torch.eye(channels, device=self.device)
                warp_mat = torch.tile(torch.unsqueeze(eye, dim=0), (batch_size, 1, 1))
                warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']

                features = torch.matmul(features, warp_mat)

            # add white noise
            if self.transform_args['white_noise_std'] > 0:
                features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']

            # add constant offset noise
            if self.transform_args['constant_offset_std'] > 0:
                features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']

            # add random walk noise
            if self.transform_args['random_walk_std'] > 0:
                features += torch.cumsum(
                    torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'],
                    dim=self.transform_args['random_walk_axis'],
                )

            # randomly cutoff part of the data timecourse
            if self.transform_args['random_cut'] > 0:
                cut = np.random.randint(0, self.transform_args['random_cut'])
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing to data
        # This is done in both training and validation
        if self.transform_args['smooth_data']:
            features = gauss_smooth(
                inputs=features,
                device=self.device,
                smooth_kernel_std=self.transform_args['smooth_kernel_std'],
                smooth_kernel_size=self.transform_args['smooth_kernel_size'],
            )

        if hasattr(self, 'model_dtype'):
            features = features.to(self.model_dtype)

        return features, n_time_steps

    def train(self):
        '''
        Train the model

        Runs one optimizer step per batch from the train loader, validates every
        'batches_per_val_step' batches (and on the last batch), checkpoints on
        improvement, and optionally stops early. Returns a dict of training stats.
        '''
        # Set model to train mode (specifically to make sure dropout layers are engaged)
        self.model.train()

        # create vars to track performance
        train_losses = []
        val_losses = []
        val_PERs = []
        val_results = []

        val_steps_since_improvement = 0

        # training params
        save_best_checkpoint = self.args.get('save_best_checkpoint', True)
        early_stopping = self.args.get('early_stopping', True)

        early_stopping_val_steps = self.args.get('early_stopping_val_steps', 20)

        train_start_time = time.time()

        # Defined up front so the final-checkpoint name is valid even if the loader is empty
        i = 0

        # train for specified number of batches
        self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...")
        for i, batch in enumerate(self.train_loader):

            self.model.train()
            self.optimizer.zero_grad()

            # Train step
            start_time = time.time()

            # Data is automatically moved to device by Accelerator
            features = batch['input_features']
            labels = batch['seq_class_ids']
            n_time_steps = batch['n_time_steps']
            phone_seq_lens = batch['phone_seq_lens']
            day_indicies = batch['day_indicies']

            # Use Accelerator's autocast (mixed precision handled by Accelerator init)
            with self.autocast_context():

                # Apply augmentations to the data
                features, n_time_steps = self.transform_data(features, n_time_steps, 'train')

                adjusted_lens = self._compute_adjusted_lens(n_time_steps)

                # Ensure features tensor matches model parameter dtype for TPU compatibility
                if features.dtype != self.model_dtype:
                    features = features.to(self.model_dtype)

                # Forward pass: enable full adversarial mode if configured and past warmup
                use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
                if use_full:
                    clean_logits, noisy_logits, noise_output = self.model(
                        features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda
                    )

                    # Clean CTC loss
                    clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2)
                    clean_loss = torch.mean(self.ctc_loss(
                        clean_log_probs,
                        labels,
                        adjusted_lens,
                        phone_seq_lens,
                    ))

                    # Noisy branch CTC loss: pushes the noisy branch to be recognizable, which
                    # the gradient reversal layer (GRL) turns into an adversarial signal for the
                    # noise model.
                    noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2)
                    noisy_loss = torch.mean(self.ctc_loss(
                        noisy_log_probs,
                        labels,
                        adjusted_lens,
                        phone_seq_lens,
                    ))

                    # Optional noise energy regularization
                    noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
                    if self.adv_noise_l2_weight > 0.0:
                        noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype)

                    loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
                else:
                    # Only the clean logits are used in this mode
                    logits = self.model(features, day_indicies, None, False, 'inference')

                    log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
                    loss = self.ctc_loss(
                        log_probs=log_probs,
                        targets=labels,
                        input_lengths=adjusted_lens,
                        target_lengths=phone_seq_lens,
                    )
                    loss = torch.mean(loss)  # take mean loss over batches

            # Use Accelerator's backward for distributed training
            self.accelerator.backward(loss)

            # Clip gradient using Accelerator's clip_grad_norm_. grad_norm stays None when
            # clipping is disabled (or when the backend does not return a norm).
            grad_norm = None
            if self.args['grad_norm_clip_value'] > 0:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.args['grad_norm_clip_value'],
                )

            self.optimizer.step()
            self.learning_rate_scheduler.step()

            # Save training metrics
            train_step_duration = time.time() - start_time
            train_losses.append(loss.detach().item())

            # Incrementally log training progress
            if i % self.args['batches_per_train_log'] == 0:
                grad_norm_text = f'{grad_norm:.2f}' if grad_norm is not None else 'n/a'
                self.logger.info(f'Train batch {i}: ' +
                                 f'loss: {(loss.detach().item()):.2f} ' +
                                 f'grad norm: {grad_norm_text} '
                                 f'time: {train_step_duration:.3f}')

            # Incrementally run a test step
            if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
                self.logger.info(f"Running test after training batch: {i}")

                # Calculate metrics on val data
                start_time = time.time()
                val_metrics = self.validation(
                    loader=self.val_loader,
                    return_logits=self.args['save_val_logits'],
                    return_data=self.args['save_val_data'],
                )
                val_step_duration = time.time() - start_time

                # Log info
                self.logger.info(f'Val batch {i}: ' +
                                 f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
                                 f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
                                 f'time: {val_step_duration:.3f}')

                if self.args['log_individual_day_val_PER']:
                    for day in val_metrics['day_PERs'].keys():
                        self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")

                # Save metrics
                val_PERs.append(val_metrics['avg_PER'])
                val_losses.append(val_metrics['avg_loss'])
                val_results.append(val_metrics)

                # Determine if new best day. Based on if PER is lower, or in the case of a PER tie, if loss is lower
                new_best = False
                if val_metrics['avg_PER'] < self.best_val_PER:
                    self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
                    self.best_val_PER = val_metrics['avg_PER']
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True
                elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
                    self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True

                if new_best:
                    # Checkpoint if metrics have improved
                    if save_best_checkpoint:
                        self.logger.info(f"Checkpointing model")
                        self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)

                    # save validation metrics to pickle file
                    if self.args['save_val_metrics']:
                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                            pickle.dump(val_metrics, f)

                    val_steps_since_improvement = 0
                else:
                    val_steps_since_improvement += 1

                # Optionally save this validation checkpoint, regardless of performance
                if self.args['save_all_val_steps']:
                    self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])

                # Early stopping
                if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
                    self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
                    break

        # Log final training steps
        training_duration = time.time() - train_start_time

        self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
        self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

        # Save final model (guard both metric lists in case no validation step ran)
        if self.args['save_final_model']:
            last_PER = val_PERs[-1] if len(val_PERs) > 0 else float('inf')
            last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
            self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', last_PER, last_loss)

        train_stats = {}
        train_stats['train_losses'] = train_losses
        train_stats['val_losses'] = val_losses
        train_stats['val_PERs'] = val_PERs
        train_stats['val_metrics'] = val_results

        return train_stats

    def validation(self, loader, return_logits=False, return_data=False):
        '''
        Calculate metrics on the validation dataset

        Batches whose day has dataset_probability_val == 0 are skipped. PER is the
        edit distance between the greedy CTC decode (argmax, collapse repeats,
        drop blank 0) and the true phoneme sequence, normalized by total true length.

        Returns a dict with 'avg_PER', 'avg_loss', 'day_PERs' (per-day edit distance
        and sequence-length totals) and per-batch record lists; logits/lengths are
        included when return_logits is set, raw input features when return_data is set.
        '''
        self.model.eval()

        metrics = {}

        # Record metrics
        if return_logits:
            metrics['logits'] = []
            metrics['n_time_steps'] = []

        if return_data:
            metrics['input_features'] = []

        metrics['decoded_seqs'] = []
        metrics['true_seq'] = []
        metrics['phone_seq_lens'] = []
        metrics['transcription'] = []
        metrics['losses'] = []
        metrics['block_nums'] = []
        metrics['trial_nums'] = []
        metrics['day_indicies'] = []

        total_edit_distance = 0
        total_seq_length = 0

        # Calculate PER for each specific day
        day_per = {
            d: {'total_edit_distance': 0, 'total_seq_length': 0}
            for d in range(len(self.args['dataset']['sessions']))
            if self.args['dataset']['dataset_probability_val'][d] == 1
        }

        for batch in loader:

            # Data is automatically moved to device by Accelerator
            features = batch['input_features']
            labels = batch['seq_class_ids']
            n_time_steps = batch['n_time_steps']
            phone_seq_lens = batch['phone_seq_lens']
            day_indicies = batch['day_indicies']

            # Determine if we should perform validation on this batch
            day = day_indicies[0].item()
            if self.args['dataset']['dataset_probability_val'][day] == 0:
                if self.args['log_val_skip_logs']:
                    self.logger.info(f"Skipping validation on day {day}")
                continue

            with torch.no_grad():

                with self.autocast_context():
                    features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                    adjusted_lens = self._compute_adjusted_lens(n_time_steps)

                    # Ensure features tensor matches model parameter dtype for TPU compatibility
                    if features.dtype != self.model_dtype:
                        features = features.to(self.model_dtype)

                    logits = self.model(features, day_indicies, None, False, 'inference')

                    val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
                    loss = self.ctc_loss(
                        val_log_probs,
                        labels,
                        adjusted_lens,
                        phone_seq_lens,
                    )
                    loss = torch.mean(loss)

                # Exactly one loss entry per validated batch
                metrics['losses'].append(loss.detach().item())

                # Calculate PER per day and also avg over entire validation set
                batch_edit_distance = 0
                decoded_seqs = []
                for trial_idx in range(logits.shape[0]):
                    decoded_seq = torch.argmax(logits[trial_idx, 0:adjusted_lens[trial_idx], :].clone().detach(), dim=-1)
                    decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
                    decoded_seq = decoded_seq.cpu().detach().numpy()
                    decoded_seq = np.array([token for token in decoded_seq if token != 0])  # drop CTC blanks

                    true_seq = np.array(
                        labels[trial_idx][0:phone_seq_lens[trial_idx]].cpu().detach()
                    )

                    batch_edit_distance += F.edit_distance(decoded_seq, true_seq)

                    decoded_seqs.append(decoded_seq)

            batch_seq_length = torch.sum(phone_seq_lens).item()

            # setdefault: days whose val probability is neither 0 nor 1 are validated
            # but have no pre-built entry
            day_stats = day_per.setdefault(day, {'total_edit_distance': 0, 'total_seq_length': 0})
            day_stats['total_edit_distance'] += batch_edit_distance
            day_stats['total_seq_length'] += batch_seq_length

            total_edit_distance += batch_edit_distance
            total_seq_length += batch_seq_length

            # Record metrics
            if return_logits:
                metrics['logits'].append(logits.cpu().float().numpy())  # Will be in bfloat16 if AMP is enabled, so need to set back to float32
                metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())

            if return_data:
                metrics['input_features'].append(_to_numpy(batch['input_features']))

            metrics['decoded_seqs'].append(decoded_seqs)
            metrics['true_seq'].append(_to_numpy(batch['seq_class_ids']))
            metrics['phone_seq_lens'].append(_to_numpy(batch['phone_seq_lens']))
            metrics['transcription'].append(_to_numpy(batch['transcriptions']))
            metrics['block_nums'].append(_to_numpy(batch['block_nums']))
            metrics['trial_nums'].append(_to_numpy(batch['trial_nums']))
            metrics['day_indicies'].append(_to_numpy(batch['day_indicies']))

        avg_PER = total_edit_distance / max(float(total_seq_length), 1e-6)

        metrics['day_PERs'] = day_per
        metrics['avg_PER'] = avg_PER
        metrics['avg_loss'] = float(np.mean(metrics['losses']))

        return metrics

    def inference(self, features, day_indicies, n_time_steps, mode='inference'):
        '''
        TPU-compatible inference method for generating phoneme logits

        Applies the validation-mode transforms (smoothing only, no augmentation)
        and returns the model's output for the given mode.
        '''
        self.model.eval()

        with torch.no_grad():
            with self.autocast_context():
                # Apply data transformations (no augmentation for inference)
                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                # Ensure features tensor matches model parameter dtype for TPU compatibility
                if features.dtype != self.model_dtype:
                    features = features.to(self.model_dtype)

                # Get phoneme predictions
                logits = self.model(features, day_indicies, None, False, mode)

        return logits

    def inference_batch(self, batch, mode='inference'):
        '''
        Inference method for processing a full batch

        Returns (logits, adjusted_lens), where adjusted_lens are the per-trial
        output lengths matching those used for CTC during training/validation.
        '''
        self.model.eval()

        # Data is automatically moved to device by Accelerator
        features = batch['input_features']
        day_indicies = batch['day_indicies']
        n_time_steps = batch['n_time_steps']

        with torch.no_grad():
            with self.autocast_context():
                # Apply data transformations (no augmentation for inference)
                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                # Calculate adjusted sequence lengths for CTC
                adjusted_lens = self._compute_adjusted_lens(n_time_steps)

                # Ensure features tensor matches model parameter dtype for TPU compatibility
                if features.dtype != self.model_dtype:
                    features = features.to(self.model_dtype)

                # Get phoneme predictions
                logits = self.model(features, day_indicies, None, False, mode)

        return logits, adjusted_lens