From c2661550ef35554a70e3afd1e4d103dd0ac925f7 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Thu, 16 Oct 2025 20:26:32 +0800
Subject: [PATCH] Fix memory leak
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model_training_nnn_tpu/dataset_tf.py |  8 +++++---
 model_training_nnn_tpu/trainer_tf.py | 28 ++++++++++++++--------------
 2 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index b7ce215..7e81f7f 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -417,9 +417,11 @@ class BrainToTextDatasetTF:
         )
 
         # Apply TPU-optimized transformations
-        if self.split == 'train':
-            # For training, add shuffling
-            dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))
+        # 🚨 GPU-version strategy: no shuffle is needed at the Dataset level!
+        # The GPU version already does random sampling in _create_batch_index_train() (lines 107-118).
+        # Shuffling again here explodes memory (1000 batches × 256 trials = 256,000 trials in memory at once).
+        # if self.split == 'train':
+        #     dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))  # ← the memory killer, commented out
 
         # Prefetch for better performance
         dataset = dataset.prefetch(self.prefetch_buffer)
diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 5438999..d4598c6 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -328,8 +328,8 @@ class BrainToTextDecoderTrainerTF:
         import psutil
         initial_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
 
-        print("🔄 Initializing training dataset with full data preloading...")
-        preload_start_time = time.time()
+        print("🔄 Initializing training dataset with GPU-style memory management...")
+        init_start_time = time.time()
         self.train_dataset_tf = BrainToTextDatasetTF(
             trial_indices=train_trials,
             n_batches=self.args['num_training_batches'],
@@ -339,18 +339,18 @@ class BrainToTextDecoderTrainerTF:
             random_seed=self.args['dataset']['seed'],
             must_include_days=self.args['dataset'].get('must_include_days'),
             feature_subset=self.args['dataset'].get('feature_subset'),
-            cache_data=True,  # enable data caching
-            preload_all_data=True  # load all training data into memory at once
+            cache_data=True,  # enable smart caching (as in the GPU version)
+            preload_all_data=False  # 🚨 GPU-version strategy: load on demand to avoid running out of memory
         )
 
-        # Log training data preloading performance
-        train_preload_time = time.time() - preload_start_time
+        # Log training dataset initialization performance
+        train_init_time = time.time() - init_start_time
         train_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
         train_memory_used = train_memory_mb - initial_memory_mb
-        print(f"✅ Training data preloaded in {train_preload_time:.2f}s, using {train_memory_used:.1f} MB RAM")
+        print(f"✅ Training dataset initialized in {train_init_time:.2f}s, using {train_memory_used:.1f} MB RAM")
 
-        print("🔄 Initializing validation dataset with caching...")
-        val_preload_start_time = time.time()
+        print("🔄 Initializing validation dataset with GPU-style memory management...")
+        val_init_start_time = time.time()
         self.val_dataset_tf = BrainToTextDatasetTF(
             trial_indices=val_trials,
             n_batches=None,  # Use all validation data
@@ -359,16 +359,16 @@ class BrainToTextDecoderTrainerTF:
             days_per_batch=1,  # One day per validation batch
             random_seed=self.args['dataset']['seed'],
             feature_subset=self.args['dataset'].get('feature_subset'),
-            cache_data=True,  # enable data caching
-            preload_all_data=True  # load all validation data into memory at once
+            cache_data=True,  # enable smart caching (as in the GPU version)
+            preload_all_data=False  # 🚨 GPU-version strategy: load on demand to avoid running out of memory
         )
 
-        # Log validation data preloading performance
-        val_preload_time = time.time() - val_preload_start_time
+        # Log validation dataset initialization performance
+        val_init_time = time.time() - val_init_start_time
         final_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
         total_memory_used = final_memory_mb - initial_memory_mb
         val_memory_used = final_memory_mb - train_memory_mb
-        print(f"✅ Validation data preloaded in {val_preload_time:.2f}s, using {val_memory_used:.1f} MB RAM")
+        print(f"✅ Validation dataset initialized in {val_init_time:.2f}s, using {val_memory_used:.1f} MB RAM")
         print(f"📊 Total data cache: {total_memory_used:.1f} MB RAM used for all datasets")
 
         self.logger.info("Successfully initialized TensorFlow datasets")
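
The shuffle this patch comments out made tf.data buffer up to 1000 pipeline elements, and because each element here is a full 256-trial batch rather than a single example, that meant 256,000 trials resident in host memory at once. The replacement the new comments describe is to randomize the lightweight batch indices before the pipeline is built. Below is a minimal sketch of that pattern, assuming hypothetical names: build_batch, N_BATCHES, BATCH_SIZE, and N_FEATURES are illustrative stand-ins, not the repo's API.

import numpy as np
import tensorflow as tf

N_BATCHES = 1000    # hypothetical batches per epoch
BATCH_SIZE = 256    # hypothetical trials per batch
N_FEATURES = 512    # hypothetical neural feature width

def build_batch(idx: int) -> np.ndarray:
    # Stand-in for reading one batch from disk on demand.
    rng = np.random.default_rng(idx)
    return rng.standard_normal((BATCH_SIZE, N_FEATURES)).astype(np.float32)

def make_pipeline(seed: int = 0) -> tf.data.Dataset:
    # Shuffle 1000 integers up front instead of 256,000 trials inside tf.data.
    order = np.random.default_rng(seed).permutation(N_BATCHES)

    def gen():
        for idx in order:                # order is already random, so no
            yield build_batch(int(idx))  # Dataset.shuffle() stage is needed

    return tf.data.Dataset.from_generator(
        gen,
        output_signature=tf.TensorSpec((BATCH_SIZE, N_FEATURES), tf.float32),
    ).prefetch(tf.data.AUTOTUNE)         # prefetch buffers O(1) batches, not 1000

for batch in make_pipeline().take(2):
    print(batch.shape)                   # (256, 512)

The element granularity is the crux: Dataset.shuffle(buffer_size=n) always materializes n elements, so its memory cost scales with whatever one element weighs, while shuffling indices costs a few kilobytes regardless of batch size.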
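
Likewise, flipping preload_all_data to False while keeping cache_data=True trades the one-shot wholesale load for lazy, cache-on-first-access reads, so startup memory stays near the baseline and only trials that are actually touched get cached. A rough sketch of that load-on-demand caching under the same caveat: load_trial and its shapes are hypothetical, not the repo's functions.

from functools import lru_cache
import numpy as np

@lru_cache(maxsize=4096)  # bounded: hot trials stay cached, cold ones are evicted
def load_trial(day: int, trial: int) -> np.ndarray:
    # Stand-in for slicing one trial out of a day's HDF5 file.
    rng = np.random.default_rng(day * 100_000 + trial)
    return rng.standard_normal((300, 512)).astype(np.float32)

x = load_trial(0, 7)  # first access: reads from "disk"
y = load_trial(0, 7)  # second access: served from the cache
assert x is y         # same cached object, no second read

An unbounded cache would drift back toward the preload-everything footprint over a long run; bounding it is what keeps resident memory flat.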