124 lines
3.8 KiB
Python
124 lines
3.8 KiB
Python
|
|
# ============================================================
# Cell 4: step-by-step debugging of full-model compilation
# ============================================================
# Run this cell only after the tests in cell 3 have passed.
print("🔧 逐步测试完整TripleGRUDecoder模型...")

# Make local modules (e.g. rnn_model.py) importable from the
# current working directory.
import sys

sys.path.append('.')
|
|||
|
|
|
|||
|
|
# Import the full model; fail softly with a hint instead of
# crashing the notebook when rnn_model.py is not present.
try:
    from rnn_model import TripleGRUDecoder
except ImportError as e:
    print(f"❌ 模型导入失败: {e}")
    print("请确保rnn_model.py在当前目录中")
else:
    print("✅ TripleGRUDecoder导入成功")
|
|||
|
|
|
|||
|
|
# Staged compilation test
def test_model_compilation_stages():
    """Compile-test the decoder models on the XLA device in three stages.

    Stage 1 compiles NoiseModel alone, stage 2 CleanSpeechModel, and
    stage 3 the full TripleGRUDecoder, so that a compilation hang or
    failure can be attributed to a specific sub-model.

    Bug fixed vs. the original: the early ``return`` at the end of
    stage 1 made stages 2 and 3 unreachable.  All three stages now run
    even when an earlier one fails, and the shared test tensors are
    created up front so later stages do not depend on stage 1's body.

    Returns:
        tuple[bool, float]: ``(all_ok, total_compile_seconds)`` —
        ``all_ok`` is True only if every stage compiled successfully;
        the float sums the compile times of the stages that succeeded.
    """
    device = xm.xla_device()

    # Shared toy inputs: batch=2, seq_len=20, neural_dim=512.
    # NOTE(review): assumes all three models accept (x, day_idx) with
    # these shapes — the original reused the same tensors; confirm
    # against rnn_model.py.
    x = torch.randn(2, 20, 512, device=device)
    day_idx = torch.tensor([0, 1], device=device)

    results = []  # one (ok, compile_seconds) entry per stage

    # --- Stage 1: NoiseModel on its own --------------------------------
    print("\n🔬 阶段1: 测试NoiseModel...")
    try:
        from rnn_model import NoiseModel
        noise_model = NoiseModel(
            neural_dim=512,
            n_units=384,   # reduced size to speed up XLA compilation
            n_days=4,
            patch_size=8,  # smaller patch size for the same reason
        ).to(device)

        start_time = time.time()
        with torch.no_grad():
            output, states = noise_model(x, day_idx)
        compile_time = time.time() - start_time

        print(f"✅ NoiseModel编译成功! 耗时: {compile_time:.2f}秒")
        print(f" 参数数量: {sum(p.numel() for p in noise_model.parameters()):,}")
        results.append((True, compile_time))
    except Exception as e:
        print(f"❌ NoiseModel编译失败: {e}")
        results.append((False, 0))

    # --- Stage 2: CleanSpeechModel -------------------------------------
    print("\n🔬 阶段2: 测试CleanSpeechModel...")
    try:
        from rnn_model import CleanSpeechModel
        clean_model = CleanSpeechModel(
            neural_dim=512,
            n_units=384,
            n_days=4,
            n_classes=41,
            patch_size=8,
        ).to(device)

        start_time = time.time()
        with torch.no_grad():
            output = clean_model(x, day_idx)
        compile_time = time.time() - start_time

        print(f"✅ CleanSpeechModel编译成功! 耗时: {compile_time:.2f}秒")
        results.append((True, compile_time))
    except Exception as e:
        print(f"❌ CleanSpeechModel编译失败: {e}")
        results.append((False, 0))

    # --- Stage 3: full TripleGRUDecoder --------------------------------
    print("\n🔬 阶段3: 测试TripleGRUDecoder...")
    try:
        model = TripleGRUDecoder(
            neural_dim=512,
            n_units=384,   # smaller than the original 768
            n_days=4,      # fewer recording days
            n_classes=41,
            patch_size=8,  # smaller patch size
        ).to(device)

        print(f"📊 完整模型参数: {sum(p.numel() for p in model.parameters()):,}")

        # compilation_monitor comes from an earlier notebook cell;
        # ensure it is always stopped, even if the forward pass raises.
        compilation_monitor.start_monitoring()
        try:
            start_time = time.time()
            with torch.no_grad():
                # inference mode only; training mode is exercised later
                logits = model(x, day_idx, None, False, 'inference')
            compile_time = time.time() - start_time
        finally:
            compilation_monitor.complete_monitoring()

        print(f"✅ TripleGRUDecoder编译成功! 耗时: {compile_time:.2f}秒")
        print(f"📤 输出形状: {logits.shape}")
        results.append((True, compile_time))
    except Exception as e:
        print(f"❌ TripleGRUDecoder编译失败: {e}")
        results.append((False, 0))

    all_ok = all(ok for ok, _ in results)
    total_time = sum(t for _, t in results)
    return all_ok, total_time
|
|||
|
|
|
|||
|
|
# Run the staged compilation test and report suggested next steps.
stage_results = test_model_compilation_stages()

# Bug fixed vs. the original: `if stage_results:` was always truthy
# because the function returns a non-empty (success, time) tuple even
# on failure — (False, 0) is truthy — so the success branch always ran.
# Inspect the success flag explicitly instead.
if stage_results and stage_results[0]:
    print(f"\n🎉 所有编译测试完成!")
    print("💡 下一步可以尝试:")
    print(" 1. 使用简化配置进行训练")
    print(" 2. 逐步增加模型复杂度")
    print(" 3. 监控TPU资源使用情况")
else:
    print(f"\n⚠️ 编译测试发现问题")
    print("💡 建议:")
    print(" 1. 进一步减小模型参数")
    print(" 2. 检查内存使用情况")
    print(" 3. 使用CPU模式进行调试")
|