This commit is contained in:
Zchen
2025-10-15 15:14:01 +08:00
parent 7bdfc0d257
commit 082018cd46
6 changed files with 496 additions and 500 deletions

View File

@@ -1,93 +0,0 @@
# ====================
# Cell 1: XLA version compatibility check and fixes
# ====================
import torch
import torch.nn as nn

print("🔧 Checking PyTorch XLA version compatibility...")

# Import XLA
import torch_xla.core.xla_model as xm
print("✅ PyTorch XLA imported successfully!")

# Define compatibility helpers
def get_xla_world_size():
    """Get the XLA world size across different torch_xla versions."""
    try:
        return xm.xrt_world_size()
    except AttributeError:
        try:
            return xm.get_world_size()
        except AttributeError:
            return 1  # Fall back to 1

def get_xla_ordinal():
    """Get the XLA ordinal across different torch_xla versions."""
    try:
        return xm.get_ordinal()
    except AttributeError:
        return 0  # Fall back to 0

def xla_mark_step():
    """Mark an XLA step across different torch_xla versions."""
    try:
        xm.mark_step()
    except AttributeError:
        try:
            xm.wait_device_ops()
        except AttributeError:
            pass  # Skip if neither API is available
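
# Note: on torch_xla 2.x (an assumption about your installed version) the
# runtime module exposes these values directly, so a version-aware variant
# of the helpers above could try it first:
#   import torch_xla.runtime as xr
#   world_size = xr.world_size()     # newer replacement for xm.xrt_world_size()
#   ordinal = xr.global_ordinal()    # newer replacement for xm.get_ordinal()
# The AttributeError fallbacks above are kept for older installs.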
def check_xla_device():
    """Check the XLA device status."""
    try:
        device = xm.xla_device()
        print(f"📱 XLA device: {device}")
        world_size = get_xla_world_size()
        ordinal = get_xla_ordinal()
        print(f"🌍 World size: {world_size}")
        print(f"🔢 Ordinal: {ordinal}")
        # Detect the device type
        device_str = str(device)
        if 'xla' in device_str and 'cpu' not in device_str:
            print("✅ TPU device detected")
            return True, "TPU"
        elif 'xla' in device_str and 'cpu' in device_str:
            print("⚠️ XLA CPU simulation mode")
            return True, "XLA_CPU"
        else:
            print("❌ No XLA device detected")
            return False, "CPU"
    except Exception as e:
        print(f"❌ XLA device check failed: {e}")
        return False, "ERROR"

# Run the compatibility check
device_available, device_type = check_xla_device()
if device_available:
    print(f"✅ XLA environment OK, device type: {device_type}")
    # Test basic XLA operations
    print("🧪 Testing basic XLA operations...")
    try:
        device = xm.xla_device()
        x = torch.randn(2, 2, device=device)
        y = torch.matmul(x, x)
        # Test the sync helper
        xla_mark_step()
        print("✅ Basic XLA operation test passed")
    except Exception as e:
        print(f"❌ XLA operation test failed: {e}")
else:
    print("❌ XLA environment unavailable")
print("\n💡 Compatibility check complete; you can run the following cells")

View File

@@ -1,167 +0,0 @@
# ====================
# Cell 2: XLA compilation progress monitor
# ====================
import torch
import torch.nn as nn
import time
import threading
import psutil  # needed below for the CPU/memory readings
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# Import XLA (environment variables were already set in Cell 1)
print("🚀 Importing PyTorch XLA...")
import torch_xla.core.xla_model as xm
print("✅ XLA imported successfully!")
print(f"   TPU device: {xm.xla_device()}")

# Compatibility with newer PyTorch XLA versions
try:
    world_size = xm.xrt_world_size()
    print(f"   World size (old API): {world_size}")
except AttributeError:
    try:
        world_size = xm.get_world_size()
        print(f"   World size (new API): {world_size}")
    except AttributeError:
        print("   World size: unavailable (possibly running in CPU mode)")

# Check XLA API compatibility
print("🔍 Checking XLA API compatibility:")
api_available = []
api_deprecated = []

# Probe the APIs this notebook relies on
test_apis = [
    ('xrt_world_size', 'xrt_world_size()'),
    ('get_world_size', 'get_world_size()'),
    ('mark_step', 'mark_step()'),
    ('wait_device_ops', 'wait_device_ops()'),
    ('get_ordinal', 'get_ordinal()'),
    ('xla_device_count', 'xla_device_count()')
]
for api_name, api_desc in test_apis:
    if hasattr(xm, api_name):
        api_available.append(api_desc)
    else:
        api_deprecated.append(api_desc)
if api_available:
    print(f"   ✅ Available APIs: {', '.join(api_available)}")
if api_deprecated:
    print(f"   ❌ Unavailable APIs: {', '.join(api_deprecated)}")
# Build the compilation progress monitor
class JupyterCompilationMonitor:
    def __init__(self):
        self.start_time = None
        self.is_monitoring = False
        # Output widget
        self.output_widget = widgets.Output()
        # Progress bar
        self.progress_bar = widgets.IntProgress(
            value=0,
            min=0,
            max=100,
            description='XLA compile:',
            bar_style='info',
            style={'bar_color': '#1f77b4'},
            orientation='horizontal'
        )
        # Status label
        self.status_label = widgets.HTML(
            value="<b>Preparing to compile...</b>"
        )
        # CPU and memory readouts
        self.cpu_label = widgets.HTML(
            value="CPU: ---%"
        )
        self.memory_label = widgets.HTML(
            value="Memory: ---%"
        )
        # Assemble the UI
        self.monitor_box = widgets.VBox([
            widgets.HTML("<h3>🔄 XLA compilation monitor</h3>"),
            self.progress_bar,
            self.status_label,
            widgets.HBox([self.cpu_label, self.memory_label]),
            self.output_widget
        ])

    def start_monitoring(self):
        """Start monitoring."""
        self.start_time = time.time()
        self.is_monitoring = True
        display(self.monitor_box)
        # Launch the monitoring thread
        self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.monitor_thread.start()

    def _monitor_loop(self):
        """Monitoring loop."""
        while self.is_monitoring:
            try:
                elapsed = time.time() - self.start_time
                minutes = int(elapsed // 60)
                seconds = int(elapsed % 60)
                # Update the progress bar (simulated progress: 95% at ~10s)
                progress = min(int(elapsed / 10 * 100), 95)
                self.progress_bar.value = progress
                # Read system resources
                cpu_percent = psutil.cpu_percent(interval=0.1)
                memory_percent = psutil.virtual_memory().percent
                # Update the display
                self.status_label.value = f"<b>Compiling... ⏱️ {minutes:02d}:{seconds:02d}</b>"
                self.cpu_label.value = f"<b>🖥️ CPU: {cpu_percent:5.1f}%</b>"
                self.memory_label.value = f"<b>💾 Memory: {memory_percent:5.1f}%</b>"
                # Heuristic completion check: compilation keeps the CPU busy,
                # so a sudden drop in CPU usage suggests it has finished
                if elapsed > 10 and cpu_percent < 20:
                    self.complete_monitoring()
                    break
                time.sleep(1)
            except Exception as e:
                with self.output_widget:
                    print(f"Monitor error: {e}")
                break

    def complete_monitoring(self):
        """Finish monitoring."""
        if self.is_monitoring:
            self.is_monitoring = False
            elapsed = time.time() - self.start_time
            self.progress_bar.value = 100
            self.progress_bar.bar_style = 'success'
            self.status_label.value = f"<b style='color: green'>✅ Compilation finished! Total time: {elapsed:.2f}s</b>"
            with self.output_widget:
                print("\n🎉 XLA compilation completed successfully!")
                print(f"⏱️ Total time: {elapsed:.2f}s")
                if elapsed < 60:
                    print("✅ Compilation speed is normal")
                elif elapsed < 300:
                    print("⚠️ Compilation is a bit slow, but acceptable")
                else:
                    print("❌ Compilation is too slow; check your setup")

# Create the global monitor
compilation_monitor = JupyterCompilationMonitor()
print("✅ Compilation monitor is ready!")
print("💡 Run the next cell to start the XLA compilation test")

View File

@@ -1,78 +0,0 @@
# ====================
# Cell 3: quick XLA compilation test
# ====================
# (Relies on the imports and `compilation_monitor` from Cells 1-2.)

# Slimmed-down test model
class QuickTestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(512, 128)
        self.gru = nn.GRU(128, 64, batch_first=True)
        self.linear2 = nn.Linear(64, 41)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x, _ = self.gru(x)
        x = self.linear2(x)
        return x

print("🧪 Starting quick XLA compilation test...")

# Start the monitor
compilation_monitor.start_monitoring()

try:
    # Get the TPU device
    device = xm.xla_device()
    # Build the small model
    model = QuickTestModel().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"📊 Test model parameters: {param_count:,}")
    # Create test data (a very small batch)
    x = torch.randn(2, 20, 512, device=device)
    print(f"📥 Input shape: {x.shape}")
    print("🔄 Running the first forward pass (triggers XLA compilation)...")
    # First forward pass - this triggers XLA compilation. XLA tensors are
    # lazy, so sync before reading the clock to capture the compile time.
    with torch.no_grad():
        start_compile = time.time()
        output = model(x)
        xla_mark_step()
        xm.wait_device_ops()
        compile_time = time.time() - start_compile
    print("✅ XLA compilation finished!")
    print(f"📤 Output shape: {output.shape}")
    # Stop the monitor
    compilation_monitor.complete_monitoring()
    # Benchmark the compiled graph
    print("\n🚀 Measuring post-compilation execution speed...")
    with torch.no_grad():
        start_exec = time.time()
        for _ in range(10):
            output = model(x)
        xla_mark_step()
        xm.wait_device_ops()
        avg_exec_time = (time.time() - start_exec) / 10
    print(f"⚡ Average execution time: {avg_exec_time*1000:.2f}ms")
    # Evaluate the result
    if compile_time < 30:
        print("✅ Excellent compile speed! You can try the full model")
        test_result = "excellent"
    elif compile_time < 120:
        print("✅ Good compile speed! A simplified config is recommended")
        test_result = "good"
    else:
        print("⚠️ Compilation is slow; further optimization is recommended")
        test_result = "slow"
except Exception as e:
    compilation_monitor.complete_monitoring()
    print(f"❌ Test failed: {e}")
    test_result = "failed"

print(f"\n📋 Test result: {test_result}")
print("💡 If the test passed, run the next cell for full training")

View File

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
"""
Quick TPU test script - verifies that a simple model can run on the TPU
"""
import os
import time
import torch
import torch.nn as nn

# Set environment variables (must happen before importing torch_xla)
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true'
os.environ['XLA_USE_BF16'] = '1'

import torch_xla.core.xla_model as xm

def quick_test():
    """Quick check that the TPU works."""
    print("🚀 Starting quick TPU test...")
    try:
        # Get the TPU device
        device = xm.xla_device()
        print(f"📱 TPU device: {device}")
        # Build a simple model
        model = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.GRU(256, 128, batch_first=True),
            nn.Linear(128, 41)
        ).to(device)
        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
        # Create test data
        x = torch.randn(8, 50, 512, device=device)
        print(f"📥 Input shape: {x.shape}")

        # Test the forward pass
        print("🔄 Testing forward pass...")
        start_time = time.time()
        with torch.no_grad():
            # nn.Sequential cannot forward the (output, h_n) tuple that
            # nn.GRU returns, so run the layers manually
            x_proj = model[1](model[0](x))   # Linear + ReLU
            gru_out, _ = model[2](x_proj)    # GRU
            output = model[3](gru_out)       # Final Linear
        # Synchronize TPU ops before reading the clock
        xm.mark_step()
        xm.wait_device_ops()
        forward_time = time.time() - start_time
        print(f"✅ Forward pass done! Time: {forward_time:.3f}s")
        print(f"📤 Output shape: {output.shape}")

        # Test the backward pass
        print("🔄 Testing backward pass...")
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        start_time = time.time()
        # Create dummy labels
        labels = torch.randint(0, 41, (8, 50), device=device)
        criterion = nn.CrossEntropyLoss()
        # Forward pass (manual layer-by-layer, as above)
        x_proj = model[1](model[0](x))
        gru_out, _ = model[2](x_proj)
        output = model[3](gru_out)
        # Compute the loss
        loss = criterion(output.view(-1, 41), labels.view(-1))
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Synchronize TPU ops
        xm.mark_step()
        xm.wait_device_ops()
        backward_time = time.time() - start_time
        print(f"✅ Backward pass done! Time: {backward_time:.3f}s")
        print(f"🎯 Loss: {loss.item():.4f}")

        # Summary
        print("\n📈 Performance summary:")
        print(f"   Forward pass: {forward_time:.3f}s")
        print(f"   Backward pass: {backward_time:.3f}s")
        print(f"   Total: {forward_time + backward_time:.3f}s")
        if (forward_time + backward_time) < 10:  # finished within 10 seconds
            print("✅ TPU test passed! Ready for full training")
            return True
        else:
            print("⚠️ TPU performance is slow; optimization may be needed")
            return False
    except Exception as e:
        print(f"❌ TPU test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

if __name__ == "__main__":
    print("=" * 50)
    print("⚡ Quick TPU test")
    print("=" * 50)
    success = quick_test()
    if success:
        print("\n🎉 Test passed! You can now run:")
        print("   python simple_tpu_model.py")
    else:
        print("\n❌ Test failed; please check your TPU configuration")
    print("=" * 50)

View File

@@ -0,0 +1,367 @@
#!/usr/bin/env python3
"""
Simple TPU model training and test script
A simplified version based on the brain-to-text data, optimized for TPU
"""
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import Dict, Any, Tuple

# Set XLA environment variables (must happen before importing torch_xla)
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={os.cpu_count()}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
os.environ['XLA_USE_BF16'] = '1'
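
# Note (an assumption about newer torch_xla releases): XLA_USE_BF16 is
# deprecated there in favor of casting explicitly, e.g.
# model.to(torch.bfloat16); if the import below prints a deprecation
# warning, that is the likely cause. Kept here to match the original setup.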
import torch_xla.core.xla_model as xm

class SimpleBrainToTextModel(nn.Module):
    """Simplified brain-to-text model - TPU-optimized version."""

    def __init__(self, input_features=512, hidden_size=256, num_classes=41, num_layers=3):
        super().__init__()
        # Input processing layers
        self.input_proj = nn.Linear(input_features, hidden_size)
        self.input_dropout = nn.Dropout(0.2)
        # GRU layers - a smaller hidden size improves TPU efficiency
        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.3 if num_layers > 1 else 0
        )
        # Output layer
        self.output_proj = nn.Linear(hidden_size, num_classes)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize model weights."""
        for name, param in self.named_parameters():
            if 'weight' in name:
                if 'gru' in name:
                    nn.init.orthogonal_(param)
                else:
                    nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, x):
        """
        Forward pass.
        Args:
            x: (batch_size, seq_len, input_features)
        Returns:
            logits: (batch_size, seq_len, num_classes)
        """
        # Input projection
        x = torch.relu(self.input_proj(x))
        x = self.input_dropout(x)
        # GRU
        output, _ = self.gru(x)
        # Output projection
        logits = self.output_proj(output)
        return logits
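
# Quick shape sanity check (a sketch, runnable on CPU; the sizes are the
# defaults above, not from any real data):
#   m = SimpleBrainToTextModel()
#   assert m(torch.randn(2, 10, 512)).shape == (2, 10, 41)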
class SimpleDataGenerator:
    """Simple data generator - simulates brain-signal data."""

    def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.input_features = input_features
        self.num_classes = num_classes

    def generate_batch(self, device):
        """Generate one batch of simulated data."""
        # Simulated neural-signal features
        features = torch.randn(
            self.batch_size, self.seq_len, self.input_features,
            device=device, dtype=torch.float32
        )
        # Simulated labels (phoneme sequence)
        labels = torch.randint(
            0, self.num_classes,
            (self.batch_size, self.seq_len),
            device=device
        )
        # Sequence lengths
        seq_lengths = torch.randint(
            self.seq_len // 2, self.seq_len + 1,
            (self.batch_size,),
            device=device
        )
        return {
            'features': features,
            'labels': labels,
            'seq_lengths': seq_lengths
        }
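
# Note: `seq_lengths` is returned but never consumed below - the loss runs
# over every position. A real setup would mask padding inside generate_batch,
# e.g. (a sketch, not part of the original trainer):
#   mask = torch.arange(self.seq_len, device=device)[None, :] < seq_lengths[:, None]
#   labels = labels.masked_fill(~mask, -1)  # -1 matches ignore_index below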
class SimpleTpuTrainer:
    """Simple TPU trainer."""

    def __init__(self, model, device, learning_rate=0.001):
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        # Data generator
        self.data_generator = SimpleDataGenerator()
        # Training statistics
        self.step = 0
        self.best_loss = float('inf')

    def train_step(self, batch):
        """Single training step."""
        self.model.train()
        self.optimizer.zero_grad()
        # Forward pass
        features = batch['features']
        labels = batch['labels']
        logits = self.model(features)
        # Compute the loss - flatten to fit CrossEntropyLoss
        batch_size, seq_len, num_classes = logits.shape
        loss = self.criterion(
            logits.reshape(-1, num_classes),
            labels.reshape(-1)
        )
        # Backward pass
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        # Update parameters
        self.optimizer.step()
        return loss.item()
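
    # Caveat: `loss.item()` forces a host-device sync on every call, which
    # largely defeats the every-10-steps mark_step() batching in train()
    # below. A common alternative (a sketch, not the original code) is to
    # return the lazy tensor and materialize it only when logging:
    #   return loss.detach()       # stays on the device
    #   ...
    #   print(loss.item())         # sync only at logging points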
    def evaluate_step(self, batch):
        """Single evaluation step."""
        self.model.eval()
        with torch.no_grad():
            features = batch['features']
            labels = batch['labels']
            logits = self.model(features)
            # Compute the loss
            batch_size, seq_len, num_classes = logits.shape
            loss = self.criterion(
                logits.reshape(-1, num_classes),
                labels.reshape(-1)
            )
            # Compute the accuracy
            predictions = torch.argmax(logits, dim=-1)
            correct = (predictions == labels).float()
            accuracy = correct.mean()
        return loss.item(), accuracy.item()

    def train(self, num_steps=1000, eval_every=100, save_every=500):
        """Train the model."""
        print(f"🚀 Starting TPU training - device: {self.device}")
        print(f"📊 Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        train_losses = []
        eval_losses = []
        eval_accuracies = []
        start_time = time.time()
        for step in range(num_steps):
            self.step = step  # keep the checkpoint step counter in sync
            # Generate training data
            train_batch = self.data_generator.generate_batch(self.device)
            # Training step
            train_loss = self.train_step(train_batch)
            train_losses.append(train_loss)
            # XLA sync
            if step % 10 == 0:  # sync every 10 steps for efficiency
                xm.mark_step()
            # Evaluation
            if step % eval_every == 0:
                eval_batch = self.data_generator.generate_batch(self.device)
                eval_loss, eval_acc = self.evaluate_step(eval_batch)
                eval_losses.append(eval_loss)
                eval_accuracies.append(eval_acc)
                # Sync XLA ops to get accurate timings
                xm.mark_step()
                xm.wait_device_ops()
                current_time = time.time()
                elapsed = current_time - start_time
                print(f"Step {step:4d}/{num_steps} | "
                      f"train loss: {train_loss:.4f} | "
                      f"val loss: {eval_loss:.4f} | "
                      f"val accuracy: {eval_acc:.4f} | "
                      f"elapsed: {elapsed:.1f}s")
                # Track the best model
                if eval_loss < self.best_loss:
                    self.best_loss = eval_loss
                    print(f"🎯 New best model! Loss: {eval_loss:.4f}")
            # Periodic checkpointing
            if step > 0 and step % save_every == 0:
                self.save_checkpoint(f"checkpoint_step_{step}.pt")
        # Final sync
        xm.mark_step()
        xm.wait_device_ops()
        total_time = time.time() - start_time
        print("\n✅ Training finished!")
        print(f"⏱️ Total time: {total_time:.1f}s")
        print(f"🎯 Final training loss: {train_losses[-1]:.4f}")
        if eval_losses:
            print(f"🎯 Final validation loss: {eval_losses[-1]:.4f}")
            print(f"🎯 Final validation accuracy: {eval_accuracies[-1]:.4f}")
        return {
            'train_losses': train_losses,
            'eval_losses': eval_losses,
            'eval_accuracies': eval_accuracies,
            'total_time': total_time
        }

    def save_checkpoint(self, filename):
        """Save a checkpoint."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'step': self.step,
            'best_loss': self.best_loss,
        }
        # On TPU, use xm.save(), which moves XLA tensors to the CPU before
        # writing (and only writes from the master ordinal on multi-core)
        if 'xla' in str(self.device):
            xm.save(checkpoint, filename)
        else:
            torch.save(checkpoint, filename)
        print(f"💾 Saved checkpoint: {filename}")

    def load_checkpoint(self, filename):
        """Load a checkpoint."""
        checkpoint = torch.load(filename, map_location='cpu')
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.step = checkpoint['step']
        self.best_loss = checkpoint['best_loss']
        print(f"📂 Loaded checkpoint: {filename}")
        print(f"   Step: {self.step}, best loss: {self.best_loss:.4f}")
def test_simple_inference():
    """Test simple inference."""
    print("\n🧪 Testing simple inference...")
    device = xm.xla_device()
    # Build the model
    model = SimpleBrainToTextModel().to(device)
    # Create test data
    batch_size = 4
    seq_len = 50
    test_input = torch.randn(batch_size, seq_len, 512, device=device)
    # Inference
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        output = model(test_input)
        xm.mark_step()
        xm.wait_device_ops()
        inference_time = time.time() - start_time
    print("✅ Inference done!")
    print(f"   Input shape: {test_input.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   Inference time: {inference_time:.4f}s")
    return True

def main():
    """Main entry point."""
    print("=" * 60)
    print("🧠 Simple TPU brain-to-text model training")
    print("=" * 60)
    try:
        # Check the TPU device
        device = xm.xla_device()
        print(f"📱 Using device: {device}")
        # Build the model
        model = SimpleBrainToTextModel(
            input_features=512,
            hidden_size=256,
            num_classes=41,
            num_layers=3
        ).to(device)
        # Build the trainer
        trainer = SimpleTpuTrainer(model, device, learning_rate=0.001)
        # Train
        results = trainer.train(
            num_steps=1000,
            eval_every=100,
            save_every=500
        )
        # Save the final model
        trainer.save_checkpoint("final_simple_model.pt")
        # Test inference
        test_simple_inference()
        print("\n🎉 All tests complete!")
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()

View File

@@ -1,162 +0,0 @@
#!/usr/bin/env python3
"""
Simplified model test script - verifies that XLA compilation works correctly
"""
import os
import time
import torch
import torch.nn as nn

# Set XLA environment variables (must happen before importing torch_xla)
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={os.cpu_count()}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
os.environ['XLA_USE_BF16'] = '1'

print("🔧 XLA environment variables:")
print(f"   CPU cores: {os.cpu_count()}")
print(f"   XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f"   PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")

import torch_xla.core.xla_model as xm

class SimpleModel(nn.Module):
    """Simplified test model."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(512, 256)
        self.gru = nn.GRU(256, 128, batch_first=True)
        self.linear2 = nn.Linear(128, 41)  # 41 phoneme classes

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x, _ = self.gru(x)
        x = self.linear2(x)
        return x
def test_xla_compilation():
    """Test XLA compilation speed."""
    print("\n🚀 Starting simplified-model XLA compilation test...")
    # Check the TPU device
    device = xm.xla_device()
    print(f"📱 TPU device: {device}")
    try:
        print(f"🌍 TPU world size: {xm.xrt_world_size()}")
    except AttributeError:
        print("🌍 TPU world size: unavailable (xrt_world_size removed in newer torch_xla)")
    # Build the simplified model
    model = SimpleModel().to(device)
    print(f"📊 Model parameter count: {sum(p.numel() for p in model.parameters()):,}")
    # Create test data
    batch_size = 8  # small batch
    seq_len = 100   # short sequence
    x = torch.randn(batch_size, seq_len, 512, device=device)
    print(f"📥 Input shape: {x.shape}")
    # First forward pass - triggers XLA compilation. Sync before reading
    # the clock, since XLA execution is lazy.
    print("🔄 Running first forward pass (XLA compilation)...")
    start_time = time.time()
    with torch.no_grad():
        output = model(x)
    xm.mark_step()
    xm.wait_device_ops()
    compile_time = time.time() - start_time
    print(f"✅ XLA compilation finished! Time: {compile_time:.2f}s")
    print(f"📤 Output shape: {output.shape}")
    # Second forward pass - reuses the compiled graph
    print("🔄 Running second forward pass (compiled graph)...")
    start_time = time.time()
    with torch.no_grad():
        output2 = model(x)
    xm.mark_step()
    xm.wait_device_ops()
    execution_time = time.time() - start_time
    print(f"⚡ Execution finished! Time: {execution_time:.4f}s")
    # Compare performance
    speedup = compile_time / execution_time if execution_time > 0 else float('inf')
    print("\n📈 Performance analysis:")
    print(f"   Compile time: {compile_time:.2f}s")
    print(f"   Execution time: {execution_time:.4f}s")
    print(f"   Speedup: {speedup:.1f}x")
    if compile_time < 60:  # compiled within one minute
        print("✅ XLA compilation looks healthy!")
        return True
    else:
        print("❌ XLA compilation is too slow; something may be wrong")
        return False
def test_training_step():
    """Test a training step."""
    print("\n🎯 Testing a simplified training step...")
    device = xm.xla_device()
    model = SimpleModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    # Create training data
    x = torch.randn(4, 50, 512, device=device)
    labels = torch.randint(0, 41, (4, 50), device=device)
    print("🔄 Running a training step (including backward pass)...")
    start_time = time.time()
    # Forward pass
    outputs = model(x)
    # Compute the loss
    loss = criterion(outputs.view(-1, 41), labels.view(-1))
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Sync before timing (XLA execution is lazy)
    xm.mark_step()
    xm.wait_device_ops()
    step_time = time.time() - start_time
    print(f"✅ Training step done! Time: {step_time:.2f}s, loss: {loss.item():.4f}")
    return step_time < 120  # should finish within 2 minutes
def main():
    print("=" * 60)
    print("🧪 Quick XLA compilation test")
    print("=" * 60)
    try:
        # Test 1: simple model compilation
        compilation_ok = test_xla_compilation()
        if compilation_ok:
            # Test 2: training step
            training_ok = test_training_step()
            if training_ok:
                print("\n✅ All tests passed! You can try full-model training")
                print("💡 Suggestions:")
                print("   1. Make sure you have enough memory (32GB+)")
                print("   2. Reduce batch_size (e.g. from 32 to 16)")
                print("   3. Compensate with gradient_accumulation_steps")
            else:
                print("\n⚠️ The training step is slow; optimization recommended")
        else:
            print("\n❌ XLA compilation has problems; check your environment")
    except Exception as e:
        print(f"\n💥 Test failed: {e}")
        print("💡 Possible causes:")
        print("   - TPU resources unavailable")
        print("   - PyTorch XLA installation problem")
        print("   - Out of memory")
    print("=" * 60)

if __name__ == "__main__":
    main()