947 lines
		
	
	
		
			41 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			947 lines
		
	
	
		
			41 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | ||
| from torch.utils.data import DataLoader
 | ||
| from torch.optim.lr_scheduler import LambdaLR
 | ||
| import random
 | ||
| import time
 | ||
| import os
 | ||
| import numpy as np
 | ||
| import math
 | ||
| import pathlib
 | ||
| import logging
 | ||
| import sys
 | ||
| import json
 | ||
| import pickle
 | ||
| from contextlib import nullcontext
 | ||
| 
 | ||
| from dataset import BrainToTextDataset, train_test_split_indicies
 | ||
| from data_augmentations import gauss_smooth
 | ||
| 
 | ||
| import torchaudio.functional as F # for edit distance
 | ||
| from omegaconf import OmegaConf
 | ||
| 
 | ||
| # Import Accelerate for TPU support
 | ||
| from accelerate import Accelerator, DataLoaderConfiguration
 | ||
| from accelerate.utils import set_seed
 | ||
| 
 | ||
| # XLA multi-threading optimization for faster compilation
 | ||
| import torch_xla.core.xla_model as xm
 | ||
if xm.get_xla_supported_devices():
    # XLA is available: request multi-threaded host-side compilation settings.
    # setdefault() means values already exported by the user take precedence.
    # os.cpu_count() may return None when the CPU count cannot be determined;
    # fall back to 1 so a literal "None" never ends up in the flag strings.
    # NOTE(review): XLA usually reads XLA_FLAGS when its runtime initialises, and
    # get_xla_supported_devices() above may already have done so -- confirm these
    # settings actually take effect, or export them before launching the script.
    _n_cpus = os.cpu_count() or 1
    os.environ.setdefault('XLA_FLAGS', ' '.join([
        '--xla_cpu_multi_thread_eigen=true',
        '--xla_cpu_enable_fast_math=true',
        f'--xla_force_host_platform_device_count={_n_cpus}',
    ]))
    # Thread count for PyTorch/XLA graph compilation
    os.environ.setdefault('PYTORCH_XLA_COMPILATION_THREADS', str(_n_cpus))

torch.set_float32_matmul_precision('high')  # makes float32 matmuls faster on some GPUs
torch.backends.cudnn.deterministic = True  # makes training more reproducible
torch._dynamo.config.cache_size_limit = 64  # allow more recompiled graph variants per function
 | ||
| 
 | ||
| from rnn_model import TripleGRUDecoder
 | ||
| 
 | ||
| class BrainToTextDecoder_Trainer:
 | ||
|     """
 | ||
|     This class will initialize and train a brain-to-text phoneme decoder
 | ||
|     
 | ||
|     Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
 | ||
|     """
 | ||
| 
 | ||
|     def __init__(self, args):
 | ||
|         '''
 | ||
|         args : dictionary of training arguments
 | ||
|         '''
 | ||
| 
 | ||
|         # Configure DataLoader behavior for TPU compatibility
 | ||
|         dataloader_config = DataLoaderConfiguration(
 | ||
|             even_batches=False  # Required for batch_size=None DataLoaders on TPU
 | ||
|         )
 | ||
| 
 | ||
|         # Initialize Accelerator for TPU/multi-device support
 | ||
|         self.use_xla = bool(xm.get_xla_supported_devices())
 | ||
|         self.amp_requested = args.get('use_amp', True)
 | ||
|         mixed_precision_mode = 'bf16' if self.amp_requested else 'no'
 | ||
| 
 | ||
|         self.accelerator = Accelerator(
 | ||
|             mixed_precision=mixed_precision_mode,
 | ||
|             gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
 | ||
|             log_with=None,  # We'll use our own logging
 | ||
|             project_dir=args.get('output_dir', './output'),
 | ||
|             dataloader_config=dataloader_config,
 | ||
|         )
 | ||
| 
 | ||
| 
 | ||
|         # Trainer fields
 | ||
|         self.args = args
 | ||
|         self.logger = None
 | ||
|         self.device = self.accelerator.device  # Use accelerator device instead of manual device selection
 | ||
|         self.model = None
 | ||
|         self.optimizer = None
 | ||
|         self.learning_rate_scheduler = None
 | ||
|         self.ctc_loss = None
 | ||
| 
 | ||
|         self.best_val_PER = torch.inf # track best PER for checkpointing
 | ||
|         self.best_val_loss = torch.inf # track best loss for checkpointing
 | ||
| 
 | ||
|         self.train_dataset = None
 | ||
|         self.val_dataset = None
 | ||
|         self.train_loader = None
 | ||
|         self.val_loader = None
 | ||
| 
 | ||
|         self.transform_args = self.args['dataset']['data_transforms']
 | ||
| 
 | ||
|         # Adversarial training config (safe defaults if not provided)
 | ||
|         adv_cfg = self.args.get('adversarial', {})
 | ||
|         self.adv_enabled = adv_cfg.get('enabled', False)
 | ||
|         self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))  # GRL strength
 | ||
|         self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))  # weight for noisy branch CTC
 | ||
|         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))  # optional L2 on noise output
 | ||
|         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))  # delay enabling adversarial after N steps
 | ||
| 
 | ||
|         # Create output directory
 | ||
|         if args['mode'] == 'train':
 | ||
|             os.makedirs(self.args['output_dir'], exist_ok=True)
 | ||
| 
 | ||
|         # Create checkpoint directory
 | ||
|         if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
 | ||
|             os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
 | ||
| 
 | ||
