Files
b2txt25/model_training_nnn_tpu/jupyter_debug_full_model.py

124 lines
3.8 KiB
Python
Raw Normal View History

2025-10-15 14:26:11 +08:00
# ====================
# Cell 4: step-by-step debugging of the full model compilation
# ====================
# Run this cell only if the tests in cell 3 passed.
print("🔧 逐步测试完整TripleGRUDecoder模型...")

# Import the full model. The sys.path tweak makes local modules importable
# when the notebook's working directory is the project root.
import sys

sys.path.append('.')  # ensure local modules can be imported

try:
    from rnn_model import TripleGRUDecoder
    print("✅ TripleGRUDecoder导入成功")
except ImportError as e:
    # Report the failure instead of crashing so the user gets an
    # actionable hint about the missing file.
    print(f"❌ 模型导入失败: {e}")
    print("请确保rnn_model.py在当前目录中")
# 分阶段测试
def test_model_compilation_stages():
    """Test XLA model compilation stage by stage.

    Runs three increasingly complex compilation tests — NoiseModel,
    CleanSpeechModel, and the full TripleGRUDecoder — and reports each
    result. The original version returned from inside stage 1's `try`,
    which made stages 2 and 3 unreachable; all three stages now run.

    Returns:
        tuple[bool, float]: (all stages succeeded, total compile time in
        seconds across successful stages). Backward-compatible with the
        original (success, compile_time) tuple shape.
    """
    device = xm.xla_device()

    # Shared dummy inputs for every stage: batch=2, seq_len=20, neural_dim=512.
    # Hoisted out of stage 1 so later stages never hit a NameError when an
    # earlier stage fails (the original defined these inside stage 1's try).
    x = torch.randn(2, 20, 512, device=device)
    day_idx = torch.tensor([0, 1], device=device)

    results = []  # one (success, compile_time) entry per stage

    # ---- Stage 1: NoiseModel alone ----
    print("\n🔬 阶段1: 测试NoiseModel...")
    try:
        from rnn_model import NoiseModel
        noise_model = NoiseModel(
            neural_dim=512,
            n_units=384,   # reduced from the training config for faster compile
            n_days=4,
            patch_size=8,  # smaller patch size
        ).to(device)
        start_time = time.time()
        with torch.no_grad():
            output, states = noise_model(x, day_idx)
        compile_time = time.time() - start_time
        print(f"✅ NoiseModel编译成功! 耗时: {compile_time:.2f}")
        print(f"   参数数量: {sum(p.numel() for p in noise_model.parameters()):,}")
        results.append((True, compile_time))
    except Exception as e:
        print(f"❌ NoiseModel编译失败: {e}")
        results.append((False, 0))

    # ---- Stage 2: CleanSpeechModel ----
    print("\n🔬 阶段2: 测试CleanSpeechModel...")
    try:
        from rnn_model import CleanSpeechModel
        clean_model = CleanSpeechModel(
            neural_dim=512,
            n_units=384,
            n_days=4,
            n_classes=41,
            patch_size=8,
        ).to(device)
        start_time = time.time()
        with torch.no_grad():
            output = clean_model(x, day_idx)
        compile_time = time.time() - start_time
        print(f"✅ CleanSpeechModel编译成功! 耗时: {compile_time:.2f}")
        results.append((True, compile_time))
    except Exception as e:
        print(f"❌ CleanSpeechModel编译失败: {e}")
        results.append((False, 0))

    # ---- Stage 3: full TripleGRUDecoder ----
    print("\n🔬 阶段3: 测试TripleGRUDecoder...")
    try:
        model = TripleGRUDecoder(
            neural_dim=512,
            n_units=384,   # smaller than the original 768
            n_days=4,      # fewer days
            n_classes=41,
            patch_size=8,  # smaller patch size
        ).to(device)
        print(f"📊 完整模型参数: {sum(p.numel() for p in model.parameters()):,}")
        # Start compilation monitoring around the first forward pass,
        # which triggers XLA tracing/compilation.
        compilation_monitor.start_monitoring()
        start_time = time.time()
        with torch.no_grad():
            # Inference-mode forward pass.
            logits = model(x, day_idx, None, False, 'inference')
        compile_time = time.time() - start_time
        compilation_monitor.complete_monitoring()
        print(f"✅ TripleGRUDecoder编译成功! 耗时: {compile_time:.2f}")
        print(f"📤 输出形状: {logits.shape}")
        results.append((True, compile_time))
    except Exception as e:
        # Stop monitoring even on failure so the monitor is not left running.
        compilation_monitor.complete_monitoring()
        print(f"❌ TripleGRUDecoder编译失败: {e}")
        results.append((False, 0))

    # Aggregate: success only if every stage compiled; report total time.
    success = all(ok for ok, _ in results)
    total_time = sum(t for _, t in results)
    return success, total_time
# Run the staged compilation tests. The function returns a
# (success, compile_time) tuple; a non-empty tuple is always truthy, so the
# original `if stage_results:` could never take the else branch — inspect
# the success flag explicitly instead.
stage_results = test_model_compilation_stages()
if stage_results and stage_results[0]:
    print(f"\n🎉 所有编译测试完成!")
    print("💡 下一步可以尝试:")
    print("   1. 使用简化配置进行训练")
    print("   2. 逐步增加模型复杂度")
    print("   3. 监控TPU资源使用情况")
else:
    print(f"\n⚠️ 编译测试发现问题")
    print("💡 建议:")
    print("   1. 进一步减小模型参数")
    print("   2. 检查内存使用情况")
    print("   3. 使用CPU模式进行调试")