From 69a72858867e0cbfb3811f981b60022216726898 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Thu, 16 Oct 2025 01:17:36 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=8A=A0=E8=BD=BD=E5=99=A8?= =?UTF-8?q?=E5=A4=9A=E7=BA=BF=E7=A8=8B=E5=8A=A0=E9=80=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model_training_nnn_tpu/trainer_tf.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index 253a0db..e14e47f 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -43,6 +43,9 @@ class BrainToTextDecoderTrainerTF: self.args = args self.logger = None + # Optimize CPU utilization for data pipeline (利用224核心) + self._configure_cpu_optimization() + # Initialize TPU strategy self.strategy = create_tpu_strategy() print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores") @@ -123,6 +126,28 @@ class BrainToTextDecoderTrainerTF: if self.mixed_precision: self.logger.info('Mixed precision (bfloat16) enabled for TPU training') + def _configure_cpu_optimization(self): + """Configure CPU utilization to make use of 224 cores for data pipeline""" + import multiprocessing + + # Get available CPU cores + available_cores = multiprocessing.cpu_count() + print(f"💻 Available CPU cores: {available_cores}") + + # Optimize for data pipeline parallelism + # Use ~1/4 of cores for inter-op (between operations) + # Use ~1/8 of cores for intra-op (within operations) + inter_op_threads = min(32, available_cores // 4) + intra_op_threads = min(16, available_cores // 8) + + tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads) + tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads) + + print(f"🔧 CPU optimization configured:") + print(f" Inter-op parallelism: {inter_op_threads} threads") + print(f" Intra-op parallelism: {intra_op_threads} threads") + print(f" This will accelerate data loading and preprocessing while TPU handles training") + def _initialize_datasets(self): """Initialize training and validation datasets""" # Create file paths