tpu维护

This commit is contained in:
Zchen
2025-10-16 13:39:05 +08:00
parent 5a1e446219
commit a545cc5648
4 changed files with 708 additions and 26 deletions

4
.gitignore vendored
View File

@@ -12,4 +12,6 @@ model_training_lstm/trained_models
model_training_lstm/trained_models_history
*.pkl
*.pkl
.idea

View File

@@ -0,0 +1,403 @@
#!/usr/bin/env python3
"""
TPU训练内存监控工具 - 专注于训练过程中的实时内存和MXU监控
适用于TPU v5e-8环境
"""
import tensorflow as tf
import time
import numpy as np
def monitor_tpu_during_training():
"""训练过程中的TPU实时内存和MXU监控"""
print("📊 TPU训练实时监控工具")
print("=" * 50)
# 获取TPU设备
try:
tpu_devices = tf.config.list_logical_devices('TPU')
print(f"📍 发现TPU设备: {len(tpu_devices)}")
if not tpu_devices:
print("❌ 未发现TPU设备")
return
except Exception as e:
print(f"❌ 无法检测TPU设备: {e}")
return
def get_detailed_memory_snapshot():
"""获取详细的内存快照,包含所有核心信息"""
snapshot = {}
total_current = 0
total_peak = 0
active_cores = 0
core_details = []
for i, device in enumerate(tpu_devices):
try:
memory_info = tf.config.experimental.get_memory_info(device.name)
if memory_info and 'current' in memory_info:
current_mb = memory_info['current'] // (1024 * 1024)
peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
if current_mb > 1: # >1MB算活跃
active_cores += 1
total_current += current_mb
total_peak += peak_mb
core_details.append(f"Core{i}:{current_mb}MB")
snapshot[f'core_{i}'] = {
'current': current_mb,
'peak': peak_mb,
'device': device.name
}
else:
snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name}
except Exception as e:
snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name, 'error': str(e)}
snapshot['summary'] = {
'total_current': total_current,
'total_peak': total_peak,
'active_cores': active_cores,
'total_cores': len(tpu_devices),
'core_details': core_details
}
return snapshot
def test_mxu_performance():
"""测试MXU性能和计算能力"""
print("\n🧮 MXU计算性能测试:")
mxu_results = []
try:
with tf.device(tpu_devices[0].name):
# 测试不同规模的矩阵运算
test_configs = [
(2000, "2K×2K", tf.bfloat16),
(4000, "4K×4K", tf.bfloat16),
(6000, "6K×6K", tf.bfloat16),
]
for size, desc, dtype in test_configs:
try:
# 获取测试前内存
pre_mem = get_detailed_memory_snapshot()
start_time = time.time()
# 创建矩阵并执行MXU密集型运算
matrix_a = tf.random.normal([size, size], dtype=dtype)
matrix_b = tf.random.normal([size, size], dtype=dtype)
@tf.function
def mxu_operation():
# 连续矩阵运算充分使用MXU
result = tf.matmul(matrix_a, matrix_b)
result = tf.matmul(result, matrix_a)
return tf.reduce_sum(result)
result = mxu_operation()
# 使用result确保计算被执行
_ = result.numpy()
end_time = time.time()
# 获取测试后内存
post_mem = get_detailed_memory_snapshot()
duration = end_time - start_time
# 计算FLOPS (两次矩阵乘法)
flops = 2 * (2 * size**3)
tflops = flops / duration / 1e12
memory_used = post_mem['summary']['total_current'] - pre_mem['summary']['total_current']
print(f" {desc} ({dtype.name}): {duration:.3f}s, {tflops:.1f}TFLOPS, 内存+{memory_used}MB")
mxu_results.append({
'size': size,
'tflops': tflops,
'duration': duration,
'memory_used': memory_used
})
except Exception as e:
print(f" {desc}: 测试失败 - {str(e)[:50]}")
# MXU性能分析
if mxu_results:
max_tflops = max(r['tflops'] for r in mxu_results)
total_memory = sum(r['memory_used'] for r in mxu_results if r['memory_used'] > 0)
# TPU v5e-8单核理论性能
theoretical_tflops = 275 # bf16峰值性能
efficiency = (max_tflops / theoretical_tflops) * 100
print(f"\n 📊 MXU性能汇总:")
print(f" 峰值性能: {max_tflops:.1f} TFLOPS")
print(f" 理论峰值: {theoretical_tflops} TFLOPS")
print(f" MXU效率: {efficiency:.1f}%")
print(f" 计算内存占用: {total_memory}MB")
if efficiency > 80:
status = "🟢 优秀"
elif efficiency > 50:
status = "🟡 良好"
elif efficiency > 20:
status = "🟠 中等"
else:
status = "🔴 需优化"
print(f" 性能评级: {status}")
except Exception as e:
print(f" MXU测试失败: {e}")
try:
print("🎯 开始TPU训练监控...")
# 1. 获取初始状态
print("\n📸 初始TPU状态:")
baseline_snapshot = get_detailed_memory_snapshot()
print(f" 总内存使用: {baseline_snapshot['summary']['total_current']}MB")
print(f" 活跃核心: {baseline_snapshot['summary']['active_cores']}/{baseline_snapshot['summary']['total_cores']}")
# 显示各核心详细状态
for i in range(len(tpu_devices)):
core = baseline_snapshot[f'core_{i}']
if core['current'] > 0 or core['peak'] > 0:
print(f" Core{i}: 当前{core['current']}MB, 峰值{core['peak']}MB")
# 2. MXU性能基准测试
test_mxu_performance()
# 3. 创建分布式策略 - 使用项目验证的TPU初始化代码
print(f"\n🔄 使用项目标准TPU初始化...")
try:
# 使用项目里验证过的TPU初始化代码
# 禁用GPU避免冲突
try:
tf.config.set_visible_devices([], 'GPU')
print("🚫 GPU已禁用避免CUDA冲突")
except:
pass
# 使用标准的TPU初始化流程
print("🚀 使用官方TensorFlow TPU初始化...")
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# 验证TPU设备
tpu_devices_check = tf.config.list_logical_devices('TPU')
print(f"✅ TPU设备验证: 发现 {len(tpu_devices_check)} 个设备")
# 创建TPU策略
strategy = tf.distribute.TPUStrategy(resolver)
print(f"✅ 成功创建TPU策略: {strategy.num_replicas_in_sync}个副本")
use_distributed = True
except Exception as e:
print(f"⚠️ 分布式策略失败: {str(e)[:80]}")
print(" 将使用单设备模拟")
use_distributed = False
# 4. 模拟Brain-to-Text训练场景
print(f"\n🧠 模拟Brain-to-Text训练场景...")
if use_distributed:
# 分布式训练模拟
with strategy.scope():
print("📦 创建分布式模型参数...")
# 创建接近真实Brain-to-Text模型的参数 (修复维度匹配)
model_components = {
# GRU层权重第一层接收512维输入后续层接收256维
'gru_layer_0': tf.Variable(tf.random.normal([512, 256]), name='gru_0'),
'gru_layer_1': tf.Variable(tf.random.normal([256, 256]), name='gru_1'),
'gru_layer_2': tf.Variable(tf.random.normal([256, 256]), name='gru_2'),
'output_projection': tf.Variable(tf.random.normal([256, 41]), name='output'),
# 添加day-specific层模拟 (输入512维输出512维)
'day_weights': [tf.Variable(tf.random.normal([512, 512]), name=f'day_{i}') for i in range(8)]
}
# 检查模型加载后内存
after_model = get_detailed_memory_snapshot()
model_memory = after_model['summary']['total_current'] - baseline_snapshot['summary']['total_current']
print(f"🧠 模型加载完成: +{model_memory}MB, {after_model['summary']['active_cores']}个活跃核心")
# 训练循环模拟
print(f"\n🔄 开始训练循环监控...")
for step in range(10):
step_start_time = time.time()
@tf.function
def distributed_training_step():
# 模拟真实训练数据大小
batch_size = 32
seq_length = 1000
features = 512
# 输入数据
neural_data = tf.random.normal([batch_size, seq_length, features])
targets = tf.random.uniform([batch_size, seq_length], maxval=41, dtype=tf.int32)
# 模拟前向传播
x = neural_data
# Day-specific transformation (简化版本避免复杂的维度操作)
# 模拟day-specific变换对每个时间步应用相同变换
day_weight = model_components['day_weights'][0] # 简化使用第一个day权重
# 对最后一个维度进行变换: [batch, seq, 512] @ [512, 512] -> [batch, seq, 512]
x = tf.matmul(x, day_weight)
# 为CTC损失添加目标使用模拟
target_length = tf.reduce_sum(tf.cast(targets > 0, tf.int32), axis=1)
# 简化的CTC相关计算
batch_loss_weight = tf.reduce_mean(tf.cast(target_length, tf.float32))
# GRU layers simulation
for i in range(3):
layer_name = f'gru_layer_{i}'
weight = model_components[layer_name]
# 处理张量维度第一层从3D输入后续层从2D输入
if i == 0:
# 第一层:取最后时间步 [batch, seq, features] -> [batch, features]
if len(x.shape) == 3:
x = x[:, -1, :] # 取最后时间步
x = tf.nn.tanh(tf.matmul(x, weight))
else:
# 后续层直接处理2D张量 [batch, features] -> [batch, features]
x = tf.nn.tanh(tf.matmul(x, weight))
# 输出投影
logits = tf.matmul(x, model_components['output_projection'])
# CTC loss模拟使用batch_loss_weight作为权重
base_loss = tf.reduce_mean(tf.square(logits))
loss = base_loss * batch_loss_weight
return loss
# 执行训练步骤
per_replica_loss = strategy.run(distributed_training_step)
# 聚合分布式结果
loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
step_duration = time.time() - step_start_time
# 获取当前内存状态
current_snapshot = get_detailed_memory_snapshot()
step_memory = current_snapshot['summary']['total_current']
memory_delta = step_memory - baseline_snapshot['summary']['total_current']
# 显示详细训练状态
active_cores_info = f"({', '.join(current_snapshot['summary']['core_details'])})" if current_snapshot['summary']['core_details'] else "(无活跃)"
print(f" Step {step:2d}: loss={float(loss.numpy()):.4f}, "
f"时间={step_duration:.3f}s, "
f"内存={step_memory}MB(+{memory_delta}), "
f"活跃={current_snapshot['summary']['active_cores']}/{current_snapshot['summary']['total_cores']} {active_cores_info}")
# 每5步显示峰值内存
if step % 5 == 0:
peak_info = f"峰值: {current_snapshot['summary']['total_peak']}MB"
print(f" {peak_info}")
time.sleep(0.2) # 短暂暂停观察
else:
# 单设备训练模拟(改进版)
print("🔸 单设备训练模拟...")
with tf.device(tpu_devices[0].name):
# 创建较小的模型参数
simple_weights = tf.Variable(tf.random.normal([512, 256]), name='simple_net')
for step in range(8):
step_start = time.time()
# 创建较大的数据批次
batch_data = tf.random.normal([64, 1000, 512]) # 增大batch size
# 模拟计算密集型操作
@tf.function
def compute_step():
x = tf.reshape(batch_data, [-1, 512])
result = tf.matmul(x, simple_weights)
result = tf.nn.relu(result)
return tf.reduce_mean(result)
result = compute_step()
step_duration = time.time() - step_start
# 获取内存状态
snapshot = get_detailed_memory_snapshot()
memory_change = snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current']
print(f" Step {step}: result={result.numpy():.4f}, "
f"时间={step_duration:.3f}s, "
f"内存变化=+{memory_change}MB, "
f"峰值={snapshot['summary']['total_peak']}MB")
# 5. 最终分析报告
final_snapshot = get_detailed_memory_snapshot()
total_growth = final_snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current']
peak_usage = final_snapshot['summary']['total_peak']
print(f"\n📈 训练监控报告:")
print(f" 总内存增长: +{total_growth}MB")
print(f" 峰值内存使用: {peak_usage}MB ({peak_usage/1024:.2f}GB)")
print(f" 最终活跃核心: {final_snapshot['summary']['active_cores']}/{final_snapshot['summary']['total_cores']}")
# 各核心最终状态
print(f" 各核心最终状态:")
has_changes = False
for i in range(len(tpu_devices)):
final_core = final_snapshot[f'core_{i}']
baseline_core = baseline_snapshot[f'core_{i}']
current_change = final_core['current'] - baseline_core['current']
peak_change = final_core['peak'] - baseline_core['peak']
if current_change != 0 or peak_change != 0:
has_changes = True
print(f" Core{i}: 当前{final_core['current']}MB(+{current_change}), 峰值{final_core['peak']}MB(+{peak_change})")
if not has_changes:
print(f" 所有核心内存无明显变化")
# 分布式使用分析
if final_snapshot['summary']['active_cores'] == 1:
print(f"\n⚠️ 分布式问题诊断:")
print(f" 只有1个核心活跃其他7个核心空闲")
print(f" 可能原因: TPU策略配置问题或模型未正确分布")
print(f" 建议: 检查分布式策略和模型分片")
elif final_snapshot['summary']['active_cores'] > 4:
print(f"\n✅ 分布式状态良好:")
print(f" {final_snapshot['summary']['active_cores']}个核心活跃,多核心并行工作正常")
else:
print(f"\n🟡 分布式部分工作:")
print(f" {final_snapshot['summary']['active_cores']}个核心活跃,可能存在负载不均衡")
print("✅ TPU训练监控完成")
except Exception as e:
print(f"❌ 训练监控失败: {e}")
import traceback
print(f"详细错误: {traceback.format_exc()[:300]}")
if __name__ == "__main__":
print("🚀 TPU训练内存监控工具")
print("专注于训练过程中的实时内存和性能监控")
print("适用于TPU v5e-8环境")
print()
monitor_tpu_during_training()
print(f"\n🎯 监控要点总结:")
print(f" 1. 确认所有8个TPU核心是否活跃")
print(f" 2. 监控内存增长模式和峰值使用")
print(f" 3. 检测MXU计算性能和效率")
print(f" 4. 验证分布式策略是否正常工作")
print(f" 5. 识别可能的内存泄漏或性能瓶颈")

