内存泄漏修复

This commit is contained in:
Zchen
2025-10-16 20:26:32 +08:00
parent 1b9e0d9bdf
commit c2661550ef
2 changed files with 19 additions and 17 deletions

View File

@@ -417,9 +417,11 @@ class BrainToTextDatasetTF:
)
# Apply TPU-optimized transformations
if self.split == 'train':
# For training, add shuffling
dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))
# 🚨 GPU版本策略不需要在Dataset级别shuffle!
# GPU版本在 _create_batch_index_train() 中已经做了随机采样第107-118行
# 这里再shuffle会导致内存爆炸1000 batch × 256 trials = 256,000 trials同时在内存
# if self.split == 'train':
# dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches)) # ← 注释掉内存杀手
# Prefetch for better performance
dataset = dataset.prefetch(self.prefetch_buffer)

View File

@@ -328,8 +328,8 @@ class BrainToTextDecoderTrainerTF:
import psutil
initial_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
print("🔄 Initializing training dataset with full data preloading...")
preload_start_time = time.time()
print("🔄 Initializing training dataset with GPU-style memory management...")
init_start_time = time.time()
self.train_dataset_tf = BrainToTextDatasetTF(
trial_indices=train_trials,
n_batches=self.args['num_training_batches'],
@@ -339,18 +339,18 @@ class BrainToTextDecoderTrainerTF:
random_seed=self.args['dataset']['seed'],
must_include_days=self.args['dataset'].get('must_include_days'),
feature_subset=self.args['dataset'].get('feature_subset'),
cache_data=True, # 启用数据缓存
preload_all_data=True # 一次性加载所有训练数据到内存
cache_data=True, # 启用智能缓存像GPU版本一样
preload_all_data=False # 🚨 采用GPU版本策略按需加载避免内存溢出
)
# Log training data preloading performance
train_preload_time = time.time() - preload_start_time
# Log training dataset initialization performance
train_init_time = time.time() - init_start_time
train_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
train_memory_used = train_memory_mb - initial_memory_mb
print(f"✅ Training data preloaded in {train_preload_time:.2f}s, using {train_memory_used:.1f} MB RAM")
print(f"✅ Training dataset initialized in {train_init_time:.2f}s, using {train_memory_used:.1f} MB RAM")
print("🔄 Initializing validation dataset with caching...")
val_preload_start_time = time.time()
print("🔄 Initializing validation dataset with GPU-style memory management...")
val_init_start_time = time.time()
self.val_dataset_tf = BrainToTextDatasetTF(
trial_indices=val_trials,
n_batches=None, # Use all validation data
@@ -359,16 +359,16 @@ class BrainToTextDecoderTrainerTF:
days_per_batch=1, # One day per validation batch
random_seed=self.args['dataset']['seed'],
feature_subset=self.args['dataset'].get('feature_subset'),
cache_data=True, # 启用数据缓存
preload_all_data=True # 一次性加载所有验证数据到内存
cache_data=True, # 启用智能缓存像GPU版本一样
preload_all_data=False # 🚨 采用GPU版本策略按需加载避免内存溢出
)
# Log validation data preloading performance
val_preload_time = time.time() - val_preload_start_time
# Log validation dataset initialization performance
val_init_time = time.time() - val_init_start_time
final_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
total_memory_used = final_memory_mb - initial_memory_mb
val_memory_used = final_memory_mb - train_memory_mb
print(f"✅ Validation data preloaded in {val_preload_time:.2f}s, using {val_memory_used:.1f} MB RAM")
print(f"✅ Validation dataset initialized in {val_init_time:.2f}s, using {val_memory_used:.1f} MB RAM")
print(f"📊 Total data cache: {total_memory_used:.1f} MB RAM used for all datasets")
self.logger.info("Successfully initialized TensorFlow datasets")