|         # Set up logging
 | ||
|         self.logger = logging.getLogger(__name__)
 | ||
|         for handler in self.logger.handlers[:]:  # make a copy of the list
 | ||
|             self.logger.removeHandler(handler)
 | ||
|         self.logger.setLevel(logging.INFO)
 | ||
|         formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')        
 | ||
| 
 | ||
|         if args['mode']=='train':
 | ||
|             # During training, save logs to file in output directory
 | ||
|             fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log')))
 | ||
|             fh.setFormatter(formatter)
 | ||
|             self.logger.addHandler(fh)
 | ||
| 
 | ||
|         # Always print logs to stdout
 | ||
|         sh = logging.StreamHandler(sys.stdout)
 | ||
|         sh.setFormatter(formatter)
 | ||
|         self.logger.addHandler(sh)
 | ||
| 
 | ||
|         # Log device information (managed by Accelerator)
 | ||
|         self.logger.info(f'Using device: {self.device}')
 | ||
|         self.logger.info(f'Accelerator state: {self.accelerator.state}')
 | ||
|         if self.accelerator.num_processes > 1:
 | ||
|             self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
 | ||
|         if self.use_xla and self.amp_requested:
 | ||
|             self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.')
 | ||
| 
 | ||
|         # Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
 | ||
|         if self.args['seed'] != -1:
 | ||
|             set_seed(self.args['seed'])
 | ||
| 
 | ||
|         # Initialize the model
 | ||
|         self.model = TripleGRUDecoder(
 | ||
|             neural_dim = self.args['model']['n_input_features'],
 | ||
|             n_units = self.args['model']['n_units'],
 | ||
|             n_days = len(self.args['dataset']['sessions']),
 | ||
|             n_classes  = self.args['dataset']['n_classes'],
 | ||
|             rnn_dropout = self.args['model']['rnn_dropout'],
 | ||
|             input_dropout = self.args['model']['input_network']['input_layer_dropout'],
 | ||
|             patch_size = self.args['model']['patch_size'],
 | ||
|             patch_stride = self.args['model']['patch_stride'],
 | ||
|         )
 | ||
| 
 | ||
|         if self.use_xla and self.amp_requested:
 | ||
|             self.model = self.model.to(torch.bfloat16)
 | ||
|             self.logger.info('Converted model parameters to bfloat16 for TPU training.')
 | ||
| 
 | ||
|         self.model_dtype = next(self.model.parameters()).dtype
 | ||
| 
 | ||
|         # Temporarily disable torch.compile for compatibility with new model architecture
 | ||
|         # TODO: Re-enable torch.compile once model is stable
 | ||
|         # self.logger.info("Using torch.compile")
 | ||
|         # self.model = torch.compile(self.model)
 | ||
|         self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
 | ||
| 
 | ||
|         self.logger.info(f"Initialized RNN decoding model")
 | ||
| 
 | ||
|         self.logger.info(self.model)
 | ||
| 
 | ||
|         # Log how many parameters are in the model
 | ||
|         total_params = sum(p.numel() for p in self.model.parameters())
 | ||
|         self.logger.info(f"Model has {total_params:,} parameters")
 | ||
| 
 | ||
|         # Determine how many day-specific parameters are in the model
 | ||
|         day_params = 0
 | ||
|         for name, param in self.model.named_parameters():
 | ||
|             if 'day' in name:
 | ||
|                 day_params += param.numel()
 | ||
|         
 | ||
|         self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")
 | ||
| 
 | ||
|         # Create datasets and dataloaders
 | ||
|         train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']]
 | ||
|         val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']]
 | ||
| 
 | ||
|         # Ensure that there are no duplicate days
 | ||
|         if len(set(train_file_paths)) != len(train_file_paths):
 | ||
|             raise ValueError("There are duplicate sessions listed in the train dataset")
 | ||
|         if len(set(val_file_paths)) != len(val_file_paths):
 | ||
|             raise ValueError("There are duplicate sessions listed in the val dataset")
 | ||
| 
 | ||
|         # Split trials into train and test sets
 | ||
|         train_trials, _ = train_test_split_indicies(
 | ||
|             file_paths = train_file_paths, 
 | ||
|             test_percentage = 0,
 | ||
|             seed = self.args['dataset']['seed'],
 | ||
|             bad_trials_dict = None,
 | ||
|             )
 | ||
|         _, val_trials = train_test_split_indicies(
 | ||
|             file_paths = val_file_paths, 
 | ||
|             test_percentage = 1,
 | ||
|             seed = self.args['dataset']['seed'],
 | ||
|             bad_trials_dict = None,
 | ||
|             )
 | ||
| 
 | ||
|         # Save dictionaries to output directory to know which trials were train vs val 
 | ||
|         with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f: 
 | ||
|             json.dump({'train' : train_trials, 'val': val_trials}, f)
 | ||
| 
 | ||
|         # Determine if a only a subset of neural features should be used
 | ||
|         feature_subset = None
 | ||
|         if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] != None: 
 | ||
|             feature_subset = self.args['dataset']['feature_subset']
 | ||
|             self.logger.info(f'Using only a subset of features: {feature_subset}')
 | ||
|             
 | ||
|         # train dataset and dataloader
 | ||
