import os
import tensorflow as tf
import h5py
import numpy as np
import math
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d


class BrainToTextDatasetTF:
    """
    TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8.

    This class creates tf.data.Dataset objects that efficiently load and batch
    brain-to-text data from HDF5 files with TPU-optimized operations.
    """

    def __init__(
        self,
        trial_indices: Dict[int, Dict[str, Any]],
        n_batches: Optional[int],
        split: str = 'train',
        batch_size: int = 64,
        days_per_batch: int = 1,
        random_seed: int = -1,
        must_include_days: Optional[List[int]] = None,
        feature_subset: Optional[List[int]] = None,
        prefetch_buffer: int = tf.data.AUTOTUNE,
        num_parallel_calls: int = tf.data.AUTOTUNE,
        cache_data: bool = True,
        preload_all_data: bool = False
    ):
        """
        Initialize TensorFlow dataset for brain-to-text data.

        Args:
            trial_indices: Dictionary with day numbers as keys and trial info as values
            n_batches: Number of training batches to create (None for validation)
            split: 'train' or 'test'
            batch_size: Number of examples per batch
            days_per_batch: Number of unique days per batch (for day-specific layers)
            random_seed: Random seed for reproducibility
            must_include_days: Days that must be included in every batch
            feature_subset: Subset of neural features to use
            prefetch_buffer: Buffer size for prefetching
            num_parallel_calls: Parallel processing threads
            cache_data: Whether to cache loaded data in memory
            preload_all_data: Whether to preload all data at initialization
        """
        # Set random seed for reproducibility
        if random_seed != -1:
            tf.random.set_seed(random_seed)
            np.random.seed(random_seed)

        self.split = split
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')
        self.days_per_batch = days_per_batch
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.trial_indices = trial_indices
        self.n_days = len(trial_indices.keys())
        self.feature_subset = feature_subset
        self.must_include_days = must_include_days
        self.prefetch_buffer = prefetch_buffer
        self.num_parallel_calls = num_parallel_calls
        self.cache_data = cache_data
        self.preload_all_data = preload_all_data

        # Initialize data cache
        self.data_cache = {} if cache_data else None

        # Calculate total number of trials
        self.n_trials = 0
        for d in trial_indices:
            self.n_trials += len(trial_indices[d]['trials'])

        # Validation checks
        if must_include_days is not None:
            if len(must_include_days) > days_per_batch:
                raise ValueError('must_include_days must be <= days_per_batch')

            # Map negative indices
            for i, d in enumerate(must_include_days):
                if d < 0:
                    must_include_days[i] = self.n_days + d

        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')

        # Create batch indices
        if self.split == 'train':
            self.batch_indices = self._create_batch_index_train()
        else:
            self.batch_indices = self._create_batch_index_test()
            self.n_batches = len(self.batch_indices)

        # Preload data if requested (speeds up first batch significantly)
        if self.preload_all_data:
            print(f"🔄 Preloading all data for {self.split} split...")
            self._preload_all_data()
            print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")

    def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
        """Create training batch indices with random sampling"""
        batch_indices = {}

        # Precompute non-must-include days
        if self.must_include_days is not None:
            non_must_include_days = [
                d for d in self.trial_indices.keys() if d not in self.must_include_days
            ]

        for batch_idx in range(self.n_batches):
            batch = {}

            # Select days for this batch
            if self.must_include_days is not None and len(self.must_include_days) > 0:
                additional_days = np.random.choice(
                    non_must_include_days,
                    size=self.days_per_batch - len(self.must_include_days),
                    replace=False
                )
                days = np.concatenate((self.must_include_days, additional_days))
            else:
                days = np.random.choice(
                    list(self.trial_indices.keys()),
                    size=self.days_per_batch,
                    replace=False
                )

            # Calculate trials per day
            num_trials = math.ceil(self.batch_size / self.days_per_batch)

            for d in days:
                # Sample trials with replacement
                trial_idxs = np.random.choice(
                    self.trial_indices[d]['trials'],
                    size=num_trials,
                    replace=True
                )
                batch[d] = trial_idxs.tolist()

            # Remove extra trials to match exact batch size
            extra_trials = (num_trials * len(days)) - self.batch_size
            while extra_trials > 0:
                d = np.random.choice(days)
                if len(batch[d]) > 0:
                    batch[d] = batch[d][:-1]
                    extra_trials -= 1

            batch_indices[batch_idx] = batch

        return batch_indices

    def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
        """Create test batch indices ensuring all trials are seen once"""
        batch_indices = {}
        batch_idx = 0

        for d in self.trial_indices.keys():
            num_trials = len(self.trial_indices[d]['trials'])
            num_batches = (num_trials + self.batch_size - 1) // self.batch_size

            for i in range(num_batches):
                start_idx = i * self.batch_size
                end_idx = min((i + 1) * self.batch_size, num_trials)
                batch_trials = self.trial_indices[d]['trials'][start_idx:end_idx]

                batch_indices[batch_idx] = {d: batch_trials}
                batch_idx += 1

        return batch_indices
    def _preload_all_data(self):
        """Preload all trial data into memory cache (uses available RAM optimally)"""
        import multiprocessing
        from concurrent.futures import ThreadPoolExecutor, as_completed

        # Use CPU cores efficiently for parallel I/O
        max_workers = min(multiprocessing.cpu_count(), 32)  # Limit to avoid overwhelming I/O

        # Collect all trials to load
        trials_to_load = []
        for day in self.trial_indices:
            for trial in self.trial_indices[day]['trials']:
                trials_to_load.append((day, trial))

        print(f"📊 Preloading {len(trials_to_load)} trials using {max_workers} workers...")

        # Parallel loading using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all loading tasks
            future_to_trial = {
                executor.submit(self._load_single_trial_data, day, trial): (day, trial)
                for day, trial in trials_to_load
            }

            # Process completed tasks and update cache
            loaded_count = 0
            for future in as_completed(future_to_trial):
                day, trial = future_to_trial[future]
                try:
                    trial_data = future.result()
                    cache_key = f"{day}_{trial}"
                    self.data_cache[cache_key] = trial_data
                    loaded_count += 1

                    # Progress indicator every 100 trials
                    if loaded_count % 100 == 0:
                        print(f"   Loaded {loaded_count}/{len(trials_to_load)} trials...")
                except Exception as e:
                    print(f"   Warning: Failed to load trial {day}_{trial}: {e}")

        print(f"✅ Preloading completed: {loaded_count}/{len(trials_to_load)} trials cached")

    def _load_single_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data - optimized version for parallel loading"""
        try:
            session_path = self.trial_indices[day]['session_path']
            with h5py.File(session_path, 'r') as f:
                g = f[f'trial_{trial:04d}']

                # Load neural features
                input_features = g['input_features'][:]
                if self.feature_subset:
                    input_features = input_features[:, self.feature_subset]

                # Convert to float32 for TF compatibility
                input_features = input_features.astype(np.float32)

                trial_data = {
                    'input_features': input_features,
                    'seq_class_ids': g['seq_class_ids'][:],
                    'transcription': g['transcription'][:],
                    'n_time_steps': g.attrs['n_time_steps'],
                    'phone_seq_lens': g.attrs['seq_len'],
                    'day_index': day,
                    'block_num': g.attrs['block_num'],
                    'trial_num': g.attrs['trial_num']
                }

            return trial_data
        except Exception:
            # Return dummy data for failed loads
            return {
                'input_features': np.zeros((100, 512), dtype=np.float32),
                'seq_class_ids': np.zeros((10,), dtype=np.int32),
                'transcription': np.zeros((50,), dtype=np.int32),
                'n_time_steps': 100,
                'phone_seq_lens': 10,
                'day_index': day,
                'block_num': 0,
                'trial_num': 0
            }

    def _load_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data from cache or HDF5 file"""
        # Check cache first if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            if cache_key in self.data_cache:
                return self.data_cache[cache_key]

        # Load from disk if not in cache
        trial_data = self._load_single_trial_data(day, trial)

        # Cache the loaded data if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            self.data_cache[cache_key] = trial_data

        return trial_data

    def _create_batch_generator(self):
        """Generator function that yields individual batches with optimized loading"""
        import time
        from concurrent.futures import ThreadPoolExecutor

        for batch_idx in range(self.n_batches):
            batch_start_time = time.time()

            batch_data = {
                'input_features': [],
                'seq_class_ids': [],
                'n_time_steps': [],
                'phone_seq_lens': [],
                'day_indices': [],
                'transcriptions': [],
                'block_nums': [],
                'trial_nums': []
            }

            batch_index = self.batch_indices[batch_idx]

            # Collect all trials to load for this batch
            trials_to_load = []
            for day in batch_index.keys():
                for trial in batch_index[day]:
                    trials_to_load.append((day, trial))
            # Use parallel loading if not preloaded and have multiple trials
            if not self.preload_all_data and len(trials_to_load) > 4:
                # Parallel loading for faster I/O
                with ThreadPoolExecutor(max_workers=min(8, len(trials_to_load))) as executor:
                    future_to_trial = {
                        executor.submit(self._load_trial_data, day, trial): (day, trial)
                        for day, trial in trials_to_load
                    }

                    # Collect results in order
                    trial_results = {}
                    for future in future_to_trial:
                        day, trial = future_to_trial[future]
                        trial_results[(day, trial)] = future.result()

                # Add data in original order
                for day, trial in trials_to_load:
                    trial_data = trial_results[(day, trial)]
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])
            else:
                # Sequential loading (fast when data is cached or few trials)
                for day, trial in trials_to_load:
                    trial_data = self._load_trial_data(day, trial)
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])

            data_loading_time = time.time() - batch_start_time

            # Add timing diagnostic for first few batches
            if batch_idx < 3:
                cache_status = "cached" if self.preload_all_data else "disk"
                loading_method = "parallel" if (not self.preload_all_data and len(trials_to_load) > 4) else "sequential"
                print(f"⏱️ Batch {batch_idx}: {len(trials_to_load)} trials loaded in {data_loading_time:.3f}s ({cache_status}, {loading_method})")

            # Pad sequences to create uniform batch
            max_time_steps = max(batch_data['n_time_steps'])
            max_phone_len = max(len(seq) for seq in batch_data['seq_class_ids'])
            max_transcription_len = max(len(trans) for trans in batch_data['transcriptions'])

            # Pad input features
            padded_features = []
            for features in batch_data['input_features']:
                if features.shape[0] < max_time_steps:
                    padding = np.zeros((max_time_steps - features.shape[0], features.shape[1]), dtype=np.float32)
                    features = np.vstack([features, padding])
                padded_features.append(features)

            # Pad sequences
            padded_seq_ids = []
            for seq in batch_data['seq_class_ids']:
                if len(seq) < max_phone_len:
                    padding = np.zeros(max_phone_len - len(seq), dtype=np.int32)
                    seq = np.concatenate([seq, padding])
                padded_seq_ids.append(seq)

            # Pad transcriptions
            padded_transcriptions = []
            for trans in batch_data['transcriptions']:
                if len(trans) < max_transcription_len:
                    padding = np.zeros(max_transcription_len - len(trans), dtype=np.int32)
                    trans = np.concatenate([trans, padding])
                padded_transcriptions.append(trans)

            # Create final batch tensors
            batch = {
                'input_features': np.stack(padded_features),
                'seq_class_ids': np.stack(padded_seq_ids),
                'n_time_steps': np.array(batch_data['n_time_steps'], dtype=np.int32),
                'phone_seq_lens': np.array(batch_data['phone_seq_lens'], dtype=np.int32),
                'day_indices': np.array(batch_data['day_indices'], dtype=np.int32),
                'transcriptions': np.stack(padded_transcriptions),
                'block_nums': np.array(batch_data['block_nums'], dtype=np.int32),
                'trial_nums': np.array(batch_data['trial_nums'], dtype=np.int32)
            }

            yield batch
    def create_dataset(self) -> tf.data.Dataset:
        """Create optimized tf.data.Dataset for TPU training"""
        # Define output signature for the dataset
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
        }

        # Create dataset from generator
        dataset = tf.data.Dataset.from_generator(
            self._create_batch_generator,
            output_signature=output_signature
        )

        # Apply TPU-optimized transformations
        if self.split == 'train':
            # For training, add shuffling
            dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))

        # Prefetch for better performance
        dataset = dataset.prefetch(self.prefetch_buffer)

        return dataset


class DataAugmentationTF:
    """
    TensorFlow data augmentation functions optimized for TPU v5e-8
    """

    @staticmethod
    def gauss_smooth(inputs: tf.Tensor,
                     smooth_kernel_std: float = 2.0,
                     smooth_kernel_size: int = 100) -> tf.Tensor:
        """
        Apply Gaussian smoothing along the time axis using TensorFlow operations

        Args:
            inputs: Input tensor [batch_size, time_steps, features]
            smooth_kernel_std: Standard deviation of Gaussian kernel
            smooth_kernel_size: Size of the Gaussian kernel

        Returns:
            Smoothed tensor with same shape as input
        """
        # Create Gaussian kernel using numpy (computed once)
        inp = np.zeros(smooth_kernel_size, dtype=np.float32)
        inp[smooth_kernel_size // 2] = 1
        gauss_kernel = gaussian_filter1d(inp, smooth_kernel_std)
        valid_idx = np.argwhere(gauss_kernel > 0.01)
        gauss_kernel = gauss_kernel[valid_idx].flatten()
        gauss_kernel = gauss_kernel / np.sum(gauss_kernel)

        # Convert to TensorFlow tensor and reshape for conv1d
        gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
        kernel_size = tf.shape(gauss_kernel)[0]
        gauss_kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1])  # [kernel_size, in_channels, out_channels]

        # Get tensor dimensions
        batch_size = tf.shape(inputs)[0]
        time_steps = tf.shape(inputs)[1]
        num_features = tf.shape(inputs)[2]

        # Apply convolution to each feature channel separately
        smoothed_features = []

        # Convert num_features to Python int for loop
        num_features_py = inputs.shape[-1] if inputs.shape[-1] is not None else tf.shape(inputs)[-1]

        if isinstance(num_features_py, tf.Tensor):
            # If dynamic, use tf.map_fn for dynamic number of features
            def smooth_single_feature(i):
                # Extract single feature channel: [batch_size, time_steps, 1]
                feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
                # Apply 1D convolution
                return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')

            # Use tf.map_fn for dynamic features
            indices = tf.range(num_features)
            smoothed_features_tensor = tf.map_fn(
                smooth_single_feature,
                indices,
                fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
            )

            # Transpose to get [batch_size, time_steps, features]
            smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
            smoothed = tf.squeeze(smoothed, axis=-1)
        else:
            # Static number of features - use loop
            for i in range(num_features_py):
                # Extract single feature channel: [batch_size, time_steps, 1]
                feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
                # Apply 1D convolution
                smoothed_channel = tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
                smoothed_features.append(smoothed_channel)

            # Concatenate all smoothed features
            smoothed = tf.concat(smoothed_features, axis=-1)  # [batch_size, time_steps, features]

        return smoothed

    @staticmethod
    def transform_data(features: tf.Tensor,
                       n_time_steps: tf.Tensor,
                       transform_args: Dict[str, Any],
                       training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply data transformations optimized for TPU

        Args:
            features: Input features [batch_size, time_steps, channels]
            n_time_steps: Number of valid time steps per sample
            transform_args: Transformation configuration
            training: Whether to apply training-only augmentations

        Returns:
            Transformed features and updated time steps
        """
        batch_size = tf.shape(features)[0]
        time_steps = tf.shape(features)[1]
        channels = tf.shape(features)[2]

        # Training-only augmentations
        if training:
            # Static gain noise
            if transform_args.get('static_gain_std', 0) > 0:
                gain_std = transform_args['static_gain_std']

                # Create identity matrices for each batch
                identity_matrices = tf.eye(channels, batch_shape=[batch_size])

                # Add noise to create warp matrices
                noise = tf.random.normal([batch_size, channels, channels]) * gain_std
                warp_matrices = identity_matrices + noise

                # Apply transformation
                features = tf.linalg.matmul(features, warp_matrices)

            # White noise
            if transform_args.get('white_noise_std', 0) > 0:
                white_noise = tf.random.normal(tf.shape(features)) * transform_args['white_noise_std']
                features = features + white_noise

            # Constant offset noise
            if transform_args.get('constant_offset_std', 0) > 0:
                offset_noise = tf.random.normal([batch_size, 1, channels]) * transform_args['constant_offset_std']
                features = features + offset_noise

            # Random walk noise
            if transform_args.get('random_walk_std', 0) > 0:
                random_walk_noise = tf.random.normal(tf.shape(features)) * transform_args['random_walk_std']
                axis = transform_args.get('random_walk_axis', 1)
                random_walk_noise = tf.cumsum(random_walk_noise, axis=axis)
                features = features + random_walk_noise

            # Random cutoff (simplified for TPU - apply to all samples in batch)
            if transform_args.get('random_cut', 0) > 0:
                max_cut = transform_args['random_cut']
                cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing (both training and validation)
        if transform_args.get('smooth_data', False):
            features = DataAugmentationTF.gauss_smooth(
                features,
                smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
                smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
            )

        return features, n_time_steps


def train_test_split_indices(file_paths: List[str],
                             test_percentage: float = 0.1,
                             seed: int = -1,
                             bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
    """
    Split data from file_paths into train and test splits

    Args:
        file_paths: List of HDF5 file paths
        test_percentage: Percentage of trials for testing
        seed: Random seed for reproducibility
        bad_trials_dict: Dictionary of trials to exclude

    Returns:
        Tuple of (train_trials, test_trials) dictionaries
    """
    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Get trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        # Handle both Windows and Unix path separators
        path_parts = path.replace('\\', '/').split('/')
        session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]

        good_trial_indices = []

        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'
                    if key not in f:
                        continue

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    # Check if trial should be excluded
                    if (bad_trials_dict is not None
                            and session in bad_trials_dict
                            and str(block_num) in bad_trials_dict[session]
                            and trial_num in bad_trials_dict[session][str(block_num)]):
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {
            'num_trials': len(good_trial_indices),
            'trial_indices': good_trial_indices,
            'session_path': path
        }

    # Split trials into train and test
    train_trials = {}
    test_trials = {}

    for day in trials_per_day.keys():
        num_trials = trials_per_day[day]['num_trials']
        all_trial_indices = trials_per_day[day]['trial_indices']

        if test_percentage == 0:
            train_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
        elif test_percentage == 1:
            train_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
        else:
            # Calculate number of test trials
            num_test = max(1, int(num_trials * test_percentage))

            # Randomly select test indices
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()

            # Remaining indices for training
            train_indices = [idx for idx in all_trial_indices if idx not in test_indices]

            train_trials[day] = {
                'trials': train_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': test_indices,
                'session_path': trials_per_day[day]['session_path']
            }

    return train_trials, test_trials


# Utility functions for TPU-optimized data pipeline

def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    training: bool = True) -> tf.data.Dataset:
    """
    Create input function for TPU training with data augmentation

    Args:
        dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)

    Returns:
        tf.data.Dataset ready for TPU training
    """
    dataset = dataset_tf.create_dataset()

    def apply_transforms(batch):
        """Apply data transformations to a batch"""
        features = batch['input_features']
        n_time_steps = batch['n_time_steps']

        # Apply transformations
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, transform_args, training=training
        )

        # Update batch with transformed data
        batch['input_features'] = features
        batch['n_time_steps'] = n_time_steps

        return batch

    # Apply transformations
    dataset = dataset.map(
        apply_transforms,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return dataset
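

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only). The session file path below is a
# placeholder and the transform_args values are arbitrary assumptions for
# demonstration, not defaults shipped with this module. The snippet shows how
# train_test_split_indices, BrainToTextDatasetTF, and create_input_fn are
# intended to fit together.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Hypothetical HDF5 session file; substitute real dataset paths.
    example_files = ['data/t15.2023.08.11/data_train.hdf5']

    # Split trials into train/test partitions per session day.
    train_trials, test_trials = train_test_split_indices(
        file_paths=example_files,
        test_percentage=0.1,
        seed=0
    )

    # Build the training dataset wrapper (here: 100 batches of 32 trials each).
    train_dataset_tf = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=100,
        split='train',
        batch_size=32,
        days_per_batch=1,
        random_seed=0
    )

    # Augmentation keys mirror those read by DataAugmentationTF.transform_data;
    # the numeric values are illustrative only.
    example_transform_args = {
        'white_noise_std': 0.1,
        'constant_offset_std': 0.05,
        'random_walk_std': 0.0,
        'static_gain_std': 0.0,
        'random_cut': 0,
        'smooth_data': True,
        'smooth_kernel_std': 2.0,
        'smooth_kernel_size': 100
    }

    # Assemble the augmented tf.data pipeline and inspect one batch's shapes.
    train_ds = create_input_fn(train_dataset_tf, example_transform_args, training=True)
    for example_batch in train_ds.take(1):
        print({key: value.shape for key, value in example_batch.items()})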