480 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			480 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| """
 | |
| TensorFlow Evaluation Script for Brain-to-Text RNN Model
 | |
| Optimized for TPU v5e-8
 | |
| 
 | |
| This script evaluates the TripleGRUDecoder model using TensorFlow and provides
 | |
| detailed metrics and analysis of model performance on test data.
 | |
| 
 | |
| Usage:
 | |
|     python evaluate_model_tf.py --model_path path/to/model --data_dir path/to/data
 | |
| 
 | |
| Requirements:
 | |
|     - TensorFlow >= 2.15.0
 | |
|     - TPU v5e-8 environment
 | |
|     - Trained model checkpoint
 | |
|     - Access to brain-to-text HDF5 dataset
 | |
| """
 | |
| 
 | |
| import argparse
 | |
| import os
 | |
| import sys
 | |
| import json
 | |
| import pickle
 | |
| import numpy as np
 | |
| import tensorflow as tf
 | |
| from typing import Dict, Any, List, Tuple
 | |
| from omegaconf import OmegaConf
 | |
| 
 | |
| # Add the current directory to Python path for imports
 | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 | |
| 
 | |
| from trainer_tf import BrainToTextDecoderTrainerTF
 | |
| from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn
 | |
| from rnn_model_tf import create_tpu_strategy, configure_mixed_precision
 | |
| 
 | |
| 
 | |