|         self.train_dataset = BrainToTextDataset(
 | ||
|             trial_indicies = train_trials,
 | ||
|             split = 'train',
 | ||
|             days_per_batch = self.args['dataset']['days_per_batch'],
 | ||
|             n_batches = self.args['num_training_batches'],
 | ||
|             batch_size = self.args['dataset']['batch_size'],
 | ||
|             must_include_days = None,
 | ||
|             random_seed = self.args['dataset']['seed'],
 | ||
|             feature_subset = feature_subset
 | ||
|             )
 | ||
|         # Custom collate function that handles pre-batched data from our dataset
 | ||
|         def collate_fn(batch):
 | ||
|             # Our dataset returns full batches, so batch will be a list of single batch dict
 | ||
|             # Extract the first (and only) element since our dataset.__getitem__() returns a full batch
 | ||
|             if len(batch) == 1 and isinstance(batch[0], dict):
 | ||
|                 return batch[0]
 | ||
|             else:
 | ||
|                 # Fallback for unexpected batch structure
 | ||
|                 return batch
 | ||
| 
 | ||
|         # DataLoader configuration compatible with Accelerate
 | ||
|         self.train_loader = DataLoader(
 | ||
|             self.train_dataset,
 | ||
|             batch_size = 1,  # Use batch_size=1 since dataset returns full batches
 | ||
|             shuffle = self.args['dataset']['loader_shuffle'],
 | ||
|             num_workers = self.args['dataset']['num_dataloader_workers'],
 | ||
|             pin_memory = True,
 | ||
|             collate_fn = collate_fn
 | ||
|         )
 | ||
| 
 | ||
|         # val dataset and dataloader
 | ||
|         self.val_dataset = BrainToTextDataset(
 | ||
|             trial_indicies = val_trials, 
 | ||
|             split = 'test',
 | ||
|             days_per_batch = None,
 | ||
|             n_batches = None,
 | ||
|             batch_size = self.args['dataset']['batch_size'],
 | ||
|             must_include_days = None,
 | ||
|             random_seed = self.args['dataset']['seed'],
 | ||
|             feature_subset = feature_subset   
 | ||
|             )
 | ||
|         # Validation DataLoader with same collate function
 | ||
|         self.val_loader = DataLoader(
 | ||
|             self.val_dataset,
 | ||
|             batch_size = 1,  # Use batch_size=1 since dataset returns full batches
 | ||
|             shuffle = False,
 | ||
|             num_workers = 0,  # Keep validation dataloader single-threaded for consistency
 | ||
|             pin_memory = True,
 | ||
|             collate_fn = collate_fn  # Use same collate function
 | ||
|         )
 | ||
| 
 | ||
|         self.logger.info("Successfully initialized datasets")
 | ||
| 
 | ||
|         # Create optimizer, learning rate scheduler, and loss
 | ||
|         self.optimizer = self.create_optimizer()
 | ||
| 
 | ||
|         if self.args['lr_scheduler_type'] == 'linear':
 | ||
|             self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
 | ||
|                 optimizer = self.optimizer,
 | ||
|                 start_factor = 1.0,
 | ||
|                 end_factor = self.args['lr_min'] / self.args['lr_max'],
 | ||
|                 total_iters = self.args['lr_decay_steps'],
 | ||
|             )
 | ||
|         elif self.args['lr_scheduler_type'] == 'cosine':
 | ||
|             self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
 | ||
|         
 | ||
|         else:
 | ||
|             raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")
 | ||
|         
 | ||
|         self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False)
 | ||
| 
 | ||
|         # If a checkpoint is provided, then load from checkpoint
 | ||
|         if self.args['init_from_checkpoint']:
 | ||
|             self.load_model_checkpoint(self.args['init_checkpoint_path'])
 | ||
| 
 | ||
|         # Set rnn and/or input layers to not trainable if specified 
 | ||
|         for name, param in self.model.named_parameters():
 | ||
|             if not self.args['model']['rnn_trainable'] and 'gru' in name:
 | ||
|                 param.requires_grad = False
 | ||
| 
 | ||
|             elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
 | ||
|                 param.requires_grad = False
 | ||
| 
 | ||
|         # Prepare model, optimizer, scheduler, and dataloaders for distributed training
 | ||
|         # Let Accelerator handle everything automatically for both GPU and TPU
 | ||
|         (
 | ||
|             self.model,
 | ||
|             self.optimizer,
 | ||
|             self.learning_rate_scheduler,
 | ||
|             self.train_loader,
 | ||
|             self.val_loader,
 | ||
|         ) = self.accelerator.prepare(
 | ||
|             self.model,
 | ||
|             self.optimizer,
 | ||
|             self.learning_rate_scheduler,
 | ||
|             self.train_loader,
 | ||
|             self.val_loader,
 | ||
|         )
 | ||
| 
 | ||
|         self.model_dtype = next(self.model.parameters()).dtype
 | ||
| 
 | ||
|         self.logger.info("Prepared model and dataloaders with Accelerator")
 | ||
|         if self.adv_enabled:
 | ||
|             self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
 | ||
| 
 | ||
|     def autocast_context(self):
 | ||
