From 5a1e446219f0d62050b0646ece07ebaf5b338dfc Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Thu, 16 Oct 2025 11:42:56 +0800
Subject: [PATCH] HBM

---
 model_training_nnn_tpu/trainer_tf.py | 57 +++++++++++++++++++++++-----
 1 file changed, 48 insertions(+), 9 deletions(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index dce2af5..31e344f 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -81,12 +81,24 @@ class BrainToTextDecoderTrainerTF:
         # Initialize datasets
         self._initialize_datasets()
 
+        # Build model within strategy scope
         with self.strategy.scope():
+            print("🔨 Building model within TPU strategy scope...")
             self.model = self._build_model()
+            print("✅ Model built successfully")
+
+            print("⚙️ Creating optimizer...")
             self.optimizer = self._create_optimizer()
+            print("✅ Optimizer created")
+
+            print("📅 Setting up learning rate scheduler...")
             self.lr_scheduler = self._create_lr_scheduler()
+            print("✅ LR scheduler ready")
+
+            print("🎯 Initializing CTC loss...")
             self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
+            print("✅ CTC loss initialized")
 
         # Log model information
         self._log_model_info()
 
@@ -154,7 +166,7 @@ class BrainToTextDecoderTrainerTF:
             print(f" This will accelerate data loading and preprocessing while TPU handles training")
 
     def _get_tpu_status(self) -> str:
-        """Get current TPU status and utilization info"""
+        """Get current TPU status and HBM utilization info"""
         try:
             # Get TPU devices
             tpu_devices = tf.config.list_logical_devices('TPU')
@@ -165,12 +177,29 @@ class BrainToTextDecoderTrainerTF:
             # Get strategy info
             num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
 
-            # Get memory usage (simplified)
-            import psutil
-            memory = psutil.virtual_memory()
+            # Try to get TPU memory info (HBM)
+            try:
+                # Attempt to get TPU memory usage for the first TPU device
+                memory_info = tf.config.experimental.get_memory_info('/TPU:0')
+                if memory_info and 'current' in memory_info:
+                    current_mb = memory_info['current'] // (1024 * 1024)
+                    peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
+                    hbm_info = f"HBM: {current_mb}MB({peak_mb}MB peak)"
+                else:
+                    hbm_info = "HBM: unknown"
+            except Exception:
+                # Fallback: simple TPU activity check
+                try:
+                    # Test TPU responsiveness
+                    with tf.device('/TPU:0'):
+                        test_tensor = tf.constant([1.0, 2.0])
+                        _ = tf.reduce_sum(test_tensor)
+                    hbm_info = "HBM: active"
+                except Exception:
+                    hbm_info = "HBM: inactive"
 
             return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
-                   f"RAM: {memory.percent:.1f}%")
+                   f"{hbm_info}")
 
         except Exception as e:
             return f"TPU: status_error({str(e)[:20]})"
@@ -188,9 +217,19 @@ class BrainToTextDecoderTrainerTF:
             num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
             strategy_type = type(self.strategy).__name__
 
-            # Get memory info
-            import psutil
-            memory = psutil.virtual_memory()
+            # Get TPU HBM memory info
+            try:
+                memory_info = tf.config.experimental.get_memory_info('/TPU:0')
+                if memory_info and 'current' in memory_info:
+                    current_gb = memory_info['current'] // (1024 * 1024 * 1024)
+                    peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
+                    # TPU v5e has ~16GB HBM per chip, so a v5e-8 slice totals ~128GB
+                    estimated_total_gb = 16 * len(tpu_devices)
+                    hbm_usage = f"HBM: {current_gb}GB/{estimated_total_gb}GB (peak: {peak_gb}GB)"
+                else:
+                    hbm_usage = "HBM: unknown"
+            except Exception:
+                hbm_usage = "HBM: unavailable"
 
             # Simple TPU test
             try:
@@ -204,7 +243,7 @@ class BrainToTextDecoderTrainerTF:
             return (f"TPU Devices: {len(tpu_devices)} | "
                     f"Strategy: {strategy_type} | "
                     f"Cores: {num_replicas} | "
-                    f"RAM: {memory.percent:.1f}% ({memory.used//1024//1024//1024}GB/{memory.total//1024//1024//1024}GB) | "
+                    f"{hbm_usage} | "
                     f"Test: {tpu_test}")
 
         except Exception as e:
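
Note (not part of the applied patch): a minimal standalone sketch of the HBM query this change relies on, assuming TF 2.x with an attached TPU runtime; the helper name tpu_hbm_summary is illustrative only and does not exist in the repo. tf.config.experimental.get_memory_info returns byte counts in its 'current'/'peak' fields, which the trainer converts to MB/GB as in the hunks above.

    import tensorflow as tf

    # Hypothetical helper (illustration only): report current/peak HBM for one TPU device.
    def tpu_hbm_summary(device: str = "/TPU:0") -> str:
        try:
            info = tf.config.experimental.get_memory_info(device)  # dict of byte counts
            current_mb = info["current"] // (1024 * 1024)
            peak_mb = info.get("peak", info["current"]) // (1024 * 1024)
            return f"HBM: {current_mb}MB (peak {peak_mb}MB)"
        except Exception:
            # CPU-only hosts or runtimes without memory stats land here,
            # mirroring the patch's fallback branches.
            return "HBM: unavailable"

    if __name__ == "__main__":
        print(tpu_hbm_summary())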