HBM: report TPU HBM usage instead of host RAM in TPU status logging
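The status strings below switch from reporting host RAM (via psutil) to reporting TPU HBM usage via tf.config.experimental.get_memory_info. For context, a minimal standalone sketch of that query (the helper name tpu_hbm_summary is illustrative and not part of this commit; the get_memory_info('/TPU:0') call mirrors the diff below):

    import tensorflow as tf

    def tpu_hbm_summary(device: str = '/TPU:0') -> str:
        # Illustrative sketch, not the committed implementation.
        try:
            # get_memory_info returns {'current': <bytes>, 'peak': <bytes>} on supported devices
            info = tf.config.experimental.get_memory_info(device)
            current_mb = info['current'] // (1024 * 1024)
            peak_mb = info.get('peak', info['current']) // (1024 * 1024)
            return f"HBM: {current_mb}MB ({peak_mb}MB peak)"
        except Exception:
            # No TPU attached or memory introspection unsupported on this runtime
            return "HBM: unknown"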
@@ -81,12 +81,24 @@ class BrainToTextDecoderTrainerTF:
         # Initialize datasets
         self._initialize_datasets()
 
 
         # Build model within strategy scope
         with self.strategy.scope():
+            print("🔨 Building model within TPU strategy scope...")
             self.model = self._build_model()
+            print("✅ Model built successfully")
+
+            print("⚙️ Creating optimizer...")
             self.optimizer = self._create_optimizer()
+            print("✅ Optimizer created")
+
+            print("📅 Setting up learning rate scheduler...")
             self.lr_scheduler = self._create_lr_scheduler()
+            print("✅ LR scheduler ready")
+
+            print("🎯 Initializing CTC loss...")
             self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
+            print("✅ CTC loss initialized")
+
         # Log model information
         self._log_model_info()
@@ -154,7 +166,7 @@ class BrainToTextDecoderTrainerTF:
         print(f" This will accelerate data loading and preprocessing while TPU handles training")
 
     def _get_tpu_status(self) -> str:
-        """Get current TPU status and utilization info"""
+        """Get current TPU status and HBM utilization info"""
         try:
             # Get TPU devices
             tpu_devices = tf.config.list_logical_devices('TPU')
@@ -165,12 +177,29 @@ class BrainToTextDecoderTrainerTF:
             # Get strategy info
             num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
 
-            # Get memory usage (simplified)
-            import psutil
-            memory = psutil.virtual_memory()
+            # Try to get TPU memory info (HBM)
+            try:
+                # Attempt to get TPU memory usage for each device
+                memory_info = tf.config.experimental.get_memory_info('/TPU:0')
+                if memory_info and 'current' in memory_info:
+                    current_mb = memory_info['current'] // (1024 * 1024)
+                    peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
+                    hbm_info = f"HBM: {current_mb}MB({peak_mb}MB peak)"
+                else:
+                    hbm_info = "HBM: unknown"
+            except Exception:
+                # Fallback: simple TPU activity check
+                try:
+                    # Test TPU responsiveness
+                    with tf.device('/TPU:0'):
+                        test_tensor = tf.constant([1.0, 2.0])
+                        _ = tf.reduce_sum(test_tensor)
+                    hbm_info = "HBM: active"
+                except Exception:
+                    hbm_info = "HBM: inactive"
 
             return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
-                    f"RAM: {memory.percent:.1f}%")
+                    f"{hbm_info}")
 
         except Exception as e:
             return f"TPU: status_error({str(e)[:20]})"
@@ -188,9 +217,19 @@ class BrainToTextDecoderTrainerTF:
             num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
             strategy_type = type(self.strategy).__name__
 
-            # Get memory info
-            import psutil
-            memory = psutil.virtual_memory()
+            # Get TPU HBM memory info
+            try:
+                memory_info = tf.config.experimental.get_memory_info('/TPU:0')
+                if memory_info and 'current' in memory_info:
+                    current_gb = memory_info['current'] // (1024 * 1024 * 1024)
+                    peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
+                    # TPU v5e-8 has ~32GB HBM per chip, 8 chips total = ~256GB
+                    estimated_total_gb = 32 * len(tpu_devices)
+                    hbm_usage = f"HBM: {current_gb}GB/{estimated_total_gb}GB (peak: {peak_gb}GB)"
+                else:
+                    hbm_usage = "HBM: unknown"
+            except Exception:
+                hbm_usage = "HBM: unavailable"
 
             # Simple TPU test
             try:
@@ -204,7 +243,7 @@ class BrainToTextDecoderTrainerTF:
             return (f"TPU Devices: {len(tpu_devices)} | "
                     f"Strategy: {strategy_type} | "
                     f"Cores: {num_replicas} | "
-                    f"RAM: {memory.percent:.1f}% ({memory.used//1024//1024//1024}GB/{memory.total//1024//1024//1024}GB) | "
+                    f"{hbm_usage} | "
                     f"Test: {tpu_test}")
 
         except Exception as e: