tpu
This commit is contained in:
78
model_training_nnn_tpu/jupyter_xla_test.py
Normal file
78
model_training_nnn_tpu/jupyter_xla_test.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# ====================
|
||||
# 单元格3: 快速XLA编译测试
|
||||
# ====================
|
||||
|
||||
# 简化测试模型
|
||||
class QuickTestModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(512, 128)
|
||||
self.gru = nn.GRU(128, 64, batch_first=True)
|
||||
self.linear2 = nn.Linear(64, 41)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.relu(self.linear1(x))
|
||||
x, _ = self.gru(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
print("🧪 开始XLA编译快速测试...")
|
||||
|
||||
# 启动监控
|
||||
compilation_monitor.start_monitoring()
|
||||
|
||||
try:
|
||||
# 获取TPU设备
|
||||
device = xm.xla_device()
|
||||
|
||||
# 创建小模型
|
||||
model = QuickTestModel().to(device)
|
||||
param_count = sum(p.numel() for p in model.parameters())
|
||||
print(f"📊 测试模型参数: {param_count:,}")
|
||||
|
||||
# 创建测试数据 (很小的batch)
|
||||
x = torch.randn(2, 20, 512, device=device)
|
||||
print(f"📥 输入数据形状: {x.shape}")
|
||||
|
||||
print("🔄 开始首次前向传播 (触发XLA编译)...")
|
||||
|
||||
# 首次前向传播 - 这会触发XLA编译
|
||||
with torch.no_grad():
|
||||
start_compile = time.time()
|
||||
output = model(x)
|
||||
compile_time = time.time() - start_compile
|
||||
|
||||
print(f"✅ XLA编译完成!")
|
||||
print(f"📤 输出形状: {output.shape}")
|
||||
|
||||
# 完成监控
|
||||
compilation_monitor.complete_monitoring()
|
||||
|
||||
# 测试编译后的性能
|
||||
print("\n🚀 测试编译后的执行速度...")
|
||||
with torch.no_grad():
|
||||
start_exec = time.time()
|
||||
for _ in range(10):
|
||||
output = model(x)
|
||||
avg_exec_time = (time.time() - start_exec) / 10
|
||||
|
||||
print(f"⚡ 平均执行时间: {avg_exec_time*1000:.2f}ms")
|
||||
|
||||
# 性能评估
|
||||
if compile_time < 30:
|
||||
print("✅ 编译速度优秀! 可以尝试完整模型")
|
||||
test_result = "excellent"
|
||||
elif compile_time < 120:
|
||||
print("✅ 编译速度良好! 建议使用简化配置")
|
||||
test_result = "good"
|
||||
else:
|
||||
print("⚠️ 编译速度较慢,建议进一步优化")
|
||||
test_result = "slow"
|
||||
|
||||
except Exception as e:
|
||||
compilation_monitor.complete_monitoring()
|
||||
print(f"❌ 测试失败: {e}")
|
||||
test_result = "failed"
|
||||
|
||||
print(f"\n📋 测试结果: {test_result}")
|
||||
print("💡 如果测试通过,可以运行下一个单元格进行完整训练")
|
Reference in New Issue
Block a user