import os
import tensorflow as tf
import numpy as np
import time
import json
import pickle
import logging
import pathlib
import sys
from typing import Dict, Any, Tuple, Optional, List
from omegaconf import OmegaConf

from rnn_model_tf import (
    TripleGRUDecoder,
    CTCLoss,
    create_tpu_strategy,
    build_model_for_tpu,
    configure_mixed_precision
)
from dataset_tf import (
    BrainToTextDatasetTF,
    DataAugmentationTF,
    train_test_split_indices,
    create_input_fn
)


class BrainToTextDecoderTrainerTF:
    """
    TensorFlow/Keras trainer for the brain-to-text phoneme decoder, optimized for TPU v5e-8.

    This trainer implements the same training logic as the PyTorch version but uses
    TensorFlow operations and a ``tf.distribute`` strategy for TPU execution.
    """

    # HBM capacity of a single TPU v5e chip in GiB (used only for status logging).
    _HBM_GB_PER_CHIP = 16

    def __init__(self, args: Dict[str, Any]):
        """
        Initialize the TensorFlow trainer.

        Args:
            args: Configuration dictionary containing all training parameters.
        """
        self.args = args
        self.logger = None

        # Size the host-side thread pools used by the data pipeline (large-core-count hosts).
        self._configure_cpu_optimization()

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")

        # Configure mixed precision for TPU v5e-8
        if args.get('use_amp', True):
            configure_mixed_precision()
            self.mixed_precision = True
        else:
            self.mixed_precision = False

        # Initialize tracking variables
        self.best_val_per = float('inf')
        self.best_val_loss = float('inf')

        # Setup directories
        if args['mode'] == 'train':
            os.makedirs(self.args['output_dir'], exist_ok=True)
        if (args.get('save_best_checkpoint', True) or
                args.get('save_all_val_steps', False) or
                args.get('save_final_model', False)):
            os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

        # Setup logging
        self._setup_logging()

        # Set random seeds
        if self.args['seed'] != -1:
            tf.random.set_seed(self.args['seed'])
            np.random.seed(self.args['seed'])

        # Initialize datasets
        self._initialize_datasets()

        # Build model, scheduler and optimizer within the strategy scope so every
        # variable they own is created as a distributed (TPU) variable.
        with self.strategy.scope():
            print("🔨 Building model within TPU strategy scope...")
            self.model = self._build_model()
            print("✅ Model built successfully")

            print("📅 Setting up learning rate scheduler...")
            self.lr_scheduler = self._create_lr_scheduler()
            print("✅ LR scheduler ready")

            print("⚙️ Creating optimizer...")
            # The schedule drives the optimizer's learning rate (it was previously unused).
            self.optimizer = self._create_optimizer(learning_rate=self.lr_scheduler)
            print("✅ Optimizer created")

            print("🎯 Initializing CTC loss...")
            self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
            print("✅ CTC loss initialized")

            # The dummy forward pass in _log_model_info creates the model's variables;
            # it must run inside the scope, or those variables would not be distributed.
            self._log_model_info()
            # Create optimizer slot variables under the same scope.
            self.optimizer.build(self.model.trainable_variables)

        # Adversarial training configuration
        adv_cfg = self.args.get('adversarial', {})
        self.adv_enabled = adv_cfg.get('enabled', False)
        self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))
        self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))
        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
        if self.adv_enabled:
            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
                             f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
                             f"noise_l2_weight={self.adv_noise_l2_weight}, "
                             f"warmup_steps={self.adv_warmup_steps}")

    def _setup_logging(self):
        """Configure the module logger: a file handler in train mode plus a stdout handler."""
        self.logger = logging.getLogger(__name__)
        # Drop handlers left over from a previous trainer instance in the same process.
        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

        if self.args['mode'] == 'train':
            fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self.logger.addHandler(sh)

        self.logger.info(f'Using TPU strategy with {self.strategy.num_replicas_in_sync} replicas')
        if self.mixed_precision:
            self.logger.info('Mixed precision (bfloat16) enabled for TPU training')

    def _configure_cpu_optimization(self):
        """Size TensorFlow's inter-/intra-op thread pools from the host CPU count.

        The pools can only be resized before the TensorFlow runtime is initialized;
        if it already is, a warning is printed and the defaults are kept.
        """
        import multiprocessing

        # Get available CPU cores
        available_cores = multiprocessing.cpu_count()
        print(f"💻 Available CPU cores: {available_cores}")

        # Optimize for data pipeline parallelism
        if available_cores >= 200:
            # High core count system (e.g. 224 cores): be more aggressive.
            inter_op_threads = min(64, available_cores // 3)
            intra_op_threads = min(32, available_cores // 6)
        else:
            # Use ~1/4 of cores for inter-op (between operations)
            # Use ~1/8 of cores for intra-op (within operations)
            inter_op_threads = min(32, available_cores // 4)
            intra_op_threads = min(16, available_cores // 8)

        try:
            tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
            tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)
        except RuntimeError as err:
            # Raised when the TF runtime was already initialized (e.g. by an earlier import).
            print(f"⚠️ Could not configure CPU thread pools (runtime already initialized): {err}")
            return

        print(f"🔧 CPU optimization configured:")
        print(f"   Inter-op parallelism: {inter_op_threads} threads")
        print(f"   Intra-op parallelism: {intra_op_threads} threads")
        print(f"   This will accelerate data loading and preprocessing while TPU handles training")

    def _get_tpu_status(self) -> str:
        """Return a short one-line TPU status string (device count, replicas, HBM of TPU:0)."""
        try:
            # Get TPU devices
            tpu_devices = tf.config.list_logical_devices('TPU')
            if not tpu_devices:
                return "TPU: No devices"

            # Get strategy info
            num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1

            # Try to get TPU memory info (HBM)
            try:
                memory_info = tf.config.experimental.get_memory_info('/TPU:0')
                if memory_info and 'current' in memory_info:
                    current_mb = memory_info['current'] // (1024 * 1024)
                    peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
                    hbm_info = f"HBM: {current_mb}MB({peak_mb}MB peak)"
                else:
                    hbm_info = "HBM: unknown"
            except Exception:
                # Fallback: simple TPU activity check
                try:
                    with tf.device('/TPU:0'):
                        test_tensor = tf.constant([1.0, 2.0])
                        _ = tf.reduce_sum(test_tensor)
                    hbm_info = "HBM: active"
                except Exception:
                    hbm_info = "HBM: inactive"

            return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
                    f"{hbm_info}")

        except Exception as e:
            # Status reporting is best-effort and must never interrupt training.
            return f"TPU: status_error({str(e)[:20]})"

    def _get_detailed_tpu_status(self) -> str:
        """Return a detailed TPU status string, logged once at the start of training."""
        try:
            # Get TPU devices
            tpu_devices = tf.config.list_logical_devices('TPU')
            if not tpu_devices:
                return "❌ No TPU devices detected"

            # Get strategy info
            num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
            strategy_type = type(self.strategy).__name__

            # Get TPU HBM memory info
            try:
                memory_info = tf.config.experimental.get_memory_info('/TPU:0')
                if memory_info and 'current' in memory_info:
                    gib = 1024 ** 3
                    # Float division: integer division showed 0GB for anything under 1 GiB.
                    current_gb = memory_info['current'] / gib
                    peak_gb = memory_info.get('peak', memory_info['current']) / gib
                    # get_memory_info reports only TPU:0, so compare against one chip's
                    # capacity (TPU v5e: 16 GiB HBM per chip), not the total of all chips.
                    hbm_usage = (f"HBM (TPU:0): {current_gb:.1f}GB/{self._HBM_GB_PER_CHIP}GB "
                                 f"(peak: {peak_gb:.1f}GB)")
                else:
                    hbm_usage = "HBM: unknown"
            except Exception:
                hbm_usage = "HBM: unavailable"

            # Simple TPU test
            try:
                with tf.device('/TPU:0'):
                    test_result = tf.constant([1.0, 2.0])
                    _ = tf.reduce_sum(test_result)
                tpu_test = "✅ responsive"
            except Exception as e:
                tpu_test = f"❌ test_failed({str(e)[:15]})"

            return (f"TPU Devices: {len(tpu_devices)} | "
                    f"Strategy: {strategy_type} | "
                    f"Cores: {num_replicas} | "
                    f"{hbm_usage} | "
                    f"Test: {tpu_test}")

        except Exception as e:
            return f"❌ TPU status check failed: {str(e)[:50]}"

    def _initialize_datasets(self):
        """Build the train/val trial splits and wrap them in BrainToTextDatasetTF objects.

        Raises:
            ValueError: if the configured session list contains duplicates.
        """
        # Create file paths
        train_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
            for s in self.args['dataset']['sessions']
        ]
        val_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5')
            for s in self.args['dataset']['sessions']
        ]

        # Validate no duplicates
        if len(set(train_file_paths)) != len(train_file_paths):
            raise ValueError("Duplicate sessions in train dataset")
        if len(set(val_file_paths)) != len(val_file_paths):
            raise ValueError("Duplicate sessions in val dataset")

        # Split trials: every train-file trial goes to train, every val-file trial to val.
        train_trials, _ = train_test_split_indices(
            file_paths=train_file_paths,
            test_percentage=0,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
        )
        _, val_trials = train_test_split_indices(
            file_paths=val_file_paths,
            test_percentage=1,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
        )

        # Save trial splits. output_dir is only created in train mode by __init__,
        # so make sure it exists before writing.
        os.makedirs(self.args['output_dir'], exist_ok=True)
        with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
            json.dump({'train': train_trials, 'val': val_trials}, f)

        # Create TensorFlow datasets
        self.train_dataset_tf = BrainToTextDatasetTF(
            trial_indices=train_trials,
            n_batches=self.args['num_training_batches'],
            split='train',
            batch_size=self.args['dataset']['batch_size'],
            days_per_batch=self.args['dataset']['days_per_batch'],
            random_seed=self.args['dataset']['seed'],
            must_include_days=self.args['dataset'].get('must_include_days'),
            feature_subset=self.args['dataset'].get('feature_subset')
        )

        self.val_dataset_tf = BrainToTextDatasetTF(
            trial_indices=val_trials,
            n_batches=None,  # Use all validation data
            split='test',
            batch_size=self.args['dataset']['batch_size'],
            days_per_batch=1,  # One day per validation batch
            random_seed=self.args['dataset']['seed'],
            feature_subset=self.args['dataset'].get('feature_subset')
        )

        self.logger.info("Successfully initialized TensorFlow datasets")

    def _build_model(self) -> TripleGRUDecoder:
        """Construct the TripleGRUDecoder from the 'model' and 'dataset' config sections."""
        model = TripleGRUDecoder(
            neural_dim=self.args['model']['n_input_features'],
            n_units=self.args['model']['n_units'],
            n_days=len(self.args['dataset']['sessions']),
            n_classes=self.args['dataset']['n_classes'],
            rnn_dropout=self.args['model']['rnn_dropout'],
            input_dropout=self.args['model']['input_network']['input_layer_dropout'],
            patch_size=self.args['model']['patch_size'],
            patch_stride=self.args['model']['patch_stride']
        )
        return model

    def _create_optimizer(self, learning_rate=None) -> tf.keras.optimizers.Optimizer:
        """Create the AdamW optimizer.

        Args:
            learning_rate: A float or a ``LearningRateSchedule``. Defaults to ``args['lr_max']``.

        Note: TensorFlow has no PyTorch-style parameter groups, so a single optimizer
        (and therefore a single learning rate) is used for all variables.
        """
        if learning_rate is None:
            learning_rate = self.args['lr_max']
        optimizer = tf.keras.optimizers.AdamW(
            learning_rate=learning_rate,
            beta_1=self.args['beta0'],
            beta_2=self.args['beta1'],
            epsilon=self.args['epsilon'],
            weight_decay=self.args['weight_decay']
        )
        return optimizer

    def _create_lr_scheduler(self):
        """Create the learning-rate schedule selected by ``args['lr_scheduler_type']``.

        Raises:
            ValueError: for an unknown scheduler type.
        """
        if self.args['lr_scheduler_type'] == 'cosine':
            return self._create_cosine_scheduler()
        elif self.args['lr_scheduler_type'] == 'linear':
            return tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=self.args['lr_max'],
                decay_steps=self.args['lr_decay_steps'],
                end_learning_rate=self.args['lr_min'],
                power=1.0  # Linear decay
            )
        else:
            raise ValueError(f"Unknown scheduler type: {self.args['lr_scheduler_type']}")

    def _create_cosine_scheduler(self):
        """Create a cosine schedule from lr_max down to lr_min, restarting every lr_decay_steps."""
        return tf.keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate=self.args['lr_max'],
            first_decay_steps=self.args['lr_decay_steps'],
            t_mul=1.0,
            m_mul=1.0,
            alpha=self.args['lr_min'] / self.args['lr_max']
        )

    def _log_model_info(self):
        """Build the model with a dummy forward pass and log its trainable parameter count."""
        self.logger.info("Initialized TripleGRUDecoder model")

        # Build the model by calling it once with dummy data
        dummy_batch_size = 2
        dummy_time_steps = 100
        dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features']))
        dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)

        # Call the model to build it
        _ = self.model(dummy_features, dummy_day_idx, training=False)

        # Count parameters
        total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights])
        self.logger.info(f"Model has {total_params:,} trainable parameters")

    def _compute_adjusted_lens(self, n_time_steps):
        """Number of valid model output frames after patching, used as CTC input lengths.

        Computes ``(T - patch_size) / patch_stride + 1`` in float and casts to int32
        (the cast truncates toward zero).
        """
        return tf.cast(
            (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size'])
            / self.args['model']['patch_stride'] + 1,
            tf.int32
        )

    def _train_step(self, batch, use_full: bool):
        """Per-replica training step (runs inside ``strategy.run``).

        Args:
            batch: Per-replica batch dict with 'input_features', 'seq_class_ids',
                'n_time_steps', 'phone_seq_lens' and 'day_indices'.
            use_full: Python bool fixed at trace time. True runs the adversarial
                'full' forward pass; False runs the plain 'inference' pass.

        Returns:
            (per-replica mean loss, per-replica gradient global norm before clipping).
        """
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        with tf.GradientTape() as tape:
            # Apply data transformations
            features, n_time_steps = DataAugmentationTF.transform_data(
                features, n_time_steps, self.args['dataset']['data_transforms'], training=True
            )
            adjusted_lens = self._compute_adjusted_lens(n_time_steps)
            loss_input = {
                'labels': labels,
                'input_lengths': adjusted_lens,
                'label_lengths': phone_seq_lens
            }

            if use_full:
                clean_logits, noisy_logits, noise_output = self.model(
                    features, day_indices, None, False, 'full',
                    grl_lambda=self.adv_grl_lambda, training=True
                )
                clean_loss = tf.reduce_mean(self.ctc_loss(loss_input, clean_logits))
                noisy_loss = tf.reduce_mean(self.ctc_loss(loss_input, noisy_logits))
                loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss
                # Optional noise L2 regularization (cast in case the model emits bfloat16).
                if self.adv_noise_l2_weight > 0.0:
                    noise_l2 = tf.cast(tf.reduce_mean(tf.square(noise_output)), clean_loss.dtype)
                    loss = loss + self.adv_noise_l2_weight * noise_l2
            else:
                clean_logits = self.model(
                    features, day_indices, None, False, 'inference', training=True
                )
                loss = tf.reduce_mean(self.ctc_loss(loss_input, clean_logits))

        # bfloat16 mixed precision needs no loss scaling, and the plain AdamW optimizer
        # has no get_scaled_loss/get_unscaled_gradients, so differentiate the loss directly.
        gradients = tape.gradient(loss, self.model.trainable_variables)

        # Clip gradients
        clip_value = self.args['grad_norm_clip_value']
        if clip_value > 0:
            gradients, grad_norm = tf.clip_by_global_norm(gradients, clip_value)
        else:
            grad_norm = tf.linalg.global_norm(gradients)

        # apply_gradients SUMS gradients across replicas; divide by the replica count so
        # the update uses the cross-replica mean instead of num_replicas times it.
        num_replicas = self.strategy.num_replicas_in_sync
        if num_replicas > 1:
            inv_replicas = 1.0 / num_replicas

            def _scale(grad):
                if grad is None:
                    return None
                if isinstance(grad, tf.IndexedSlices):
                    return tf.IndexedSlices(grad.values * inv_replicas, grad.indices, grad.dense_shape)
                return grad * inv_replicas

            gradients = [_scale(g) for g in gradients]

        # Apply gradients
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        return loss, grad_norm

    @tf.function
    def _distributed_train_step(self, batch, use_full: bool):
        """Run one training step on every replica and reduce loss/grad-norm across replicas.

        ``use_full`` is a Python bool, so this function is traced at most twice. Earlier
        code passed the Python step counter, which retraced (and recompiled) every step.
        """
        def replica_fn(replica_batch):
            return self._train_step(replica_batch, use_full)

        per_replica_losses, per_replica_grad_norms = self.strategy.run(replica_fn, args=(batch,))
        loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
        grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)
        return loss, grad_norm

    def _validation_step(self, batch):
        """Per-replica validation step (runs inside ``strategy.run``).

        Only TPU-compatible ops are used. CTC decoding and edit distance are done on
        the host in ``_validate``.

        Returns:
            (mean CTC loss, greedy argmax frame ids, adjusted lengths, labels, label lengths).
        """
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation for validation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )
        adjusted_lens = self._compute_adjusted_lens(n_time_steps)

        # Forward pass (inference mode only)
        logits = self.model(features, day_indices, None, False, 'inference', training=False)

        # Calculate loss
        loss_input = {
            'labels': labels,
            'input_lengths': adjusted_lens,
            'label_lengths': phone_seq_lens
        }
        loss = tf.reduce_mean(self.ctc_loss(loss_input, logits))

        # Greedy decoding: per-frame argmax over classes (collapsed on the host).
        predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)

        return loss, predicted_ids, adjusted_lens, labels, phone_seq_lens

    @tf.function
    def _distributed_validation_step(self, batch):
        """Run ``_validation_step`` on every replica and return each output as a tuple of local per-replica tensors."""
        outputs = self.strategy.run(self._validation_step, args=(batch,))
        return tf.nest.map_structure(self.strategy.experimental_local_results, outputs)

    @staticmethod
    def _ctc_greedy_decode(frame_ids, length, blank_index: int = 0) -> List[int]:
        """Collapse per-frame ids into a label sequence (merge repeats, then drop blanks).

        Only the first ``length`` frames are considered; a non-positive length yields [].
        """
        decoded = []
        previous = None
        for token in frame_ids[:max(int(length), 0)]:
            token = int(token)
            # A repeat is merged only when consecutive; a blank in between separates repeats.
            if token != previous and token != blank_index:
                decoded.append(token)
            previous = token
        return decoded

    @staticmethod
    def _edit_distance(hypothesis, reference) -> int:
        """Levenshtein distance (unit-cost insert/delete/substitute) between two sequences."""
        if not hypothesis:
            return len(reference)
        if not reference:
            return len(hypothesis)
        previous_row = list(range(len(reference) + 1))
        for i, hyp_token in enumerate(hypothesis, start=1):
            current_row = [i]
            for j, ref_token in enumerate(reference, start=1):
                current_row.append(min(
                    previous_row[j] + 1,                              # deletion
                    current_row[j - 1] + 1,                           # insertion
                    previous_row[j - 1] + (hyp_token != ref_token),   # substitution
                ))
            previous_row = current_row
        return previous_row[-1]

    def train(self) -> Dict[str, Any]:
        """Run the main training loop with periodic validation, checkpointing and optional early stopping.

        Returns:
            Dict with 'train_losses', 'val_losses', 'val_pers' and 'val_metrics' histories.
        """
        self.logger.info("Starting training loop...")

        # Log initial TPU status
        initial_tpu_status = self._get_detailed_tpu_status()
        self.logger.info(f"Initial TPU Status: {initial_tpu_status}")

        # Create datasets
        train_dataset = create_input_fn(
            self.train_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=True
        )
        val_dataset = create_input_fn(
            self.val_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=False
        )

        # Distribute datasets
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
        val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)

        num_training_batches = self.args['num_training_batches']
        early_stopping = self.args.get('early_stopping', False)
        early_stopping_val_steps = self.args.get('early_stopping_val_steps', 20)

        # Training metrics
        train_losses = []
        val_losses = []
        val_pers = []
        val_results = []
        val_steps_since_improvement = 0

        train_start_time = time.time()

        # Training loop
        step = 0
        last_step = -1  # index of the last batch actually trained on
        for batch in train_dist_dataset:
            if step >= num_training_batches:
                break
            last_step = step

            start_time = time.time()

            # Decided on the host so the compiled step only ever sees two variants.
            use_full = bool(self.adv_enabled and step >= self.adv_warmup_steps)
            loss, grad_norm = self._distributed_train_step(batch, use_full)

            # .numpy() waits for the device, so the timing below covers the real step.
            loss_value = float(loss.numpy())
            grad_norm_value = float(grad_norm.numpy())
            train_step_duration = time.time() - start_time
            train_losses.append(loss_value)

            # Log training progress with TPU status
            if step % self.args['batches_per_train_log'] == 0:
                tpu_status = self._get_tpu_status()
                self.logger.info(f'Train batch {step}: '
                                 f'loss: {loss_value:.2f} '
                                 f'grad norm: {grad_norm_value:.2f} '
                                 f'time: {train_step_duration:.3f}s '
                                 f'| {tpu_status}')

            # Validation step
            if step % self.args['batches_per_val_step'] == 0 or step == (num_training_batches - 1):
                self.logger.info(f"Running validation after training batch: {step}")

                val_start_time = time.time()
                val_metrics = self._validate(val_dist_dataset)
                val_step_duration = time.time() - val_start_time

                tpu_status = self._get_tpu_status()
                self.logger.info(f'Val batch {step}: '
                                 f'PER (avg): {val_metrics["avg_per"]:.4f} '
                                 f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
                                 f'time: {val_step_duration:.3f}s '
                                 f'| {tpu_status}')

                val_pers.append(val_metrics['avg_per'])
                val_losses.append(val_metrics['avg_loss'])
                val_results.append(val_metrics)

                # Check for improvement: lower PER wins; equal PER falls back to lower loss.
                new_best = False
                if val_metrics['avg_per'] < self.best_val_per:
                    self.logger.info(f"New best test PER {self.best_val_per:.4f} --> {val_metrics['avg_per']:.4f}")
                    self.best_val_per = val_metrics['avg_per']
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True
                elif (val_metrics['avg_per'] == self.best_val_per and
                      val_metrics['avg_loss'] < self.best_val_loss):
                    self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True

                if new_best:
                    if self.args.get('save_best_checkpoint', True):
                        self.logger.info("Checkpointing model")
                        self._save_checkpoint('best_checkpoint', step)

                    if self.args.get('save_val_metrics', True):
                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                            pickle.dump(val_metrics, f)

                    val_steps_since_improvement = 0
                else:
                    val_steps_since_improvement += 1

                # Optional save all validation checkpoints
                if self.args.get('save_all_val_steps', False):
                    self._save_checkpoint(f'checkpoint_batch_{step}', step)

                # Early stopping
                if early_stopping and val_steps_since_improvement >= early_stopping_val_steps:
                    self.logger.info(f'Validation PER has not improved in {early_stopping_val_steps} '
                                     f'validation steps. Stopping training early at batch: {step}')
                    break

            step += 1

        # Training completed
        training_duration = time.time() - train_start_time
        self.logger.info(f'Best avg val PER achieved: {self.best_val_per:.5f}')
        self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

        # Save final model, tagged with the last batch actually trained
        # (step - 1 was off by one after an early-stopping break).
        if self.args.get('save_final_model', False):
            self._save_checkpoint(f'final_checkpoint_batch_{last_step}', last_step)

        return {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_pers': val_pers,
            'val_metrics': val_results
        }

    def _validate(self, val_dataset) -> Dict[str, Any]:
        """Run validation over the whole distributed validation dataset.

        The loss per batch is the mean of the per-replica mean losses. PER is the
        total edit distance of the greedy CTC decodes divided by the total reference
        length, both computed on the host.
        """
        total_loss = 0.0
        total_edit_distance = 0.0
        total_seq_length = 0.0
        num_batches = 0

        for batch in val_dataset:
            (replica_losses, replica_ids, replica_lens,
             replica_labels, replica_label_lens) = self._distributed_validation_step(batch)

            # Equivalent to ReduceOp.MEAN over the per-replica scalar losses.
            total_loss += float(np.mean([float(replica_loss.numpy()) for replica_loss in replica_losses]))

            for ids, lens, labels, label_lens in zip(replica_ids, replica_lens,
                                                     replica_labels, replica_label_lens):
                ids = ids.numpy()
                lens = lens.numpy()
                labels = labels.numpy()
                label_lens = label_lens.numpy()
                for i in range(ids.shape[0]):
                    hypothesis = self._ctc_greedy_decode(ids[i], lens[i])
                    reference = labels[i, :max(int(label_lens[i]), 0)].tolist()
                    total_edit_distance += self._edit_distance(hypothesis, reference)
                total_seq_length += float(np.sum(label_lens))

            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        avg_per = total_edit_distance / max(total_seq_length, 1e-6)

        return {
            'avg_loss': avg_loss,
            'avg_per': avg_per,
            'total_edit_distance': total_edit_distance,
            'total_seq_length': total_seq_length,
            'num_batches': num_batches
        }

    def _save_checkpoint(self, name: str, step: int):
        """Save model weights, optimizer state, training state and the run config under checkpoint_dir."""
        checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)

        # Save model weights
        self.model.save_weights(checkpoint_path + '.weights.h5')

        # Save optimizer state. A fresh tf.train.Checkpoint starts its save counter at 0,
        # so save() writes '<prefix>-1'; load_checkpoint relies on that suffix.
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.save(checkpoint_path + '.optimizer')

        # Save training state
        state = {
            'step': step,
            'best_val_per': float(self.best_val_per),
            'best_val_loss': float(self.best_val_loss)
        }

        with open(checkpoint_path + '.state.json', 'w') as f:
            json.dump(state, f)

        # Save config
        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
            OmegaConf.save(config=self.args, f=f)

        self.logger.info(f"Saved checkpoint: {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model weights, optimizer state and best-metric tracking written by ``_save_checkpoint``."""
        # Load model weights
        self.model.load_weights(checkpoint_path + '.weights.h5')

        # Load optimizer state ('-1' is the suffix tf.train.Checkpoint.save appended).
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1')

        # Load training state
        with open(checkpoint_path + '.state.json', 'r') as f:
            state = json.load(f)

        self.best_val_per = state['best_val_per']
        self.best_val_loss = state['best_val_loss']

        self.logger.info(f"Loaded checkpoint: {checkpoint_path}")

    def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
                  n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
        """
        Run inference on input features.

        Args:
            features: Input neural features [batch_size, time_steps, features]
            day_indices: Day indices [batch_size]
            n_time_steps: Number of valid time steps [batch_size]
            mode: 'inference' or 'full'

        Returns:
            Whatever the model returns for ``mode``; for 'inference' this is presumably
            the phoneme logits [batch_size, time_steps, n_classes] — confirm against the model.
        """
        # Apply data transformations (no augmentation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )

        # Run model inference
        logits = self.model(features, day_indices, None, False, mode, training=False)

        return logits