Files
b2txt25/model_training_nnn_tpu/mnist_tpu_simple.py

253 lines
6.4 KiB
Python
Raw Normal View History

2025-10-15 15:22:13 +08:00
#!/usr/bin/env python3
"""
超简单MNIST TPU训练 - 完全避开混合精度问题
只使用float32确保稳定运行
"""
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# Remove every bf16-related XLA switch inherited from the environment so the
# run stays in float32. This is done ahead of the torch_xla imports that
# follow (presumably so the library never observes them -- confirm).
for key in ('XLA_USE_BF16', 'XLA_DOWNCAST_BF16'):
    os.environ.pop(key, None)
# Only the most basic XLA tuning: multi-threaded Eigen on, fast-math off.
os.environ['XLA_FLAGS'] = ' '.join((
    '--xla_cpu_multi_thread_eigen=true',
    '--xla_cpu_enable_fast_math=false',
))
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
class SimpleMNISTNet(nn.Module):
    """Very small fully connected MNIST classifier.

    The input is flattened and passed through three linear layers
    (784 -> 128 -> 64 -> 10) with a ReLU after each of the first two.
    The output is the raw 10-way logits; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order they are applied in
        # forward(); the attribute names determine the state_dict keys.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        hidden = self.fc1(self.flatten(x))
        hidden = self.fc2(self.relu1(hidden))
        return self.fc3(self.relu2(hidden))
def get_mnist_data(batch_size=64):
    """Build the MNIST train and test data loaders.

    Both splits are stored under ``./mnist_data`` (downloaded on demand) and
    share one preprocessing pipeline. Only the training loader shuffles.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    # Tensor conversion followed by normalization with fixed constants
    # (presumably the MNIST mean/std -- verify if these are ever changed).
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # The same flag selects the split and decides whether to shuffle:
    # True -> training split (shuffled), False -> test split (in order).
    split_flags = (True, False)

    datasets = [
        torchvision.datasets.MNIST(
            root='./mnist_data',
            train=is_train,
            download=True,
            transform=pipeline,
        )
        for is_train in split_flags
    ]

    train_loader, test_loader = (
        torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=0,
        )
        for dataset, is_train in zip(datasets, split_flags)
    )
    return train_loader, test_loader
def train_mnist(max_batches=100, batch_size=64, lr=0.001):
    """Train a SimpleMNISTNet on the XLA device for a limited number of batches.

    Args:
        max_batches: Upper bound on the number of training batches to run.
            The default of 100 keeps this a quick smoke test.
        batch_size: Samples per batch passed to ``get_mnist_data``.
        lr: Learning rate for the Adam optimizer.

    Returns:
        A ``(model, final_loss, final_acc)`` tuple: the trained model (still
        on the XLA device), the mean loss over the processed batches, and the
        training accuracy in percent.

    Raises:
        RuntimeError: If no batch was processed (empty loader or
            ``max_batches <= 0``), since no metrics can be computed then.
    """
    print("🚀 开始MNIST TPU训练...")
    # Acquire the XLA device.
    device = xm.xla_device()
    print(f"📱 设备: {device}")
    # Build the model on the device.
    model = SimpleMNISTNet().to(device)
    # Force every parameter to float32 (normally they already are; this is a
    # safeguard against an altered default dtype).
    for param in model.parameters():
        param.data = param.data.to(torch.float32)
    print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")
    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Load the data; only the training loader is needed here.
    print("📥 加载MNIST数据...")
    train_loader, _ = get_mnist_data(batch_size=batch_size)
    # Wrap the loader so batches are delivered on the XLA device.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    print("🎯 开始训练...")
    model.train()
    start_time = time.time()
    total_loss = 0.0
    correct = 0
    total = 0
    # Explicit count of processed batches; avoids relying on the loop
    # variable, which is unbound when the loader yields nothing.
    batches_done = 0
    for batch_idx, (data, target) in enumerate(train_device_loader):
        if batch_idx >= max_batches:
            break
        # Make sure the dtypes are what the model and loss expect.
        data = data.to(torch.float32)
        target = target.to(torch.long)
        # Forward pass.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()
        # Running statistics.
        # NOTE(review): .item() copies a scalar to the host on every step,
        # which presumably already forces a device sync per batch -- confirm
        # against the torch_xla documentation before tuning performance.
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        batches_done += 1
        # Every 10th batch: explicit step marker plus a progress report.
        if batch_idx % 10 == 0:
            xm.mark_step()
            current_acc = 100. * correct / total
            avg_loss = total_loss / batches_done
            print(f'批次 {batch_idx:3d}/{max_batches} | '
                  f'损失: {avg_loss:.4f} | '
                  f'准确率: {current_acc:.2f}%')
    # Final sync so the timing below covers all outstanding device work.
    xm.mark_step()
    xm.wait_device_ops()
    train_time = time.time() - start_time
    if batches_done == 0 or total == 0:
        raise RuntimeError(
            "train_mnist processed no batches; check max_batches and the "
            "training data loader"
        )
    final_acc = 100. * correct / total
    final_loss = total_loss / batches_done
    print(f"\n✅ 训练完成!")
    print(f"⏱️  训练时间: {train_time:.2f}")
    print(f"🎯 最终损失: {final_loss:.4f}")
    print(f"🎯 训练准确率: {final_acc:.2f}%")
    return model, final_loss, final_acc
