From 082018cd46713837cf2c7c2eacfac69b828a3190 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Wed, 15 Oct 2025 15:14:01 +0800
Subject: [PATCH] tpu-test
---
.../jupyter_xla_compatibility.py | 93 -----
model_training_nnn_tpu/jupyter_xla_monitor.py | 167 --------
model_training_nnn_tpu/jupyter_xla_test.py | 78 ----
model_training_nnn_tpu/quick_tpu_test.py | 129 ++++++
model_training_nnn_tpu/simple_tpu_model.py | 367 ++++++++++++++++++
model_training_nnn_tpu/test_simple_model.py | 162 --------
6 files changed, 496 insertions(+), 500 deletions(-)
delete mode 100644 model_training_nnn_tpu/jupyter_xla_compatibility.py
delete mode 100644 model_training_nnn_tpu/jupyter_xla_monitor.py
delete mode 100644 model_training_nnn_tpu/jupyter_xla_test.py
create mode 100644 model_training_nnn_tpu/quick_tpu_test.py
create mode 100644 model_training_nnn_tpu/simple_tpu_model.py
delete mode 100644 model_training_nnn_tpu/test_simple_model.py
diff --git a/model_training_nnn_tpu/jupyter_xla_compatibility.py b/model_training_nnn_tpu/jupyter_xla_compatibility.py
deleted file mode 100644
index ac275af..0000000
--- a/model_training_nnn_tpu/jupyter_xla_compatibility.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# ====================
-# 单元格: XLA版本兼容性检查和修复
-# ====================
-
-import torch
-import torch.nn as nn
-print("🔧 PyTorch XLA版本兼容性检查...")
-
-# 导入XLA
-import torch_xla.core.xla_model as xm
-
-print("✅ PyTorch XLA导入成功!")
-
-# 定义兼容性函数
-def get_xla_world_size():
- """获取XLA world size,兼容不同版本"""
- try:
- return xm.xrt_world_size()
- except AttributeError:
- try:
- return xm.get_world_size()
- except AttributeError:
- return 1 # 默认返回1
-
-def get_xla_ordinal():
- """获取XLA ordinal,兼容不同版本"""
- try:
- return xm.get_ordinal()
- except AttributeError:
- return 0 # 默认返回0
-
-def xla_mark_step():
- """XLA mark step,兼容不同版本"""
- try:
- xm.mark_step()
- except AttributeError:
- try:
- xm.wait_device_ops()
- except AttributeError:
- pass # 如果都不可用,则跳过
-
-def check_xla_device():
- """检查XLA设备状态"""
- try:
- device = xm.xla_device()
- print(f"📱 XLA设备: {device}")
-
- world_size = get_xla_world_size()
- ordinal = get_xla_ordinal()
-
- print(f"🌍 World Size: {world_size}")
- print(f"🔢 Ordinal: {ordinal}")
-
- # 检测设备类型
- device_str = str(device)
- if 'xla' in device_str and 'cpu' not in device_str:
- print("✅ 检测到TPU设备")
- return True, "TPU"
- elif 'xla' in device_str and 'cpu' in device_str:
- print("⚠️ XLA CPU模拟模式")
- return True, "XLA_CPU"
- else:
- print("❌ 未检测到XLA设备")
- return False, "CPU"
-
- except Exception as e:
- print(f"❌ XLA设备检查失败: {e}")
- return False, "ERROR"
-
-# 执行兼容性检查
-device_available, device_type = check_xla_device()
-
-if device_available:
- print(f"✅ XLA环境正常,设备类型: {device_type}")
-
- # 测试基本XLA操作
- print("🧪 测试基本XLA操作...")
- try:
- device = xm.xla_device()
- x = torch.randn(2, 2, device=device)
- y = torch.matmul(x, x)
-
- # 测试同步函数
- xla_mark_step()
-
- print("✅ 基本XLA操作测试成功")
-
- except Exception as e:
- print(f"❌ XLA操作测试失败: {e}")
-else:
- print("❌ XLA环境不可用")
-
-print("\n💡 兼容性检查完成,可以运行后续单元格")
\ No newline at end of file
diff --git a/model_training_nnn_tpu/jupyter_xla_monitor.py b/model_training_nnn_tpu/jupyter_xla_monitor.py
deleted file mode 100644
index e194752..0000000
--- a/model_training_nnn_tpu/jupyter_xla_monitor.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# ====================
-# 单元格2: XLA编译进度监控
-# ====================
-
-import torch
-import torch.nn as nn
-import time
-import threading
-from IPython.display import display, HTML, clear_output
-import ipywidgets as widgets
-
-# 导入XLA (环境变量已在单元格1中设置)
-print("🚀 导入PyTorch XLA...")
-import torch_xla.core.xla_model as xm
-
-print(f"✅ XLA导入成功!")
-print(f" TPU设备: {xm.xla_device()}")
-
-# 兼容新版本PyTorch XLA
-try:
- world_size = xm.xrt_world_size()
- print(f" World Size (旧API): {world_size}")
-except AttributeError:
- try:
- world_size = xm.get_world_size()
- print(f" World Size (新API): {world_size}")
- except AttributeError:
- print(" World Size: 无法获取 (可能在CPU模式)")
-
-# 检查XLA版本兼容性
-print("🔍 检查XLA API兼容性:")
-api_available = []
-api_deprecated = []
-
-# 检查各种API
-test_apis = [
- ('xrt_world_size', 'xrt_world_size()'),
- ('get_world_size', 'get_world_size()'),
- ('mark_step', 'mark_step()'),
- ('wait_device_ops', 'wait_device_ops()'),
- ('get_ordinal', 'get_ordinal()'),
- ('xla_device_count', 'xla_device_count()')
-]
-
-for api_name, api_desc in test_apis:
- if hasattr(xm, api_name):
- api_available.append(api_desc)
- else:
- api_deprecated.append(api_desc)
-
-if api_available:
- print(f" ✅ 可用API: {', '.join(api_available)}")
-if api_deprecated:
- print(f" ❌ 不可用API: {', '.join(api_deprecated)}")
-
-# 创建编译进度监控器
-class JupyterCompilationMonitor:
- def __init__(self):
- self.start_time = None
- self.is_monitoring = False
-
- # 创建输出widget
- self.output_widget = widgets.Output()
-
- # 创建进度条
- self.progress_bar = widgets.IntProgress(
- value=0,
- min=0,
- max=100,
- description='XLA编译:',
- bar_style='info',
- style={'bar_color': '#1f77b4'},
- orientation='horizontal'
- )
-
- # 创建状态标签
- self.status_label = widgets.HTML(
- value="准备开始编译..."
- )
-
- # 创建CPU使用率显示
- self.cpu_label = widgets.HTML(
- value="CPU: ---%"
- )
-
- self.memory_label = widgets.HTML(
- value="内存: ---%"
- )
-
- # 组合界面
- self.monitor_box = widgets.VBox([
- widgets.HTML("
🔄 XLA编译监控
"),
- self.progress_bar,
- self.status_label,
- widgets.HBox([self.cpu_label, self.memory_label]),
- self.output_widget
- ])
-
- def start_monitoring(self):
- """开始监控"""
- self.start_time = time.time()
- self.is_monitoring = True
-
- display(self.monitor_box)
-
- # 启动监控线程
- self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
- self.monitor_thread.start()
-
- def _monitor_loop(self):
- """监控循环"""
- while self.is_monitoring:
- try:
- elapsed = time.time() - self.start_time
- minutes = int(elapsed // 60)
- seconds = int(elapsed % 60)
-
- # 更新进度条 (模拟进度)
- progress = min(int(elapsed / 10 * 100), 95) # 10秒内达到95%
- self.progress_bar.value = progress
-
- # 获取系统资源
- cpu_percent = psutil.cpu_percent(interval=0.1)
- memory_percent = psutil.virtual_memory().percent
-
- # 更新显示
- self.status_label.value = f"编译进行中... ⏱️ {minutes:02d}:{seconds:02d}"
- self.cpu_label.value = f"🖥️ CPU: {cpu_percent:5.1f}%"
- self.memory_label.value = f"💾 内存: {memory_percent:5.1f}%"
-
- # 检测是否编译完成 (CPU使用率突然下降)
- if elapsed > 10 and cpu_percent < 20: # 编译通常CPU使用率很高
- self.complete_monitoring()
- break
-
- time.sleep(1)
-
- except Exception as e:
- with self.output_widget:
- print(f"监控错误: {e}")
- break
-
- def complete_monitoring(self):
- """完成监控"""
- if self.is_monitoring:
- self.is_monitoring = False
- elapsed = time.time() - self.start_time
-
- self.progress_bar.value = 100
- self.progress_bar.bar_style = 'success'
- self.status_label.value = f"✅ 编译完成! 总耗时: {elapsed:.2f}秒"
-
- with self.output_widget:
- print(f"\n🎉 XLA编译成功完成!")
- print(f"⏱️ 总耗时: {elapsed:.2f}秒")
- if elapsed < 60:
- print("✅ 编译速度正常")
- elif elapsed < 300:
- print("⚠️ 编译稍慢,但可接受")
- else:
- print("❌ 编译过慢,建议检查设置")
-
-# 创建全局监控器
-compilation_monitor = JupyterCompilationMonitor()
-
-print("✅ 编译监控器已准备就绪!")
-print("💡 运行下一个单元格开始XLA编译测试")
\ No newline at end of file
diff --git a/model_training_nnn_tpu/jupyter_xla_test.py b/model_training_nnn_tpu/jupyter_xla_test.py
deleted file mode 100644
index 2d45b9d..0000000
--- a/model_training_nnn_tpu/jupyter_xla_test.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# ====================
-# 单元格3: 快速XLA编译测试
-# ====================
-
-# 简化测试模型
-class QuickTestModel(nn.Module):
- def __init__(self):
- super().__init__()
- self.linear1 = nn.Linear(512, 128)
- self.gru = nn.GRU(128, 64, batch_first=True)
- self.linear2 = nn.Linear(64, 41)
-
- def forward(self, x):
- x = torch.relu(self.linear1(x))
- x, _ = self.gru(x)
- x = self.linear2(x)
- return x
-
-print("🧪 开始XLA编译快速测试...")
-
-# 启动监控
-compilation_monitor.start_monitoring()
-
-try:
- # 获取TPU设备
- device = xm.xla_device()
-
- # 创建小模型
- model = QuickTestModel().to(device)
- param_count = sum(p.numel() for p in model.parameters())
- print(f"📊 测试模型参数: {param_count:,}")
-
- # 创建测试数据 (很小的batch)
- x = torch.randn(2, 20, 512, device=device)
- print(f"📥 输入数据形状: {x.shape}")
-
- print("🔄 开始首次前向传播 (触发XLA编译)...")
-
- # 首次前向传播 - 这会触发XLA编译
- with torch.no_grad():
- start_compile = time.time()
- output = model(x)
- compile_time = time.time() - start_compile
-
- print(f"✅ XLA编译完成!")
- print(f"📤 输出形状: {output.shape}")
-
- # 完成监控
- compilation_monitor.complete_monitoring()
-
- # 测试编译后的性能
- print("\n🚀 测试编译后的执行速度...")
- with torch.no_grad():
- start_exec = time.time()
- for _ in range(10):
- output = model(x)
- avg_exec_time = (time.time() - start_exec) / 10
-
- print(f"⚡ 平均执行时间: {avg_exec_time*1000:.2f}ms")
-
- # 性能评估
- if compile_time < 30:
- print("✅ 编译速度优秀! 可以尝试完整模型")
- test_result = "excellent"
- elif compile_time < 120:
- print("✅ 编译速度良好! 建议使用简化配置")
- test_result = "good"
- else:
- print("⚠️ 编译速度较慢,建议进一步优化")
- test_result = "slow"
-
-except Exception as e:
- compilation_monitor.complete_monitoring()
- print(f"❌ 测试失败: {e}")
- test_result = "failed"
-
-print(f"\n📋 测试结果: {test_result}")
-print("💡 如果测试通过,可以运行下一个单元格进行完整训练")
\ No newline at end of file
diff --git a/model_training_nnn_tpu/quick_tpu_test.py b/model_training_nnn_tpu/quick_tpu_test.py
new file mode 100644
index 0000000..2cb93fb
--- /dev/null
+++ b/model_training_nnn_tpu/quick_tpu_test.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+"""
+快速TPU测试脚本 - 验证简单模型是否可以在TPU上运行
+"""
+
+import os
+import time
+import torch
+import torch.nn as nn
+
+# 设置环境变量
+os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true'
+os.environ['XLA_USE_BF16'] = '1'
+
+import torch_xla.core.xla_model as xm
+
+def quick_test():
+ """快速测试TPU是否工作正常"""
+ print("🚀 开始快速TPU测试...")
+
+ try:
+ # 获取TPU设备
+ device = xm.xla_device()
+ print(f"📱 TPU设备: {device}")
+
+ # 创建简单模型
+ model = nn.Sequential(
+ nn.Linear(512, 256),
+ nn.ReLU(),
+ nn.GRU(256, 128, batch_first=True),
+ nn.Linear(128, 41)
+ ).to(device)
+
+ print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")
+
+ # 创建测试数据
+ x = torch.randn(8, 50, 512, device=device)
+ print(f"📥 输入形状: {x.shape}")
+
+ # 测试前向传播
+ print("🔄 测试前向传播...")
+ start_time = time.time()
+
+ with torch.no_grad():
+ if hasattr(model, '__getitem__'):
+ # 对于Sequential模型,手动处理GRU层
+ x_proj = model[1](model[0](x)) # Linear + ReLU
+ gru_out, _ = model[2](x_proj) # GRU
+ output = model[3](gru_out) # Final Linear
+ else:
+ output = model(x)
+
+ # 同步TPU操作
+ xm.mark_step()
+ xm.wait_device_ops()
+
+ forward_time = time.time() - start_time
+ print(f"✅ 前向传播完成! 耗时: {forward_time:.3f}秒")
+ print(f"📤 输出形状: {output.shape}")
+
+ # 测试反向传播
+ print("🔄 测试反向传播...")
+ model.train()
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+ start_time = time.time()
+
+ # 创建虚拟标签
+ labels = torch.randint(0, 41, (8, 50), device=device)
+ criterion = nn.CrossEntropyLoss()
+
+ # 前向传播
+ if hasattr(model, '__getitem__'):
+ x_proj = model[1](model[0](x))
+ gru_out, _ = model[2](x_proj)
+ output = model[3](gru_out)
+ else:
+ output = model(x)
+
+ # 计算损失
+ loss = criterion(output.view(-1, 41), labels.view(-1))
+
+ # 反向传播
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ # 同步TPU操作
+ xm.mark_step()
+ xm.wait_device_ops()
+
+ backward_time = time.time() - start_time
+ print(f"✅ 反向传播完成! 耗时: {backward_time:.3f}秒")
+ print(f"🎯 损失值: {loss.item():.4f}")
+
+ # 总结
+ print(f"\n📈 性能总结:")
+ print(f" 前向传播: {forward_time:.3f}秒")
+ print(f" 反向传播: {backward_time:.3f}秒")
+ print(f" 总计: {forward_time + backward_time:.3f}秒")
+
+ if (forward_time + backward_time) < 10: # 10秒内完成
+ print("✅ TPU测试通过! 可以进行完整训练")
+ return True
+ else:
+ print("⚠️ TPU性能较慢,可能需要优化")
+ return False
+
+ except Exception as e:
+ print(f"❌ TPU测试失败: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+
+if __name__ == "__main__":
+ print("=" * 50)
+ print("⚡ 快速TPU测试")
+ print("=" * 50)
+
+ success = quick_test()
+
+ if success:
+ print("\n🎉 测试成功! 现在可以运行:")
+ print(" python simple_tpu_model.py")
+ else:
+ print("\n❌ 测试失败,请检查TPU配置")
+
+ print("=" * 50)
\ No newline at end of file
diff --git a/model_training_nnn_tpu/simple_tpu_model.py b/model_training_nnn_tpu/simple_tpu_model.py
new file mode 100644
index 0000000..21089ed
--- /dev/null
+++ b/model_training_nnn_tpu/simple_tpu_model.py
@@ -0,0 +1,367 @@
+#!/usr/bin/env python3
+"""
+简单TPU模型训练和测试脚本
+基于大脑到文本数据的简化版本,专门为TPU优化
+"""
+
+import os
+import time
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+from typing import Dict, Any, Tuple
+
+# 设置XLA环境变量
+os.environ['XLA_FLAGS'] = (
+ '--xla_cpu_multi_thread_eigen=true '
+ '--xla_cpu_enable_fast_math=true '
+ f'--xla_force_host_platform_device_count={os.cpu_count()}'
+)
+os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
+os.environ['XLA_USE_BF16'] = '1'
+
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+
+
+class SimpleBrainToTextModel(nn.Module):
+ """简化的大脑到文本模型 - TPU优化版本"""
+
+ def __init__(self, input_features=512, hidden_size=256, num_classes=41, num_layers=3):
+ super().__init__()
+
+ # 输入处理层
+ self.input_proj = nn.Linear(input_features, hidden_size)
+ self.input_dropout = nn.Dropout(0.2)
+
+ # GRU层 - 使用较小的隐藏层以提高TPU效率
+ self.gru = nn.GRU(
+ input_size=hidden_size,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ batch_first=True,
+ dropout=0.3 if num_layers > 1 else 0
+ )
+
+ # 输出层
+ self.output_proj = nn.Linear(hidden_size, num_classes)
+
+ # 初始化权重
+ self._init_weights()
+
+ def _init_weights(self):
+ """初始化模型权重"""
+ for name, param in self.named_parameters():
+ if 'weight' in name:
+ if 'gru' in name:
+ nn.init.orthogonal_(param)
+ else:
+ nn.init.xavier_uniform_(param)
+ elif 'bias' in name:
+ nn.init.zeros_(param)
+
+ def forward(self, x):
+ """
+ 前向传播
+ Args:
+ x: (batch_size, seq_len, input_features)
+ Returns:
+ logits: (batch_size, seq_len, num_classes)
+ """
+ # 输入投影
+ x = torch.relu(self.input_proj(x))
+ x = self.input_dropout(x)
+
+ # GRU处理
+ output, _ = self.gru(x)
+
+ # 输出投影
+ logits = self.output_proj(output)
+
+ return logits
+
+
+class SimpleDataGenerator:
+ """简单的数据生成器 - 模拟大脑信号数据"""
+
+ def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41):
+ self.batch_size = batch_size
+ self.seq_len = seq_len
+ self.input_features = input_features
+ self.num_classes = num_classes
+
+ def generate_batch(self, device):
+ """生成一个批次的模拟数据"""
+ # 生成模拟的神经信号数据
+ features = torch.randn(
+ self.batch_size, self.seq_len, self.input_features,
+ device=device, dtype=torch.float32
+ )
+
+ # 生成模拟的标签(音素序列)
+ labels = torch.randint(
+ 0, self.num_classes,
+ (self.batch_size, self.seq_len),
+ device=device
+ )
+
+ # 生成序列长度
+ seq_lengths = torch.randint(
+ self.seq_len // 2, self.seq_len + 1,
+ (self.batch_size,),
+ device=device
+ )
+
+ return {
+ 'features': features,
+ 'labels': labels,
+ 'seq_lengths': seq_lengths
+ }
+
+
+class SimpleTpuTrainer:
+ """简单的TPU训练器"""
+
+ def __init__(self, model, device, learning_rate=0.001):
+ self.model = model
+ self.device = device
+ self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+ self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
+
+ # 数据生成器
+ self.data_generator = SimpleDataGenerator()
+
+ # 训练统计
+ self.step = 0
+ self.best_loss = float('inf')
+
+ def train_step(self, batch):
+ """单个训练步骤"""
+ self.model.train()
+ self.optimizer.zero_grad()
+
+ # 前向传播
+ features = batch['features']
+ labels = batch['labels']
+
+ logits = self.model(features)
+
+ # 计算损失 - 重新调整形状以适应CrossEntropyLoss
+ batch_size, seq_len, num_classes = logits.shape
+ loss = self.criterion(
+ logits.reshape(-1, num_classes),
+ labels.reshape(-1)
+ )
+
+ # 反向传播
+ loss.backward()
+
+ # 梯度裁剪
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+ # 更新参数
+ self.optimizer.step()
+
+ return loss.item()
+
+ def evaluate_step(self, batch):
+ """单个评估步骤"""
+ self.model.eval()
+
+ with torch.no_grad():
+ features = batch['features']
+ labels = batch['labels']
+
+ logits = self.model(features)
+
+ # 计算损失
+ batch_size, seq_len, num_classes = logits.shape
+ loss = self.criterion(
+ logits.reshape(-1, num_classes),
+ labels.reshape(-1)
+ )
+
+ # 计算准确率
+ predictions = torch.argmax(logits, dim=-1)
+ correct = (predictions == labels).float()
+ accuracy = correct.mean()
+
+ return loss.item(), accuracy.item()
+
+ def train(self, num_steps=1000, eval_every=100, save_every=500):
+ """训练模型"""
+ print(f"🚀 开始TPU训练 - 设备: {self.device}")
+ print(f"📊 模型参数: {sum(p.numel() for p in self.model.parameters()):,}")
+
+ train_losses = []
+ eval_losses = []
+ eval_accuracies = []
+
+ start_time = time.time()
+
+ for step in range(num_steps):
+ # 生成训练数据
+ train_batch = self.data_generator.generate_batch(self.device)
+
+ # 训练步骤
+ train_loss = self.train_step(train_batch)
+ train_losses.append(train_loss)
+
+ # XLA同步
+ if step % 10 == 0: # 每10步同步一次以提高效率
+ xm.mark_step()
+
+ # 评估
+ if step % eval_every == 0:
+ eval_batch = self.data_generator.generate_batch(self.device)
+ eval_loss, eval_acc = self.evaluate_step(eval_batch)
+ eval_losses.append(eval_loss)
+ eval_accuracies.append(eval_acc)
+
+ # 同步XLA操作以获得准确的时间
+ xm.mark_step()
+ xm.wait_device_ops()
+
+ current_time = time.time()
+ elapsed = current_time - start_time
+
+ print(f"步骤 {step:4d}/{num_steps} | "
+ f"训练损失: {train_loss:.4f} | "
+ f"验证损失: {eval_loss:.4f} | "
+ f"验证准确率: {eval_acc:.4f} | "
+ f"耗时: {elapsed:.1f}s")
+
+ # 保存最佳模型
+ if eval_loss < self.best_loss:
+ self.best_loss = eval_loss
+ print(f"🎯 新的最佳模型! 损失: {eval_loss:.4f}")
+
+ # 定期保存
+ if step > 0 and step % save_every == 0:
+ self.save_checkpoint(f"checkpoint_step_{step}.pt")
+
+ # 最终同步
+ xm.mark_step()
+ xm.wait_device_ops()
+
+ total_time = time.time() - start_time
+ print(f"\n✅ 训练完成!")
+ print(f"⏱️ 总耗时: {total_time:.1f}秒")
+ print(f"🎯 最终训练损失: {train_losses[-1]:.4f}")
+ if eval_losses:
+ print(f"🎯 最终验证损失: {eval_losses[-1]:.4f}")
+ print(f"🎯 最终验证准确率: {eval_accuracies[-1]:.4f}")
+
+ return {
+ 'train_losses': train_losses,
+ 'eval_losses': eval_losses,
+ 'eval_accuracies': eval_accuracies,
+ 'total_time': total_time
+ }
+
+ def save_checkpoint(self, filename):
+ """保存检查点"""
+ checkpoint = {
+ 'model_state_dict': self.model.state_dict(),
+ 'optimizer_state_dict': self.optimizer.state_dict(),
+ 'step': self.step,
+ 'best_loss': self.best_loss,
+ }
+
+ # 在TPU上需要先移动到CPU再保存
+ if 'xla' in str(self.device):
+ checkpoint = xm.send_cpu_data_to_device(checkpoint, torch.device('cpu'))
+
+ torch.save(checkpoint, filename)
+ print(f"💾 保存检查点: {filename}")
+
+ def load_checkpoint(self, filename):
+ """加载检查点"""
+ checkpoint = torch.load(filename, map_location='cpu')
+
+ self.model.load_state_dict(checkpoint['model_state_dict'])
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ self.step = checkpoint['step']
+ self.best_loss = checkpoint['best_loss']
+
+ print(f"📂 加载检查点: {filename}")
+ print(f" 步骤: {self.step}, 最佳损失: {self.best_loss:.4f}")
+
+
+def test_simple_inference():
+ """测试简单推理"""
+ print("\n🧪 测试简单推理...")
+
+ device = xm.xla_device()
+
+ # 创建模型
+ model = SimpleBrainToTextModel().to(device)
+
+ # 创建测试数据
+ batch_size = 4
+ seq_len = 50
+ test_input = torch.randn(batch_size, seq_len, 512, device=device)
+
+ # 推理
+ model.eval()
+ with torch.no_grad():
+ start_time = time.time()
+ output = model(test_input)
+ xm.mark_step()
+ xm.wait_device_ops()
+ inference_time = time.time() - start_time
+
+ print(f"✅ 推理完成!")
+ print(f" 输入形状: {test_input.shape}")
+ print(f" 输出形状: {output.shape}")
+ print(f" 推理时间: {inference_time:.4f}秒")
+
+ return True
+
+
+def main():
+ """主函数"""
+ print("=" * 60)
+ print("🧠 简单TPU大脑到文本模型训练")
+ print("=" * 60)
+
+ try:
+ # 检查TPU设备
+ device = xm.xla_device()
+ print(f"📱 使用设备: {device}")
+
+ # 创建模型
+ model = SimpleBrainToTextModel(
+ input_features=512,
+ hidden_size=256,
+ num_classes=41,
+ num_layers=3
+ ).to(device)
+
+ # 创建训练器
+ trainer = SimpleTpuTrainer(model, device, learning_rate=0.001)
+
+ # 开始训练
+ results = trainer.train(
+ num_steps=1000,
+ eval_every=100,
+ save_every=500
+ )
+
+ # 保存最终模型
+ trainer.save_checkpoint("final_simple_model.pt")
+
+ # 测试推理
+ test_simple_inference()
+
+ print("\n🎉 所有测试完成!")
+
+ except Exception as e:
+ print(f"❌ 训练失败: {e}")
+ import traceback
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/model_training_nnn_tpu/test_simple_model.py b/model_training_nnn_tpu/test_simple_model.py
deleted file mode 100644
index 4a18c5a..0000000
--- a/model_training_nnn_tpu/test_simple_model.py
+++ /dev/null
@@ -1,162 +0,0 @@
-#!/usr/bin/env python3
-"""
-简化模型测试脚本 - 验证XLA编译是否正常工作
-"""
-
-import os
-import time
-import torch
-import torch.nn as nn
-
-# 设置XLA环境变量(必须在导入torch_xla之前)
-os.environ['XLA_FLAGS'] = (
- '--xla_cpu_multi_thread_eigen=true '
- '--xla_cpu_enable_fast_math=true '
- f'--xla_force_host_platform_device_count={os.cpu_count()}'
-)
-os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
-os.environ['XLA_USE_BF16'] = '1'
-
-print(f"🔧 XLA环境变量设置:")
-print(f" CPU核心数: {os.cpu_count()}")
-print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
-print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
-
-import torch_xla.core.xla_model as xm
-
-class SimpleModel(nn.Module):
- """简化的测试模型"""
- def __init__(self):
- super().__init__()
- self.linear1 = nn.Linear(512, 256)
- self.gru = nn.GRU(256, 128, batch_first=True)
- self.linear2 = nn.Linear(128, 41) # 41个音素类别
-
- def forward(self, x):
- x = torch.relu(self.linear1(x))
- x, _ = self.gru(x)
- x = self.linear2(x)
- return x
-
-def test_xla_compilation():
- """测试XLA编译速度"""
- print("\n🚀 开始简化模型XLA编译测试...")
-
- # 检查TPU设备
- device = xm.xla_device()
- print(f"📱 TPU设备: {device}")
- print(f"🌍 TPU World Size: {xm.xrt_world_size()}")
-
- # 创建简化模型
- model = SimpleModel().to(device)
- print(f"📊 模型参数数量: {sum(p.numel() for p in model.parameters()):,}")
-
- # 创建测试数据
- batch_size = 8 # 小批次
- seq_len = 100 # 短序列
- x = torch.randn(batch_size, seq_len, 512, device=device)
-
- print(f"📥 输入形状: {x.shape}")
-
- # 首次前向传播 - 触发XLA编译
- print(f"🔄 开始首次前向传播 (XLA编译)...")
- start_time = time.time()
-
- with torch.no_grad():
- output = model(x)
-
- compile_time = time.time() - start_time
- print(f"✅ XLA编译完成! 耗时: {compile_time:.2f}秒")
- print(f"📤 输出形状: {output.shape}")
-
- # 再次前向传播 - 使用编译后的图
- print(f"🔄 第二次前向传播 (使用编译后的图)...")
- start_time = time.time()
-
- with torch.no_grad():
- output2 = model(x)
-
- execution_time = time.time() - start_time
- print(f"⚡ 执行完成! 耗时: {execution_time:.4f}秒")
-
- # 性能对比
- speedup = compile_time / execution_time if execution_time > 0 else float('inf')
- print(f"\n📈 性能分析:")
- print(f" 编译时间: {compile_time:.2f}秒")
- print(f" 执行时间: {execution_time:.4f}秒")
- print(f" 加速比: {speedup:.1f}x")
-
- if compile_time < 60: # 1分钟内编译完成
- print("✅ XLA编译正常!")
- return True
- else:
- print("❌ XLA编译过慢,可能有问题")
- return False
-
-def test_training_step():
- """测试训练步骤"""
- print("\n🎯 测试简化训练步骤...")
-
- device = xm.xla_device()
- model = SimpleModel().to(device)
- optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
- criterion = nn.CrossEntropyLoss()
-
- # 创建训练数据
- x = torch.randn(4, 50, 512, device=device)
- labels = torch.randint(0, 41, (4, 50), device=device)
-
- print(f"🔄 开始训练步骤 (包含反向传播)...")
- start_time = time.time()
-
- # 前向传播
- outputs = model(x)
-
- # 计算损失
- loss = criterion(outputs.view(-1, 41), labels.view(-1))
-
- # 反向传播
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- step_time = time.time() - start_time
- print(f"✅ 训练步骤完成! 耗时: {step_time:.2f}秒, 损失: {loss.item():.4f}")
-
- return step_time < 120 # 2分钟内完成
-
-def main():
- print("=" * 60)
- print("🧪 XLA编译快速测试")
- print("=" * 60)
-
- try:
- # 测试1: 简单模型编译
- compilation_ok = test_xla_compilation()
-
- if compilation_ok:
- # 测试2: 训练步骤
- training_ok = test_training_step()
-
- if training_ok:
- print("\n✅ 所有测试通过! 可以尝试完整模型训练")
- print("💡 建议:")
- print(" 1. 确保有足够内存 (32GB+)")
- print(" 2. 减小batch_size (比如从32改为16)")
- print(" 3. 使用gradient_accumulation_steps补偿")
- else:
- print("\n⚠️ 训练步骤较慢,建议优化")
- else:
- print("\n❌ XLA编译有问题,需要检查环境")
-
- except Exception as e:
- print(f"\n💥 测试失败: {e}")
- print("💡 可能的问题:")
- print(" - TPU资源不可用")
- print(" - PyTorch XLA安装问题")
- print(" - 内存不足")
-
- print("=" * 60)
-
-if __name__ == "__main__":
- main()
\ No newline at end of file