253 lines
6.4 KiB
Python
253 lines
6.4 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
"""
|
|||
|
超简单MNIST TPU训练 - 完全避开混合精度问题
|
|||
|
只使用float32,确保稳定运行
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import time
|
|||
|
import torch
|
|||
|
import torch.nn as nn
|
|||
|
import torch.optim as optim
|
|||
|
import torchvision
|
|||
|
import torchvision.transforms as transforms
|
|||
|
|
|||
|
# Remove every environment variable that could cause bf16 problems.
for key in ('XLA_USE_BF16', 'XLA_DOWNCAST_BF16'):
    # pop() with a default is a no-op when the variable is not set.
    os.environ.pop(key, None)

# Only set the most basic XLA optimization flags.
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=false'
)
|
|||
|
|
|||
|
import torch_xla.core.xla_model as xm
|
|||
|
import torch_xla.distributed.parallel_loader as pl
|
|||
|
|
|||
|
|
|||
|
class SimpleMNISTNet(nn.Module):
    """Very small fully connected MNIST classifier.

    The input is flattened and passed through three linear layers
    (784 -> 128 -> 64 -> 10), with a ReLU after each of the first two.
    """

    def __init__(self):
        """Create the layers; attribute names define the state-dict keys."""
        super().__init__()
        in_features = 28 * 28
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the 10 raw (unnormalized) class scores for each sample in ``x``."""
        hidden = self.relu1(self.fc1(self.flatten(x)))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)
|
|||
|
|
|||
|
|
|||
|
def get_mnist_data(batch_size=64):
    """Build the MNIST training and test data loaders.

    Both datasets are rooted at ``./mnist_data`` and are created with
    ``download=True``. Images are converted to tensors and normalized with
    the constants (0.1307,) and (0.3081,).

    Args:
        batch_size: Batch size used by both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its data.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Both datasets are created first (training split, then test split),
    # before any loader is built.
    train_set, test_set = (
        torchvision.datasets.MNIST(
            root='./mnist_data',
            train=is_train,
            download=True,
            transform=preprocess,
        )
        for is_train in (True, False)
    )

    loader_options = {'batch_size': batch_size, 'num_workers': 0}
    train_loader = torch.utils.data.DataLoader(
        train_set, shuffle=True, **loader_options
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, shuffle=False, **loader_options
    )
    return train_loader, test_loader
|
|||
|
|
|||
|
|
|||
|
def train_mnist():
    """Train SimpleMNISTNet on the XLA device for at most 100 batches.

    Progress is printed every 10 batches. Loss and correct-prediction
    counts are accumulated in device tensors and copied to the host only
    when a progress line is printed and once at the end, so the loop does
    not force a host/device synchronization on every batch.

    Returns:
        A ``(model, final_loss, final_acc)`` tuple: the trained model (still
        on the XLA device), the mean loss over the processed batches, and
        the training accuracy in percent.

    Raises:
        RuntimeError: If the training loader yields no batches.
    """
    print("🚀 开始MNIST TPU训练...")

    # Acquire the device.
    device = xm.xla_device()
    print(f"📱 设备: {device}")

    # Create the model.
    model = SimpleMNISTNet().to(device)

    # Make sure every parameter is float32.
    for param in model.parameters():
        param.data = param.data.to(torch.float32)

    print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Load the data; only the training loader is needed here.
    print("📥 加载MNIST数据...")
    train_loader, _ = get_mnist_data(batch_size=64)

    # Wrap the loader with the XLA parallel device loader.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    print("🎯 开始训练...")

    model.train()
    start_time = time.time()

    # Running statistics live on the device; calling .item() on every batch
    # would block on the device each step.
    loss_sum = torch.zeros((), dtype=torch.float32, device=device)
    correct_sum = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    batches_done = 0
    max_batches = 100  # Train only 100 batches for a quick validation run.

    for batch_idx, (data, target) in enumerate(train_device_loader):
        if batch_idx >= max_batches:
            break

        # Make sure the dtypes are correct.
        data = data.to(torch.float32)
        target = target.to(torch.long)

        # Forward pass.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        # Backward pass.
        loss.backward()
        optimizer.step()

        # Statistics (kept on device, no host transfer here).
        loss_sum += loss.detach()
        correct_sum += output.argmax(dim=1).eq(target).sum()
        total += target.size(0)
        batches_done += 1

        # Synchronize and report once every 10 batches.
        if batch_idx % 10 == 0:
            xm.mark_step()
            current_acc = 100. * correct_sum.item() / total
            avg_loss = loss_sum.item() / batches_done

            print(f'批次 {batch_idx:3d}/{max_batches} | '
                  f'损失: {avg_loss:.4f} | '
                  f'准确率: {current_acc:.2f}%')

    # Final synchronization.
    xm.mark_step()
    xm.wait_device_ops()

    train_time = time.time() - start_time

    if batches_done == 0:
        # Without this guard the averages below would divide by zero.
        raise RuntimeError("训练数据加载器为空,未执行任何训练批次")

    final_acc = 100. * correct_sum.item() / total
    final_loss = loss_sum.item() / batches_done

    print(f"\n✅ 训练完成!")
    print(f"⏱️ 训练时间: {train_time:.2f}秒")
    print(f"🎯 最终损失: {final_loss:.4f}")
    print(f"🎯 训练准确率: {final_acc:.2f}%")

    return model, final_loss, final_acc
|
|||
|
|
|||
|
|
|||
|
def test_mnist(model):
    """Evaluate ``model`` on at most 50 batches of the MNIST test split.

    The correct-prediction count is accumulated in a device tensor and
    copied to the host once after the loop, so evaluation does not force a
    host/device synchronization on every batch.

    Args:
        model: The model to evaluate. It is moved nowhere by this function;
            batches are delivered on the XLA device, so the model is
            presumably expected to already be there.

    Returns:
        The test accuracy in percent.

    Raises:
        RuntimeError: If the test loader yields no batches.
    """
    print("\n🧪 开始测试...")

    device = xm.xla_device()
    _, test_loader = get_mnist_data(batch_size=64)

    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    model.eval()
    correct_sum = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    max_test_batches = 50  # Evaluate only 50 batches.

    start_time = time.time()

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_device_loader):
            if batch_idx >= max_test_batches:
                break

            # Make sure the dtypes are correct.
            data = data.to(torch.float32)
            target = target.to(torch.long)

            output = model(data)
            pred = output.argmax(dim=1)
            # Kept on device; read back once after the loop.
            correct_sum += pred.eq(target).sum()
            total += target.size(0)

            if batch_idx % 10 == 0:
                xm.mark_step()

    xm.mark_step()
    xm.wait_device_ops()

    test_time = time.time() - start_time

    if total == 0:
        # Without this guard the accuracy below would divide by zero.
        raise RuntimeError("测试数据加载器为空,未评估任何批次")

    accuracy = 100. * correct_sum.item() / total

    print(f"✅ 测试完成!")
    print(f"⏱️ 测试时间: {test_time:.2f}秒")
    print(f"🎯 测试准确率: {accuracy:.2f}%")

    return accuracy
|
|||
|
|
|||
|
|
|||
|
def main():
    """Run training, evaluation and model saving, reporting on stdout.

    Any exception raised along the way is caught, printed together with its
    traceback, and not re-raised.
    """
    banner = "=" * 60
    print(banner)
    print("🔢 超简单MNIST TPU训练 (仅float32)")
    print(banner)

    try:
        # Train (the returned loss is not used here).
        model, _, train_acc = train_mnist()

        # Evaluate.
        test_acc = test_mnist(model)

        # Save the model weights after moving the model to the CPU.
        print("\n💾 保存模型...")
        torch.save(model.cpu().state_dict(), 'mnist_simple_model.pth')
        print("✅ 模型已保存")

        print("\n🎉 全部完成!")
        print(f"📊 训练准确率: {train_acc:.2f}%")
        print(f"📊 测试准确率: {test_acc:.2f}%")

        # Both thresholds must be exceeded for the success message.
        passed = train_acc > 80 and test_acc > 75
        print("✅ 模型训练成功!" if passed else "⚠️ 模型性能一般,但TPU功能正常")

    except Exception as exc:
        print(f"❌ 训练失败: {exc}")
        import traceback
        traceback.print_exc()


# Run the full pipeline only when executed as a script.
if __name__ == "__main__":
    main()
|