|         """Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
 | ||
|         if self.device.type == 'xla':
 | ||
|             return nullcontext()
 | ||
|         return self.accelerator.autocast()
 | ||
| 
 | ||
|     def create_optimizer(self):
 | ||
|         '''
 | ||
|         Create the optimizer with special param groups 
 | ||
| 
 | ||
|         Biases and day weights should not be decayed
 | ||
| 
 | ||
|         Day weights should have a separate learning rate
 | ||
|         '''
 | ||
|         bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
 | ||
|         day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
 | ||
|         other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name]
 | ||
| 
 | ||
|         if len(day_params) != 0:
 | ||
|             param_groups = [
 | ||
|                     {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
 | ||
|                     {'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'},
 | ||
|                     {'params' : other_params, 'group_type' : 'other'}
 | ||
|                 ]
 | ||
|         else: 
 | ||
|             param_groups = [
 | ||
|                     {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
 | ||
|                     {'params' : other_params, 'group_type' : 'other'}
 | ||
|                 ]
 | ||
|             
 | ||
|         optim = torch.optim.AdamW(
 | ||
|             param_groups,
 | ||
|             lr = self.args['lr_max'],
 | ||
|             betas = (self.args['beta0'], self.args['beta1']),
 | ||
|             eps = self.args['epsilon'],
 | ||
|             weight_decay = self.args['weight_decay'],
 | ||
|             fused = True
 | ||
|         )
 | ||
| 
 | ||
|         return optim 
 | ||
| 
 | ||
|     def create_cosine_lr_scheduler(self, optim):
 | ||
|         lr_max = self.args['lr_max']
 | ||
|         lr_min = self.args['lr_min']
 | ||
|         lr_decay_steps = self.args['lr_decay_steps']
 | ||
| 
 | ||
|         lr_max_day =  self.args['lr_max_day']
 | ||
|         lr_min_day = self.args['lr_min_day']
 | ||
|         lr_decay_steps_day = self.args['lr_decay_steps_day']
 | ||
| 
 | ||
|         lr_warmup_steps = self.args['lr_warmup_steps']
 | ||
|         lr_warmup_steps_day = self.args['lr_warmup_steps_day']
 | ||
| 
 | ||
|         def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
 | ||
|             '''
 | ||
|             Create lr lambdas for each param group that implement cosine decay
 | ||
| 
 | ||
|             Different lr lambda decaying for day params vs rest of the model
 | ||
|             '''
 | ||
|             # Warmup phase
 | ||
|             if current_step < warmup_steps:
 | ||
|                 return float(current_step) / float(max(1, warmup_steps))
 | ||
|             
 | ||
|             # Cosine decay phase
 | ||
|             if current_step < decay_steps:
 | ||
|                 progress = float(current_step - warmup_steps) / float(
 | ||
|                     max(1, decay_steps - warmup_steps)
 | ||
|                 )
 | ||
|                 cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
 | ||
|                 # Scale from 1.0 to min_lr_ratio
 | ||
|                 return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)
 | ||
|             
 | ||
|             # After cosine decay is complete, maintain min_lr_ratio
 | ||
|             return min_lr_ratio
 | ||
| 
 | ||
|         if len(optim.param_groups) == 3:
 | ||
|             lr_lambdas = [
 | ||
|                 lambda step: lr_lambda(
 | ||
|                     step, 
 | ||
|                     lr_min / lr_max, 
 | ||
|                     lr_decay_steps, 
 | ||
|                     lr_warmup_steps), # biases 
 | ||
|                 lambda step: lr_lambda(
 | ||
|                     step, 
 | ||
|                     lr_min_day / lr_max_day, 
 | ||
|                     lr_decay_steps_day,
 | ||
|                     lr_warmup_steps_day, 
 | ||
|                     ), # day params
 | ||
|                 lambda step: lr_lambda(
 | ||
|                     step, 
 | ||
|                     lr_min / lr_max, 
 | ||
|                     lr_decay_steps, 
 | ||
|                     lr_warmup_steps), # rest of model weights
 | ||
|             ]
 | ||
|         elif len(optim.param_groups) == 2:
 | ||
|             lr_lambdas = [
 | ||
|                 lambda step: lr_lambda(
 | ||
|                     step, 
 | ||
|                     lr_min / lr_max, 
 | ||
|                     lr_decay_steps, 
 | ||
|                     lr_warmup_steps), # biases 
 | ||
|                 lambda step: lr_lambda(
 | ||
|                     step, 
 | ||
|                     lr_min / lr_max, 
 | ||
|                     lr_decay_steps, 
 | ||
|                     lr_warmup_steps), # rest of model weights
 | ||
|             ]
 | ||
|         else:
 | ||
|             raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")
 | ||
|         
 | ||
|         return LambdaLR(optim, lr_lambdas, -1)
 | ||
|         
 | ||
|     def load_model_checkpoint(self, load_path):
 | ||
|         '''
 | ||
|         Load a training checkpoint for distributed training
 | ||
|         '''
 | ||
|         # Load checkpoint on CPU first to avoid OOM issues
 | ||
|         checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict
 | ||
| 
 | ||
|         # Get unwrapped model for loading state dict
 | ||
|         unwrapped_model = self.accelerator.unwrap_model(self.model)
 | ||
|         unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
 | ||
| 
 | ||
|         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
 | ||
|         self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
 | ||
|         self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
 | ||
|         self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
 | ||
| 
 | ||
|         # Device handling is managed by Accelerator, no need to manually move to device
 | ||
| 
 | ||
|         self.logger.info("Loaded model from checkpoint: " + load_path)
 | ||
| 
 | ||
|     def save_model_checkpoint(self, save_path, PER, loss):
 | ||
|         '''
 | ||
|         Save a training checkpoint using Accelerator for distributed training
 | ||
|         '''
 | ||
|         # Only save on main process to avoid conflicts
 | ||
|         if self.accelerator.is_main_process:
 | ||
|             # Unwrap model to get base model for saving
 | ||
|             unwrapped_model = self.accelerator.unwrap_model(self.model)
 | ||
| 
 | ||
|             checkpoint = {
 | ||
|                 'model_state_dict' : unwrapped_model.state_dict(),
 | ||
|                 'optimizer_state_dict' : self.optimizer.state_dict(),
 | ||
|                 'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
 | ||
|                 'val_PER' : PER,
 | ||
|                 'val_loss' : loss
 | ||
|             }
 | ||
| 
 | ||
|             torch.save(checkpoint, save_path)
 | ||
| 
 | ||
|             self.logger.info("Saved model to checkpoint: " + save_path)
 | ||
| 
 | ||
|             # Save the args file alongside the checkpoint
 | ||
|             with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
 | ||
|                 OmegaConf.save(config=self.args, f=f)
 | ||
| 
 | ||
|         # Wait for all processes to complete checkpoint saving
 | ||
|         self.accelerator.wait_for_everyone()
 | ||
| 
 | ||
|     def create_attention_mask(self, sequence_lengths):
 | ||
| 
 | ||
|         max_length = torch.max(sequence_lengths).item()
 | ||
| 
 | ||
|         batch_size = sequence_lengths.size(0)
 | ||
|         
 | ||
|         # Create a mask for valid key positions (columns)
 | ||
|         # Shape: [batch_size, max_length]
 | ||
|         key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
 | ||
|         key_mask = key_mask < sequence_lengths.unsqueeze(1)
 | ||
|         
 | ||
|         # Expand key_mask to [batch_size, 1, 1, max_length]
 | ||
|         # This will be broadcast across all query positions
 | ||
|         key_mask = key_mask.unsqueeze(1).unsqueeze(1)
 | ||
|         
 | ||
|         # Create the attention mask of shape [batch_size, 1, max_length, max_length]
 | ||
|         # by broadcasting key_mask across all query positions
 | ||
|         attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)
 | ||
|         
 | ||
|         # Convert boolean mask to float mask:
 | ||
|         # - True (valid key positions) -> 0.0 (no change to attention scores)
 | ||
|         # - False (padding key positions) -> -inf (will become 0 after softmax)
 | ||
|         attention_mask_float = torch.where(attention_mask, 
 | ||
|                                         True,
 | ||
|                                         False)
 | ||
|         
 | ||
|         return attention_mask_float
 | ||
| 
 | ||
|     def transform_data(self, features, n_time_steps, mode = 'train'):
 | ||
|         '''
 | ||
|         Apply various augmentations and smoothing to data
 | ||
|         Performing augmentations is much faster on GPU than CPU
 | ||
|         '''
 | ||
| 
 | ||
|         # TPU and GPU should now handle data consistently with our improved DataLoader configuration
 | ||
| 
 | ||
|         data_shape = features.shape
 | ||
|         batch_size = data_shape[0]
 | ||
|         channels = data_shape[-1]
 | ||
| 
 | ||
|         # We only apply these augmentations in training
 | ||
|         if mode == 'train':
 | ||
|             # add static gain noise 
 | ||
|             if self.transform_args['static_gain_std'] > 0:
 | ||
|                 warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels), dim = 0), (batch_size, 1, 1))
 | ||
|                 warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']
 | ||
| 
 | ||
|                 features = torch.matmul(features, warp_mat)
 | ||
| 
 | ||
|             # add white noise
 | ||
|             if self.transform_args['white_noise_std'] > 0:
 | ||
|                 features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']
 | ||
| 
 | ||
|             # add constant offset noise 
 | ||
|             if self.transform_args['constant_offset_std'] > 0:
 | ||
|                 features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']
 | ||
| 
 | ||
|             # add random walk noise
 | ||
|             if self.transform_args['random_walk_std'] > 0:
 | ||
|                 features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis'])
 | ||
| 
 | ||
|             # randomly cutoff part of the data timecourse
 | ||
|             if self.transform_args['random_cut'] > 0:
 | ||
|                 cut = np.random.randint(0, self.transform_args['random_cut'])
 | ||
|                 features = features[:, cut:, :]
 | ||
|                 n_time_steps = n_time_steps - cut
 | ||
| 
 | ||
|         # Apply Gaussian smoothing to data 
 | ||
|         # This is done in both training and validation
 | ||
|         if self.transform_args['smooth_data']:
 | ||
|             features = gauss_smooth(
 | ||
|                 inputs = features,
 | ||
|                 device = self.device,
 | ||
|                 smooth_kernel_std = self.transform_args['smooth_kernel_std'],
 | ||
|                 smooth_kernel_size= self.transform_args['smooth_kernel_size'],
 | ||
|                 )
 | ||
| 
 | ||
|         if hasattr(self, 'model_dtype'):
 | ||
|             features = features.to(self.model_dtype)
 | ||
|             
 | ||
|         
 | ||
|         return features, n_time_steps
 | ||
| 
 | ||
|     def train(self):
 | ||
|         '''
 | ||
|         Train the model 
 | ||
|         '''
 | ||
| 
 | ||
|         # Set model to train mode (specificially to make sure dropout layers are engaged)
 | ||
|         self.model.train()
 | ||
| 
 | ||
|         # create vars to track performance
 | ||
|         train_losses = []
 | ||
|         val_losses = []
 | ||
|         val_PERs = []
 | ||
|         val_results = []
 | ||
| 
 | ||
|         val_steps_since_improvement = 0
 | ||
| 
 | ||
|         # training params 
 | ||
|         save_best_checkpoint = self.args.get('save_best_checkpoint', True)
 | ||
|         early_stopping = self.args.get('early_stopping', True)
 | ||
| 
 | ||
|         early_stopping_val_steps = self.args['early_stopping_val_steps']
 | ||
| 
 | ||
|         train_start_time = time.time()
 | ||
| 
 | ||
|         # train for specified number of batches
 | ||
|         self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...")
 | ||
|         for i, batch in enumerate(self.train_loader):
 | ||
|             
 | ||
|             self.model.train()
 | ||
|             self.optimizer.zero_grad()
 | ||
|             
 | ||
|             # Train step
 | ||
|             start_time = time.time()
 | ||
| 
 | ||
|             # Data is automatically moved to device by Accelerator
 | ||
|             features = batch['input_features']
 | ||
|             labels = batch['seq_class_ids']
 | ||
|             n_time_steps = batch['n_time_steps']
 | ||
|             phone_seq_lens = batch['phone_seq_lens']
 | ||
|             day_indicies = batch['day_indicies']
 | ||
| 
 | ||
|             # Use Accelerator's autocast (mixed precision handled by Accelerator init)
 | ||
|             with self.autocast_context():
 | ||
| 
 | ||
|                 # Apply augmentations to the data
 | ||
|                 features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
 | ||
| 
 | ||
|                 # Ensure proper dtype handling for TPU mixed precision
 | ||
|                 adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 | ||
| 
 | ||
|                 # Get phoneme predictions using inference mode during training
 | ||
|                 # (We use inference mode for simplicity - only clean logits are used for CTC loss)
 | ||
|                 # Ensure features tensor matches model parameter dtype for TPU compatibility
 | ||
|                 if features.dtype != self.model_dtype:
 | ||
|                     features = features.to(self.model_dtype)
 | ||
| 
 | ||
|                 # Forward pass: enable full adversarial mode if configured and past warmup
 | ||
|                 use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
 | ||
|                 if use_full:
 | ||
|                     clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)
 | ||
|                 else:
 | ||
|                     logits = self.model(features, day_indicies, None, False, 'inference')
 | ||
| 
 | ||
|                 # Calculate CTC Loss
 | ||
|                 if use_full:
 | ||
|                     # Clean CTC loss
 | ||
|                     clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2)
 | ||
|                     clean_loss = self.ctc_loss(
 | ||
|                         clean_log_probs,
 | ||
|                         labels,
 | ||
|                         adjusted_lens,
 | ||
|                         phone_seq_lens
 | ||
|                     )
 | ||
|                     clean_loss = torch.mean(clean_loss)
 | ||
| 
 | ||
|                     # Noisy branch CTC loss(让 Noisy 更可识别,但经 GRL 对 NoiseModel 变成对抗)
 | ||
|                     noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2)
 | ||
|                     noisy_loss = self.ctc_loss(
 | ||
|                         noisy_log_probs,
 | ||
|                         labels,
 | ||
|                         adjusted_lens,
 | ||
|                         phone_seq_lens
 | ||
|                     )
 | ||
|                     noisy_loss = torch.mean(noisy_loss)
 | ||
| 
 | ||
|                     # Optional noise energy regularization
 | ||
|                     noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
 | ||
|                     if self.adv_noise_l2_weight > 0.0:
 | ||
|                         noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype)
 | ||
| 
 | ||
|                     loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
 | ||
|                 else:
 | ||
|                     log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
 | ||
|                     loss = self.ctc_loss(
 | ||
|                         log_probs=log_probs,
 | ||
|                         targets=labels,
 | ||
|                         input_lengths=adjusted_lens,
 | ||
|                         target_lengths=phone_seq_lens
 | ||
|                     )
 | ||
|                     loss = torch.mean(loss) # take mean loss over batches
 | ||
| 
 | ||
|             # Use Accelerator's backward for distributed training
 | ||
|             self.accelerator.backward(loss)
 | ||
| 
 | ||
|             # Clip gradient using Accelerator's clip_grad_norm_
 | ||
|             if self.args['grad_norm_clip_value'] > 0:
 | ||
|                 grad_norm = self.accelerator.clip_grad_norm_(self.model.parameters(),
 | ||
|                                                            max_norm = self.args['grad_norm_clip_value'])
 | ||
| 
 | ||
|             self.optimizer.step()
 | ||
|             self.learning_rate_scheduler.step()
 | ||
|             
 | ||
|             # Save training metrics 
 | ||
|             train_step_duration = time.time() - start_time
 | ||
|             train_losses.append(loss.detach().item())
 | ||
| 
 | ||
|             # Incrementally log training progress
 | ||
|             if i % self.args['batches_per_train_log'] == 0:
 | ||
|                 self.logger.info(f'Train batch {i}: ' +
 | ||
|                         f'loss: {(loss.detach().item()):.2f} ' +
 | ||
|                         f'grad norm: {grad_norm:.2f} '
 | ||
|                         f'time: {train_step_duration:.3f}')
 | ||
| 
 | ||
|             # Incrementally run a test step
 | ||
|             if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
 | ||
|                 self.logger.info(f"Running test after training batch: {i}")
 | ||
|                 
 | ||
|                 # Calculate metrics on val data
 | ||
|                 start_time = time.time()
 | ||
|                 val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data'])
 | ||
|                 val_step_duration = time.time() - start_time
 | ||
| 
 | ||
| 
 | ||
|                 # Log info 
 | ||
|                 self.logger.info(f'Val batch {i}: ' +
 | ||
|                         f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
 | ||
|                         f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
 | ||
|                         f'time: {val_step_duration:.3f}')
 | ||
|                 
 | ||
|                 if self.args['log_individual_day_val_PER']:
 | ||
|                     for day in val_metrics['day_PERs'].keys():
 | ||
|                         self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")
 | ||
| 
 | ||
|                 # Save metrics 
 | ||
|                 val_PERs.append(val_metrics['avg_PER'])
 | ||
|                 val_losses.append(val_metrics['avg_loss'])
 | ||
|                 val_results.append(val_metrics)
 | ||
| 
 | ||
|                 # Determine if new best day. Based on if PER is lower, or in the case of a PER tie, if loss is lower
 | ||
|                 new_best = False
 | ||
|                 if val_metrics['avg_PER'] < self.best_val_PER:
 | ||
|                     self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
 | ||
|                     self.best_val_PER = val_metrics['avg_PER']
 | ||
|                     self.best_val_loss = val_metrics['avg_loss']
 | ||
|                     new_best = True
 | ||
|                 elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss): 
 | ||
|                     self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
 | ||
|                     self.best_val_loss = val_metrics['avg_loss']
 | ||
|                     new_best = True
 | ||
| 
 | ||
|                 if new_best:
 | ||
| 
 | ||
|                     # Checkpoint if metrics have improved 
 | ||
|                     if save_best_checkpoint:
 | ||
|                         self.logger.info(f"Checkpointing model")
 | ||
|                         self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)
 | ||
| 
 | ||
|                     # save validation metrics to pickle file
 | ||
|                     if self.args['save_val_metrics']:
 | ||
|                         with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
 | ||
|                             pickle.dump(val_metrics, f) 
 | ||
| 
 | ||
|                     val_steps_since_improvement = 0
 | ||
|                     
 | ||
|                 else:
 | ||
|                     val_steps_since_improvement +=1
 | ||
| 
 | ||
|                 # Optionally save this validation checkpoint, regardless of performance
 | ||
|                 if self.args['save_all_val_steps']:
 | ||
|                     self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])
 | ||
| 
 | ||
|                 # Early stopping 
 | ||
|                 if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
 | ||
|                     self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
 | ||
|                     break
 | ||
|                 
 | ||
|         # Log final training steps 
 | ||
|         training_duration = time.time() - train_start_time
 | ||
| 
 | ||
| 
 | ||
|         self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
 | ||
|         self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')
 | ||
| 
 | ||
|         # Save final model 
 | ||
|         if self.args['save_final_model']:
 | ||
|             last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
 | ||
|             self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss)
 | ||
| 
 | ||
|         train_stats = {}
 | ||
|         train_stats['train_losses'] = train_losses
 | ||
|         train_stats['val_losses'] = val_losses 
 | ||
|         train_stats['val_PERs'] = val_PERs
 | ||
|         train_stats['val_metrics'] = val_results
 | ||
| 
 | ||
|         return train_stats
 | ||
| 
 | ||
|     def validation(self, loader, return_logits = False, return_data = False):
 | ||
|         '''
 | ||
|         Calculate metrics on the validation dataset
 | ||
|         '''
 | ||
|         self.model.eval()
 | ||
| 
 | ||
|         metrics = {}
 | ||
|         
 | ||
|         # Record metrics
 | ||
|         if return_logits: 
 | ||
|             metrics['logits'] = []
 | ||
|             metrics['n_time_steps'] = []
 | ||
| 
 | ||
|         if return_data: 
 | ||
|             metrics['input_features'] = []
 | ||
| 
 | ||
|         metrics['decoded_seqs'] = []
 | ||
|         metrics['true_seq'] = []
 | ||
|         metrics['phone_seq_lens'] = []
 | ||
|         metrics['transcription'] = []
 | ||
|         metrics['losses'] = []
 | ||
|         metrics['block_nums'] = []
 | ||
|         metrics['trial_nums'] = []
 | ||
|         metrics['day_indicies'] = []
 | ||
| 
 | ||
|         total_edit_distance = 0
 | ||
|         total_seq_length = 0
 | ||
| 
 | ||
|         # Calculate PER for each specific day
 | ||
|         day_per = {}
 | ||
|         for d in range(len(self.args['dataset']['sessions'])):
 | ||
|             if self.args['dataset']['dataset_probability_val'][d] == 1: 
 | ||
|                 day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}
 | ||
| 
 | ||
|         for i, batch in enumerate(loader):
 | ||
| 
 | ||
|             # Data is automatically moved to device by Accelerator
 | ||
|             features = batch['input_features']
 | ||
|             labels = batch['seq_class_ids']
 | ||
|             n_time_steps = batch['n_time_steps']
 | ||
|             phone_seq_lens = batch['phone_seq_lens']
 | ||
|             day_indicies = batch['day_indicies']
 | ||
| 
 | ||
|             # Determine if we should perform validation on this batch
 | ||
|             day = day_indicies[0].item()
 | ||
|             if self.args['dataset']['dataset_probability_val'][day] == 0: 
 | ||
|                 if self.args['log_val_skip_logs']:
 | ||
|                     self.logger.info(f"Skipping validation on day {day}")
 | ||
|                 continue
 | ||
|             
 | ||
|             with torch.no_grad():
 | ||
| 
 | ||
|                 with self.autocast_context():
 | ||
|                     features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 | ||
| 
 | ||
|                     # Ensure proper dtype handling for TPU mixed precision
 | ||
|                     adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 | ||
| 
 | ||
|                     # Ensure features tensor matches model parameter dtype for TPU compatibility
 | ||
|                     model_param = next(self.model.parameters()) if self.model is not None else None
 | ||
|                     if model_param is not None and features.dtype != model_param.dtype:
 | ||
|                         features = features.to(model_param.dtype)
 | ||
| 
 | ||
|                     logits = self.model(features, day_indicies, None, False, 'inference')
 | ||
| 
 | ||
|                     val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
 | ||
|                     loss = self.ctc_loss(
 | ||
|                         val_log_probs,
 | ||
|                         labels,
 | ||
|                         adjusted_lens,
 | ||
|                         phone_seq_lens,
 | ||
|                     )
 | ||
|                     loss = torch.mean(loss)
 | ||
| 
 | ||
|                 metrics['losses'].append(loss.cpu().detach().numpy())
 | ||
| 
 | ||
|                 # Calculate PER per day and also avg over entire validation set
 | ||
|                 batch_edit_distance = 0 
 | ||
|                 decoded_seqs = []
 | ||
|                 for iterIdx in range(logits.shape[0]):
 | ||
|                     decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1)
 | ||
|                     decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
 | ||
|                     decoded_seq = decoded_seq.cpu().detach().numpy()
 | ||
|                     decoded_seq = np.array([i for i in decoded_seq if i != 0])
 | ||
| 
 | ||
|                     trueSeq = np.array(
 | ||
|                         labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach()
 | ||
|                     )
 | ||
|             
 | ||
|                     batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)
 | ||
| 
 | ||
|                     decoded_seqs.append(decoded_seq)
 | ||
| 
 | ||
|             day = batch['day_indicies'][0].item()
 | ||
|                 
 | ||
|             day_per[day]['total_edit_distance'] += batch_edit_distance
 | ||
|             day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item()
 | ||
| 
 | ||
| 
 | ||
|             total_edit_distance += batch_edit_distance
 | ||
|             total_seq_length += torch.sum(phone_seq_lens)
 | ||
| 
 | ||
|             # Record metrics
 | ||
|             if return_logits: 
 | ||
|                 metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32
 | ||
|                 metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())
 | ||
| 
 | ||
|             if return_data: 
 | ||
|                 metrics['input_features'].append(batch['input_features'].cpu().numpy()) 
 | ||
| 
 | ||
|             metrics['decoded_seqs'].append(decoded_seqs)
 | ||
|             metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
 | ||
|             metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
 | ||
|             metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
 | ||
|             metrics['losses'].append(loss.detach().item())
 | ||
|             metrics['block_nums'].append(batch['block_nums'].numpy())
 | ||
|             metrics['trial_nums'].append(batch['trial_nums'].numpy())
 | ||
|             metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())
 | ||
| 
 | ||
|         if isinstance(total_seq_length, torch.Tensor):
 | ||
|             total_length_value = float(total_seq_length.item())
 | ||
|         else:
 | ||
|             total_length_value = float(total_seq_length)
 | ||
| 
 | ||
|         avg_PER = total_edit_distance / max(total_length_value, 1e-6)
 | ||
| 
 | ||
|         metrics['day_PERs'] = day_per
 | ||
|         metrics['avg_PER'] = avg_PER
 | ||
|         metrics['avg_loss'] = float(np.mean(metrics['losses']))
 | ||
| 
 | ||
|         return metrics
 | ||
| 
 | ||
|     def inference(self, features, day_indicies, n_time_steps, mode='inference'):
 | ||
|         '''
 | ||
|         TPU-compatible inference method for generating phoneme logits
 | ||
|         '''
 | ||
|         self.model.eval()
 | ||
| 
 | ||
|         with torch.no_grad():
 | ||
|             with self.autocast_context():
 | ||
|                 # Apply data transformations (no augmentation for inference)
 | ||
|                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 | ||
| 
 | ||
|                 # Ensure features tensor matches model parameter dtype for TPU compatibility
 | ||
|                 if features.dtype != self.model_dtype:
 | ||
|                     features = features.to(self.model_dtype)
 | ||
| 
 | ||
|                 # Get phoneme predictions
 | ||
|                 logits = self.model(features, day_indicies, None, False, mode)
 | ||
| 
 | ||
|         return logits
 | ||
| 
 | ||
|     def inference_batch(self, batch, mode='inference'):
 | ||
|         '''
 | ||
|         Inference method for processing a full batch
 | ||
|         '''
 | ||
|         self.model.eval()
 | ||
| 
 | ||
|         # Data is automatically moved to device by Accelerator
 | ||
|         features = batch['input_features']
 | ||
|         day_indicies = batch['day_indicies']
 | ||
|         n_time_steps = batch['n_time_steps']
 | ||
| 
 | ||
|         with torch.no_grad():
 | ||
|             with self.autocast_context():
 | ||
|                 # Apply data transformations (no augmentation for inference)
 | ||
|                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 | ||
| 
 | ||
|                 # Calculate adjusted sequence lengths for CTC with proper dtype handling
 | ||
|                 adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 | ||
| 
 | ||
|                 # Ensure features tensor matches model parameter dtype for TPU compatibility
 | ||
|                 if features.dtype != self.model_dtype:
 | ||
|                     features = features.to(self.model_dtype)
 | ||
| 
 | ||
|                 # Get phoneme predictions
 | ||
|                 logits = self.model(features, day_indicies, None, False, mode)
 | ||
| 
 | ||
|         return logits, adjusted_lens | 