View File

@@ -0,0 +1,236 @@
#!/usr/bin/env python3
"""
TPU内存监控工具 - 专门用于训练过程
解决tf.config.experimental.get_memory_info()在TPU上无法工作的问题
"""
import tensorflow as tf
import time
import psutil
import os
class TPUMemoryMonitor:
"""TPU内存监控类"""
def __init__(self):
self.tpu_devices = tf.config.list_logical_devices('TPU')
self.baseline_memory = None
self.peak_allocations = {}
def get_tpu_status(self) -> str:
"""获取TPU状态 - 实用版本不依赖get_memory_info"""
try:
if not self.tpu_devices:
return "TPU: No devices"
num_cores = len(self.tpu_devices)
# 测试TPU响应性
try:
with tf.device('/TPU:0'):
test_tensor = tf.constant([1.0, 2.0, 3.0])
result = tf.reduce_sum(test_tensor)
_ = result.numpy() # 强制执行
activity = "active"
except Exception:
activity = "inactive"
# 获取主机内存作为参考
try:
memory = psutil.virtual_memory()
host_mem = f"Host:{memory.percent:.1f}%"
except:
host_mem = "Host:unknown"
return f"TPU: {num_cores}cores {activity} {host_mem}"
except Exception as e:
return f"TPU: error({str(e)[:20]})"
def estimate_tensor_memory(self, tensor_shape, dtype=tf.float32):
"""估算张量内存使用量"""
if dtype == tf.float32:
bytes_per_element = 4
elif dtype == tf.float16 or dtype == tf.bfloat16:
bytes_per_element = 2
elif dtype == tf.int32:
bytes_per_element = 4
elif dtype == tf.int64:
bytes_per_element = 8
else:
bytes_per_element = 4 # 默认
total_elements = 1
for dim in tensor_shape:
total_elements *= dim
total_bytes = total_elements * bytes_per_element
return total_bytes / (1024 * 1024) # 返回MB
def track_allocation(self, name: str, tensor_shape, dtype=tf.float32):
"""跟踪内存分配"""
mb = self.estimate_tensor_memory(tensor_shape, dtype)
self.peak_allocations[name] = self.peak_allocations.get(name, 0) + mb
return mb
def get_allocation_summary(self) -> str:
"""获取分配汇总"""
if not self.peak_allocations:
return "No allocations tracked"
total_mb = sum(self.peak_allocations.values())
top_3 = sorted(self.peak_allocations.items(), key=lambda x: x[1], reverse=True)[:3]
summary = f"Tracked:{total_mb:.1f}MB "
summary += f"Top:({top_3[0][0]}:{top_3[0][1]:.1f}MB)"
return summary
def test_memory_allocation_across_cores(self):
"""测试8个核心的内存分配"""
print("🧪 测试所有TPU核心内存分配")
print("=" * 40)
allocations_per_core = []
for i, device in enumerate(self.tpu_devices):
print(f"核心 {i+1}: {device.name}")
try:
with tf.device(device.name):
# 创建不同大小的测试张量
test_sizes = [
([1000, 1000], "1K×1K"),
([3000, 3000], "3K×3K"),
([5000, 5000], "5K×5K"),
([7000, 7000], "7K×7K"),
]
core_total = 0
successful_allocs = []
for shape, desc in test_sizes:
try:
tensor = tf.ones(shape, dtype=tf.float32)
mb = self.estimate_tensor_memory(shape)
core_total += mb
successful_allocs.append(f"{desc}({mb:.1f}MB)")
# 实际使用张量防止被优化
_ = tf.reduce_mean(tensor)
except Exception as e:
print(f" {desc} 失败: {str(e)[:30]}")
break
allocations_per_core.append(core_total)
print(f" 成功分配: {' + '.join(successful_allocs)}")
print(f" 核心总计: {core_total:.1f}MB")
except Exception as e:
print(f" 核心{i+1}失败: {e}")
allocations_per_core.append(0)
# 汇总结果
total_all_cores = sum(allocations_per_core)
avg_per_core = total_all_cores / len(self.tpu_devices) if self.tpu_devices else 0
print(f"\n📊 汇总结果:")
print(f" 总分配: {total_all_cores:.1f}MB ({total_all_cores/1024:.2f}GB)")
print(f" 平均每核: {avg_per_core:.1f}MB ({avg_per_core/1024:.2f}GB)")
# 推测内存配置
if avg_per_core > 8000: # > 8GB
print(" 推测: 每核心≥16GB (高端配置)")
elif avg_per_core > 4000: # > 4GB
print(" 推测: 每核心8-16GB (标准配置)")
elif avg_per_core > 1000: # > 1GB
print(" 推测: 每核心2-8GB (受限或共享)")
else:
print(" 推测: 每核心<2GB (严重受限)")
return allocations_per_core
def test_training_memory_pattern():
"""测试模拟训练的内存模式"""
print("\n🏋️ 模拟训练内存模式测试")
print("=" * 30)
monitor = TPUMemoryMonitor()
# 模拟典型的brain-to-text模型内存使用
with tf.device('/TPU:0'):
print("创建模拟模型组件...")
# 1. 输入数据 (batch_size=32, seq_len=1000, features=512)
batch_size, seq_len, features = 32, 1000, 512
input_data = tf.random.normal([batch_size, seq_len, features])
input_mb = monitor.track_allocation("input_data", [batch_size, seq_len, features])
print(f" 输入数据: {input_mb:.1f}MB")
# 2. GRU权重 (假设3层, 每层256单元)
n_layers, n_units = 3, 256
for layer in range(n_layers):
# GRU有3个门每个门需要权重矩阵
weight_shape = [features if layer == 0 else n_units, n_units * 3]
weights = tf.random.normal(weight_shape)
weight_mb = monitor.track_allocation(f"gru_layer_{layer}", weight_shape)
print(f" GRU层{layer+1}权重: {weight_mb:.1f}MB")
# 3. 输出投影层 (n_units -> n_classes=41)
n_classes = 41
output_weights = tf.random.normal([n_units, n_classes])
output_mb = monitor.track_allocation("output_projection", [n_units, n_classes])
print(f" 输出投影: {output_mb:.1f}MB")
# 4. 中间激活值 (前向传播)
hidden_states = tf.random.normal([batch_size, seq_len, n_units])
hidden_mb = monitor.track_allocation("hidden_states", [batch_size, seq_len, n_units])
print(f" 隐藏状态: {hidden_mb:.1f}MB")
# 5. 梯度 (反向传播时会翻倍内存)
total_params_mb = sum([v for k, v in monitor.peak_allocations.items() if 'layer' in k or 'projection' in k])
gradient_mb = total_params_mb # 梯度内存约等于参数内存
print(f" 梯度内存: {gradient_mb:.1f}MB (估算)")
print(f"\n模型总内存估算: {monitor.get_allocation_summary()}")
# 实际执行一些操作确保内存被分配
result = tf.reduce_mean(input_data) + tf.reduce_mean(hidden_states)
print(f"验证计算结果: {result.numpy():.4f}")
if __name__ == "__main__":
print("🚀 TPU内存监控工具启动")
monitor = TPUMemoryMonitor()
# 基础状态检查
print(f"当前TPU状态: {monitor.get_tpu_status()}")
# 测试所有核心
print("\n" + "="*50)
core_allocations = monitor.test_memory_allocation_across_cores()
# 训练内存模式测试
print("\n" + "="*50)
test_training_memory_pattern()
print(f"\n🎯 关键发现:")
if core_allocations:
max_core = max(core_allocations)
min_core = min([x for x in core_allocations if x > 0])
print(f" 最大单核分配: {max_core:.1f}MB")
print(f" 最小单核分配: {min_core:.1f}MB")
if max_core > 9000: # 你之前测试到9.4GB
print(" ✅ 内存充足,可支持大模型训练")
elif max_core > 5000:
print(" ⚠️ 内存中等,建议优化模型大小")
else:
print(" ❌ 内存不足,需要大幅减少模型参数")
print(f"\n💡 针对你的训练卡顿问题:")
print(f" - SetPriority错误通常是XLA编译问题不是内存问题")
print(f" - 你的9.4GB测试说明TPU内存工作正常")
print(f" - 建议检查模型是否有导致XLA编译卡顿的操作")
print(f" - 考虑使用更简单的操作或关闭某些XLA优化")