class BrainToTextEvaluatorTF:
    """
    TensorFlow evaluator for brain-to-text model performance analysis.

    The forward pass and CTC loss run on the accelerator through
    ``strategy.run``. Greedy CTC decoding and edit-distance (PER) computation
    run on the host in NumPy, because ``tf.py_function`` and sparse ops such
    as ``tf.edit_distance`` are not supported inside XLA/TPU-compiled replica
    functions.
    """

    # Class index treated as the CTC blank when collapsing greedy predictions
    # (the previous implementation also assumed blank == 0).
    BLANK_INDEX = 0

    def __init__(self, model_path: str, config: Dict[str, Any], eval_type: str = 'test'):
        """
        Initialize evaluator: build the distribution strategy, optionally enable
        mixed precision, then construct the trainer and load the checkpoint.

        Args:
            model_path: Path to trained model checkpoint
            config: Configuration dictionary (an OmegaConf config when run from ``main``)
            eval_type: 'test' or 'val' evaluation type
        """
        self.model_path = model_path
        self.config = config
        self.eval_type = eval_type

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        print(f"Evaluation using {self.strategy.num_replicas_in_sync} TPU cores")

        # Configure mixed precision (enabled unless config sets use_amp to False)
        if config.get('use_amp', True):
            configure_mixed_precision()

        # Model variables must be created under the strategy scope.
        with self.strategy.scope():
            self.trainer = BrainToTextDecoderTrainerTF(config)
            self.trainer.load_checkpoint(model_path)

        print(f"Model loaded from: {model_path}")

    def evaluate_dataset(self, save_results: bool = True,
                         return_predictions: bool = False) -> Dict[str, Any]:
        """
        Evaluate the model on the configured split.

        Args:
            save_results: Whether to write summary/detailed results under
                ``config['output_dir']`` (default ``./eval_output``)
            return_predictions: Whether to keep decoded predictions, targets,
                logits and input features in the detailed results

        Returns:
            Dictionary with ``summary_metrics`` and ``detailed_results`` (the
            per-batch result dicts when ``return_predictions`` is True, else None)
        """
        print(f"Starting {self.eval_type} evaluation...")

        # No separate test dataset is wired up in this script: both 'test' and
        # 'val' evaluate on the trainer's validation split.
        if self.eval_type == 'test':
            print("Note: eval_type='test' evaluates the trainer's validation split (val_dataset_tf).")
        dataset_tf = self.trainer.val_dataset_tf

        eval_dataset = create_input_fn(
            dataset_tf,
            self.config['dataset']['data_transforms'],
            training=False
        )
        eval_dist_dataset = self.strategy.experimental_distribute_dataset(eval_dataset)

        results = self._run_evaluation(eval_dist_dataset, return_predictions)
        summary_metrics = self._calculate_summary_metrics(results)

        print("Evaluation completed!")
        print(f"Overall PER: {summary_metrics['overall_per']:.4f}")
        print(f"Overall Loss: {summary_metrics['overall_loss']:.4f}")
        print(f"Total trials evaluated: {summary_metrics['total_trials']}")

        if save_results:
            self._save_results(results, summary_metrics)

        return {
            'summary_metrics': summary_metrics,
            'detailed_results': results if return_predictions else None
        }

    def _run_evaluation(self, eval_dataset, return_predictions: bool) -> List[Dict[str, Any]]:
        """
        Run the distributed forward pass over every batch and score it on the host.

        Args:
            eval_dataset: Distributed dataset from ``experimental_distribute_dataset``
            return_predictions: Forwarded to ``_evaluation_step`` / ``_score_batch``

        Returns:
            One result dict per batch (see ``_score_batch``)
        """
        strategy = self.strategy

        # return_predictions is captured by closure rather than passed through
        # strategy.run, so only tensors cross the replica boundary.
        def replica_step(batch):
            return self._evaluation_step(batch, return_predictions)

        @tf.function
        def distributed_step(batch):
            return strategy.run(replica_step, args=(batch,))

        all_results = []
        for batch_idx, batch in enumerate(eval_dataset, start=1):
            per_replica_outputs = distributed_step(batch)

            # One NumPy array per local replica for every output key.
            gathered = {
                key: [t.numpy() for t in strategy.experimental_local_results(value)]
                for key, value in per_replica_outputs.items()
            }
            all_results.append(self._score_batch(gathered, return_predictions))

            if batch_idx % 10 == 0:
                print(f"Processed {batch_idx} batches...")

        return all_results

    def _evaluation_step(self, batch, return_predictions: bool = False):
        """
        Per-replica forward pass and CTC loss; contains only device-compilable ops.

        Args:
            batch: Dict with 'input_features', 'seq_class_ids', 'n_time_steps',
                'phone_seq_lens' and 'day_indices' tensors
            return_predictions: Also return raw logits and input features

        Returns:
            Dict of tensors: 'loss' (replica mean), 'predicted_ids' (argmax over
            the last logits axis), 'adjusted_lens', 'labels', 'phone_seq_lens',
            'day_indices', 'n_time_steps', plus 'logits'/'features' on request.
        """
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation).
        # NOTE(review): create_input_fn also receives data_transforms; if it
        # already applies them, this call applies them a second time — confirm
        # against dataset_tf.
        from dataset_tf import DataAugmentationTF
        features_transformed, n_time_steps_transformed = DataAugmentationTF.transform_data(
            features, n_time_steps, self.config['dataset']['data_transforms'], training=False
        )

        # Number of model output steps after patching, used as CTC input lengths.
        adjusted_lens = tf.cast(
            (tf.cast(n_time_steps_transformed, tf.float32) - self.config['model']['patch_size']) /
            self.config['model']['patch_stride'] + 1,
            tf.int32
        )

        logits = self.trainer.model(
            features_transformed, day_indices, None, False, 'inference', training=False
        )

        loss = self.trainer.ctc_loss(
            {
                'labels': labels,
                'input_lengths': adjusted_lens,
                'label_lengths': phone_seq_lens
            },
            logits
        )

        outputs = {
            'loss': tf.reduce_mean(loss),
            # Argmax on device keeps the host transfer small; decoding happens on host.
            'predicted_ids': tf.argmax(logits, axis=-1, output_type=tf.int32),
            'adjusted_lens': adjusted_lens,
            'labels': labels,
            'phone_seq_lens': phone_seq_lens,
            'day_indices': day_indices,
            'n_time_steps': n_time_steps,
        }
        if return_predictions:
            outputs['logits'] = logits
            outputs['features'] = features
        return outputs

    def _score_batch(self, gathered: Dict[str, List[Any]],
                     return_predictions: bool) -> Dict[str, Any]:
        """
        Greedy-decode one batch on the host and compute per-trial edit distances.

        Args:
            gathered: Output name -> list with one NumPy array per local replica
                (keys as returned by ``_evaluation_step``)
            return_predictions: Keep decoded sequences, targets, logits and features

        Returns:
            Dict of per-replica lists: 'loss' (float replica mean), 'num_trials'
            (int), 'edit_distance' / 'seq_length' (float sums), and
            'sample_edit_distances' / 'day_indices' / 'n_time_steps' /
            'phone_seq_lens' (arrays with one entry per trial).
        """
        keys = ['loss', 'num_trials', 'edit_distance', 'seq_length',
                'sample_edit_distances', 'day_indices', 'n_time_steps', 'phone_seq_lens']
        if return_predictions:
            keys += ['predictions', 'targets', 'logits', 'features']
        batch_results: Dict[str, List[Any]] = {key: [] for key in keys}

        for replica in range(len(gathered['loss'])):
            predicted_ids = np.asarray(gathered['predicted_ids'][replica])
            adjusted_lens = np.asarray(gathered['adjusted_lens'][replica]).reshape(-1)
            labels = np.asarray(gathered['labels'][replica])
            label_lens = np.asarray(gathered['phone_seq_lens'][replica]).reshape(-1)

            decoded, targets, sample_eds = [], [], []
            for i, (out_len, label_len) in enumerate(zip(adjusted_lens, label_lens)):
                # Clamp at 0 so a negative length never becomes a from-the-end slice.
                collapsed = self._remove_consecutive_duplicates(
                    predicted_ids[i, :max(int(out_len), 0)]
                )
                hypothesis = collapsed[collapsed != self.BLANK_INDEX]
                target = labels[i, :int(label_len)].astype(np.int64)
                # Empty hypotheses still count: distance is then len(target).
                sample_eds.append(self._edit_distance(hypothesis, target))
                if return_predictions:
                    decoded.append(hypothesis)
                    targets.append(target)

            batch_results['loss'].append(float(np.asarray(gathered['loss'][replica])))
            batch_results['num_trials'].append(len(label_lens))
            batch_results['edit_distance'].append(float(sum(sample_eds)))
            batch_results['seq_length'].append(float(label_lens.sum()))
            batch_results['sample_edit_distances'].append(np.asarray(sample_eds, dtype=np.int64))
            batch_results['day_indices'].append(
                np.asarray(gathered['day_indices'][replica]).reshape(-1)
            )
            batch_results['n_time_steps'].append(np.asarray(gathered['n_time_steps'][replica]))
            batch_results['phone_seq_lens'].append(label_lens)
            if return_predictions:
                batch_results['predictions'].append(decoded)
                batch_results['targets'].append(targets)
                batch_results['logits'].append(gathered['logits'][replica])
                batch_results['features'].append(gathered['features'][replica])

        return batch_results

    def _remove_consecutive_duplicates(self, seq) -> np.ndarray:
        """
        Collapse runs of repeated ids (the first half of greedy CTC decoding).

        Args:
            seq: 1-D sequence of ids (NumPy array, list, or eager tensor)

        Returns:
            int64 NumPy array with each run of equal consecutive values kept once
        """
        values = seq.numpy() if hasattr(seq, 'numpy') else seq
        values = np.asarray(values, dtype=np.int64).reshape(-1)
        if values.size == 0:
            return values
        keep = np.ones(values.shape, dtype=bool)
        keep[1:] = values[1:] != values[:-1]
        return values[keep]

    @staticmethod
    def _edit_distance(hypothesis, reference) -> int:
        """
        Unnormalized Levenshtein distance (unit-cost insert/delete/substitute).

        Equivalent to ``tf.edit_distance(..., normalize=False)`` for one pair,
        including the empty-hypothesis case (distance == len(reference)).
        """
        hyp = hypothesis.tolist() if hasattr(hypothesis, 'tolist') else list(hypothesis)
        ref = reference.tolist() if hasattr(reference, 'tolist') else list(reference)

        previous = list(range(len(ref) + 1))
        for i, h in enumerate(hyp, start=1):
            current = [i]
            for j, r in enumerate(ref, start=1):
                current.append(min(previous[j] + 1,            # deletion
                                   current[j - 1] + 1,         # insertion
                                   previous[j - 1] + (h != r)))  # substitution / match
            previous = current
        return previous[-1]

    def _calculate_summary_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Aggregate per-batch results into overall and per-day metrics.

        Loss is averaged per trial (each replica's mean weighted by its trial
        count); PER is total edit distance divided by total target length.
        """
        total_loss = 0.0
        loss_trials = 0
        day_totals: Dict[int, Dict[str, int]] = {}

        for batch_results in results:
            for replica_loss, n_trials in zip(batch_results['loss'], batch_results['num_trials']):
                # A replica that received no trials has no meaningful mean loss.
                if n_trials > 0:
                    total_loss += replica_loss * n_trials
                    loss_trials += n_trials

            for days, edit_dists, seq_lens in zip(batch_results['day_indices'],
                                                  batch_results['sample_edit_distances'],
                                                  batch_results['phone_seq_lens']):
                for day_idx, edit_dist, seq_len in zip(days, edit_dists, seq_lens):
                    stats = day_totals.setdefault(
                        int(day_idx), {'edit_distance': 0, 'seq_length': 0, 'trials': 0}
                    )
                    stats['edit_distance'] += int(edit_dist)
                    stats['seq_length'] += int(seq_len)
                    stats['trials'] += 1

        total_edit_distance = sum(s['edit_distance'] for s in day_totals.values())
        total_seq_length = sum(s['seq_length'] for s in day_totals.values())
        total_trials = sum(s['trials'] for s in day_totals.values())

        overall_per = total_edit_distance / max(total_seq_length, 1e-6)
        avg_loss = total_loss / max(loss_trials, 1)

        day_metrics = {
            day_idx: {
                'per': stats['edit_distance'] / max(stats['seq_length'], 1e-6),
                'edit_distance': stats['edit_distance'],
                'seq_length': stats['seq_length'],
                'trials': stats['trials']
            }
            for day_idx, stats in sorted(day_totals.items())
        }

        return {
            'overall_per': float(overall_per),
            'overall_loss': float(avg_loss),
            'total_edit_distance': int(total_edit_distance),
            'total_seq_length': int(total_seq_length),
            'total_trials': total_trials,
            'num_batches': len(results),
            'day_metrics': day_metrics
        }

    def _save_results(self, detailed_results: List[Dict[str, Any]],
                      summary_metrics: Dict[str, Any]):
        """
        Write summary JSON, pickled detailed results and per-day JSON to
        ``config['output_dir']`` (default ``./eval_output``).
        """
        output_dir = self.config.get('output_dir', './eval_output')
        os.makedirs(output_dir, exist_ok=True)

        summary_path = os.path.join(output_dir, f'{self.eval_type}_summary_metrics.json')
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary_metrics, f, indent=2)
        print(f"Summary metrics saved to: {summary_path}")

        detailed_path = os.path.join(output_dir, f'{self.eval_type}_detailed_results.pkl')
        with open(detailed_path, 'wb') as f:
            pickle.dump(detailed_results, f)
        print(f"Detailed results saved to: {detailed_path}")

        if 'day_metrics' in summary_metrics:
            day_breakdown_path = os.path.join(output_dir, f'{self.eval_type}_day_breakdown.json')
            with open(day_breakdown_path, 'w', encoding='utf-8') as f:
                json.dump(summary_metrics['day_metrics'], f, indent=2)
            print(f"Per-day breakdown saved to: {day_breakdown_path}")
 | |
| 
 | |
| 
 | |
def main():
    """
    Command-line entry point for model evaluation.

    Parses arguments, loads the OmegaConf configuration, applies command-line
    overrides, checks that the checkpoint file exists, then runs the evaluator
    and prints a results report.

    Raises:
        FileNotFoundError: if the config file or ``<model_path>.weights.h5`` is
            missing (checked before evaluation starts).

    Any exception raised during evaluation itself is printed with a traceback
    and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, keyword arguments), registered in this order.
    cli_options = [
        ('--model_path', {'required': True,
                          'help': 'Path to trained model checkpoint (without extension)'}),
        ('--config_path', {'default': 'rnn_args.yaml',
                           'help': 'Path to model configuration file'}),
        ('--data_dir', {'default': None,
                        'help': 'Override data directory from config'}),
        ('--eval_type', {'choices': ['test', 'val'], 'default': 'test',
                         'help': 'Type of evaluation to run'}),
        ('--output_dir', {'default': './eval_output',
                          'help': 'Directory to save evaluation results'}),
        ('--save_predictions', {'action': 'store_true',
                                'help': 'Save individual predictions and targets'}),
        ('--batch_size', {'type': int, 'default': None,
                          'help': 'Override batch size for evaluation'}),
        ('--sessions', {'nargs': '+', 'default': None,
                        'help': 'Specific sessions to evaluate (overrides config)'}),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # NOTE(review): tensorflow is imported at module load, before this runs, so
    # this may be too late to quiet TF's C++ logging — confirm.
    os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')

    config_path = args.config_path
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    config = OmegaConf.load(config_path)

    # Dataset overrides are applied only when the CLI value is truthy.
    dataset_overrides = (
        ('dataset_dir', args.data_dir),
        ('batch_size', args.batch_size),
        ('sessions', args.sessions),
    )
    for key, value in dataset_overrides:
        if value:
            setattr(config.dataset, key, value)
    if args.output_dir:
        config.output_dir = args.output_dir

    checkpoint_file = args.model_path + '.weights.h5'
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f"Model checkpoint not found: {args.model_path}")

    try:
        evaluator = BrainToTextEvaluatorTF(
            model_path=args.model_path,
            config=config,
            eval_type=args.eval_type
        )
        metrics = evaluator.evaluate_dataset(
            save_results=True,
            return_predictions=args.save_predictions
        )['summary_metrics']

        rule = "=" * 60
        print("\n" + rule)
        print("EVALUATION RESULTS")
        print(rule)
        print(f"Overall PER: {metrics['overall_per']:.6f}")
        print(f"Overall Loss: {metrics['overall_loss']:.6f}")
        print(f"Total Edit Distance: {metrics['total_edit_distance']}")
        print(f"Total Sequence Length: {metrics['total_seq_length']}")
        print(f"Total Trials: {metrics['total_trials']}")
        print(f"Batches Processed: {metrics['num_batches']}")

        per_day = metrics.get('day_metrics')
        if per_day:
            print("\nPER-DAY RESULTS:")
            print("-" * 40)
            sessions = config.dataset.sessions
            for day_idx, day_stats in per_day.items():
                if day_idx < len(sessions):
                    session_name = sessions[day_idx]
                else:
                    session_name = f"Day_{day_idx}"
                print(f"{session_name}: PER={day_stats['per']:.6f}, Trials={day_stats['trials']}")

        print("\nEvaluation completed successfully!")

    except Exception as e:
        print(f"Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
 | |
| 
 | |
| 
 | |
# Run the command-line evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()
