Files
b2txt25/model_training_nnn_tpu/quick_tpu_test.py
2025-10-15 15:14:01 +08:00

129 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Quick TPU test script - verifies that a simple model can run on the TPU.
"""
import os
import time
import torch
import torch.nn as nn
# Configure XLA before torch_xla is imported: multi-threaded Eigen + fast
# math for the CPU fallback path, and automatic bf16 casting on the TPU.
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true'
os.environ['XLA_USE_BF16'] = '1'
import torch_xla.core.xla_model as xm
def quick_test():
    """Run one forward and one backward pass on the TPU as a smoke test.

    Builds a small Linear -> ReLU -> GRU -> Linear stack, times a single
    forward pass and a single training step (both include XLA trace/compile
    time on the first run), and judges the result against a 10-second budget.

    Returns:
        bool: True when both passes finish within 10 seconds combined,
        False on a slow run or on any exception.
    """
    print("🚀 开始快速TPU测试...")
    try:
        # Acquire the XLA (TPU) device.
        device = xm.xla_device()
        print(f"📱 TPU设备: {device}")

        # NOTE: nn.GRU returns an (output, hidden) tuple, so this Sequential
        # cannot be called directly as model(x); _forward below threads the
        # layers by hand instead.
        model = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.GRU(256, 128, batch_first=True),
            nn.Linear(128, 41)
        ).to(device)
        print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")

        def _forward(inp):
            # Manually chain the layers; the GRU's tuple output breaks
            # Sequential's implicit chaining.
            hidden = model[1](model[0](inp))   # Linear + ReLU
            gru_out, _ = model[2](hidden)      # GRU (final hidden state discarded)
            return model[3](gru_out)           # Final Linear -> 41 classes

        # Dummy batch: (batch=8, seq_len=50, features=512).
        x = torch.randn(8, 50, 512, device=device)
        print(f"📥 输入形状: {x.shape}")

        # --- Forward pass (includes XLA compilation on first execution) ---
        print("🔄 测试前向传播...")
        start_time = time.perf_counter()  # monotonic clock for intervals
        with torch.no_grad():
            output = _forward(x)
        # Flush the lazy XLA graph and block until the device finishes so
        # the measured time reflects real execution, not just graph tracing.
        xm.mark_step()
        xm.wait_device_ops()
        forward_time = time.perf_counter() - start_time
        print(f"✅ 前向传播完成! 耗时: {forward_time:.3f}")
        print(f"📤 输出形状: {output.shape}")

        # --- Backward pass (one full optimizer step) ---
        print("🔄 测试反向传播...")
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        start_time = time.perf_counter()

        # Random class labels in [0, 41) for every timestep.
        labels = torch.randint(0, 41, (8, 50), device=device)
        criterion = nn.CrossEntropyLoss()

        output = _forward(x)
        # Flatten (batch, seq, classes) -> (batch*seq, classes) for CE loss.
        loss = criterion(output.view(-1, 41), labels.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Synchronize again so backward_time covers actual device work.
        xm.mark_step()
        xm.wait_device_ops()
        backward_time = time.perf_counter() - start_time
        print(f"✅ 反向传播完成! 耗时: {backward_time:.3f}")
        print(f"🎯 损失值: {loss.item():.4f}")

        # --- Summary ---
        print(f"\n📈 性能总结:")
        print(f" 前向传播: {forward_time:.3f}")
        print(f" 反向传播: {backward_time:.3f}")
        print(f" 总计: {forward_time + backward_time:.3f}")

        # 10-second budget is generous because it includes XLA compilation.
        if (forward_time + backward_time) < 10:
            print("✅ TPU测试通过! 可以进行完整训练")
            return True
        else:
            print("⚠️ TPU性能较慢可能需要优化")
            return False
    except Exception as e:
        # Best-effort smoke test: report the failure and return False
        # instead of crashing the caller.
        print(f"❌ TPU测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
print("=" * 50)
print("⚡ 快速TPU测试")
print("=" * 50)
success = quick_test()
if success:
print("\n🎉 测试成功! 现在可以运行:")
print(" python simple_tpu_model.py")
else:
print("\n❌ 测试失败请检查TPU配置")
print("=" * 50)