View File

@@ -177,26 +177,43 @@ class BrainToTextDecoderTrainerTF:
# Get strategy info
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
# Try to get TPU memory info (HBM)
# Get TPU memory info using the working /device:TPU:X format
try:
# Attempt to get TPU memory usage for each device
memory_info = tf.config.experimental.get_memory_info('/TPU:0')
if memory_info and 'current' in memory_info:
current_mb = memory_info['current'] // (1024 * 1024)
peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
hbm_info = f"HBM: {current_mb}MB({peak_mb}MB peak)"
# Check all TPU devices for memory usage
active_cores = 0
total_current_mb = 0
max_peak_mb = 0
for device in tpu_devices:
try:
memory_info = tf.config.experimental.get_memory_info(device.name)
if memory_info and 'current' in memory_info:
current_mb = memory_info['current'] // (1024 * 1024)
peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
if current_mb > 1: # >1MB considered active
active_cores += 1
total_current_mb += current_mb
max_peak_mb = max(max_peak_mb, peak_mb)
except:
continue
if active_cores > 0:
if active_cores == 1:
hbm_info = f"HBM:{total_current_mb}MB(peak:{max_peak_mb}MB)"
else:
hbm_info = f"HBM:{total_current_mb}MB/{active_cores}cores(peak:{max_peak_mb}MB)"
else:
hbm_info = "HBM: unknown"
hbm_info = "HBM:idle"
except Exception:
# Fallback: simple TPU activity check
try:
# Test TPU responsiveness
with tf.device('/TPU:0'):
test_tensor = tf.constant([1.0, 2.0])
_ = tf.reduce_sum(test_tensor)
hbm_info = "HBM: active"
_ = tf.constant(1.0)
hbm_info = "HBM:active"
except Exception:
hbm_info = "HBM: inactive"
hbm_info = "HBM:inactive"
return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
f"{hbm_info}")
@@ -217,17 +234,37 @@ class BrainToTextDecoderTrainerTF:
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
strategy_type = type(self.strategy).__name__
# Get TPU HBM memory info
# Get TPU HBM memory info using working device format
try:
memory_info = tf.config.experimental.get_memory_info('/TPU:0')
if memory_info and 'current' in memory_info:
current_gb = memory_info['current'] // (1024 * 1024 * 1024)
peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
# TPU v5e-8 has ~32GB HBM per chip, 8 chips total = ~256GB
estimated_total_gb = 32 * len(tpu_devices)
hbm_usage = f"HBM: {current_gb}GB/{estimated_total_gb}GB (peak: {peak_gb}GB)"
active_cores = 0
total_current_gb = 0
max_peak_gb = 0
memory_details = []
for i, device in enumerate(tpu_devices):
try:
memory_info = tf.config.experimental.get_memory_info(device.name)
if memory_info and 'current' in memory_info:
current_gb = memory_info['current'] // (1024 * 1024 * 1024)
peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
if current_gb > 0 or memory_info['current'] > 1024*1024: # >1MB
active_cores += 1
total_current_gb += current_gb
max_peak_gb = max(max_peak_gb, peak_gb)
if current_gb > 0:
memory_details.append(f"Core{i}:{current_gb}GB")
except:
continue
if active_cores > 0:
# Based on your test: TPU:0 peaked at 14.5GB, suggesting ~16GB per core
estimated_per_core = 16 # Conservative estimate
estimated_total_gb = estimated_per_core * len(tpu_devices)
hbm_usage = f"HBM: {total_current_gb}GB/{estimated_total_gb}GB (peak: {max_peak_gb}GB) active:{active_cores}cores"
else:
hbm_usage = "HBM: unknown"
hbm_usage = "HBM: 0GB/256GB (idle)"
except Exception:
hbm_usage = "HBM: unavailable"
@@ -559,39 +596,43 @@ class BrainToTextDecoderTrainerTF:
self.args['dataset']['data_transforms'],
training=True
)
val_dataset = create_input_fn(
self.val_dataset_tf,
self.args['dataset']['data_transforms'],
training=False
)
# Distribute datasets
train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
self.logger.info("Created distributed training and validation datasets")
# Training metrics
train_losses = []
val_losses = []
val_pers = []
val_results = []
val_steps_since_improvement = 0
self.logger.info("Training time count beginning...")
train_start_time = time.time()
# Training loop
step = 0
for batch in train_dist_dataset:
if step >= self.args['num_training_batches']:
self.logger.info("Reached maximum training batches, stopping training")
break
start_time = time.time()
# Distributed training step
self.logger.info("Running distributed training step...")
per_replica_losses, per_replica_grad_norms = self.strategy.run(
self._train_step, args=(batch, step)
)
# Reduce across replicas
self.logger.info("Reducing results across replicas...")
loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)