#!/usr/bin/env python3
"""Quick smoke test for PyTorch/XLA compilation using a small model.

Checks that XLA compilation works and finishes in a reasonable time before
attempting training of the full model.
"""

import os
import time

import torch
import torch.nn as nn

# os.cpu_count() returns None when the count cannot be determined; fall back to
# 1 so the flag values below are always valid integers instead of "None".
CPU_COUNT = os.cpu_count() or 1

# XLA environment variables (must be set before torch_xla is imported).
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={CPU_COUNT}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(CPU_COUNT)
os.environ['XLA_USE_BF16'] = '1'

print("🔧 XLA环境变量设置:")
print(f" CPU核心数: {CPU_COUNT}")
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")

import torch_xla.core.xla_model as xm


class SimpleModel(nn.Module):
    """Small Linear -> GRU -> Linear network used to smoke-test XLA compilation.

    The default sizes reproduce the original fixed architecture: 512 input
    features per step, a 256-wide projection, a 128-unit GRU and 41 output
    classes (phoneme classes, per the original comment). Layer attribute names
    are unchanged, so existing state_dicts still load.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Output width of the first linear layer (GRU input size).
        rnn_hidden_dim: Hidden size of the single-layer GRU.
        num_classes: Number of output logits produced per time step.
    """

    def __init__(self, input_dim=512, hidden_dim=256, rnn_hidden_dim=128, num_classes=41):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        # batch_first=True: the GRU consumes and returns (batch, seq, feature).
        self.gru = nn.GRU(hidden_dim, rnn_hidden_dim, batch_first=True)
        self.linear2 = nn.Linear(rnn_hidden_dim, num_classes)

    def forward(self, x):
        """Map (batch, seq, input_dim) features to (batch, seq, num_classes) logits."""
        x = torch.relu(self.linear1(x))
        x, _ = self.gru(x)  # the final hidden state is not needed
        return self.linear2(x)
def _timed_forward(model, x):
    """Run ``model(x)`` without autograd on XLA and time it end to end.

    PyTorch/XLA records operations lazily; ``xm.mark_step()`` cuts the graph
    so it is compiled (or fetched from the compilation cache) and launched,
    and ``xm.wait_device_ops()`` blocks until the device has finished. Both
    are inside the timed region so the measurement covers real work, not just
    graph tracing.

    Returns:
        A ``(output, seconds)`` tuple.
    """
    start = time.perf_counter()
    with torch.no_grad():
        output = model(x)
    xm.mark_step()
    xm.wait_device_ops()
    return output, time.perf_counter() - start


def test_xla_compilation():
    """Time the first (compiling) and second (cached) forward pass on XLA.

    Returns:
        True if the first pass, including XLA compilation, finishes in under
        60 seconds; False otherwise.
    """
    print("\n🚀 开始简化模型XLA编译测试...")

    # Check the TPU device.
    device = xm.xla_device()
    print(f"📱 TPU设备: {device}")
    # NOTE(review): xrt_world_size() is deprecated in newer torch_xla releases in
    # favour of torch_xla.runtime.world_size() — confirm against the installed version.
    print(f"🌍 TPU World Size: {xm.xrt_world_size()}")

    # Build the simplified model.
    model = SimpleModel().to(device)
    print(f"📊 模型参数数量: {sum(p.numel() for p in model.parameters()):,}")

    # Test data: small batch, short sequence.
    batch_size = 8
    seq_len = 100
    x = torch.randn(batch_size, seq_len, 512, device=device)
    # Materialize the (lazy) random input and the weights now. Otherwise the
    # first timed graph would also contain the RNG op, and the second pass
    # would trace a *different* graph and recompile instead of hitting the cache.
    xm.mark_step()
    xm.wait_device_ops()

    print(f"📥 输入形状: {x.shape}")

    # First forward pass - triggers XLA compilation.
    print("🔄 开始首次前向传播 (XLA编译)...")
    output, compile_time = _timed_forward(model, x)
    print(f"✅ XLA编译完成! 耗时: {compile_time:.2f}秒")
    print(f"📤 输出形状: {output.shape}")

    # Second forward pass - same graph, so the compiled program is reused.
    print("🔄 第二次前向传播 (使用编译后的图)...")
    _, execution_time = _timed_forward(model, x)
    print(f"⚡ 执行完成! 耗时: {execution_time:.4f}秒")

    # Performance comparison.
    speedup = compile_time / execution_time if execution_time > 0 else float('inf')
    print("\n📈 性能分析:")
    print(f" 编译时间: {compile_time:.2f}秒")
    print(f" 执行时间: {execution_time:.4f}秒")
    print(f" 加速比: {speedup:.1f}x")

    if compile_time < 60:  # compiled within one minute
        print("✅ XLA编译正常!")
        return True
    print("❌ XLA编译过慢,可能有问题")
    return False
