HBM
@@ -81,12 +81,24 @@ class BrainToTextDecoderTrainerTF:
         # Initialize datasets
         self._initialize_datasets()
 
+
         # Build model within strategy scope
         with self.strategy.scope():
+            print("🔨 Building model within TPU strategy scope...")
             self.model = self._build_model()
+            print("✅ Model built successfully")
+
+            print("⚙️ Creating optimizer...")
             self.optimizer = self._create_optimizer()
+            print("✅ Optimizer created")
+
+            print("📅 Setting up learning rate scheduler...")
             self.lr_scheduler = self._create_lr_scheduler()
+            print("✅ LR scheduler ready")
+
+            print("🎯 Initializing CTC loss...")
             self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
+            print("✅ CTC loss initialized")
 
         # Log model information
         self._log_model_info()
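
For context, TensorFlow requires that variable-creating objects (the model, the optimizer, and other stateful training components) be instantiated inside the distribution strategy's scope so their variables are placed and replicated across the TPU replicas; that is why the trainer builds everything under self.strategy.scope(). A minimal standalone sketch of the pattern, assuming a TPU VM (the resolver, strategy, and layer sizes below are illustrative and not taken from this repository):

import tensorflow as tf

# Illustrative TPU setup on a TPU VM; the trainer's own __init__ wires this up itself.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    # Variables created inside the scope are distributed across all TPU replicas.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.build(input_shape=(None, 32))  # force variable creation while still in scope
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
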
@@ -154,7 +166,7 @@ class BrainToTextDecoderTrainerTF:
         print(f" This will accelerate data loading and preprocessing while TPU handles training")
 
     def _get_tpu_status(self) -> str:
-        """Get current TPU status and utilization info"""
+        """Get current TPU status and HBM utilization info"""
         try:
             # Get TPU devices
             tpu_devices = tf.config.list_logical_devices('TPU')
@@ -165,12 +177,29 @@ class BrainToTextDecoderTrainerTF:
             # Get strategy info
             num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
 
-            # Get memory usage (simplified)
-            import psutil
-            memory = psutil.virtual_memory()
+            # Try to get TPU memory info (HBM)
+            try:
+                # Attempt to get TPU memory usage for each device
+                memory_info = tf.config.experimental.get_memory_info('/TPU:0')
+                if memory_info and 'current' in memory_info:
+                    current_mb = memory_info['current'] // (1024 * 1024)
+                    peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
+                    hbm_info = f"HBM: {current_mb}MB({peak_mb}MB peak)"
+                else:
+                    hbm_info = "HBM: unknown"
+            except Exception:
+                # Fallback: simple TPU activity check
+                try:
+                    # Test TPU responsiveness
+                    with tf.device('/TPU:0'):
+                        test_tensor = tf.constant([1.0, 2.0])
+                        _ = tf.reduce_sum(test_tensor)
+                    hbm_info = "HBM: active"
+                except Exception:
+                    hbm_info = "HBM: inactive"
 
             return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
-                    f"RAM: {memory.percent:.1f}%")
+                    f"{hbm_info}")
 
         except Exception as e:
             return f"TPU: status_error({str(e)[:20]})"
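
The probe above reads only /TPU:0. tf.config.experimental.get_memory_info returns a dict with 'current' and 'peak' byte counts for whichever logical device it is given, so the same call can be repeated per core when an aggregate view is wanted. A hedged sketch of that aggregation (an illustrative helper, not part of this commit):

import tensorflow as tf

def summarize_tpu_hbm() -> str:
    """Sum current and peak HBM usage (reported in bytes) across all logical TPU devices."""
    current_total = 0
    peak_total = 0
    for i, _ in enumerate(tf.config.list_logical_devices('TPU')):
        info = tf.config.experimental.get_memory_info(f'/TPU:{i}')
        current_total += info.get('current', 0)
        peak_total += info.get('peak', info.get('current', 0))
    mb = 1024 * 1024
    return f"HBM: {current_total // mb}MB ({peak_total // mb}MB peak, all devices)"
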
@@ -188,9 +217,19 @@ class BrainToTextDecoderTrainerTF:
             num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
             strategy_type = type(self.strategy).__name__
 
-            # Get memory info
-            import psutil
-            memory = psutil.virtual_memory()
+            # Get TPU HBM memory info
+            try:
+                memory_info = tf.config.experimental.get_memory_info('/TPU:0')
+                if memory_info and 'current' in memory_info:
+                    current_gb = memory_info['current'] // (1024 * 1024 * 1024)
+                    peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
+                    # TPU v5e has ~16GB HBM per chip, so a v5e-8 slice totals ~128GB
+                    estimated_total_gb = 16 * len(tpu_devices)
+                    hbm_usage = f"HBM: {current_gb}GB/{estimated_total_gb}GB (peak: {peak_gb}GB)"
+                else:
+                    hbm_usage = "HBM: unknown"
+            except Exception:
+                hbm_usage = "HBM: unavailable"
 
             # Simple TPU test
             try:
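
One caveat in the GB conversion above: integer floor division by 1024 * 1024 * 1024 truncates, so anything under 1GB is reported as 0GB. If sub-GB usage matters, a float conversion with one decimal place keeps the figure readable; a small illustrative helper (not from the commit):

def bytes_to_gb(num_bytes: int) -> str:
    """Format a byte count in gigabytes with one decimal instead of flooring to whole GB."""
    return f"{num_bytes / (1024 ** 3):.1f}GB"

# Example: 805306368 bytes (768MB) -> "0.8GB", whereas floor division would report "0GB".
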
@@ -204,7 +243,7 @@ class BrainToTextDecoderTrainerTF:
             return (f"TPU Devices: {len(tpu_devices)} | "
                     f"Strategy: {strategy_type} | "
                     f"Cores: {num_replicas} | "
-                    f"RAM: {memory.percent:.1f}% ({memory.used//1024//1024//1024}GB/{memory.total//1024//1024//1024}GB) | "
+                    f"{hbm_usage} | "
                     f"Test: {tpu_test}")
 
         except Exception as e: