#!/usr/bin/env python3
"""TPU training script using AMP (automatic mixed precision).

Handles mixed-precision training correctly to avoid dtype mismatches.
"""

import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# XLA flags. Note: these two flags affect the XLA CPU backend only; they are
# harmless on TPU but do not change TPU behavior.
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true'
)
# Enable bf16. Caution: XLA_USE_BF16=1 casts *all* float tensors to bf16,
# which overlaps with the autocast regions below; bf16 also does not need
# loss scaling, so GradScaler is kept mainly to demonstrate the full AMP
# pattern. This must be set before torch_xla is imported.
os.environ['XLA_USE_BF16'] = '1'

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.amp as xla_amp


class AMPModel(nn.Module):
    """Simple MLP that works with AMP."""

    def __init__(self, input_size=784, hidden_size=512, num_classes=10):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden_size // 2, num_classes)
        )

    def forward(self, x):
        # Flatten the input image to a vector
        x = x.view(x.size(0), -1)
        return self.network(x)


class AMPTrainer:
    """AMP trainer."""

    def __init__(self, model, device, learning_rate=0.001):
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()

        # Initialize the AMP gradient scaler
        self.scaler = xla_amp.GradScaler()

        print("✅ AMP trainer initialized")
        print(f"   Device: {device}")
        print(f"   Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    def train_step(self, data, target):
        """Single AMP training step."""
        self.model.train()
        self.optimizer.zero_grad()

        # Mixed-precision forward pass under autocast. torch_xla's autocast
        # takes the XLA device as its first argument (required in current
        # torch_xla releases).
        with xla_amp.autocast(self.device):
            output = self.model(data)
            loss = self.criterion(output, target)

        # Scaled backward pass
        self.scaler.scale(loss).backward()

        # Optional gradient clipping (gradients must be unscaled first)
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Parameter update
        self.scaler.step(self.optimizer)
        self.scaler.update()

        # Accuracy. Note: .item() forces a host-device sync on XLA; fine for
        # a demo, but it slows down real training loops.
        pred = output.argmax(dim=1)
        correct = pred.eq(target).sum().item()
        accuracy = correct / target.size(0)

        return loss.item(), accuracy

    def evaluate_step(self, data, target):
        """Single evaluation step."""
        self.model.eval()

        with torch.no_grad():
            with xla_amp.autocast(self.device):
                output = self.model(data)
                loss = self.criterion(output, target)

            pred = output.argmax(dim=1)
            correct = pred.eq(target).sum().item()
            accuracy = correct / target.size(0)

        return loss.item(), accuracy


def get_mnist_loaders(batch_size=64):
    """Build the MNIST data loaders."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST(
        root='./mnist_data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        root='./mnist_data', train=False, download=True, transform=transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )

    return train_loader, test_loader


def train_with_amp():
    """Train with AMP."""
    print("🚀 Starting AMP TPU training...")

    # Acquire the XLA device
    device = xm.xla_device()
    print(f"📱 Device: {device}")

    # Build the model
    model = AMPModel(input_size=784, hidden_size=512, num_classes=10).to(device)

    # Build the trainer
    trainer = AMPTrainer(model, device, learning_rate=0.001)

    # Load the data
    print("📥 Loading MNIST data...")
    train_loader, test_loader = get_mnist_loaders(batch_size=64)

    # Wrap with XLA parallel loaders. MpDeviceLoader also inserts a
    # mark_step after every batch, so the explicit calls below are extra.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    print("🎯 Starting AMP training...")

    # Training loop
    num_epochs = 2
    train_losses = []
    train_accuracies = []

    for epoch in range(num_epochs):
        print(f"\n📊 Epoch {epoch + 1}/{num_epochs}")
        epoch_start = time.time()
        epoch_loss = 0.0
        epoch_acc = 0.0
        num_batches = 0

        max_batches_per_epoch = 200  # cap the number of batches per epoch

        for batch_idx, (data, target) in enumerate(train_device_loader):
            if batch_idx >= max_batches_per_epoch:
                break

            # One training step
            loss, accuracy = trainer.train_step(data, target)
            epoch_loss += loss
            epoch_acc += accuracy
            num_batches += 1

            # Synchronize and report every 20 batches
            if batch_idx % 20 == 0:
                xm.mark_step()
                avg_loss = epoch_loss / num_batches
                avg_acc = epoch_acc / num_batches * 100
                print(f"   Batch {batch_idx:3d}/{max_batches_per_epoch} | "
                      f"Loss: {avg_loss:.4f} | "
                      f"Accuracy: {avg_acc:.2f}%")

        # Synchronize at the end of the epoch
        xm.mark_step()
        xm.wait_device_ops()

        epoch_time = time.time() - epoch_start
        final_loss = epoch_loss / num_batches
        final_acc = epoch_acc / num_batches * 100

        train_losses.append(final_loss)
        train_accuracies.append(final_acc)

        print(f"✅ Epoch {epoch + 1} done | "
              f"Time: {epoch_time:.2f}s | "
              f"Avg loss: {final_loss:.4f} | "
              f"Avg accuracy: {final_acc:.2f}%")

    return trainer, train_losses, train_accuracies


def test_with_amp(trainer):
    """Evaluate with AMP."""
    print("\n🧪 Starting AMP evaluation...")

    device = xm.xla_device()
    _, test_loader = get_mnist_loaders(batch_size=64)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    total_loss = 0.0
    total_acc = 0.0
    num_batches = 0
    max_test_batches = 100

    start_time = time.time()

    for batch_idx, (data, target) in enumerate(test_device_loader):
        if batch_idx >= max_test_batches:
            break

        loss, accuracy = trainer.evaluate_step(data, target)
        total_loss += loss
        total_acc += accuracy
        num_batches += 1

        if batch_idx % 20 == 0:
            xm.mark_step()

    xm.mark_step()
    xm.wait_device_ops()

    test_time = time.time() - start_time
    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches * 100

    print("✅ Evaluation finished!")
    print(f"⏱️ Evaluation time: {test_time:.2f}s")
    print(f"🎯 Test loss: {avg_loss:.4f}")
    print(f"🎯 Test accuracy: {avg_acc:.2f}%")

    return avg_loss, avg_acc


def main():
    """Entry point."""
    print("=" * 60)
    print("⚡ AMP TPU training example")
    print("=" * 60)

    try:
        # Train
        trainer, train_losses, train_accuracies = train_with_amp()

        # Evaluate
        test_loss, test_acc = test_with_amp(trainer)

        # Save the model. Move it to CPU first so the checkpoint holds
        # plain CPU tensors rather than XLA tensors.
        print("\n💾 Saving model...")
        model_cpu = trainer.model.cpu()
        torch.save({
            'model_state_dict': model_cpu.state_dict(),
            'train_losses': train_losses,
            'train_accuracies': train_accuracies,
            'test_loss': test_loss,
            'test_accuracy': test_acc
        }, 'amp_mnist_model.pth')
        print("✅ Model saved to amp_mnist_model.pth")

        print("\n🎉 AMP training finished!")
        print(f"📊 Final training accuracy: {train_accuracies[-1]:.2f}%")
        print(f"📊 Test accuracy: {test_acc:.2f}%")

        if train_accuracies[-1] > 85 and test_acc > 80:
            print("✅ AMP training succeeded! Model performance looks good")
        else:
            print("⚠️ Model performance is mediocre, but AMP itself works")

    except Exception as e:
        print(f"❌ AMP training failed: {e}")
        import traceback
        traceback.print_exc()

        print("\n💡 Troubleshooting suggestions:")
        print("   1. Make sure your PyTorch XLA version supports AMP")
        print("   2. Check that enough TPU resources are available")
        print("   3. Try a smaller batch_size")


if __name__ == "__main__":
    main()
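
# ---------------------------------------------------------------------------
# Side note (not part of the original script): a minimal sketch of the AMP
# step using torch_xla's sync-free optimizers, assuming a torch_xla release
# that ships `torch_xla.amp.syncfree`. A sync-free optimizer lets GradScaler
# skip the update on inf/nan gradients without a host-device sync, so the
# whole step stays on the TPU. This is a hedged alternative, not what the
# script above does:
#
#     from torch_xla.amp import syncfree
#
#     model = AMPModel().to(xm.xla_device())
#     optimizer = syncfree.Adam(model.parameters(), lr=1e-3)
#     scaler = xla_amp.GradScaler()
#
#     with xla_amp.autocast(xm.xla_device()):
#         loss = nn.CrossEntropyLoss()(model(data), target)
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()
#     xm.mark_step()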