def test_mnist(model, max_test_batches=50, batch_size=64):
    """Evaluate ``model`` on a limited number of MNIST test batches.

    Args:
        model: The model to evaluate; it is switched to eval mode and called
            on batches placed on the XLA device.
        max_test_batches: Upper bound on the number of test batches to run.
            The default of 50 keeps this a quick check.
        batch_size: Samples per batch passed to ``get_mnist_data``.

    Returns:
        The accuracy in percent over the evaluated samples.

    Raises:
        RuntimeError: If no sample was evaluated (empty loader or
            ``max_test_batches <= 0``), since accuracy is undefined then.
    """
    print("\n🧪 开始测试...")
    device = xm.xla_device()
    # Only the test loader is needed here.
    _, test_loader = get_mnist_data(batch_size=batch_size)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    model.eval()
    correct = 0
    total = 0
    start_time = time.time()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_device_loader):
            if batch_idx >= max_test_batches:
                break
            # Make sure the dtypes are what the model expects.
            data = data.to(torch.float32)
            target = target.to(torch.long)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            # Explicit step marker every 10th batch.
            if batch_idx % 10 == 0:
                xm.mark_step()
    # Final sync so the timing below covers all outstanding device work.
    xm.mark_step()
    xm.wait_device_ops()
    test_time = time.time() - start_time
    if total == 0:
        raise RuntimeError(
            "test_mnist evaluated no samples; check max_test_batches and "
            "the test data loader"
        )
    accuracy = 100. * correct / total
    print(f"✅ 测试完成!")
    print(f"⏱️  测试时间: {test_time:.2f}")
    print(f"🎯 测试准确率: {accuracy:.2f}%")
    return accuracy
def main():
    """Run the full pipeline: train, evaluate, save weights, print a summary.

    Any exception raised along the way is reported together with its
    traceback and is not re-raised.
    """
    banner = "=" * 60
    print(banner)
    print("🔢 超简单MNIST TPU训练 (仅float32)")
    print(banner)
    try:
        # Train, then evaluate the trained model.
        model, _train_loss, train_acc = train_mnist()
        test_acc = test_mnist(model)

        # Move the model to CPU and persist its weights.
        print("\n💾 保存模型...")
        torch.save(model.cpu().state_dict(), 'mnist_simple_model.pth')
        print("✅ 模型已保存")

        print("\n🎉 全部完成!")
        print(f"📊 训练准确率: {train_acc:.2f}%")
        print(f"📊 测试准确率: {test_acc:.2f}%")

        # Success requires both accuracy thresholds to be exceeded.
        good_enough = train_acc > 80 and test_acc > 75
        verdict = "✅ 模型训练成功!" if good_enough else "⚠️  模型性能一般但TPU功能正常"
        print(verdict)
    except Exception as e:
        import traceback
        print(f"❌ 训练失败: {e}")
        traceback.print_exc()


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()