def test_training_step():
    """Time one full training step (forward, loss, backward, Adam update) on XLA.

    PyTorch/XLA executes lazily: the step is only compiled and run when
    ``xm.mark_step()`` cuts the graph, and ``xm.wait_device_ops()`` makes the
    timer wait for the device. Both are inside the timed region; the loss is
    read back only after the clock has stopped.

    Returns:
        True if the step finishes in under 120 seconds, False otherwise.
    """
    print("\n🎯 测试简化训练步骤...")

    device = xm.xla_device()
    model = SimpleModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Training data: 4 sequences of 50 steps, one integer label per step.
    num_classes = 41
    x = torch.randn(4, 50, 512, device=device)
    labels = torch.randint(0, num_classes, (4, 50), device=device)
    # Materialize the lazy random inputs so generating them is not timed.
    xm.mark_step()
    xm.wait_device_ops()

    print("🔄 开始训练步骤 (包含反向传播)...")
    start_time = time.perf_counter()

    # Forward pass.
    outputs = model(x)

    # Loss: flatten (batch, seq, classes) -> (batch * seq, classes) for CrossEntropyLoss.
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))

    # Backward pass and parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    xm.mark_step()        # compile and launch forward + backward + update
    xm.wait_device_ops()  # wait for the device before stopping the clock

    step_time = time.perf_counter() - start_time
    print(f"✅ 训练步骤完成! 耗时: {step_time:.2f}秒, 损失: {loss.item():.4f}")

    return step_time < 120  # finished within two minutes
def main():
    """Run the compilation test and, if it passes, the training-step test.

    Prints a summary with suggestions. Any exception raised by a test is
    reported (with its traceback) instead of propagating.

    Returns:
        Process exit status: 0 when both tests pass, 1 otherwise.
    """
    import traceback

    print("=" * 60)
    print("🧪 XLA编译快速测试")
    print("=" * 60)

    exit_code = 1
    try:
        # Test 1: compile the simple model. Test 2 (training step) runs only if it passed.
        if not test_xla_compilation():
            print("\n❌ XLA编译有问题,需要检查环境")
        elif not test_training_step():
            print("\n⚠️ 训练步骤较慢,建议优化")
        else:
            print("\n✅ 所有测试通过! 可以尝试完整模型训练")
            print("💡 建议:")
            print(" 1. 确保有足够内存 (32GB+)")
            print(" 2. 减小batch_size (比如从32改为16)")
            print(" 3. 使用gradient_accumulation_steps补偿")
            exit_code = 0
    except Exception as e:
        # Top-level boundary: report rather than crash, but keep the traceback
        # so the real cause can be diagnosed.
        print(f"\n💥 测试失败: {e}")
        traceback.print_exc()
        print("💡 可能的问题:")
        print(" - TPU资源不可用")
        print(" - PyTorch XLA安装问题")
        print(" - 内存不足")

    print("=" * 60)
    return exit_code
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (a None return exits with 0).
    raise SystemExit(main())