Files
b2txt25/model_training_nnn_tpu/check_tpu_memory.py

403 lines
18 KiB
Python
Raw Normal View History

2025-10-16 13:39:05 +08:00
#!/usr/bin/env python3
"""
TPU训练内存监控工具 - 专注于训练过程中的实时内存和MXU监控
适用于TPU v5e-8环境
"""
import tensorflow as tf
import time
import numpy as np
def monitor_tpu_during_training():
"""训练过程中的TPU实时内存和MXU监控"""
print("📊 TPU训练实时监控工具")
print("=" * 50)
# 获取TPU设备
try:
tpu_devices = tf.config.list_logical_devices('TPU')
print(f"📍 发现TPU设备: {len(tpu_devices)}")
if not tpu_devices:
print("❌ 未发现TPU设备")
return
except Exception as e:
print(f"❌ 无法检测TPU设备: {e}")
return
def get_detailed_memory_snapshot():
"""获取详细的内存快照,包含所有核心信息"""
snapshot = {}
total_current = 0
total_peak = 0
active_cores = 0
core_details = []
for i, device in enumerate(tpu_devices):
try:
memory_info = tf.config.experimental.get_memory_info(device.name)
if memory_info and 'current' in memory_info:
current_mb = memory_info['current'] // (1024 * 1024)
peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
if current_mb > 1: # >1MB算活跃
active_cores += 1
total_current += current_mb
total_peak += peak_mb
core_details.append(f"Core{i}:{current_mb}MB")
snapshot[f'core_{i}'] = {
'current': current_mb,
'peak': peak_mb,
'device': device.name
}
else:
snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name}
except Exception as e:
snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name, 'error': str(e)}
snapshot['summary'] = {
'total_current': total_current,
'total_peak': total_peak,
'active_cores': active_cores,
'total_cores': len(tpu_devices),
'core_details': core_details
}
return snapshot
def test_mxu_performance():
"""测试MXU性能和计算能力"""
print("\n🧮 MXU计算性能测试:")
mxu_results = []
try:
with tf.device(tpu_devices[0].name):
# 测试不同规模的矩阵运算
test_configs = [
(2000, "2K×2K", tf.bfloat16),
(4000, "4K×4K", tf.bfloat16),
(6000, "6K×6K", tf.bfloat16),
]
for size, desc, dtype in test_configs:
try:
# 获取测试前内存
pre_mem = get_detailed_memory_snapshot()
start_time = time.time()
# 创建矩阵并执行MXU密集型运算
matrix_a = tf.random.normal([size, size], dtype=dtype)
matrix_b = tf.random.normal([size, size], dtype=dtype)
@tf.function
def mxu_operation():
# 连续矩阵运算充分使用MXU
result = tf.matmul(matrix_a, matrix_b)
result = tf.matmul(result, matrix_a)
return tf.reduce_sum(result)
result = mxu_operation()
# 使用result确保计算被执行
_ = result.numpy()
end_time = time.time()
# 获取测试后内存
post_mem = get_detailed_memory_snapshot()
duration = end_time - start_time
# 计算FLOPS (两次矩阵乘法)
flops = 2 * (2 * size**3)
tflops = flops / duration / 1e12
memory_used = post_mem['summary']['total_current'] - pre_mem['summary']['total_current']
print(f" {desc} ({dtype.name}): {duration:.3f}s, {tflops:.1f}TFLOPS, 内存+{memory_used}MB")
mxu_results.append({
'size': size,
'tflops': tflops,
'duration': duration,
'memory_used': memory_used
})
except Exception as e:
print(f" {desc}: 测试失败 - {str(e)[:50]}")
# MXU性能分析
if mxu_results:
max_tflops = max(r['tflops'] for r in mxu_results)
total_memory = sum(r['memory_used'] for r in mxu_results if r['memory_used'] > 0)
# TPU v5e-8单核理论性能
theoretical_tflops = 275 # bf16峰值性能
efficiency = (max_tflops / theoretical_tflops) * 100
print(f"\n 📊 MXU性能汇总:")
print(f" 峰值性能: {max_tflops:.1f} TFLOPS")
print(f" 理论峰值: {theoretical_tflops} TFLOPS")
print(f" MXU效率: {efficiency:.1f}%")
print(f" 计算内存占用: {total_memory}MB")
if efficiency > 80:
status = "🟢 优秀"
elif efficiency > 50:
status = "🟡 良好"
elif efficiency > 20:
status = "🟠 中等"
else:
status = "🔴 需优化"
print(f" 性能评级: {status}")
except Exception as e:
print(f" MXU测试失败: {e}")
try:
print("🎯 开始TPU训练监控...")
# 1. 获取初始状态
print("\n📸 初始TPU状态:")
baseline_snapshot = get_detailed_memory_snapshot()
print(f" 总内存使用: {baseline_snapshot['summary']['total_current']}MB")
print(f" 活跃核心: {baseline_snapshot['summary']['active_cores']}/{baseline_snapshot['summary']['total_cores']}")
# 显示各核心详细状态
for i in range(len(tpu_devices)):
core = baseline_snapshot[f'core_{i}']
if core['current'] > 0 or core['peak'] > 0:
print(f" Core{i}: 当前{core['current']}MB, 峰值{core['peak']}MB")
# 2. MXU性能基准测试
test_mxu_performance()
# 3. 创建分布式策略 - 使用项目验证的TPU初始化代码
print(f"\n🔄 使用项目标准TPU初始化...")
try:
# 使用项目里验证过的TPU初始化代码
# 禁用GPU避免冲突
try:
tf.config.set_visible_devices([], 'GPU')
print("🚫 GPU已禁用避免CUDA冲突")
except:
pass
# 使用标准的TPU初始化流程
print("🚀 使用官方TensorFlow TPU初始化...")
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# 验证TPU设备
tpu_devices_check = tf.config.list_logical_devices('TPU')
print(f"✅ TPU设备验证: 发现 {len(tpu_devices_check)} 个设备")
# 创建TPU策略
strategy = tf.distribute.TPUStrategy(resolver)
print(f"✅ 成功创建TPU策略: {strategy.num_replicas_in_sync}个副本")
use_distributed = True
except Exception as e:
print(f"⚠️ 分布式策略失败: {str(e)[:80]}")
print(" 将使用单设备模拟")
use_distributed = False
# 4. 模拟Brain-to-Text训练场景
print(f"\n🧠 模拟Brain-to-Text训练场景...")
if use_distributed:
# 分布式训练模拟
with strategy.scope():
print("📦 创建分布式模型参数...")
# 创建接近真实Brain-to-Text模型的参数 (修复维度匹配)
model_components = {
# GRU层权重第一层接收512维输入后续层接收256维
'gru_layer_0': tf.Variable(tf.random.normal([512, 256]), name='gru_0'),
'gru_layer_1': tf.Variable(tf.random.normal([256, 256]), name='gru_1'),
'gru_layer_2': tf.Variable(tf.random.normal([256, 256]), name='gru_2'),
'output_projection': tf.Variable(tf.random.normal([256, 41]), name='output'),
# 添加day-specific层模拟 (输入512维输出512维)
'day_weights': [tf.Variable(tf.random.normal([512, 512]), name=f'day_{i}') for i in range(8)]
}
# 检查模型加载后内存
after_model = get_detailed_memory_snapshot()
model_memory = after_model['summary']['total_current'] - baseline_snapshot['summary']['total_current']
print(f"🧠 模型加载完成: +{model_memory}MB, {after_model['summary']['active_cores']}个活跃核心")
# 训练循环模拟
print(f"\n🔄 开始训练循环监控...")
for step in range(10):
step_start_time = time.time()
@tf.function
def distributed_training_step():
# 模拟真实训练数据大小
batch_size = 32
seq_length = 1000
features = 512
# 输入数据
neural_data = tf.random.normal([batch_size, seq_length, features])
targets = tf.random.uniform([batch_size, seq_length], maxval=41, dtype=tf.int32)
# 模拟前向传播
x = neural_data
# Day-specific transformation (简化版本避免复杂的维度操作)
# 模拟day-specific变换对每个时间步应用相同变换
day_weight = model_components['day_weights'][0] # 简化使用第一个day权重
# 对最后一个维度进行变换: [batch, seq, 512] @ [512, 512] -> [batch, seq, 512]
x = tf.matmul(x, day_weight)
# 为CTC损失添加目标使用模拟
target_length = tf.reduce_sum(tf.cast(targets > 0, tf.int32), axis=1)
# 简化的CTC相关计算
batch_loss_weight = tf.reduce_mean(tf.cast(target_length, tf.float32))
# GRU layers simulation
for i in range(3):
layer_name = f'gru_layer_{i}'
weight = model_components[layer_name]
# 处理张量维度第一层从3D输入后续层从2D输入
if i == 0:
# 第一层:取最后时间步 [batch, seq, features] -> [batch, features]
if len(x.shape) == 3:
x = x[:, -1, :] # 取最后时间步
x = tf.nn.tanh(tf.matmul(x, weight))
else:
# 后续层直接处理2D张量 [batch, features] -> [batch, features]
x = tf.nn.tanh(tf.matmul(x, weight))
# 输出投影
logits = tf.matmul(x, model_components['output_projection'])
# CTC loss模拟使用batch_loss_weight作为权重
base_loss = tf.reduce_mean(tf.square(logits))
loss = base_loss * batch_loss_weight
return loss
# 执行训练步骤
per_replica_loss = strategy.run(distributed_training_step)
# 聚合分布式结果
loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
step_duration = time.time() - step_start_time
# 获取当前内存状态
current_snapshot = get_detailed_memory_snapshot()
step_memory = current_snapshot['summary']['total_current']
memory_delta = step_memory - baseline_snapshot['summary']['total_current']
# 显示详细训练状态
active_cores_info = f"({', '.join(current_snapshot['summary']['core_details'])})" if current_snapshot['summary']['core_details'] else "(无活跃)"
print(f" Step {step:2d}: loss={float(loss.numpy()):.4f}, "
f"时间={step_duration:.3f}s, "
f"内存={step_memory}MB(+{memory_delta}), "
f"活跃={current_snapshot['summary']['active_cores']}/{current_snapshot['summary']['total_cores']} {active_cores_info}")
# 每5步显示峰值内存
if step % 5 == 0:
peak_info = f"峰值: {current_snapshot['summary']['total_peak']}MB"
print(f" {peak_info}")
time.sleep(0.2) # 短暂暂停观察
else:
# 单设备训练模拟(改进版)
print("🔸 单设备训练模拟...")
with tf.device(tpu_devices[0].name):
# 创建较小的模型参数
simple_weights = tf.Variable(tf.random.normal([512, 256]), name='simple_net')
for step in range(8):
step_start = time.time()
# 创建较大的数据批次
batch_data = tf.random.normal([64, 1000, 512]) # 增大batch size
# 模拟计算密集型操作
@tf.function
def compute_step():
x = tf.reshape(batch_data, [-1, 512])
result = tf.matmul(x, simple_weights)
result = tf.nn.relu(result)
return tf.reduce_mean(result)
result = compute_step()
step_duration = time.time() - step_start
# 获取内存状态
snapshot = get_detailed_memory_snapshot()
memory_change = snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current']
print(f" Step {step}: result={result.numpy():.4f}, "
f"时间={step_duration:.3f}s, "
f"内存变化=+{memory_change}MB, "
f"峰值={snapshot['summary']['total_peak']}MB")
# 5. 最终分析报告
final_snapshot = get_detailed_memory_snapshot()
total_growth = final_snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current']
peak_usage = final_snapshot['summary']['total_peak']
print(f"\n📈 训练监控报告:")
print(f" 总内存增长: +{total_growth}MB")
print(f" 峰值内存使用: {peak_usage}MB ({peak_usage/1024:.2f}GB)")
print(f" 最终活跃核心: {final_snapshot['summary']['active_cores']}/{final_snapshot['summary']['total_cores']}")
# 各核心最终状态
print(f" 各核心最终状态:")
has_changes = False
for i in range(len(tpu_devices)):
final_core = final_snapshot[f'core_{i}']
baseline_core = baseline_snapshot[f'core_{i}']
current_change = final_core['current'] - baseline_core['current']
peak_change = final_core['peak'] - baseline_core['peak']
if current_change != 0 or peak_change != 0:
has_changes = True
print(f" Core{i}: 当前{final_core['current']}MB(+{current_change}), 峰值{final_core['peak']}MB(+{peak_change})")
if not has_changes:
print(f" 所有核心内存无明显变化")
# 分布式使用分析
if final_snapshot['summary']['active_cores'] == 1:
print(f"\n⚠️ 分布式问题诊断:")
print(f" 只有1个核心活跃其他7个核心空闲")
print(f" 可能原因: TPU策略配置问题或模型未正确分布")
print(f" 建议: 检查分布式策略和模型分片")
elif final_snapshot['summary']['active_cores'] > 4:
print(f"\n✅ 分布式状态良好:")
print(f" {final_snapshot['summary']['active_cores']}个核心活跃,多核心并行工作正常")
else:
print(f"\n🟡 分布式部分工作:")
print(f" {final_snapshot['summary']['active_cores']}个核心活跃,可能存在负载不均衡")
print("✅ TPU训练监控完成")
except Exception as e:
print(f"❌ 训练监控失败: {e}")
import traceback
print(f"详细错误: {traceback.format_exc()[:300]}")
if __name__ == "__main__":
print("🚀 TPU训练内存监控工具")
print("专注于训练过程中的实时内存和性能监控")
print("适用于TPU v5e-8环境")
print()
monitor_tpu_during_training()
print(f"\n🎯 监控要点总结:")
print(f" 1. 确认所有8个TPU核心是否活跃")
print(f" 2. 监控内存增长模式和峰值使用")
print(f" 3. 检测MXU计算性能和效率")
print(f" 4. 验证分布式策略是否正常工作")
print(f" 5. 识别可能的内存泄漏或性能瓶颈")