595 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			595 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import tensorflow as tf
 | |
| import h5py
 | |
| import numpy as np
 | |
| import math
 | |
| from typing import Dict, List, Tuple, Optional, Any
 | |
| from scipy.ndimage import gaussian_filter1d
 | |
| 
 | |
| 
 | |
class BrainToTextDatasetTF:
    """
    TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8

    This class creates tf.data.Dataset objects that load and batch
    brain-to-text data from HDF5 files. Which trials go into which batch is
    decided once in ``__init__`` (stored in ``self.batch_indices``); the HDF5
    reads happen lazily inside a Python generator wrapped by
    ``tf.data.Dataset.from_generator``.
    """

    # Feature width of the zero placeholder returned by _load_trial_data when a
    # trial cannot be read and no feature_subset is configured.
    # NOTE(review): 512 is the value the original code hard-coded; confirm it
    # matches the channel count stored in the HDF5 files.
    _FALLBACK_N_FEATURES = 512

    def __init__(
        self,
        trial_indices: Dict[int, Dict[str, Any]],
        n_batches: Optional[int],
        split: str = 'train',
        batch_size: int = 64,
        days_per_batch: int = 1,
        random_seed: int = -1,
        must_include_days: Optional[List[int]] = None,
        feature_subset: Optional[List[int]] = None,
        prefetch_buffer: int = tf.data.AUTOTUNE,
        num_parallel_calls: int = tf.data.AUTOTUNE
    ):
        """
        Initialize TensorFlow dataset for brain-to-text data

        Args:
            trial_indices: Dictionary keyed by day; each value holds at least
                'trials' (list of trial numbers) and 'session_path' (HDF5 path)
            n_batches: Number of training batches to create. Required for the
                'train' split; for 'test' it is recomputed from the data
            split: 'train' or 'test'
            batch_size: Number of examples per batch
            days_per_batch: Number of unique days per training batch
            random_seed: If not -1, seeds the global TensorFlow and NumPy RNGs
            must_include_days: Days that must be included in every training
                batch; negative values are mapped to ``n_days + d``. The
                caller's list is not modified
            feature_subset: Column indices of the neural features to keep
            prefetch_buffer: Buffer size passed to ``Dataset.prefetch``
            num_parallel_calls: Stored as ``self.num_parallel_calls`` for
                callers that build ``map`` stages on top of this dataset

        Raises:
            ValueError: unknown split; ``n_batches`` is None for the train
                split; more must_include_days than days_per_batch; or (train
                split) days_per_batch larger than the number of days.
        """
        # Set random seed for reproducibility (global TF and NumPy state)
        if random_seed != -1:
            tf.random.set_seed(random_seed)
            np.random.seed(random_seed)

        self.split = split
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')
        if self.split == 'train' and n_batches is None:
            # Fail early instead of with a TypeError from range(None) later.
            raise ValueError('n_batches must be provided for the "train" split')

        self.days_per_batch = days_per_batch
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.trial_indices = trial_indices
        self.n_days = len(trial_indices)
        self.feature_subset = feature_subset
        self.prefetch_buffer = prefetch_buffer
        self.num_parallel_calls = num_parallel_calls

        # Total number of trials across all days
        self.n_trials = sum(len(info['trials']) for info in trial_indices.values())

        if must_include_days is not None:
            if len(must_include_days) > days_per_batch:
                raise ValueError(
                    f'must_include_days ({len(must_include_days)}) must be <= '
                    f'days_per_batch ({days_per_batch})'
                )
            # Map negative indices Python-style, building a new list so the
            # caller's argument is not mutated.
            # NOTE(review): this mapping assumes day keys are 0..n_days-1.
            must_include_days = [self.n_days + d if d < 0 else d for d in must_include_days]
        self.must_include_days = must_include_days

        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')

        # Create batch indices
        if self.split == 'train':
            self.batch_indices = self._create_batch_index_train()
        else:
            self.batch_indices = self._create_batch_index_test()
            self.n_batches = len(self.batch_indices)

    def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
        """
        Create training batch indices with random sampling.

        Each batch picks ``days_per_batch`` distinct days (always including
        ``must_include_days``). It then samples ``ceil(batch_size / days_per_batch)``
        trials per day with replacement, and finally drops trials from randomly
        chosen days until exactly ``batch_size`` trials remain.

        Returns:
            ``{batch_idx: {day: [trial, ...]}}``
        """
        batch_indices = {}

        must_include = self.must_include_days or []
        if must_include:
            # Precompute the pool the remaining days are drawn from
            non_must_include_days = [
                d for d in self.trial_indices.keys()
                if d not in must_include
            ]

        # Trials drawn per selected day; any surplus over batch_size is trimmed below
        num_trials = math.ceil(self.batch_size / self.days_per_batch)

        for batch_idx in range(self.n_batches):
            # Select days for this batch
            if must_include:
                additional_days = np.random.choice(
                    non_must_include_days,
                    size=self.days_per_batch - len(must_include),
                    replace=False
                )
                days = np.concatenate((must_include, additional_days))
            else:
                days = np.random.choice(
                    list(self.trial_indices.keys()),
                    size=self.days_per_batch,
                    replace=False
                )

            # Sample trials with replacement within each selected day
            batch = {}
            for d in days:
                trial_idxs = np.random.choice(
                    self.trial_indices[d]['trials'],
                    size=num_trials,
                    replace=True
                )
                batch[d] = trial_idxs.tolist()

            # Remove extra trials to match the exact batch size
            extra_trials = (num_trials * len(days)) - self.batch_size
            while extra_trials > 0:
                d = np.random.choice(days)
                if len(batch[d]) > 0:
                    batch[d] = batch[d][:-1]
                    extra_trials -= 1

            batch_indices[batch_idx] = batch

        return batch_indices

    def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
        """
        Create test batch indices so that every trial is seen exactly once.

        Batches never mix days: each day's trial list is cut into consecutive
        chunks of at most ``batch_size`` trials.

        Returns:
            ``{batch_idx: {day: [trial, ...]}}``
        """
        batch_indices = {}
        batch_idx = 0

        for d in self.trial_indices.keys():
            trials = self.trial_indices[d]['trials']
            for start_idx in range(0, len(trials), self.batch_size):
                batch_indices[batch_idx] = {d: trials[start_idx:start_idx + self.batch_size]}
                batch_idx += 1

        return batch_indices

    def _load_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """
        Load a single trial's data from its day's HDF5 file.

        Returns a dict of NumPy arrays and scalars. If anything fails, the
        error is printed and a zero-filled placeholder is returned instead, so
        the batch keeps its size. The placeholder's feature width honours
        ``feature_subset``, so it can still be stacked with real trials.
        """
        try:
            session_path = self.trial_indices[day]['session_path']

            with h5py.File(session_path, 'r') as f:
                g = f[f'trial_{trial:04d}']

                # Load neural features (optionally restricted to a column subset)
                input_features = g['input_features'][:]
                if self.feature_subset:
                    input_features = input_features[:, self.feature_subset]

                # Kept as float32 on the host; this method does no bfloat16 cast.
                input_features = input_features.astype(np.float32)

                return {
                    'input_features': input_features,
                    'seq_class_ids': g['seq_class_ids'][:],
                    'transcription': g['transcription'][:],
                    'n_time_steps': g.attrs['n_time_steps'],
                    'phone_seq_lens': g.attrs['seq_len'],
                    'day_index': day,
                    'block_num': g.attrs['block_num'],
                    'trial_num': g.attrs['trial_num']
                }

        except Exception as e:
            # Deliberate best-effort: report and substitute a placeholder trial.
            print(f'Error loading trial {trial} from day {day}: {e}')
            n_features = (len(self.feature_subset) if self.feature_subset
                          else self._FALLBACK_N_FEATURES)
            return {
                'input_features': np.zeros((100, n_features), dtype=np.float32),
                'seq_class_ids': np.zeros((10,), dtype=np.int32),
                'transcription': np.zeros((50,), dtype=np.int32),
                'n_time_steps': 100,
                'phone_seq_lens': 10,
                'day_index': day,
                'block_num': 0,
                'trial_num': 0
            }

    @staticmethod
    def _pad_axis0(arr: np.ndarray, length: int) -> np.ndarray:
        """Zero-pad ``arr`` along axis 0 up to ``length`` rows; no-op if already that long."""
        pad = length - arr.shape[0]
        if pad <= 0:
            return arr
        return np.pad(arr, [(0, pad)] + [(0, 0)] * (arr.ndim - 1))

    def _create_batch_generator(self):
        """
        Yield one padded batch dict per entry of ``self.batch_indices``.

        Variable-length fields are zero-padded to the longest array actually
        loaded in the batch. The original ``n_time_steps`` / ``phone_seq_lens``
        values are returned alongside, so callers can still mask the padding.
        """
        for batch_idx in range(self.n_batches):
            trials = [
                self._load_trial_data(day, trial)
                for day, day_trials in self.batch_indices[batch_idx].items()
                for trial in day_trials
            ]

            features = [t['input_features'] for t in trials]
            seq_ids = [np.asarray(t['seq_class_ids']) for t in trials]
            transcriptions = [np.asarray(t['transcription']) for t in trials]

            # Pad to real array lengths rather than the n_time_steps attribute
            # so np.stack cannot fail if the two ever disagree.
            max_time_steps = max(f.shape[0] for f in features)
            max_phone_len = max(s.shape[0] for s in seq_ids)
            max_transcription_len = max(t.shape[0] for t in transcriptions)

            # dtypes match output_signature in create_dataset
            yield {
                'input_features': np.stack(
                    [self._pad_axis0(f, max_time_steps) for f in features]
                ).astype(np.float32, copy=False),
                'seq_class_ids': np.stack(
                    [self._pad_axis0(s, max_phone_len) for s in seq_ids]
                ).astype(np.int32, copy=False),
                'n_time_steps': np.array([t['n_time_steps'] for t in trials], dtype=np.int32),
                'phone_seq_lens': np.array([t['phone_seq_lens'] for t in trials], dtype=np.int32),
                'day_indices': np.array([t['day_index'] for t in trials], dtype=np.int32),
                'transcriptions': np.stack(
                    [self._pad_axis0(t, max_transcription_len) for t in transcriptions]
                ).astype(np.int32, copy=False),
                'block_nums': np.array([t['block_num'] for t in trials], dtype=np.int32),
                'trial_nums': np.array([t['trial_num'] for t in trials], dtype=np.int32)
            }

    def create_dataset(self) -> tf.data.Dataset:
        """
        Create a tf.data.Dataset of pre-batched dicts.

        The train split is additionally shuffled at the batch level, with a
        buffer of at most 1000 batches. Both splits are prefetched with
        ``self.prefetch_buffer``.
        """
        # Output signature: leading dim is the batch, the others are padded lengths
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
        }

        dataset = tf.data.Dataset.from_generator(
            self._create_batch_generator,
            output_signature=output_signature
        )

        if self.split == 'train':
            # Buffer must be >= 1 even when n_batches is 0.
            dataset = dataset.shuffle(buffer_size=max(1, min(1000, self.n_batches)))

        return dataset.prefetch(self.prefetch_buffer)
 | |
