b2txt25/model_training_nnn_tpu/test_simple_model.py
Zchen 56fa336af0 tpu
2025-10-15 14:26:11 +08:00

#!/usr/bin/env python3
"""
简化模型测试脚本 - 验证XLA编译是否正常工作
"""
import os
import time
import torch
import torch.nn as nn
# XLA environment variables must be set before torch_xla is imported
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={os.cpu_count()}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
os.environ['XLA_USE_BF16'] = '1'
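# Note: XLA_USE_BF16=1 asks torch_xla to store float32 tensors as bfloat16 on
# TPU, and --xla_force_host_platform_device_count exposes one virtual XLA
# device per CPU core for host-side testing; both only take effect because
# they are set before torch_xla is imported.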
print(f"🔧 XLA环境变量设置:")
print(f" CPU核心数: {os.cpu_count()}")
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
import torch_xla.core.xla_model as xm
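# torch_xla traces operations lazily and only compiles/executes the graph at a
# step boundary, so the timing code below calls xm.mark_step() to force
# execution before reading the clock.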

class SimpleModel(nn.Module):
    """Simplified test model."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(512, 256)
        self.gru = nn.GRU(256, 128, batch_first=True)
        self.linear2 = nn.Linear(128, 41)  # 41 phoneme classes

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x, _ = self.gru(x)
        x = self.linear2(x)
        return x
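
# Quick shape check for SimpleModel (a sketch; runs on plain CPU, no TPU needed):
#   >>> m = SimpleModel()
#   >>> m(torch.randn(2, 10, 512)).shape
#   torch.Size([2, 10, 41])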

def test_xla_compilation():
    """Measure XLA compilation speed."""
    print("\n🚀 Starting simplified-model XLA compilation test...")

    # Check the TPU device
    device = xm.xla_device()
    print(f"📱 TPU device: {device}")
    # xrt_world_size() is deprecated in newer torch_xla releases
    # (torch_xla.runtime.world_size() is the replacement)
    print(f"🌍 TPU world size: {xm.xrt_world_size()}")

    # Create the simplified model
    model = SimpleModel().to(device)
    print(f"📊 Parameter count: {sum(p.numel() for p in model.parameters()):,}")

    # Create test data
    batch_size = 8  # small batch
    seq_len = 100   # short sequence
    x = torch.randn(batch_size, seq_len, 512, device=device)
    print(f"📥 Input shape: {x.shape}")

    # First forward pass - triggers XLA compilation
    print("🔄 Running first forward pass (XLA compilation)...")
    start_time = time.time()
    with torch.no_grad():
        output = model(x)
    xm.mark_step()  # force the lazy graph to compile and run before stopping the clock
    compile_time = time.time() - start_time
    print(f"✅ XLA compilation finished! Took: {compile_time:.2f}s")
    print(f"📤 Output shape: {output.shape}")

    # Second forward pass - reuses the compiled graph
    print("🔄 Running second forward pass (using the compiled graph)...")
    start_time = time.time()
    with torch.no_grad():
        output2 = model(x)
    xm.mark_step()  # sync again so the timing reflects real execution
    execution_time = time.time() - start_time
    print(f"⚡ Execution finished! Took: {execution_time:.4f}s")

    # Performance comparison
    speedup = compile_time / execution_time if execution_time > 0 else float('inf')
    print("\n📈 Performance analysis:")
    print(f"   Compile time: {compile_time:.2f}s")
    print(f"   Execution time: {execution_time:.4f}s")
    print(f"   Speedup: {speedup:.1f}x")

    if compile_time < 60:  # compiled within one minute
        print("✅ XLA compilation looks healthy!")
        return True
    else:
        print("❌ XLA compilation is too slow; something may be wrong")
        return False

def test_training_step():
    """Test a single training step."""
    print("\n🎯 Testing a simplified training step...")

    device = xm.xla_device()
    model = SimpleModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Create training data
    x = torch.randn(4, 50, 512, device=device)
    labels = torch.randint(0, 41, (4, 50), device=device)

    print("🔄 Running training step (including backward pass)...")
    start_time = time.time()

    # Forward pass
    outputs = model(x)
    # Compute the loss
    loss = criterion(outputs.view(-1, 41), labels.view(-1))
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    xm.mark_step()  # force execution so step_time measures real work
    step_time = time.time() - start_time

    print(f"✅ Training step finished! Took: {step_time:.2f}s, loss: {loss.item():.4f}")
    return step_time < 120  # should finish within 2 minutes

def main():
    print("=" * 60)
    print("🧪 XLA compilation quick test")
    print("=" * 60)
    try:
        # Test 1: simple-model compilation
        compilation_ok = test_xla_compilation()
        if compilation_ok:
            # Test 2: training step
            training_ok = test_training_step()
            if training_ok:
                print("\n✅ All tests passed! Full model training is worth trying")
                print("💡 Suggestions:")
                print("   1. Make sure there is enough memory (32GB+)")
                print("   2. Reduce batch_size (e.g. from 32 to 16)")
                print("   3. Compensate with gradient_accumulation_steps")
            else:
                print("\n⚠️ Training step is slow; consider optimizing")
        else:
            print("\n❌ XLA compilation has problems; check the environment")
    except Exception as e:
        print(f"\n💥 Test failed: {e}")
        print("💡 Possible causes:")
        print("   - TPU resources unavailable")
        print("   - PyTorch XLA installation problems")
        print("   - Out of memory")
    print("=" * 60)

if __name__ == "__main__":
    main()
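
# Minimal usage sketch (assumptions: a TPU VM with torch_xla installed; recent
# torch_xla releases select the runtime via the PJRT_DEVICE variable):
#   PJRT_DEVICE=TPU python test_simple_model.py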