#!/usr/bin/env python3
"""
TPU memory monitoring tool, intended for use during training.
Works around tf.config.experimental.get_memory_info() not working on TPU.
"""
import tensorflow as tf
import time
import psutil
import os
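
# Hedged sketch (not in the original file): the call this module works
# around. On GPU, tf.config.experimental.get_memory_info('GPU:0') returns
# {'current', 'peak'} in bytes, but on TPU devices it is not supported and
# raises, which is why the monitor below has to probe memory indirectly.
def try_get_memory_info(device: str = 'TPU:0'):
    """Return get_memory_info() for `device`, or None where unsupported."""
    try:
        return tf.config.experimental.get_memory_info(device)
    except Exception as e:
        print(f"get_memory_info({device}) unavailable: {str(e)[:60]}")
        return None
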
class TPUMemoryMonitor:
"""TPU内存监控类"""
def __init__(self):
self.tpu_devices = tf.config.list_logical_devices('TPU')
self.baseline_memory = None
self.peak_allocations = {}
def get_tpu_status(self) -> str:
"""获取TPU状态 - 实用版本不依赖get_memory_info"""
try:
if not self.tpu_devices:
return "TPU: No devices"
num_cores = len(self.tpu_devices)
# 测试TPU响应性
try:
with tf.device('/TPU:0'):
test_tensor = tf.constant([1.0, 2.0, 3.0])
result = tf.reduce_sum(test_tensor)
_ = result.numpy() # 强制执行
activity = "active"
except Exception:
activity = "inactive"
# 获取主机内存作为参考
try:
memory = psutil.virtual_memory()
host_mem = f"Host:{memory.percent:.1f}%"
except:
host_mem = "Host:unknown"
return f"TPU: {num_cores}cores {activity} {host_mem}"
except Exception as e:
return f"TPU: error({str(e)[:20]})"

    def estimate_tensor_memory(self, tensor_shape, dtype=tf.float32):
        """Estimate a tensor's memory footprint in MB."""
        if dtype == tf.float32:
            bytes_per_element = 4
        elif dtype == tf.float16 or dtype == tf.bfloat16:
            bytes_per_element = 2
        elif dtype == tf.int32:
            bytes_per_element = 4
        elif dtype == tf.int64:
            bytes_per_element = 8
        else:
            bytes_per_element = 4  # default
        total_elements = 1
        for dim in tensor_shape:
            total_elements *= dim
        total_bytes = total_elements * bytes_per_element
        return total_bytes / (1024 * 1024)  # return MB
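
    # Worked example (hedged arithmetic, not from the original file): a
    # [3000, 3000] float32 tensor is 3000 * 3000 * 4 B = 36,000,000 B
    # ≈ 34.3 MB, matching the 3K×3K probe used below.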

    def track_allocation(self, name: str, tensor_shape, dtype=tf.float32):
        """Track an allocation under `name` and return its size in MB."""
        mb = self.estimate_tensor_memory(tensor_shape, dtype)
        self.peak_allocations[name] = self.peak_allocations.get(name, 0) + mb
        return mb

    def get_allocation_summary(self) -> str:
        """Summarize tracked allocations."""
        if not self.peak_allocations:
            return "No allocations tracked"
        total_mb = sum(self.peak_allocations.values())
        top_3 = sorted(self.peak_allocations.items(), key=lambda x: x[1], reverse=True)[:3]
        summary = f"Tracked:{total_mb:.1f}MB "
        summary += f"Top:({top_3[0][0]}:{top_3[0][1]:.1f}MB)"
        return summary

    def test_memory_allocation_across_cores(self):
        """Test memory allocation across all 8 TPU cores."""
        print("🧪 Testing memory allocation on all TPU cores")
        print("=" * 40)
        allocations_per_core = []
        for i, device in enumerate(self.tpu_devices):
            print(f"Core {i+1}: {device.name}")
            try:
                with tf.device(device.name):
                    # Probe with test tensors of increasing size
                    test_sizes = [
                        ([1000, 1000], "1K×1K"),
                        ([3000, 3000], "3K×3K"),
                        ([5000, 5000], "5K×5K"),
                        ([7000, 7000], "7K×7K"),
                    ]
                    core_total = 0
                    successful_allocs = []
                    for shape, desc in test_sizes:
                        try:
                            tensor = tf.ones(shape, dtype=tf.float32)
                            mb = self.estimate_tensor_memory(shape)
                            core_total += mb
                            successful_allocs.append(f"{desc}({mb:.1f}MB)")
                            # Actually use the tensor so it is not optimized away
                            _ = tf.reduce_mean(tensor)
                        except Exception as e:
                            print(f" {desc} failed: {str(e)[:30]}")
                            break
                    allocations_per_core.append(core_total)
                    print(f" Allocated: {' + '.join(successful_allocs)}")
                    print(f" Core total: {core_total:.1f}MB")
            except Exception as e:
                print(f" Core {i+1} failed: {e}")
                allocations_per_core.append(0)
        # Aggregate results
        total_all_cores = sum(allocations_per_core)
        avg_per_core = total_all_cores / len(self.tpu_devices) if self.tpu_devices else 0
        print(f"\n📊 Summary:")
        print(f" Total allocated: {total_all_cores:.1f}MB ({total_all_cores/1024:.2f}GB)")
        print(f" Average per core: {avg_per_core:.1f}MB ({avg_per_core/1024:.2f}GB)")
        # Infer the memory configuration from what fit
        if avg_per_core > 8000:  # > 8GB
            print(" Estimate: ≥16GB per core (high-end configuration)")
        elif avg_per_core > 4000:  # > 4GB
            print(" Estimate: 8-16GB per core (standard configuration)")
        elif avg_per_core > 1000:  # > 1GB
            print(" Estimate: 2-8GB per core (limited or shared)")
        else:
            print(" Estimate: <2GB per core (severely limited)")
        return allocations_per_core
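
# Note (hedged): the per-core totals above are estimates derived from the
# requested shapes; XLA may fold constants or reuse buffers, so actual
# device residency can be lower than reported.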

def test_training_memory_pattern():
    """Simulate the memory pattern of a typical training run."""
    print("\n🏋️ Simulated training memory pattern test")
    print("=" * 30)
    monitor = TPUMemoryMonitor()
    # Simulate typical brain-to-text model memory usage
    with tf.device('/TPU:0'):
        print("Creating mock model components...")
        # 1. Input data (batch_size=32, seq_len=1000, features=512)
        batch_size, seq_len, features = 32, 1000, 512
        input_data = tf.random.normal([batch_size, seq_len, features])
        input_mb = monitor.track_allocation("input_data", [batch_size, seq_len, features])
        print(f" Input data: {input_mb:.1f}MB")
        # 2. GRU weights (assume 3 layers, 256 units each)
        n_layers, n_units = 3, 256
        for layer in range(n_layers):
            # A GRU has 3 gates, each needing its own weight matrix
            weight_shape = [features if layer == 0 else n_units, n_units * 3]
            weights = tf.random.normal(weight_shape)
            weight_mb = monitor.track_allocation(f"gru_layer_{layer}", weight_shape)
            print(f" GRU layer {layer+1} weights: {weight_mb:.1f}MB")
        # 3. Output projection (n_units -> n_classes=41)
        n_classes = 41
        output_weights = tf.random.normal([n_units, n_classes])
        output_mb = monitor.track_allocation("output_projection", [n_units, n_classes])
        print(f" Output projection: {output_mb:.1f}MB")
        # 4. Intermediate activations (forward pass)
        hidden_states = tf.random.normal([batch_size, seq_len, n_units])
        hidden_mb = monitor.track_allocation("hidden_states", [batch_size, seq_len, n_units])
        print(f" Hidden states: {hidden_mb:.1f}MB")
        # 5. Gradients (backprop roughly doubles parameter memory)
        total_params_mb = sum([v for k, v in monitor.peak_allocations.items() if 'layer' in k or 'projection' in k])
        gradient_mb = total_params_mb  # gradient memory ~= parameter memory
        print(f" Gradient memory: {gradient_mb:.1f}MB (estimated)")
        print(f"\nTotal model memory estimate: {monitor.get_allocation_summary()}")
        # Actually run some ops so the memory really gets allocated
        result = tf.reduce_mean(input_data) + tf.reduce_mean(hidden_states)
        print(f"Sanity-check result: {result.numpy():.4f}")

if __name__ == "__main__":
    print("🚀 TPU memory monitor starting")
    monitor = TPUMemoryMonitor()
    # Basic status check
    print(f"Current TPU status: {monitor.get_tpu_status()}")
    # Test every core
    print("\n" + "="*50)
    core_allocations = monitor.test_memory_allocation_across_cores()
    # Training memory pattern test
    print("\n" + "="*50)
    test_training_memory_pattern()
    print(f"\n🎯 Key findings:")
    if core_allocations:
        max_core = max(core_allocations)
        # default=0.0 guards against the case where every core failed
        min_core = min((x for x in core_allocations if x > 0), default=0.0)
        print(f" Max single-core allocation: {max_core:.1f}MB")
        print(f" Min single-core allocation: {min_core:.1f}MB")
        if max_core > 9000:  # an earlier test reached 9.4GB
            print(" ✅ Ample memory; can support large-model training")
        elif max_core > 5000:
            print(" ⚠️ Moderate memory; consider shrinking the model")
        else:
            print(" ❌ Insufficient memory; cut model parameters substantially")
    print(f"\n💡 On your training stalls:")
    print(f" - SetPriority errors are usually an XLA compilation issue, not a memory issue")
    print(f" - Your 9.4GB test shows TPU memory itself works")
    print(f" - Check the model for ops that stall XLA compilation")
    print(f" - Try simpler ops, or disable some XLA optimizations")