b2txt25/model_training_nnn_tpu/simple_tpu_model.py
#!/usr/bin/env python3
"""
Simple TPU model training and test script.
A simplified version of the brain-to-text model, optimized specifically for TPU.
"""
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
# Configure XLA environment variables before importing torch_xla.
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={os.cpu_count()}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
os.environ['XLA_USE_BF16'] = '1'
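# Note: XLA_USE_BF16=1 makes torch_xla treat float32 tensors as bfloat16 on
# TPU, trading precision for speed and memory; newer torch_xla releases
# deprecate this flag in favor of explicit dtype control.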
import torch_xla.core.xla_model as xm

class SimpleBrainToTextModel(nn.Module):
    """Simplified brain-to-text model, optimized for TPU."""

    def __init__(self, input_features=512, hidden_size=256, num_classes=41, num_layers=3):
        super().__init__()
        # Input processing layers
        self.input_proj = nn.Linear(input_features, hidden_size)
        self.input_dropout = nn.Dropout(0.2)
        # GRU layers -- a relatively small hidden size keeps the model TPU-efficient
        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.3 if num_layers > 1 else 0
        )
        # Output layer
        self.output_proj = nn.Linear(hidden_size, num_classes)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize model weights."""
        for name, param in self.named_parameters():
            if 'weight' in name:
                if 'gru' in name:
                    nn.init.orthogonal_(param)
                else:
                    nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: (batch_size, seq_len, input_features)
        Returns:
            logits: (batch_size, seq_len, num_classes)
        """
        # Input projection
        x = torch.relu(self.input_proj(x))
        x = self.input_dropout(x)
        # GRU processing
        output, _ = self.gru(x)
        # Output projection
        logits = self.output_proj(output)
        return logits
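
# A quick shape check (illustrative only): with the default constructor, a
# (2, 100, 512) input yields (2, 100, 41) per-timestep logits:
#   model = SimpleBrainToTextModel()
#   logits = model(torch.randn(2, 100, 512))  # -> torch.Size([2, 100, 41])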

class SimpleDataGenerator:
    """Simple data generator that simulates brain-signal data."""

    def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.input_features = input_features
        self.num_classes = num_classes

    def generate_batch(self, device):
        """Generate one batch of simulated data."""
        # Simulated neural-signal features
        features = torch.randn(
            self.batch_size, self.seq_len, self.input_features,
            device=device, dtype=torch.float32
        )
        # Simulated labels (phoneme sequences)
        labels = torch.randint(
            0, self.num_classes,
            (self.batch_size, self.seq_len),
            device=device
        )
        # Simulated sequence lengths (not used for masking by the trainer below)
        seq_lengths = torch.randint(
            self.seq_len // 2, self.seq_len + 1,
            (self.batch_size,),
            device=device
        )
        return {
            'features': features,
            'labels': labels,
            'seq_lengths': seq_lengths
        }
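
# With the defaults, each generated batch holds 'features' of shape
# (16, 100, 512), 'labels' of shape (16, 100), and 'seq_lengths' of shape (16,).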

class SimpleTpuTrainer:
    """Simple TPU trainer."""

    def __init__(self, model, device, learning_rate=0.001):
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        # Data generator
        self.data_generator = SimpleDataGenerator()
        # Training statistics
        self.step = 0
        self.best_loss = float('inf')

    def train_step(self, batch):
        """Run a single training step."""
        self.model.train()
        self.optimizer.zero_grad()
        # Forward pass
        features = batch['features']
        labels = batch['labels']
        logits = self.model(features)
        # Compute the loss -- flatten to the shapes CrossEntropyLoss expects
        batch_size, seq_len, num_classes = logits.shape
        loss = self.criterion(
            logits.reshape(-1, num_classes),
            labels.reshape(-1)
        )
        # Backward pass
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        # Update parameters
        self.optimizer.step()
        # Note: .item() pulls the loss back to the host, which forces XLA to
        # execute the pending graph (an implicit sync on every step).
        return loss.item()

    def evaluate_step(self, batch):
        """Run a single evaluation step."""
        self.model.eval()
        with torch.no_grad():
            features = batch['features']
            labels = batch['labels']
            logits = self.model(features)
            # Compute the loss
            batch_size, seq_len, num_classes = logits.shape
            loss = self.criterion(
                logits.reshape(-1, num_classes),
                labels.reshape(-1)
            )
            # Compute accuracy (over all timesteps; seq_lengths are not masked out)
            predictions = torch.argmax(logits, dim=-1)
            correct = (predictions == labels).float()
            accuracy = correct.mean()
        return loss.item(), accuracy.item()

    def train(self, num_steps=1000, eval_every=100, save_every=500):
        """Train the model."""
        print(f"🚀 Starting TPU training - device: {self.device}")
        print(f"📊 Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        train_losses = []
        eval_losses = []
        eval_accuracies = []
        start_time = time.time()
        for step in range(num_steps):
            self.step = step
            # Generate training data
            train_batch = self.data_generator.generate_batch(self.device)
            # Training step
            train_loss = self.train_step(train_batch)
            train_losses.append(train_loss)
            # XLA sync
            if step % 10 == 0:  # sync every 10 steps for efficiency
                xm.mark_step()
            # Evaluation
            if step % eval_every == 0:
                eval_batch = self.data_generator.generate_batch(self.device)
                eval_loss, eval_acc = self.evaluate_step(eval_batch)
                eval_losses.append(eval_loss)
                eval_accuracies.append(eval_acc)
                # Sync XLA ops to get accurate timings
                xm.mark_step()
                xm.wait_device_ops()
                elapsed = time.time() - start_time
                print(f"Step {step:4d}/{num_steps} | "
                      f"train loss: {train_loss:.4f} | "
                      f"eval loss: {eval_loss:.4f} | "
                      f"eval accuracy: {eval_acc:.4f} | "
                      f"elapsed: {elapsed:.1f}s")
                # Track the best loss so far
                if eval_loss < self.best_loss:
                    self.best_loss = eval_loss
                    print(f"🎯 New best model! Loss: {eval_loss:.4f}")
            # Periodic checkpointing
            if step > 0 and step % save_every == 0:
                self.save_checkpoint(f"checkpoint_step_{step}.pt")
        # Final sync
        xm.mark_step()
        xm.wait_device_ops()
        total_time = time.time() - start_time
        print("\n✅ Training finished!")
        print(f"⏱️ Total time: {total_time:.1f}s")
        print(f"🎯 Final train loss: {train_losses[-1]:.4f}")
        if eval_losses:
            print(f"🎯 Final eval loss: {eval_losses[-1]:.4f}")
            print(f"🎯 Final eval accuracy: {eval_accuracies[-1]:.4f}")
        return {
            'train_losses': train_losses,
            'eval_losses': eval_losses,
            'eval_accuracies': eval_accuracies,
            'total_time': total_time
        }

    def save_checkpoint(self, filename):
        """Save a checkpoint."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'step': self.step,
            'best_loss': self.best_loss,
        }
        # On TPU, tensors must be moved to the CPU before serializing;
        # xm.save() handles that transfer internally.
        if 'xla' in str(self.device):
            xm.save(checkpoint, filename)
        else:
            torch.save(checkpoint, filename)
        print(f"💾 Saved checkpoint: {filename}")

    def load_checkpoint(self, filename):
        """Load a checkpoint."""
        checkpoint = torch.load(filename, map_location='cpu')
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.step = checkpoint['step']
        self.best_loss = checkpoint['best_loss']
        print(f"📂 Loaded checkpoint: {filename}")
        print(f"   Step: {self.step}, best loss: {self.best_loss:.4f}")

def test_simple_inference():
    """Test simple inference."""
    print("\n🧪 Testing simple inference...")
    device = xm.xla_device()
    # Create the model
    model = SimpleBrainToTextModel().to(device)
    # Create test data
    batch_size = 4
    seq_len = 50
    test_input = torch.randn(batch_size, seq_len, 512, device=device)
    # Inference
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        output = model(test_input)
        xm.mark_step()
        xm.wait_device_ops()
        # The first run includes XLA compilation, so this measures
        # compile + execute time rather than steady-state latency.
        inference_time = time.time() - start_time
    print("✅ Inference finished!")
    print(f"   Input shape: {test_input.shape}")
    print(f"   Output shape: {output.shape}")
    print(f"   Inference time: {inference_time:.4f}s")
    return True

def main():
    """Main entry point."""
    print("=" * 60)
    print("🧠 Simple TPU brain-to-text model training")
    print("=" * 60)
    try:
        # Check the TPU device
        device = xm.xla_device()
        print(f"📱 Using device: {device}")
        # Create the model
        model = SimpleBrainToTextModel(
            input_features=512,
            hidden_size=256,
            num_classes=41,
            num_layers=3
        ).to(device)
        # Create the trainer
        trainer = SimpleTpuTrainer(model, device, learning_rate=0.001)
        # Start training
        results = trainer.train(
            num_steps=1000,
            eval_every=100,
            save_every=500
        )
        # Save the final model
        trainer.save_checkpoint("final_simple_model.pt")
        # Test inference
        test_simple_inference()
        print("\n🎉 All tests finished!")
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
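
# Typical invocation (assuming torch_xla is installed and an XLA device is
# available; the XLA_FLAGS set at the top also configure the CPU host
# platform, e.g. for local smoke tests):
#   python simple_tpu_model.py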