From 082018cd46713837cf2c7c2eacfac69b828a3190 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Wed, 15 Oct 2025 15:14:01 +0800 Subject: [PATCH] tpu-test --- .../jupyter_xla_compatibility.py | 93 ----- model_training_nnn_tpu/jupyter_xla_monitor.py | 167 -------- model_training_nnn_tpu/jupyter_xla_test.py | 78 ---- model_training_nnn_tpu/quick_tpu_test.py | 129 ++++++ model_training_nnn_tpu/simple_tpu_model.py | 367 ++++++++++++++++++ model_training_nnn_tpu/test_simple_model.py | 162 -------- 6 files changed, 496 insertions(+), 500 deletions(-) delete mode 100644 model_training_nnn_tpu/jupyter_xla_compatibility.py delete mode 100644 model_training_nnn_tpu/jupyter_xla_monitor.py delete mode 100644 model_training_nnn_tpu/jupyter_xla_test.py create mode 100644 model_training_nnn_tpu/quick_tpu_test.py create mode 100644 model_training_nnn_tpu/simple_tpu_model.py delete mode 100644 model_training_nnn_tpu/test_simple_model.py diff --git a/model_training_nnn_tpu/jupyter_xla_compatibility.py b/model_training_nnn_tpu/jupyter_xla_compatibility.py deleted file mode 100644 index ac275af..0000000 --- a/model_training_nnn_tpu/jupyter_xla_compatibility.py +++ /dev/null @@ -1,93 +0,0 @@ -# ==================== -# 单元格: XLA版本兼容性检查和修复 -# ==================== - -import torch -import torch.nn as nn -print("🔧 PyTorch XLA版本兼容性检查...") - -# 导入XLA -import torch_xla.core.xla_model as xm - -print("✅ PyTorch XLA导入成功!") - -# 定义兼容性函数 -def get_xla_world_size(): - """获取XLA world size,兼容不同版本""" - try: - return xm.xrt_world_size() - except AttributeError: - try: - return xm.get_world_size() - except AttributeError: - return 1 # 默认返回1 - -def get_xla_ordinal(): - """获取XLA ordinal,兼容不同版本""" - try: - return xm.get_ordinal() - except AttributeError: - return 0 # 默认返回0 - -def xla_mark_step(): - """XLA mark step,兼容不同版本""" - try: - xm.mark_step() - except AttributeError: - try: - xm.wait_device_ops() - except AttributeError: - pass # 如果都不可用,则跳过 - -def check_xla_device(): - """检查XLA设备状态""" - try: - device = xm.xla_device() - print(f"📱 XLA设备: {device}") - - world_size = get_xla_world_size() - ordinal = get_xla_ordinal() - - print(f"🌍 World Size: {world_size}") - print(f"🔢 Ordinal: {ordinal}") - - # 检测设备类型 - device_str = str(device) - if 'xla' in device_str and 'cpu' not in device_str: - print("✅ 检测到TPU设备") - return True, "TPU" - elif 'xla' in device_str and 'cpu' in device_str: - print("⚠️ XLA CPU模拟模式") - return True, "XLA_CPU" - else: - print("❌ 未检测到XLA设备") - return False, "CPU" - - except Exception as e: - print(f"❌ XLA设备检查失败: {e}") - return False, "ERROR" - -# 执行兼容性检查 -device_available, device_type = check_xla_device() - -if device_available: - print(f"✅ XLA环境正常,设备类型: {device_type}") - - # 测试基本XLA操作 - print("🧪 测试基本XLA操作...") - try: - device = xm.xla_device() - x = torch.randn(2, 2, device=device) - y = torch.matmul(x, x) - - # 测试同步函数 - xla_mark_step() - - print("✅ 基本XLA操作测试成功") - - except Exception as e: - print(f"❌ XLA操作测试失败: {e}") -else: - print("❌ XLA环境不可用") - -print("\n💡 兼容性检查完成,可以运行后续单元格") \ No newline at end of file diff --git a/model_training_nnn_tpu/jupyter_xla_monitor.py b/model_training_nnn_tpu/jupyter_xla_monitor.py deleted file mode 100644 index e194752..0000000 --- a/model_training_nnn_tpu/jupyter_xla_monitor.py +++ /dev/null @@ -1,167 +0,0 @@ -# ==================== -# 单元格2: XLA编译进度监控 -# ==================== - -import torch -import torch.nn as nn -import time -import threading -from IPython.display import display, HTML, clear_output -import ipywidgets as widgets - -# 导入XLA (环境变量已在单元格1中设置) -print("🚀 导入PyTorch XLA...") -import torch_xla.core.xla_model as xm - -print(f"✅ XLA导入成功!") -print(f" TPU设备: {xm.xla_device()}") - -# 兼容新版本PyTorch XLA -try: - world_size = xm.xrt_world_size() - print(f" World Size (旧API): {world_size}") -except AttributeError: - try: - world_size = xm.get_world_size() - print(f" World Size (新API): {world_size}") - except AttributeError: - print(" World Size: 无法获取 (可能在CPU模式)") - -# 检查XLA版本兼容性 -print("🔍 检查XLA API兼容性:") -api_available = [] -api_deprecated = [] - -# 检查各种API -test_apis = [ - ('xrt_world_size', 'xrt_world_size()'), - ('get_world_size', 'get_world_size()'), - ('mark_step', 'mark_step()'), - ('wait_device_ops', 'wait_device_ops()'), - ('get_ordinal', 'get_ordinal()'), - ('xla_device_count', 'xla_device_count()') -] - -for api_name, api_desc in test_apis: - if hasattr(xm, api_name): - api_available.append(api_desc) - else: - api_deprecated.append(api_desc) - -if api_available: - print(f" ✅ 可用API: {', '.join(api_available)}") -if api_deprecated: - print(f" ❌ 不可用API: {', '.join(api_deprecated)}") - -# 创建编译进度监控器 -class JupyterCompilationMonitor: - def __init__(self): - self.start_time = None - self.is_monitoring = False - - # 创建输出widget - self.output_widget = widgets.Output() - - # 创建进度条 - self.progress_bar = widgets.IntProgress( - value=0, - min=0, - max=100, - description='XLA编译:', - bar_style='info', - style={'bar_color': '#1f77b4'}, - orientation='horizontal' - ) - - # 创建状态标签 - self.status_label = widgets.HTML( - value="准备开始编译..." - ) - - # 创建CPU使用率显示 - self.cpu_label = widgets.HTML( - value="CPU: ---%" - ) - - self.memory_label = widgets.HTML( - value="内存: ---%" - ) - - # 组合界面 - self.monitor_box = widgets.VBox([ - widgets.HTML("

