Fix data loader inefficiency
@@ -26,7 +26,9 @@ class BrainToTextDatasetTF:
         must_include_days: Optional[List[int]] = None,
         feature_subset: Optional[List[int]] = None,
         prefetch_buffer: int = tf.data.AUTOTUNE,
-        num_parallel_calls: int = tf.data.AUTOTUNE
+        num_parallel_calls: int = tf.data.AUTOTUNE,
+        cache_data: bool = True,
+        preload_all_data: bool = False
     ):
         """
         Initialize TensorFlow dataset for brain-to-text data
@@ -42,6 +44,8 @@ class BrainToTextDatasetTF:
             feature_subset: Subset of neural features to use
             prefetch_buffer: Buffer size for prefetching
             num_parallel_calls: Parallel processing threads
+            cache_data: Whether to cache loaded data in memory
+            preload_all_data: Whether to preload all data at initialization
         """
 
         # Set random seed for reproducibility
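
For context, a minimal usage sketch of the two new constructor flags. Only `cache_data`, `preload_all_data`, and `must_include_days` appear in this diff; any other arguments the real signature takes are omitted here, so treat this as illustrative rather than the actual call:

    # Hypothetical usage (remaining constructor arguments omitted):
    dataset = BrainToTextDatasetTF(
        must_include_days=None,       # from the signature shown above
        cache_data=True,              # keep each trial in RAM after its first read
        preload_all_data=True,        # eagerly load every trial during __init__
    )
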
@@ -62,6 +66,11 @@ class BrainToTextDatasetTF:
         self.must_include_days = must_include_days
         self.prefetch_buffer = prefetch_buffer
         self.num_parallel_calls = num_parallel_calls
+        self.cache_data = cache_data
+        self.preload_all_data = preload_all_data
+
+        # Initialize data cache
+        self.data_cache = {} if cache_data else None
 
         # Calculate total number of trials
         self.n_trials = 0
@@ -88,6 +97,12 @@ class BrainToTextDatasetTF:
             self.batch_indices = self._create_batch_index_test()
             self.n_batches = len(self.batch_indices)
 
+        # Preload data if requested (speeds up first batch significantly)
+        if self.preload_all_data:
+            print(f"🔄 Preloading all data for {self.split} split...")
+            self._preload_all_data()
+            print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")
+
     def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
         """Create training batch indices with random sampling"""
         batch_indices = {}
@@ -160,8 +175,51 @@ class BrainToTextDatasetTF:
 
         return batch_indices
 
-    def _load_trial_data(self, day: int, trial: int) -> Dict[str, tf.Tensor]:
-        """Load a single trial's data from HDF5 file"""
+    def _preload_all_data(self):
+        """Preload all trial data into memory cache (uses available RAM optimally)"""
+        import multiprocessing
+        from concurrent.futures import ThreadPoolExecutor, as_completed
+
+        # Use CPU cores efficiently for parallel I/O
+        max_workers = min(multiprocessing.cpu_count(), 32)  # Limit to avoid overwhelming I/O
+
+        # Collect all trials to load
+        trials_to_load = []
+        for day in self.trial_indices:
+            for trial in self.trial_indices[day]['trials']:
+                trials_to_load.append((day, trial))
+
+        print(f"📊 Preloading {len(trials_to_load)} trials using {max_workers} workers...")
+
+        # Parallel loading using ThreadPoolExecutor
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            # Submit all loading tasks
+            future_to_trial = {
+                executor.submit(self._load_single_trial_data, day, trial): (day, trial)
+                for day, trial in trials_to_load
+            }
+
+            # Process completed tasks and update cache
+            loaded_count = 0
+            for future in as_completed(future_to_trial):
+                day, trial = future_to_trial[future]
+                try:
+                    trial_data = future.result()
+                    cache_key = f"{day}_{trial}"
+                    self.data_cache[cache_key] = trial_data
+                    loaded_count += 1
+
+                    # Progress indicator every 100 trials
+                    if loaded_count % 100 == 0:
+                        print(f"   Loaded {loaded_count}/{len(trials_to_load)} trials...")
+
+                except Exception as e:
+                    print(f"   Warning: Failed to load trial {day}_{trial}: {e}")
+
+        print(f"✅ Preloading completed: {loaded_count}/{len(trials_to_load)} trials cached")
+
+    def _load_single_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
+        """Load a single trial's data - optimized version for parallel loading"""
         try:
             session_path = self.trial_indices[day]['session_path']
 
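
The preload path above uses the standard submit/as_completed idiom: submit every load up front, keep a future-to-key map, then consume results in completion order. A self-contained toy sketch of that pattern (the `load` function here is a stand-in for the HDF5 read, not code from this commit):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def load(key):
        return key[0] * 10 + key[1]  # stand-in for an I/O-bound trial read

    keys = [(day, trial) for day in range(2) for trial in range(3)]
    results = {}
    with ThreadPoolExecutor(max_workers=4) as executor:
        future_to_key = {executor.submit(load, k): k for k in keys}
        for future in as_completed(future_to_key):  # yields futures as they finish
            results[future_to_key[future]] = future.result()
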
@@ -173,8 +231,8 @@ class BrainToTextDatasetTF:
             if self.feature_subset:
                 input_features = input_features[:, self.feature_subset]
 
-            # Convert to bfloat16 for TPU efficiency
-            input_features = input_features.astype(np.float32)  # TF will handle bfloat16 conversion
+            # Convert to float32 for TF compatibility
+            input_features = input_features.astype(np.float32)
 
             trial_data = {
                 'input_features': input_features,
@@ -190,8 +248,7 @@ class BrainToTextDatasetTF:
             return trial_data
 
         except Exception as e:
             print(f'Error loading trial {trial} from day {day}: {e}')
-            # Return dummy data to maintain batch structure
-
+            # Return dummy data for failed loads
             return {
                 'input_features': np.zeros((100, 512), dtype=np.float32),
                 'seq_class_ids': np.zeros((10,), dtype=np.int32),
@@ -203,9 +260,32 @@ class BrainToTextDatasetTF:
                 'trial_num': 0
             }
 
+    def _load_trial_data(self, day: int, trial: int) -> Dict[str, tf.Tensor]:
+        """Load a single trial's data from cache or HDF5 file"""
+        # Check cache first if caching is enabled
+        if self.cache_data:
+            cache_key = f"{day}_{trial}"
+            if cache_key in self.data_cache:
+                return self.data_cache[cache_key]
+
+        # Load from disk if not in cache
+        trial_data = self._load_single_trial_data(day, trial)
+
+        # Cache the loaded data if caching is enabled
+        if self.cache_data:
+            cache_key = f"{day}_{trial}"
+            self.data_cache[cache_key] = trial_data
+
+        return trial_data
+
     def _create_batch_generator(self):
-        """Generator function that yields individual batches"""
+        """Generator function that yields individual batches with optimized loading"""
+        import time
+        from concurrent.futures import ThreadPoolExecutor
+
         for batch_idx in range(self.n_batches):
+            batch_start_time = time.time()
+
             batch_data = {
                 'input_features': [],
                 'seq_class_ids': [],
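
`_load_trial_data` is now a read-through cache over `_load_single_trial_data`: check the dict, fall back to disk, populate. The same shape in isolation (a generic sketch, not code from this commit):

    def cached_load(cache, key, load_fn):
        if cache is not None and key in cache:
            return cache[key]          # fast path: already in memory
        value = load_fn(key)           # slow path: hit the disk
        if cache is not None:
            cache[key] = value         # populate for subsequent lookups
        return value

One design note: in the parallel branch below, several threads can miss on the same key at once; under CPython's GIL the dict writes themselves are safe, so the worst case is a duplicated load, not a corrupted cache.
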
@@ -219,11 +299,42 @@ class BrainToTextDatasetTF:
 
             batch_index = self.batch_indices[batch_idx]
 
-            # Load data for each day in the batch
+            # Collect all trials to load for this batch
+            trials_to_load = []
             for day in batch_index.keys():
                 for trial in batch_index[day]:
-                    trial_data = self._load_trial_data(day, trial)
+                    trials_to_load.append((day, trial))
+
+            # Use parallel loading if not preloaded and have multiple trials
+            if not self.preload_all_data and len(trials_to_load) > 4:
+                # Parallel loading for faster I/O
+                with ThreadPoolExecutor(max_workers=min(8, len(trials_to_load))) as executor:
+                    future_to_trial = {
+                        executor.submit(self._load_trial_data, day, trial): (day, trial)
+                        for day, trial in trials_to_load
+                    }
+
+                    # Collect results in order
+                    trial_results = {}
+                    for future in future_to_trial:
+                        day, trial = future_to_trial[future]
+                        trial_results[(day, trial)] = future.result()
+
+                # Add data in original order
+                for day, trial in trials_to_load:
+                    trial_data = trial_results[(day, trial)]
+                    batch_data['input_features'].append(trial_data['input_features'])
+                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
+                    batch_data['transcriptions'].append(trial_data['transcription'])
+                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
+                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
+                    batch_data['day_indices'].append(trial_data['day_index'])
+                    batch_data['block_nums'].append(trial_data['block_num'])
+                    batch_data['trial_nums'].append(trial_data['trial_num'])
+            else:
+                # Sequential loading (fast when data is cached or few trials)
+                for day, trial in trials_to_load:
+                    trial_data = self._load_trial_data(day, trial)
                     batch_data['input_features'].append(trial_data['input_features'])
                     batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                     batch_data['transcriptions'].append(trial_data['transcription'])
@@ -233,6 +344,14 @@ class BrainToTextDatasetTF:
                     batch_data['block_nums'].append(trial_data['block_num'])
                     batch_data['trial_nums'].append(trial_data['trial_num'])
 
+            data_loading_time = time.time() - batch_start_time
+
+            # Add timing diagnostic for first few batches
+            if batch_idx < 3:
+                cache_status = "cached" if self.preload_all_data else "disk"
+                loading_method = "parallel" if (not self.preload_all_data and len(trials_to_load) > 4) else "sequential"
+                print(f"⏱️ Batch {batch_idx}: {len(trials_to_load)} trials loaded in {data_loading_time:.3f}s ({cache_status}, {loading_method})")
+
             # Pad sequences to create uniform batch
             max_time_steps = max(batch_data['n_time_steps'])
             max_phone_len = max(len(seq) for seq in batch_data['seq_class_ids'])
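
The hunk ends just before the padding logic. For reference, padding variable-length trials to `max_time_steps` / `max_phone_len` typically looks like the sketch below; the actual padding code is not part of this diff, so this is an assumption about its shape, with a hypothetical `pad_batch` helper:

    import numpy as np

    # Hypothetical helper, not from this commit:
    def pad_batch(features, seqs, max_time_steps, max_phone_len):
        # Zero-pad each (time, channels) array along the time axis only
        feats = np.stack([np.pad(f, ((0, max_time_steps - f.shape[0]), (0, 0))) for f in features])
        # Zero-pad each 1-D phone sequence to the longest length in the batch
        ids = np.stack([np.pad(s, (0, max_phone_len - len(s))) for s in seqs])
        return feats, ids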