| 
 | |
| 
 | |
class DataAugmentationTF:
    """
    TensorFlow data augmentation functions optimized for TPU v5e-8
    """

    @staticmethod
    def _gaussian_kernel(smooth_kernel_std: float, smooth_kernel_size: int) -> np.ndarray:
        """
        Build a normalized 1-D Gaussian kernel as a NumPy array.

        An impulse of length ``smooth_kernel_size`` is blurred with
        ``gaussian_filter1d``. Taps <= 0.01 are dropped and the rest are
        rescaled to sum to 1.
        """
        impulse = np.zeros(smooth_kernel_size, dtype=np.float32)
        impulse[smooth_kernel_size // 2] = 1
        kernel = gaussian_filter1d(impulse, smooth_kernel_std)
        kernel = kernel[kernel > 0.01]
        return kernel / np.sum(kernel)

    @staticmethod
    def gauss_smooth(inputs: tf.Tensor,
                     smooth_kernel_std: float = 2.0,
                     smooth_kernel_size: int = 100) -> tf.Tensor:
        """
        Apply Gaussian smoothing along the time axis using TensorFlow operations

        The same kernel is applied independently to every feature channel.
        The channel axis is folded into the batch axis, so a single conv1d
        handles all channels, whether the channel count is static or dynamic.

        Args:
            inputs: Input tensor [batch_size, time_steps, features]
            smooth_kernel_std: Standard deviation of Gaussian kernel
            smooth_kernel_size: Size of the impulse the kernel is derived from

        Returns:
            Smoothed tensor with same shape and dtype as input
        """
        kernel = DataAugmentationTF._gaussian_kernel(smooth_kernel_std, smooth_kernel_size)
        # conv1d filter layout: [width, in_channels, out_channels]
        kernel = tf.reshape(tf.constant(kernel, dtype=tf.float32), [-1, 1, 1])
        kernel = tf.cast(kernel, inputs.dtype)

        shape = tf.shape(inputs)
        batch_size, time_steps, num_features = shape[0], shape[1], shape[2]

        # [B, T, C] -> [B, C, T] -> [B*C, T, 1]: one single-channel signal per row
        per_channel = tf.reshape(tf.transpose(inputs, [0, 2, 1]), [-1, time_steps, 1])
        smoothed = tf.nn.conv1d(per_channel, kernel, stride=1, padding='SAME')

        # [B*C, T, 1] -> [B, C, T] -> [B, T, C]
        smoothed = tf.reshape(smoothed, [batch_size, num_features, time_steps])
        smoothed = tf.transpose(smoothed, [0, 2, 1])
        # Reshape with dynamic dims drops static shape info; restore it.
        smoothed.set_shape(inputs.shape)
        return smoothed

    @staticmethod
    def transform_data(features: tf.Tensor,
                       n_time_steps: tf.Tensor,
                       transform_args: Dict[str, Any],
                       training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply data transformations optimized for TPU

        Training-only augmentations run in this order: static gain, white noise,
        constant offset, random walk, random cut. Gaussian smoothing (if
        enabled) is applied in both modes.

        Args:
            features: Input features [batch_size, time_steps, channels]
            n_time_steps: Number of valid time steps per sample
            transform_args: Transformation configuration
            training: Whether to apply training-only augmentations

        Returns:
            Transformed features and updated time steps
        """
        batch_size = tf.shape(features)[0]
        channels = tf.shape(features)[2]
        # Generate noise in the features' own dtype so non-float32 inputs work
        dtype = features.dtype

        if training:
            # Static gain noise: multiply by (I + N) per batch element
            if transform_args.get('static_gain_std', 0) > 0:
                gain_std = transform_args['static_gain_std']
                identity_matrices = tf.eye(channels, batch_shape=[batch_size], dtype=dtype)
                noise = tf.random.normal([batch_size, channels, channels], dtype=dtype) * gain_std
                features = tf.linalg.matmul(features, identity_matrices + noise)

            # White noise
            if transform_args.get('white_noise_std', 0) > 0:
                white_noise = tf.random.normal(tf.shape(features), dtype=dtype) * transform_args['white_noise_std']
                features = features + white_noise

            # Constant offset noise, shared across time steps
            if transform_args.get('constant_offset_std', 0) > 0:
                offset_noise = tf.random.normal([batch_size, 1, channels], dtype=dtype) * transform_args['constant_offset_std']
                features = features + offset_noise

            # Random walk noise: cumulative sum of Gaussian steps along `axis`
            if transform_args.get('random_walk_std', 0) > 0:
                random_walk_noise = tf.random.normal(tf.shape(features), dtype=dtype) * transform_args['random_walk_std']
                axis = transform_args.get('random_walk_axis', 1)
                features = features + tf.cumsum(random_walk_noise, axis=axis)

            # Random cutoff (one cut for the whole batch); maxval is exclusive
            if transform_args.get('random_cut', 0) > 0:
                max_cut = transform_args['random_cut']
                cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing (both training and validation)
        if transform_args.get('smooth_data', False):
            features = DataAugmentationTF.gauss_smooth(
                features,
                smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
                smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
            )

        return features, n_time_steps
 | |
| 
 | |
| 
 | |
def _find_session_name(path: str) -> Optional[str]:
    """Return the first path component starting with 't15.20' or 't12.20', or None."""
    # Handle both Windows and Unix path separators
    parts = path.replace('\\', '/').split('/')
    return next((p for p in parts if p.startswith(('t15.20', 't12.20'))), None)


def _trial_numbers(keys) -> List[int]:
    """
    Return the sorted trial numbers of HDF5 keys of the form 'trial_NNNN'.

    Only keys that round-trip through the 'trial_{n:04d}' format are accepted,
    because that is the format the trials are later looked up with.
    """
    numbers = []
    for key in keys:
        if not key.startswith('trial_'):
            continue
        suffix = key[len('trial_'):]
        if suffix.isascii() and suffix.isdigit() and f'trial_{int(suffix):04d}' == key:
            numbers.append(int(suffix))
    return sorted(numbers)


def _load_good_trial_indices(path: str,
                             session: Optional[str],
                             bad_trials_dict: Optional[Dict]) -> List[int]:
    """
    Return the trial numbers in the HDF5 file at ``path`` that are not excluded.

    A trial is excluded when ``bad_trials_dict[session][str(block_num)]``
    contains its ``trial_num`` attribute. A missing file yields an empty list.
    """
    if not os.path.exists(path):
        return []

    bad_blocks = {}
    if bad_trials_dict is not None and session is not None:
        bad_blocks = bad_trials_dict.get(session, {})

    good_trial_indices = []
    with h5py.File(path, 'r') as f:
        for t in _trial_numbers(f.keys()):
            attrs = f[f'trial_{t:04d}'].attrs
            bad_in_block = bad_blocks.get(str(attrs['block_num']))
            if bad_in_block is not None and attrs['trial_num'] in bad_in_block:
                continue
            good_trial_indices.append(t)
    return good_trial_indices


def train_test_split_indices(file_paths: List[str],
                             test_percentage: float = 0.1,
                             seed: int = -1,
                             bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
    """
    Split data from file_paths into train and test splits

    Each file is one "day", keyed by its position in ``file_paths``. For a
    fractional ``test_percentage``, at least one trial of every non-empty day
    goes to the test split.

    Args:
        file_paths: List of HDF5 file paths
        test_percentage: Fraction of trials (0..1) held out for testing
        seed: Seeds NumPy's global RNG unless -1
        bad_trials_dict: ``{session: {str(block_num): [trial_num, ...]}}`` of
            trials to exclude; the session is the path component starting
            with 't15.20' or 't12.20'

    Returns:
        Tuple of (train_trials, test_trials) dictionaries, each
        ``{day: {'trials': [...], 'session_path': path}}``

    Raises:
        ValueError: if test_percentage is outside [0, 1]
    """
    if not 0 <= test_percentage <= 1:
        raise ValueError(f'test_percentage must be in [0, 1]. Received {test_percentage}')

    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    train_trials = {}
    test_trials = {}

    for day, path in enumerate(file_paths):
        session = _find_session_name(path)
        all_trial_indices = _load_good_trial_indices(path, session, bad_trials_dict)
        num_trials = len(all_trial_indices)

        if test_percentage == 0 or num_trials == 0:
            # Nothing to sample from (or nothing requested): avoids
            # np.random.choice on an empty population.
            train_indices, test_indices = all_trial_indices, []
        elif test_percentage == 1:
            train_indices, test_indices = [], all_trial_indices
        else:
            num_test = min(num_trials, max(1, int(num_trials * test_percentage)))
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()
            test_set = set(test_indices)
            train_indices = [idx for idx in all_trial_indices if idx not in test_set]

        train_trials[day] = {'trials': train_indices, 'session_path': path}
        test_trials[day] = {'trials': test_indices, 'session_path': path}

    return train_trials, test_trials
 | |
| 
 | |
| 
 | |
| # Utility functions for TPU-optimized data pipeline
 | |
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    training: bool = True) -> tf.data.Dataset:
    """
    Create input function for TPU training with data augmentation

    Args:
        dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)

    Returns:
        tf.data.Dataset whose batches have 'input_features' and
        'n_time_steps' replaced by their transformed values
    """
    dataset = dataset_tf.create_dataset()

    def apply_transforms(batch):
        """Apply data transformations to a batch"""
        features, n_time_steps = DataAugmentationTF.transform_data(
            batch['input_features'], batch['n_time_steps'], transform_args, training=training
        )
        # Return a new dict instead of mutating the one tf.data passed in
        return {**batch, 'input_features': features, 'n_time_steps': n_time_steps}

    # Honour the parallelism configured on the dataset object (AUTOTUNE by default)
    num_parallel_calls = getattr(dataset_tf, 'num_parallel_calls', tf.data.AUTOTUNE)
    return dataset.map(apply_transforms, num_parallel_calls=num_parallel_calls)