🔄 XLA编译监控

"), - self.progress_bar, - self.status_label, - widgets.HBox([self.cpu_label, self.memory_label]), - self.output_widget - ]) - - def start_monitoring(self): - """开始监控""" - self.start_time = time.time() - self.is_monitoring = True - - display(self.monitor_box) - - # 启动监控线程 - self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) - self.monitor_thread.start() - - def _monitor_loop(self): - """监控循环""" - while self.is_monitoring: - try: - elapsed = time.time() - self.start_time - minutes = int(elapsed // 60) - seconds = int(elapsed % 60) - - # 更新进度条 (模拟进度) - progress = min(int(elapsed / 10 * 100), 95) # 10秒内达到95% - self.progress_bar.value = progress - - # 获取系统资源 - cpu_percent = psutil.cpu_percent(interval=0.1) - memory_percent = psutil.virtual_memory().percent - - # 更新显示 - self.status_label.value = f"编译进行中... ⏱️ {minutes:02d}:{seconds:02d}" - self.cpu_label.value = f"🖥️ CPU: {cpu_percent:5.1f}%" - self.memory_label.value = f"💾 内存: {memory_percent:5.1f}%" - - # 检测是否编译完成 (CPU使用率突然下降) - if elapsed > 10 and cpu_percent < 20: # 编译通常CPU使用率很高 - self.complete_monitoring() - break - - time.sleep(1) - - except Exception as e: - with self.output_widget: - print(f"监控错误: {e}") - break - - def complete_monitoring(self): - """完成监控""" - if self.is_monitoring: - self.is_monitoring = False - elapsed = time.time() - self.start_time - - self.progress_bar.value = 100 - self.progress_bar.bar_style = 'success' - self.status_label.value = f"✅ 编译完成! 总耗时: {elapsed:.2f}秒" - - with self.output_widget: - print(f"\n🎉 XLA编译成功完成!") - print(f"⏱️ 总耗时: {elapsed:.2f}秒") - if elapsed < 60: - print("✅ 编译速度正常") - elif elapsed < 300: - print("⚠️ 编译稍慢,但可接受") - else: - print("❌ 编译过慢,建议检查设置") - -# 创建全局监控器 -compilation_monitor = JupyterCompilationMonitor() - -print("✅ 编译监控器已准备就绪!") -print("💡 运行下一个单元格开始XLA编译测试") \ No newline at end of file diff --git a/model_training_nnn_tpu/jupyter_xla_test.py b/model_training_nnn_tpu/jupyter_xla_test.py deleted file mode 100644 index 2d45b9d..0000000 --- a/model_training_nnn_tpu/jupyter_xla_test.py +++ /dev/null @@ -1,78 +0,0 @@ -# ==================== -# 单元格3: 快速XLA编译测试 -# ==================== - -# 简化测试模型 -class QuickTestModel(nn.Module): - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(512, 128) - self.gru = nn.GRU(128, 64, batch_first=True) - self.linear2 = nn.Linear(64, 41) - - def forward(self, x): - x = torch.relu(self.linear1(x)) - x, _ = self.gru(x) - x = self.linear2(x) - return x - -print("🧪 开始XLA编译快速测试...") - -# 启动监控 -compilation_monitor.start_monitoring() - -try: - # 获取TPU设备 - device = xm.xla_device() - - # 创建小模型 - model = QuickTestModel().to(device) - param_count = sum(p.numel() for p in model.parameters()) - print(f"📊 测试模型参数: {param_count:,}") - - # 创建测试数据 (很小的batch) - x = torch.randn(2, 20, 512, device=device) - print(f"📥 输入数据形状: {x.shape}") - - print("🔄 开始首次前向传播 (触发XLA编译)...") - - # 首次前向传播 - 这会触发XLA编译 - with torch.no_grad(): - start_compile = time.time() - output = model(x) - compile_time = time.time() - start_compile - - print(f"✅ XLA编译完成!") - print(f"📤 输出形状: {output.shape}") - - # 完成监控 - compilation_monitor.complete_monitoring() - - # 测试编译后的性能 - print("\n🚀 测试编译后的执行速度...") - with torch.no_grad(): - start_exec = time.time() - for _ in range(10): - output = model(x) - avg_exec_time = (time.time() - start_exec) / 10 - - print(f"⚡ 平均执行时间: {avg_exec_time*1000:.2f}ms") - - # 性能评估 - if compile_time < 30: - print("✅ 编译速度优秀! 可以尝试完整模型") - test_result = "excellent" - elif compile_time < 120: - print("✅ 编译速度良好! 建议使用简化配置") - test_result = "good" - else: - print("⚠️ 编译速度较慢,建议进一步优化") - test_result = "slow" - -except Exception as e: - compilation_monitor.complete_monitoring() - print(f"❌ 测试失败: {e}") - test_result = "failed" - -print(f"\n📋 测试结果: {test_result}") -print("💡 如果测试通过,可以运行下一个单元格进行完整训练") \ No newline at end of file diff --git a/model_training_nnn_tpu/quick_tpu_test.py b/model_training_nnn_tpu/quick_tpu_test.py new file mode 100644 index 0000000..2cb93fb --- /dev/null +++ b/model_training_nnn_tpu/quick_tpu_test.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +""" +快速TPU测试脚本 - 验证简单模型是否可以在TPU上运行 +""" + +import os +import time +import torch +import torch.nn as nn + +# 设置环境变量 +os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true' +os.environ['XLA_USE_BF16'] = '1' + +import torch_xla.core.xla_model as xm + +def quick_test(): + """快速测试TPU是否工作正常""" + print("🚀 开始快速TPU测试...") + + try: + # 获取TPU设备 + device = xm.xla_device() + print(f"📱 TPU设备: {device}") + + # 创建简单模型 + model = nn.Sequential( + nn.Linear(512, 256), + nn.ReLU(), + nn.GRU(256, 128, batch_first=True), + nn.Linear(128, 41) + ).to(device) + + print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}") + + # 创建测试数据 + x = torch.randn(8, 50, 512, device=device) + print(f"📥 输入形状: {x.shape}") + + # 测试前向传播 + print("🔄 测试前向传播...") + start_time = time.time() + + with torch.no_grad(): + if hasattr(model, '__getitem__'): + # 对于Sequential模型,手动处理GRU层 + x_proj = model[1](model[0](x)) # Linear + ReLU + gru_out, _ = model[2](x_proj) # GRU + output = model[3](gru_out) # Final Linear + else: + output = model(x) + + # 同步TPU操作 + xm.mark_step() + xm.wait_device_ops() + + forward_time = time.time() - start_time + print(f"✅ 前向传播完成! 耗时: {forward_time:.3f}秒") + print(f"📤 输出形状: {output.shape}") + + # 测试反向传播 + print("🔄 测试反向传播...") + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + + start_time = time.time() + + # 创建虚拟标签 + labels = torch.randint(0, 41, (8, 50), device=device) + criterion = nn.CrossEntropyLoss() + + # 前向传播 + if hasattr(model, '__getitem__'): + x_proj = model[1](model[0](x)) + gru_out, _ = model[2](x_proj) + output = model[3](gru_out) + else: + output = model(x) + + # 计算损失 + loss = criterion(output.view(-1, 41), labels.view(-1)) + + # 反向传播 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # 同步TPU操作 + xm.mark_step() + xm.wait_device_ops() + + backward_time = time.time() - start_time + print(f"✅ 反向传播完成! 耗时: {backward_time:.3f}秒") + print(f"🎯 损失值: {loss.item():.4f}") + + # 总结 + print(f"\n📈 性能总结:") + print(f" 前向传播: {forward_time:.3f}秒") + print(f" 反向传播: {backward_time:.3f}秒") + print(f" 总计: {forward_time + backward_time:.3f}秒") + + if (forward_time + backward_time) < 10: # 10秒内完成 + print("✅ TPU测试通过! 可以进行完整训练") + return True + else: + print("⚠️ TPU性能较慢,可能需要优化") + return False + + except Exception as e: + print(f"❌ TPU测试失败: {e}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + print("=" * 50) + print("⚡ 快速TPU测试") + print("=" * 50) + + success = quick_test() + + if success: + print("\n🎉 测试成功! 现在可以运行:") + print(" python simple_tpu_model.py") + else: + print("\n❌ 测试失败,请检查TPU配置") + + print("=" * 50) \ No newline at end of file diff --git a/model_training_nnn_tpu/simple_tpu_model.py b/model_training_nnn_tpu/simple_tpu_model.py new file mode 100644 index 0000000..21089ed --- /dev/null +++ b/model_training_nnn_tpu/simple_tpu_model.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 +""" +简单TPU模型训练和测试脚本 +基于大脑到文本数据的简化版本,专门为TPU优化 +""" + +import os +import time +import torch +import torch.nn as nn +import torch.optim as optim +import numpy as np +from typing import Dict, Any, Tuple + +# 设置XLA环境变量 +os.environ['XLA_FLAGS'] = ( + '--xla_cpu_multi_thread_eigen=true ' + '--xla_cpu_enable_fast_math=true ' + f'--xla_force_host_platform_device_count={os.cpu_count()}' +) +os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count()) +os.environ['XLA_USE_BF16'] = '1' + +import torch_xla.core.xla_model as xm +import torch_xla.distributed.parallel_loader as pl + + +class SimpleBrainToTextModel(nn.Module): + """简化的大脑到文本模型 - TPU优化版本""" + + def __init__(self, input_features=512, hidden_size=256, num_classes=41, num_layers=3): + super().__init__() + + # 输入处理层 + self.input_proj = nn.Linear(input_features, hidden_size) + self.input_dropout = nn.Dropout(0.2) + + # GRU层 - 使用较小的隐藏层以提高TPU效率 + self.gru = nn.GRU( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=0.3 if num_layers > 1 else 0 + ) + + # 输出层 + self.output_proj = nn.Linear(hidden_size, num_classes) + + # 初始化权重 + self._init_weights() + + def _init_weights(self): + """初始化模型权重""" + for name, param in self.named_parameters(): + if 'weight' in name: + if 'gru' in name: + nn.init.orthogonal_(param) + else: + nn.init.xavier_uniform_(param) + elif 'bias' in name: + nn.init.zeros_(param) + + def forward(self, x): + """ + 前向传播 + Args: + x: (batch_size, seq_len, input_features) + Returns: + logits: (batch_size, seq_len, num_classes) + """ + # 输入投影 + x = torch.relu(self.input_proj(x)) + x = self.input_dropout(x) + + # GRU处理 + output, _ = self.gru(x) + + # 输出投影 + logits = self.output_proj(output) + + return logits + + +class SimpleDataGenerator: + """简单的数据生成器 - 模拟大脑信号数据""" + + def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41): + self.batch_size = batch_size + self.seq_len = seq_len + self.input_features = input_features + self.num_classes = num_classes + + def generate_batch(self, device): + """生成一个批次的模拟数据""" + # 生成模拟的神经信号数据 + features = torch.randn( + self.batch_size, self.seq_len, self.input_features, + device=device, dtype=torch.float32 + ) + + # 生成模拟的标签(音素序列) + labels = torch.randint( + 0, self.num_classes, + (self.batch_size, self.seq_len), + device=device + ) + + # 生成序列长度 + seq_lengths = torch.randint( + self.seq_len // 2, self.seq_len + 1, + (self.batch_size,), + device=device + ) + + return { + 'features': features, + 'labels': labels, + 'seq_lengths': seq_lengths + } + + +class SimpleTpuTrainer: + """简单的TPU训练器""" + + def __init__(self, model, device, learning_rate=0.001): + self.model = model + self.device = device + self.optimizer = optim.Adam(model.parameters(), lr=learning_rate) + self.criterion = nn.CrossEntropyLoss(ignore_index=-1) + + # 数据生成器 + self.data_generator = SimpleDataGenerator() + + # 训练统计 + self.step = 0 + self.best_loss = float('inf') + + def train_step(self, batch): + """单个训练步骤""" + self.model.train() + self.optimizer.zero_grad() + + # 前向传播 + features = batch['features'] + labels = batch['labels'] + + logits = self.model(features) + + # 计算损失 - 重新调整形状以适应CrossEntropyLoss + batch_size, seq_len, num_classes = logits.shape + loss = self.criterion( + logits.reshape(-1, num_classes), + labels.reshape(-1) + ) + + # 反向传播 + loss.backward() + + # 梯度裁剪 + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + # 更新参数 + self.optimizer.step() + + return loss.item() + + def evaluate_step(self, batch): + """单个评估步骤""" + self.model.eval() + + with torch.no_grad(): + features = batch['features'] + labels = batch['labels'] + + logits = self.model(features) + + # 计算损失 + batch_size, seq_len, num_classes = logits.shape + loss = self.criterion( + logits.reshape(-1, num_classes), + labels.reshape(-1) + ) + + # 计算准确率 + predictions = torch.argmax(logits, dim=-1) + correct = (predictions == labels).float() + accuracy = correct.mean() + + return loss.item(), accuracy.item() + + def train(self, num_steps=1000, eval_every=100, save_every=500): + """训练模型""" + print(f"🚀 开始TPU训练 - 设备: {self.device}") + print(f"📊 模型参数: {sum(p.numel() for p in self.model.parameters()):,}") + + train_losses = [] + eval_losses = [] + eval_accuracies = [] + + start_time = time.time() + + for step in range(num_steps): + # 生成训练数据 + train_batch = self.data_generator.generate_batch(self.device) + + # 训练步骤 + train_loss = self.train_step(train_batch) + train_losses.append(train_loss) + + # XLA同步 + if step % 10 == 0: # 每10步同步一次以提高效率 + xm.mark_step() + + # 评估 + if step % eval_every == 0: + eval_batch = self.data_generator.generate_batch(self.device) + eval_loss, eval_acc = self.evaluate_step(eval_batch) + eval_losses.append(eval_loss) + eval_accuracies.append(eval_acc) + + # 同步XLA操作以获得准确的时间 + xm.mark_step() + xm.wait_device_ops() + + current_time = time.time() + elapsed = current_time - start_time + + print(f"步骤 {step:4d}/{num_steps} | " + f"训练损失: {train_loss:.4f} | " + f"验证损失: {eval_loss:.4f} | " + f"验证准确率: {eval_acc:.4f} | " + f"耗时: {elapsed:.1f}s") + + # 保存最佳模型 + if eval_loss < self.best_loss: + self.best_loss = eval_loss + print(f"🎯 新的最佳模型! 损失: {eval_loss:.4f}") + + # 定期保存 + if step > 0 and step % save_every == 0: + self.save_checkpoint(f"checkpoint_step_{step}.pt") + + # 最终同步 + xm.mark_step() + xm.wait_device_ops() + + total_time = time.time() - start_time + print(f"\n✅ 训练完成!") + print(f"⏱️ 总耗时: {total_time:.1f}秒") + print(f"🎯 最终训练损失: {train_losses[-1]:.4f}") + if eval_losses: + print(f"🎯 最终验证损失: {eval_losses[-1]:.4f}") + print(f"🎯 最终验证准确率: {eval_accuracies[-1]:.4f}") + + return { + 'train_losses': train_losses, + 'eval_losses': eval_losses, + 'eval_accuracies': eval_accuracies, + 'total_time': total_time + } + + def save_checkpoint(self, filename): + """保存检查点""" + checkpoint = { + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'step': self.step, + 'best_loss': self.best_loss, + } + + # 在TPU上需要先移动到CPU再保存 + if 'xla' in str(self.device): + checkpoint = xm.send_cpu_data_to_device(checkpoint, torch.device('cpu')) + + torch.save(checkpoint, filename) + print(f"💾 保存检查点: {filename}") + + def load_checkpoint(self, filename): + """加载检查点""" + checkpoint = torch.load(filename, map_location='cpu') + + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.step = checkpoint['step'] + self.best_loss = checkpoint['best_loss'] + + print(f"📂 加载检查点: {filename}") + print(f" 步骤: {self.step}, 最佳损失: {self.best_loss:.4f}") + + +def test_simple_inference(): + """测试简单推理""" + print("\n🧪 测试简单推理...") + + device = xm.xla_device() + + # 创建模型 + model = SimpleBrainToTextModel().to(device) + + # 创建测试数据 + batch_size = 4 + seq_len = 50 + test_input = torch.randn(batch_size, seq_len, 512, device=device) + + # 推理 + model.eval() + with torch.no_grad(): + start_time = time.time() + output = model(test_input) + xm.mark_step() + xm.wait_device_ops() + inference_time = time.time() - start_time + + print(f"✅ 推理完成!") + print(f" 输入形状: {test_input.shape}") + print(f" 输出形状: {output.shape}") + print(f" 推理时间: {inference_time:.4f}秒") + + return True + + +def main(): + """主函数""" + print("=" * 60) + print("🧠 简单TPU大脑到文本模型训练") + print("=" * 60) + + try: + # 检查TPU设备 + device = xm.xla_device() + print(f"📱 使用设备: {device}") + + # 创建模型 + model = SimpleBrainToTextModel( + input_features=512, + hidden_size=256, + num_classes=41, + num_layers=3 + ).to(device) + + # 创建训练器 + trainer = SimpleTpuTrainer(model, device, learning_rate=0.001) + + # 开始训练 + results = trainer.train( + num_steps=1000, + eval_every=100, + save_every=500 + ) + + # 保存最终模型 + trainer.save_checkpoint("final_simple_model.pt") + + # 测试推理 + test_simple_inference() + + print("\n🎉 所有测试完成!") + + except Exception as e: + print(f"❌ 训练失败: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/model_training_nnn_tpu/test_simple_model.py b/model_training_nnn_tpu/test_simple_model.py deleted file mode 100644 index 4a18c5a..0000000 --- a/model_training_nnn_tpu/test_simple_model.py +++ /dev/null @@ -1,162 +0,0 @@ -#!/usr/bin/env python3 -""" -简化模型测试脚本 - 验证XLA编译是否正常工作 -""" - -import os -import time -import torch -import torch.nn as nn - -# 设置XLA环境变量(必须在导入torch_xla之前) -os.environ['XLA_FLAGS'] = ( - '--xla_cpu_multi_thread_eigen=true ' - '--xla_cpu_enable_fast_math=true ' - f'--xla_force_host_platform_device_count={os.cpu_count()}' -) -os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count()) -os.environ['XLA_USE_BF16'] = '1' - -print(f"🔧 XLA环境变量设置:") -print(f" CPU核心数: {os.cpu_count()}") -print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}") -print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}") - -import torch_xla.core.xla_model as xm - -class SimpleModel(nn.Module): - """简化的测试模型""" - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(512, 256) - self.gru = nn.GRU(256, 128, batch_first=True) - self.linear2 = nn.Linear(128, 41) # 41个音素类别 - - def forward(self, x): - x = torch.relu(self.linear1(x)) - x, _ = self.gru(x) - x = self.linear2(x) - return x - -def test_xla_compilation(): - """测试XLA编译速度""" - print("\n🚀 开始简化模型XLA编译测试...") - - # 检查TPU设备 - device = xm.xla_device() - print(f"📱 TPU设备: {device}") - print(f"🌍 TPU World Size: {xm.xrt_world_size()}") - - # 创建简化模型 - model = SimpleModel().to(device) - print(f"📊 模型参数数量: {sum(p.numel() for p in model.parameters()):,}") - - # 创建测试数据 - batch_size = 8 # 小批次 - seq_len = 100 # 短序列 - x = torch.randn(batch_size, seq_len, 512, device=device) - - print(f"📥 输入形状: {x.shape}") - - # 首次前向传播 - 触发XLA编译 - print(f"🔄 开始首次前向传播 (XLA编译)...") - start_time = time.time() - - with torch.no_grad(): - output = model(x) - - compile_time = time.time() - start_time - print(f"✅ XLA编译完成! 耗时: {compile_time:.2f}秒") - print(f"📤 输出形状: {output.shape}") - - # 再次前向传播 - 使用编译后的图 - print(f"🔄 第二次前向传播 (使用编译后的图)...") - start_time = time.time() - - with torch.no_grad(): - output2 = model(x) - - execution_time = time.time() - start_time - print(f"⚡ 执行完成! 耗时: {execution_time:.4f}秒") - - # 性能对比 - speedup = compile_time / execution_time if execution_time > 0 else float('inf') - print(f"\n📈 性能分析:") - print(f" 编译时间: {compile_time:.2f}秒") - print(f" 执行时间: {execution_time:.4f}秒") - print(f" 加速比: {speedup:.1f}x") - - if compile_time < 60: # 1分钟内编译完成 - print("✅ XLA编译正常!") - return True - else: - print("❌ XLA编译过慢,可能有问题") - return False - -def test_training_step(): - """测试训练步骤""" - print("\n🎯 测试简化训练步骤...") - - device = xm.xla_device() - model = SimpleModel().to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = nn.CrossEntropyLoss() - - # 创建训练数据 - x = torch.randn(4, 50, 512, device=device) - labels = torch.randint(0, 41, (4, 50), device=device) - - print(f"🔄 开始训练步骤 (包含反向传播)...") - start_time = time.time() - - # 前向传播 - outputs = model(x) - - # 计算损失 - loss = criterion(outputs.view(-1, 41), labels.view(-1)) - - # 反向传播 - optimizer.zero_grad() - loss.backward() - optimizer.step() - - step_time = time.time() - start_time - print(f"✅ 训练步骤完成! 耗时: {step_time:.2f}秒, 损失: {loss.item():.4f}") - - return step_time < 120 # 2分钟内完成 - -def main(): - print("=" * 60) - print("🧪 XLA编译快速测试") - print("=" * 60) - - try: - # 测试1: 简单模型编译 - compilation_ok = test_xla_compilation() - - if compilation_ok: - # 测试2: 训练步骤 - training_ok = test_training_step() - - if training_ok: - print("\n✅ 所有测试通过! 可以尝试完整模型训练") - print("💡 建议:") - print(" 1. 确保有足够内存 (32GB+)") - print(" 2. 减小batch_size (比如从32改为16)") - print(" 3. 使用gradient_accumulation_steps补偿") - else: - print("\n⚠️ 训练步骤较慢,建议优化") - else: - print("\n❌ XLA编译有问题,需要检查环境") - - except Exception as e: - print(f"\n💥 测试失败: {e}") - print("💡 可能的问题:") - print(" - TPU资源不可用") - print(" - PyTorch XLA安装问题") - print(" - 内存不足") - - print("=" * 60) - -if __name__ == "__main__": - main() \ No newline at end of file