Remove setup script, TPU memory monitor, and training model script
- Deleted `setup_tensorflow_tpu.sh` which was responsible for setting up the TensorFlow environment on TPU v5e-8. - Removed `tpu_memory_monitor.py`, a tool for monitoring TPU memory usage during training. - Eliminated `train_model.py`, the script for training the Brain-to-Text RNN model.
This commit is contained in:
@@ -1,315 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
使用AMP的TPU训练脚本
|
||||
正确处理混合精度训练,避免dtype不匹配问题
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
# AMP-related environment variables. They are assigned here, ahead of the
# torch_xla imports that follow in this file.
os.environ['XLA_FLAGS'] = ' '.join((
    '--xla_cpu_multi_thread_eigen=true',
    '--xla_cpu_enable_fast_math=true',
))
os.environ['XLA_USE_BF16'] = '1'  # enable bf16
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
import torch_xla.amp as xla_amp
|
||||
|
||||
|
||||
class AMPModel(nn.Module):
    """Small fully connected classifier used to exercise AMP training.

    Args:
        input_size: number of features per sample once the input is flattened.
        hidden_size: width of the first hidden layer; the second hidden
            layer is ``hidden_size // 2`` wide.
        num_classes: number of output logits.
    """

    def __init__(self, input_size=784, hidden_size=512, num_classes=10):
        super().__init__()

        half_hidden = hidden_size // 2
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, half_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(half_hidden, num_classes),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten every sample to one vector and return the logits."""
        flat = x.view(x.size(0), -1)
        return self.network(flat)
|
||||
|
||||
|
||||
class AMPTrainer:
    """Runs single training / evaluation steps under XLA autocast.

    Args:
        model: the module to optimise (expected to already live on ``device``).
        device: the XLA device; it is also handed to ``xla_amp.autocast``.
        learning_rate: learning rate for the Adam optimiser.
    """

    def __init__(self, model, device, learning_rate=0.001):
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()

        # Gradient scaler used by train_step for the scaled backward pass.
        self.scaler = xla_amp.GradScaler()

        print("✅ AMP训练器初始化完成")
        print(f" 设备: {device}")
        print(f" 模型参数: {sum(p.numel() for p in model.parameters()):,}")

    @staticmethod
    def _accuracy(output, target):
        """Return the fraction of rows whose arg-max logit equals the target."""
        pred = output.argmax(dim=1)
        correct = pred.eq(target).sum().item()
        return correct / target.size(0)

    def train_step(self, data, target):
        """Run one optimisation step and return ``(loss, accuracy)`` as floats."""
        self.model.train()
        self.optimizer.zero_grad()

        # torch_xla's autocast takes the XLA device as its first argument;
        # calling it with no arguments raises TypeError.
        with xla_amp.autocast(self.device):
            output = self.model(data)
            loss = self.criterion(output, target)

        # Scaled backward pass.
        self.scaler.scale(loss).backward()

        # Unscale first so the clipping threshold applies to the real gradients.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Parameter update.
        self.scaler.step(self.optimizer)
        self.scaler.update()

        accuracy = self._accuracy(output, target)
        return loss.item(), accuracy

    def evaluate_step(self, data, target):
        """Run one forward pass without gradients and return ``(loss, accuracy)``."""
        self.model.eval()

        with torch.no_grad():
            with xla_amp.autocast(self.device):
                output = self.model(data)
                loss = self.criterion(output, target)

            accuracy = self._accuracy(output, target)

        return loss.item(), accuracy
|
||||
|
||||
|
||||
def get_mnist_loaders(batch_size=64):
    """Build the MNIST data loaders.

    The dataset is downloaded into ``./mnist_data`` when missing. Only the
    training loader shuffles; both loaders load in the main process.

    Returns:
        ``(train_loader, test_loader)``
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    loaders = []
    # First pass builds the training split, second pass the test split.
    for is_train in (True, False):
        dataset = torchvision.datasets.MNIST(
            root='./mnist_data',
            train=is_train,
            download=True,
            transform=preprocessing,
        )
        loaders.append(
            torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=is_train,
                num_workers=0,
            )
        )

    train_loader, test_loader = loaders
    return train_loader, test_loader
|
||||
|
||||
|
||||
def train_with_amp():
    """Train ``AMPModel`` on MNIST on the XLA device using AMP.

    Runs 2 epochs of at most 200 batches each and prints running averages
    every 20 batches.

    Returns:
        ``(trainer, train_losses, train_accuracies)`` where the two lists
        hold one mean loss and one mean accuracy (in percent) per epoch.
    """
    print("🚀 开始AMP TPU训练...")

    # Acquire the device.
    device = xm.xla_device()
    print(f"📱 设备: {device}")

    # Build model and trainer.
    model = AMPModel(input_size=784, hidden_size=512, num_classes=10).to(device)
    trainer = AMPTrainer(model, device, learning_rate=0.001)

    # Load data. The test split is not used in this function, so no device
    # loader is built for it.
    print("📥 加载MNIST数据...")
    train_loader, _ = get_mnist_loaders(batch_size=64)
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    print("🎯 开始AMP训练...")

    num_epochs = 2
    max_batches_per_epoch = 200  # cap on the number of batches per epoch
    train_losses = []
    train_accuracies = []

    for epoch in range(num_epochs):
        print(f"\n📊 Epoch {epoch + 1}/{num_epochs}")

        epoch_start = time.time()
        epoch_loss = 0.0
        epoch_acc = 0.0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(train_device_loader):
            if batch_idx >= max_batches_per_epoch:
                break

            loss, accuracy = trainer.train_step(data, target)

            epoch_loss += loss
            epoch_acc += accuracy
            num_batches += 1

            # Every 20 batches: mark a step and report the running averages.
            if batch_idx % 20 == 0:
                xm.mark_step()

                avg_loss = epoch_loss / num_batches
                avg_acc = epoch_acc / num_batches * 100

                print(f" 批次 {batch_idx:3d}/{max_batches_per_epoch} | "
                      f"损失: {avg_loss:.4f} | "
                      f"准确率: {avg_acc:.2f}%")

        # End-of-epoch synchronisation.
        xm.mark_step()
        xm.wait_device_ops()

        epoch_time = time.time() - epoch_start
        # An empty loader would otherwise divide by zero; report 0.0 instead.
        denominator = max(num_batches, 1)
        final_loss = epoch_loss / denominator
        final_acc = epoch_acc / denominator * 100

        train_losses.append(final_loss)
        train_accuracies.append(final_acc)

        print(f"✅ Epoch {epoch + 1} 完成 | "
              f"耗时: {epoch_time:.2f}s | "
              f"平均损失: {final_loss:.4f} | "
              f"平均准确率: {final_acc:.2f}%")

    return trainer, train_losses, train_accuracies
|
||||
|
||||
|
||||
def test_with_amp(trainer):
    """Evaluate ``trainer`` on at most 100 MNIST test batches.

    Args:
        trainer: object exposing ``evaluate_step(data, target)`` that
            returns ``(loss, accuracy)``.

    Returns:
        ``(avg_loss, avg_acc)`` with ``avg_acc`` expressed in percent.
    """
    print("\n🧪 开始AMP测试...")

    xla_device = xm.xla_device()
    loaders = get_mnist_loaders(batch_size=64)
    eval_loader = pl.MpDeviceLoader(loaders[1], xla_device)

    batch_limit = 100
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0

    started = time.time()

    for step, (inputs, labels) in enumerate(eval_loader):
        if step >= batch_limit:
            break

        step_loss, step_acc = trainer.evaluate_step(inputs, labels)
        loss_sum += step_loss
        acc_sum += step_acc
        seen += 1

        if step % 20 == 0:
            xm.mark_step()

    # Final synchronisation before the clock is stopped.
    xm.mark_step()
    xm.wait_device_ops()

    test_time = time.time() - started
    avg_loss = loss_sum / seen
    avg_acc = acc_sum / seen * 100

    print("✅ 测试完成!")
    print(f"⏱️ 测试时间: {test_time:.2f}秒")
    print(f"🎯 测试损失: {avg_loss:.4f}")
    print(f"🎯 测试准确率: {avg_acc:.2f}%")

    return avg_loss, avg_acc
|
||||
|
||||
|
||||
def main():
    """Entry point: train, evaluate, save a checkpoint and print a summary.

    Any exception is caught, printed with its traceback, and followed by a
    short list of troubleshooting hints.
    """
    banner = "=" * 60
    print(banner)
    print("⚡ AMP TPU训练示例")
    print(banner)

    try:
        # Train, then evaluate.
        trainer, train_losses, train_accuracies = train_with_amp()
        test_loss, test_acc = test_with_amp(trainer)

        # Save the model (moved to CPU) together with the collected metrics.
        print("\n💾 保存模型...")
        model_cpu = trainer.model.cpu()
        checkpoint = {
            'model_state_dict': model_cpu.state_dict(),
            'train_losses': train_losses,
            'train_accuracies': train_accuracies,
            'test_loss': test_loss,
            'test_accuracy': test_acc,
        }
        torch.save(checkpoint, 'amp_mnist_model.pth')
        print("✅ 模型已保存到 amp_mnist_model.pth")

        last_train_acc = train_accuracies[-1]
        print("\n🎉 AMP训练完成!")
        print(f"📊 最终训练准确率: {last_train_acc:.2f}%")
        print(f"📊 测试准确率: {test_acc:.2f}%")

        meets_targets = last_train_acc > 85 and test_acc > 80
        if not meets_targets:
            print("⚠️ 模型性能一般,但AMP功能正常")
        else:
            print("✅ AMP训练成功! 模型性能优秀")

    except Exception as e:
        print(f"❌ AMP训练失败: {e}")
        import traceback
        traceback.print_exc()

        print("\n💡 故障排除建议:")
        for hint in (
            " 1. 确保PyTorch XLA版本支持AMP",
            " 2. 检查TPU资源是否充足",
            " 3. 尝试减小batch_size",
        ):
            print(hint)
|
||||
|
||||
|
||||
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -1,403 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TPU训练内存监控工具 - 专注于训练过程中的实时内存和MXU监控
|
||||
适用于TPU v5e-8环境
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
def monitor_tpu_during_training():
    """Report TPU memory usage and MXU throughput around a simulated training run.

    Everything is printed; the function returns ``None``. Steps:

    1. snapshot per-core memory via ``tf.config.experimental.get_memory_info``;
    2. time bf16 matrix multiplications on the first TPU device;
    3. try to initialise a ``TPUStrategy`` and run a small simulated training
       loop (distributed when that works, single-device otherwise);
    4. print a final memory / active-core report.

    Exceptions are caught and printed rather than propagated.
    """
    print("📊 TPU训练实时监控工具")
    print("=" * 50)

    # Locate the TPU devices; stop early when none are visible.
    try:
        tpu_devices = tf.config.list_logical_devices('TPU')
        print(f"📍 发现TPU设备: {len(tpu_devices)}个")
        if not tpu_devices:
            print("❌ 未发现TPU设备")
            return
    except Exception as e:
        print(f"❌ 无法检测TPU设备: {e}")
        return

    bytes_per_mb = 1024 * 1024

    def get_detailed_memory_snapshot():
        """Return per-core memory figures (in MB) plus a ``'summary'`` entry."""
        snapshot = {}
        total_current = 0
        total_peak = 0
        active_cores = 0
        core_details = []

        for i, device in enumerate(tpu_devices):
            try:
                memory_info = tf.config.experimental.get_memory_info(device.name)
                if memory_info and 'current' in memory_info:
                    current_mb = memory_info['current'] // bytes_per_mb
                    peak_mb = memory_info.get('peak', memory_info['current']) // bytes_per_mb

                    if current_mb > 1:  # more than 1 MB counts as active
                        active_cores += 1
                        total_current += current_mb
                        total_peak += peak_mb
                        core_details.append(f"Core{i}:{current_mb}MB")

                    snapshot[f'core_{i}'] = {
                        'current': current_mb,
                        'peak': peak_mb,
                        'device': device.name
                    }
                else:
                    snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name}
            except Exception as e:
                # Keep the core in the snapshot so later per-core lookups succeed.
                snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name, 'error': str(e)}

        snapshot['summary'] = {
            'total_current': total_current,
            'total_peak': total_peak,
            'active_cores': active_cores,
            'total_cores': len(tpu_devices),
            'core_details': core_details
        }
        return snapshot

    def test_mxu_performance():
        """Time chained bf16 matmuls of several sizes and print a throughput summary."""
        print("\n🧮 MXU计算性能测试:")

        mxu_results = []
        try:
            with tf.device(tpu_devices[0].name):
                # Defined once and fed tensors as arguments, instead of being
                # re-created for every matrix size.
                @tf.function
                def mxu_operation(matrix_a, matrix_b):
                    # Two chained matmuls followed by a reduction.
                    result = tf.matmul(matrix_a, matrix_b)
                    result = tf.matmul(result, matrix_a)
                    return tf.reduce_sum(result)

                # Matrix sizes to benchmark.
                test_configs = [
                    (2000, "2K×2K", tf.bfloat16),
                    (4000, "4K×4K", tf.bfloat16),
                    (6000, "6K×6K", tf.bfloat16),
                ]

                for size, desc, dtype in test_configs:
                    try:
                        pre_mem = get_detailed_memory_snapshot()

                        matrix_a = tf.random.normal([size, size], dtype=dtype)
                        matrix_b = tf.random.normal([size, size], dtype=dtype)

                        # Warm-up call: tracing/compilation for this shape
                        # happens here so it does not pollute the timing.
                        _ = mxu_operation(matrix_a, matrix_b).numpy()

                        start_time = time.perf_counter()
                        result = mxu_operation(matrix_a, matrix_b)
                        # Fetch the value so the computation is actually finished.
                        _ = result.numpy()
                        duration = time.perf_counter() - start_time

                        post_mem = get_detailed_memory_snapshot()

                        # Two matmuls of 2 * size^3 floating point operations each.
                        flops = 2 * (2 * size**3)
                        tflops = flops / duration / 1e12

                        memory_used = post_mem['summary']['total_current'] - pre_mem['summary']['total_current']

                        print(f" {desc} ({dtype.name}): {duration:.3f}s, {tflops:.1f}TFLOPS, 内存{memory_used:+d}MB")

                        mxu_results.append({
                            'size': size,
                            'tflops': tflops,
                            'duration': duration,
                            'memory_used': memory_used
                        })

                    except Exception as e:
                        print(f" {desc}: 测试失败 - {str(e)[:50]}")

                # Throughput summary.
                if mxu_results:
                    max_tflops = max(r['tflops'] for r in mxu_results)
                    total_memory = sum(r['memory_used'] for r in mxu_results if r['memory_used'] > 0)

                    # Published bf16 peak of one TPU v5e chip (197 TFLOPS);
                    # 275 TFLOPS is the TPU v4 figure.
                    theoretical_tflops = 197
                    efficiency = (max_tflops / theoretical_tflops) * 100

                    print("\n 📊 MXU性能汇总:")
                    print(f" 峰值性能: {max_tflops:.1f} TFLOPS")
                    print(f" 理论峰值: {theoretical_tflops} TFLOPS")
                    print(f" MXU效率: {efficiency:.1f}%")
                    print(f" 计算内存占用: {total_memory}MB")

                    if efficiency > 80:
                        status = "🟢 优秀"
                    elif efficiency > 50:
                        status = "🟡 良好"
                    elif efficiency > 20:
                        status = "🟠 中等"
                    else:
                        status = "🔴 需优化"

                    print(f" 性能评级: {status}")

        except Exception as e:
            print(f" MXU测试失败: {e}")

    try:
        print("🎯 开始TPU训练监控...")

        # 1. Baseline state.
        print("\n📸 初始TPU状态:")
        baseline_snapshot = get_detailed_memory_snapshot()
        baseline_total = baseline_snapshot['summary']['total_current']

        print(f" 总内存使用: {baseline_total}MB")
        print(f" 活跃核心: {baseline_snapshot['summary']['active_cores']}/{baseline_snapshot['summary']['total_cores']}")

        # Per-core details for cores that report any usage.
        for i in range(len(tpu_devices)):
            core = baseline_snapshot[f'core_{i}']
            if core['current'] > 0 or core['peak'] > 0:
                print(f" Core{i}: 当前{core['current']}MB, 峰值{core['peak']}MB")

        # 2. MXU benchmark.
        test_mxu_performance()

        # 3. Distribution strategy.
        print("\n🔄 使用项目标准TPU初始化...")
        try:
            # Hide GPUs to avoid conflicts. Best effort: a failure here must
            # not stop TPU initialisation.
            try:
                tf.config.set_visible_devices([], 'GPU')
                print("🚫 GPU已禁用,避免CUDA冲突")
            except (RuntimeError, ValueError):
                pass

            print("🚀 使用官方TensorFlow TPU初始化...")
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(resolver)
            tf.tpu.experimental.initialize_tpu_system(resolver)

            # Verify the TPU devices.
            tpu_devices_check = tf.config.list_logical_devices('TPU')
            print(f"✅ TPU设备验证: 发现 {len(tpu_devices_check)} 个设备")

            strategy = tf.distribute.TPUStrategy(resolver)
            print(f"✅ 成功创建TPU策略: {strategy.num_replicas_in_sync}个副本")
            use_distributed = True

        except Exception as e:
            print(f"⚠️ 分布式策略失败: {str(e)[:80]}")
            print(" 将使用单设备模拟")
            use_distributed = False

        # 4. Simulated Brain-to-Text training.
        print("\n🧠 模拟Brain-to-Text训练场景...")

        if use_distributed:
            with strategy.scope():
                print("📦 创建分布式模型参数...")

                # Stand-in parameters: three dense "GRU" weights, an output
                # projection and eight day-specific 512x512 matrices.
                model_components = {
                    'gru_layer_0': tf.Variable(tf.random.normal([512, 256]), name='gru_0'),
                    'gru_layer_1': tf.Variable(tf.random.normal([256, 256]), name='gru_1'),
                    'gru_layer_2': tf.Variable(tf.random.normal([256, 256]), name='gru_2'),
                    'output_projection': tf.Variable(tf.random.normal([256, 41]), name='output'),
                    'day_weights': [tf.Variable(tf.random.normal([512, 512]), name=f'day_{i}') for i in range(8)]
                }

                # Memory after the parameters were created.
                after_model = get_detailed_memory_snapshot()
                model_memory = after_model['summary']['total_current'] - baseline_total
                print(f"🧠 模型加载完成: {model_memory:+d}MB, {after_model['summary']['active_cores']}个活跃核心")

                print("\n🔄 开始训练循环监控...")

                # Built once outside the step loop so it is traced once
                # rather than on every step.
                @tf.function
                def distributed_training_step():
                    batch_size = 32
                    seq_length = 1000
                    features = 512

                    neural_data = tf.random.normal([batch_size, seq_length, features])
                    targets = tf.random.uniform([batch_size, seq_length], maxval=41, dtype=tf.int32)

                    # Day-specific transform, simplified to the first day's
                    # weights: [batch, seq, 512] @ [512, 512].
                    x = tf.matmul(neural_data, model_components['day_weights'][0])

                    # Use the targets so the simulated loss depends on them.
                    target_length = tf.reduce_sum(tf.cast(targets > 0, tf.int32), axis=1)
                    batch_loss_weight = tf.reduce_mean(tf.cast(target_length, tf.float32))

                    # Keep only the last time step, then apply the three layers.
                    x = x[:, -1, :]
                    for i in range(3):
                        x = tf.nn.tanh(tf.matmul(x, model_components[f'gru_layer_{i}']))

                    logits = tf.matmul(x, model_components['output_projection'])

                    # Stand-in loss weighted by the target statistic.
                    base_loss = tf.reduce_mean(tf.square(logits))
                    return base_loss * batch_loss_weight

                for step in range(10):
                    step_start_time = time.time()

                    per_replica_loss = strategy.run(distributed_training_step)
                    # Aggregate across replicas.
                    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
                    step_duration = time.time() - step_start_time

                    current_snapshot = get_detailed_memory_snapshot()
                    summary = current_snapshot['summary']
                    step_memory = summary['total_current']
                    memory_delta = step_memory - baseline_total

                    active_cores_info = f"({', '.join(summary['core_details'])})" if summary['core_details'] else "(无活跃)"

                    print(f" Step {step:2d}: loss={float(loss.numpy()):.4f}, "
                          f"时间={step_duration:.3f}s, "
                          f"内存={step_memory}MB({memory_delta:+d}), "
                          f"活跃={summary['active_cores']}/{summary['total_cores']} {active_cores_info}")

                    # Every fifth step also prints the peak.
                    if step % 5 == 0:
                        peak_info = f"峰值: {summary['total_peak']}MB"
                        print(f" {peak_info}")

                    time.sleep(0.2)  # short pause between steps

        else:
            print("🔸 单设备训练模拟...")

            with tf.device(tpu_devices[0].name):
                simple_weights = tf.Variable(tf.random.normal([512, 256]), name='simple_net')

                # Takes the batch as an argument so one traced function
                # serves every step.
                @tf.function
                def compute_step(batch_data):
                    x = tf.reshape(batch_data, [-1, 512])
                    result = tf.matmul(x, simple_weights)
                    result = tf.nn.relu(result)
                    return tf.reduce_mean(result)

                for step in range(8):
                    step_start = time.time()

                    batch_data = tf.random.normal([64, 1000, 512])
                    result = compute_step(batch_data)
                    step_duration = time.time() - step_start

                    snapshot = get_detailed_memory_snapshot()
                    memory_change = snapshot['summary']['total_current'] - baseline_total

                    print(f" Step {step}: result={result.numpy():.4f}, "
                          f"时间={step_duration:.3f}s, "
                          f"内存变化={memory_change:+d}MB, "
                          f"峰值={snapshot['summary']['total_peak']}MB")

        # 5. Final report.
        final_snapshot = get_detailed_memory_snapshot()
        final_summary = final_snapshot['summary']
        total_growth = final_summary['total_current'] - baseline_total
        peak_usage = final_summary['total_peak']

        print("\n📈 训练监控报告:")
        print(f" 总内存增长: {total_growth:+d}MB")
        print(f" 峰值内存使用: {peak_usage}MB ({peak_usage/1024:.2f}GB)")
        print(f" 最终活跃核心: {final_summary['active_cores']}/{final_summary['total_cores']}")

        print(" 各核心最终状态:")
        has_changes = False
        for i in range(len(tpu_devices)):
            final_core = final_snapshot[f'core_{i}']
            baseline_core = baseline_snapshot[f'core_{i}']
            current_change = final_core['current'] - baseline_core['current']
            peak_change = final_core['peak'] - baseline_core['peak']

            if current_change != 0 or peak_change != 0:
                has_changes = True
                print(f" Core{i}: 当前{final_core['current']}MB({current_change:+d}), 峰值{final_core['peak']}MB({peak_change:+d})")

        if not has_changes:
            print(" 所有核心内存无明显变化")

        # Diagnose how well work was spread across cores.
        active = final_summary['active_cores']
        if active == 1:
            # Idle count derived from the detected core count instead of
            # assuming eight cores.
            idle = final_summary['total_cores'] - 1
            print("\n⚠️ 分布式问题诊断:")
            print(f" 只有1个核心活跃,其他{idle}个核心空闲")
            print(" 可能原因: TPU策略配置问题或模型未正确分布")
            print(" 建议: 检查分布式策略和模型分片")
        elif active > 4:
            print("\n✅ 分布式状态良好:")
            print(f" {active}个核心活跃,多核心并行工作正常")
        else:
            print("\n🟡 分布式部分工作:")
            print(f" {active}个核心活跃,可能存在负载不均衡")

        print("✅ TPU训练监控完成")

    except Exception as e:
        print(f"❌ 训练监控失败: {e}")
        import traceback
        print(f"详细错误: {traceback.format_exc()[:300]}")
|
||||
|
||||
# Script entry point: print a banner, run the monitor, then print a checklist.
if __name__ == "__main__":
    for banner_line in (
        "🚀 TPU训练内存监控工具",
        "专注于训练过程中的实时内存和性能监控",
        "适用于TPU v5e-8环境",
    ):
        print(banner_line)
    print()

    monitor_tpu_during_training()

    print("\n🎯 监控要点总结:")
    for checklist_line in (
        " 1. 确认所有8个TPU核心是否活跃",
        " 2. 监控内存增长模式和峰值使用",
        " 3. 检测MXU计算性能和效率",
        " 4. 验证分布式策略是否正常工作",
        " 5. 识别可能的内存泄漏或性能瓶颈",
    ):
        print(checklist_line)
|
||||
@@ -1,336 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import h5py
|
||||
import numpy as np
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import math
|
||||
|
||||
class BrainToTextDataset(Dataset):
|
||||
'''
|
||||
Dataset for brain-to-text data
|
||||
|
||||
Returns an entire batch of data instead of a single example
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trial_indicies,
|
||||
n_batches,
|
||||
split = 'train',
|
||||
batch_size = 64,
|
||||
days_per_batch = 1,
|
||||
random_seed = -1,
|
||||
must_include_days = None,
|
||||
feature_subset = None
|
||||
):
|
||||
'''
|
||||
trial_indicies: (dict) - dictionary with day numbers as keys and lists of trial indices as values
|
||||
n_batches: (int) - number of random training batches to create
|
||||
split: (string) - string specifying if this is a train or test dataset
|
||||
batch_size: (int) - number of examples to include in batch returned from __getitem_()
|
||||
days_per_batch: (int) - how many unique days can exist in a batch; this is important for making sure that updates
|
||||
to individual day layers in the GRU are not excesively noisy. Validation data will always have 1 day per batch
|
||||
random_seed: (int) - seed to set for randomly assigning trials to a batch. If set to -1, trial assignment will be random
|
||||
must_include_days ([int]) - list of days that must be included in every batch
|
||||
feature_subset ([int]) - list of neural feature indicies that should be the only features included in the neural data
|
||||
'''
|
||||
|
||||
# Set random seed for reproducibility
|
||||
if random_seed != -1:
|
||||
np.random.seed(random_seed)
|
||||
torch.manual_seed(random_seed)
|
||||
|
||||
self.split = split
|
||||
|
||||
# Ensure the split is valid
|
||||
if self.split not in ['train', 'test']:
|
||||
raise ValueError(f'split must be either "train" or "test". Received {self.split}')
|
||||
|
||||
self.days_per_batch = days_per_batch
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.n_batches = n_batches
|
||||
|
||||
self.days = {}
|
||||
self.n_trials = 0
|
||||
self.trial_indicies = trial_indicies
|
||||
self.n_days = len(trial_indicies.keys())
|
||||
|
||||
self.feature_subset = feature_subset
|
||||
|
||||
# Calculate total number of trials in the dataset
|
||||
for d in trial_indicies:
|
||||
self.n_trials += len(trial_indicies[d]['trials'])
|
||||
|
||||
if must_include_days is not None and len(must_include_days) > days_per_batch:
|
||||
raise ValueError(f'must_include_days must be less than or equal to days_per_batch. Received {must_include_days} and days_per_batch {days_per_batch}')
|
||||
|
||||
if must_include_days is not None and len(must_include_days) > self.n_days and split != 'train':
|
||||
raise ValueError(f'must_include_days is not valid for test data. Received {must_include_days} and but only {self.n_days} in the dataset')
|
||||
|
||||
if must_include_days is not None:
|
||||
# Map must_include_days to correct indicies if they are negative
|
||||
for i, d in enumerate(must_include_days):
|
||||
if d < 0:
|
||||
must_include_days[i] = self.n_days + d
|
||||
|
||||
self.must_include_days = must_include_days
|
||||
|
||||
# Ensure that the days_per_batch is not greater than the number of days in the dataset. Raise error
|
||||
if self.split == 'train' and self.days_per_batch > self.n_days:
|
||||
raise ValueError(f'Requested days_per_batch: {days_per_batch} is greater than available days {self.n_days}.')
|
||||
|
||||
|
||||
if self.split == 'train':
|
||||
self.batch_index = self.create_batch_index_train()
|
||||
else:
|
||||
self.batch_index = self.create_batch_index_test()
|
||||
self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
|
||||
|
||||
def __len__(self):
|
||||
'''
|
||||
How many batches are in this dataset.
|
||||
Because training data is sampled randomly, there is no fixed dataset length,
|
||||
however this method is required for DataLoader to work
|
||||
'''
|
||||
return self.n_batches if self.n_batches is not None else 0
|
||||
|
||||
def __getitem__(self, idx):
|
||||
'''
|
||||
Gets an entire batch of data from the dataset, not just a single item
|
||||
'''
|
||||
batch = {
|
||||
'input_features' : [],
|
||||
'seq_class_ids' : [],
|
||||
'n_time_steps' : [],
|
||||
'phone_seq_lens' : [],
|
||||
'day_indicies' : [],
|
||||
'transcriptions' : [],
|
||||
'block_nums' : [],
|
||||
'trial_nums' : [],
|
||||
}
|
||||
|
||||
index = self.batch_index[idx]
|
||||
|
||||
# Iterate through each day in the index
|
||||
for d in index.keys():
|
||||
|
||||
# Open the hdf5 file for that day
|
||||
with h5py.File(self.trial_indicies[d]['session_path'], 'r') as f:
|
||||
|
||||
# For each trial in the selected trials in that day
|
||||
for t in index[d]:
|
||||
|
||||
try:
|
||||
g = f[f'trial_{t:04d}']
|
||||
|
||||
# Remove features is neccessary
|
||||
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # neural data - convert to bf16 for TPU compatibility
|
||||
if self.feature_subset:
|
||||
input_features = input_features[:,self.feature_subset]
|
||||
|
||||
batch['input_features'].append(input_features)
|
||||
|
||||
batch['seq_class_ids'].append(torch.from_numpy(g['seq_class_ids'][:])) # phoneme labels
|
||||
batch['transcriptions'].append(torch.from_numpy(g['transcription'][:])) # character level transcriptions
|
||||
batch['n_time_steps'].append(g.attrs['n_time_steps']) # number of time steps in the trial - required since we are padding
|
||||
batch['phone_seq_lens'].append(g.attrs['seq_len']) # number of phonemes in the label - required since we are padding
|
||||
batch['day_indicies'].append(int(d)) # day index of each trial - required for the day specific layers
|
||||
batch['block_nums'].append(g.attrs['block_num'])
|
||||
batch['trial_nums'].append(g.attrs['trial_num'])
|
||||
|
||||
except Exception as e:
|
||||
print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
|
||||
continue
|
||||
|
||||
# Pad data to form a cohesive batch - ensure bf16 dtype is preserved
|
||||
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
|
||||
batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
|
||||
|
||||
batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
|
||||
batch['phone_seq_lens'] = torch.tensor(batch['phone_seq_lens'])
|
||||
batch['day_indicies'] = torch.tensor(batch['day_indicies'])
|
||||
batch['transcriptions'] = torch.stack(batch['transcriptions'])
|
||||
batch['block_nums'] = torch.tensor(batch['block_nums'])
|
||||
batch['trial_nums'] = torch.tensor(batch['trial_nums'])
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def create_batch_index_train(self):
|
||||
'''
|
||||
Create an index that maps a batch_number to batch_size number of trials
|
||||
|
||||
Each batch will have days_per_batch unique days of data, with the number of trials for each day evenly split between the days
|
||||
(or as even as possible if batch_size is not divisible by days_per_batch)
|
||||
'''
|
||||
|
||||
batch_index = {}
|
||||
|
||||
# Precompute the days that are not in must_include_days
|
||||
if self.must_include_days is not None:
|
||||
non_must_include_days = [d for d in self.trial_indicies.keys() if d not in self.must_include_days]
|
||||
|
||||
for batch_idx in range(self.n_batches):
|
||||
batch = {}
|
||||
|
||||
# Which days will be used for this batch. Picked randomly without replacement
|
||||
# TODO: In the future we may want to consider sampling days in proportion to the number of trials in each day
|
||||
|
||||
# If must_include_days is not empty, we will use those days and then randomly sample the rest
|
||||
if self.must_include_days is not None and len(self.must_include_days) > 0:
|
||||
|
||||
days = np.concatenate((self.must_include_days, np.random.choice(non_must_include_days, size = self.days_per_batch - len(self.must_include_days), replace = False)))
|
||||
|
||||
# Otherwise we will select random days without replacement
|
||||
else:
|
||||
days = np.random.choice(list(self.trial_indicies.keys()), size = self.days_per_batch, replace = False)
|
||||
|
||||
# How many trials will be sampled from each day
|
||||
num_trials = math.ceil(self.batch_size / self.days_per_batch) # Use ceiling to make sure we get at least batch_size trials
|
||||
|
||||
for d in days:
|
||||
|
||||
# Trials are sampled with replacement, so if a day has less than (self.batch_size / days_per_batch trials) trials, it won't be a problem
|
||||
trial_idxs = np.random.choice(self.trial_indicies[d]['trials'], size = num_trials, replace = True)
|
||||
batch[d] = trial_idxs
|
||||
|
||||
# Remove extra trials
|
||||
extra_trials = (num_trials * len(days)) - self.batch_size
|
||||
|
||||
# While we still have extra trials, remove the last trial from a random day
|
||||
while extra_trials > 0:
|
||||
d = np.random.choice(days)
|
||||
batch[d] = batch[d][:-1]
|
||||
extra_trials -= 1
|
||||
|
||||
batch_index[batch_idx] = batch
|
||||
|
||||
return batch_index
|
||||
|
||||
def create_batch_index_test(self):
    '''
    Build a batch index over all validation/testing trials.

    Each day's trial list (self.trial_indicies[day]['trials']) is cut into
    consecutive slices of at most self.batch_size trials, so the last slice
    of a day may be shorter. A batch never mixes trials from different days,
    and every trial appears in exactly one batch.

    Returns:
        dict: {batch_number: {day: trials_slice}}, with batch numbers counting
        up from 0 in the iteration order of self.trial_indicies.
    '''
    index = {}
    next_id = 0
    size = self.batch_size

    for day, info in self.trial_indicies.items():
        day_trials = info['trials']
        total = len(day_trials)

        # Ceiling division: number of slices needed to cover this day
        n_chunks = (total + size - 1) // size

        for chunk in range(n_chunks):
            lo = chunk * size
            hi = min(lo + size, total)

            index[next_id] = {day: day_trials[lo:hi]}
            next_id += 1

    return index
|
||||
|
||||
def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_trials_dict = None):
    '''
    Split the trials of each session file into train and test sets.

    Every entry of file_paths is treated as one "day", keyed by its position in
    the list. Trials listed in bad_trials_dict are dropped before splitting.
    A path that does not exist on disk contributes a day with no trials.

    Returns two dictionaries (train, test) that detail which trials in each day
    belong to that split:
        {
            0: {'trials': [1, 2, 3], 'session_path': 'path'},
            1: {'trials': [2, 5, 6], 'session_path': 'path'},
        }

    Args:
        file_paths (list): List of file paths to the hdf5 files containing the data.
            Each path must contain a component starting with 't15.20' or 't12.20'
            (the session name); otherwise an IndexError is raised.
        test_percentage (float): Fraction of trials to use for testing. 0 will use all
            trials for training, 1 will use all trials for testing. For any other value
            at least one trial per non-empty day goes to the test split.
        seed (int): Seed for reproducibility. If set to -1, the split will be random.
            Note that this seeds numpy's global random state.
        bad_trials_dict (dict): Dictionary of trials to exclude from the dataset. Formatted as:
            {
                'session_name_1': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
                'session_name_2': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
                ...
            }
            Block numbers are looked up as strings (str(block_num)).
    '''
    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Get the usable trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        # Handle both Windows and Unix path separators
        path_parts = path.replace('\\', '/').split('/')
        session = [s for s in path_parts if s.startswith(('t15.20', 't12.20'))][0]

        good_trial_indices = []

        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    if (
                        bad_trials_dict is not None
                        and session in bad_trials_dict
                        and str(block_num) in bad_trials_dict[session]
                        and trial_num in bad_trials_dict[session][str(block_num)]
                    ):
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {
            'num_trials': len(good_trial_indices),
            'trial_indices': good_trial_indices,
            'session_path': path,
        }

    # Pick test_percentage of trials from each day for testing and (1 - test_percentage) for training
    train_trials = {}
    test_trials = {}

    for day, info in trials_per_day.items():
        num_trials = info['num_trials']
        all_trial_indices = info['trial_indices']
        session_path = info['session_path']

        if test_percentage == 0:
            # Everything goes to training
            train_indices, test_indices = all_trial_indices, []
        elif test_percentage == 1:
            # Everything goes to testing
            train_indices, test_indices = [], all_trial_indices
        elif num_trials == 0:
            # Nothing to sample from (missing file or every trial excluded).
            # Sampling one item from an empty population would raise ValueError.
            train_indices, test_indices = [], []
        else:
            # Calculate how many trials to use for testing (at least one)
            num_test = max(1, int(num_trials * test_percentage))

            # Randomly select indices for testing
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()

            # Remaining indices go to training; a set keeps the membership test O(1)
            held_out = set(test_indices)
            train_indices = [idx for idx in all_trial_indices if idx not in held_out]

        train_trials[day] = {'trials' : train_indices, 'session_path' : session_path}
        test_trials[day] = {'trials' : test_indices, 'session_path' : session_path}

    return train_trials, test_trials
|
||||
@@ -1,304 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import redis
|
||||
from omegaconf import OmegaConf
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import editdistance
|
||||
import argparse
|
||||
|
||||
from rnn_model import GRUDecoder
|
||||
from evaluate_model_helpers import *
|
||||
|
||||
# Command line interface
parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
parser.add_argument(
    '--model_path',
    type=str,
    default='../data/t15_pretrained_rnn_baseline',
    help='Path to the pretrained model directory (relative to the current working directory).',
)
parser.add_argument(
    '--data_dir',
    type=str,
    default='../data/hdf5_data_final',
    help='Path to the dataset directory (relative to the current working directory).',
)
parser.add_argument(
    '--eval_type',
    type=str,
    default='test',
    choices=['val', 'test'],
    help='Evaluation type: "val" for validation set, "test" for test set. '
         'If "test", ground truth is not available.',
)
parser.add_argument(
    '--csv_path',
    type=str,
    default='../data/t15_copyTaskData_description.csv',
    help='Path to the CSV file with metadata about the dataset (relative to the current working directory).',
)
parser.add_argument(
    '--gpu_number',
    type=int,
    default=-1,
    help='GPU number to use for RNN model inference. Set to -1 to use CPU.',
)
args = parser.parse_args()

# Model and data locations (both relative to the current working directory)
model_path = args.model_path
data_dir = args.data_dir

# Which split to evaluate: 'val' or 'test'. Ground truth is not available for 'test'.
eval_type = args.eval_type

# Dataset description table
b2txt_csv_df = pd.read_csv(args.csv_path)

# Arguments the model was trained with
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))

# Pick the inference device: the requested GPU when CUDA is usable, otherwise the CPU
gpu_number = args.gpu_number
if not (torch.cuda.is_available() and gpu_number >= 0):
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')
else:
    available_gpus = torch.cuda.device_count()
    if gpu_number >= available_gpus:
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {available_gpus}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')

# Build the decoder with the hyperparameters stored next to the checkpoint
model_cfg = model_args['model']
dataset_cfg = model_args['dataset']
model = GRUDecoder(
    neural_dim=model_cfg['n_input_features'],
    n_units=model_cfg['n_units'],
    n_days=len(dataset_cfg['sessions']),
    n_classes=dataset_cfg['n_classes'],
    rnn_dropout=model_cfg['rnn_dropout'],
    input_dropout=model_cfg['input_network']['input_layer_dropout'],
    n_layers=model_cfg['n_layers'],
    patch_size=model_cfg['patch_size'],
    patch_stride=model_cfg['patch_stride'],
)
|
||||
|
||||
# Load model weights.
# NOTE(review): weights_only=False unpickles arbitrary Python objects, so this
# must only be pointed at checkpoints from a trusted source.
checkpoint = torch.load(
    os.path.join(model_path, 'checkpoint/best_checkpoint'),
    map_location=device,
    weights_only=False,
)

# Strip wrapper markers from parameter names: "module." (added when the model
# was saved with DataParallel) and "_orig_mod." (presumably from torch.compile).
# Both are removed in a single rename per key; popping the same key twice
# raised KeyError for any key containing "module.".
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
    clean_key = key.replace("module.", "").replace("_orig_mod.", "")
    state_dict[clean_key] = state_dict.pop(key)
model.load_state_dict(state_dict)

# add model to device
model.to(device)

# set model to eval mode
model.eval()
|
||||
|
||||
# Load the evaluation file of every session that has one
test_data = {}
total_test_trials = 0
eval_filename = f'data_{eval_type}.hdf5'

for session in model_args['dataset']['sessions']:
    session_dir = os.path.join(data_dir, session)
    hdf5_names = [name for name in os.listdir(session_dir) if name.endswith('.hdf5')]

    # Sessions without a file for this split are skipped
    if eval_filename not in hdf5_names:
        continue

    eval_file = os.path.join(session_dir, eval_filename)
    test_data[session] = load_h5py_file(eval_file, b2txt_csv_df)

    n_loaded = len(test_data[session]["neural_features"])
    total_test_trials += n_loaded
    print(f'Loaded {n_loaded} {eval_type} trials for session {session}.')

print(f'Total number of {eval_type} trials: {total_test_trials}')
print()
|
||||
|
||||
|
||||
# Run every trial through the pretrained model to get phoneme predictions (logits)
session_order = model_args['dataset']['sessions']

with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
    for session, data in test_data.items():

        data['logits'] = []
        data['pred_seq'] = []

        # The position of the session in the training config is passed to the
        # decoding helper as the input layer for this session
        input_layer = session_order.index(session)

        for trial in range(len(data['neural_features'])):
            # Add a leading batch dimension and move the trial to the device as bfloat16
            batched = np.expand_dims(data['neural_features'][trial], axis=0)
            neural_input = torch.tensor(batched, device=device, dtype=torch.bfloat16)

            # Run the decoding step and keep the logits for the later stages
            data['logits'].append(
                runSingleDecodingStep(neural_input, input_layer, model, model_args, device)
            )

            pbar.update(1)
pbar.close()
|
||||
|
||||
|
||||
# Convert logits to phoneme sequences and print them out
for session, data in test_data.items():
    data['pred_seq'] = []
    for trial, trial_logits in enumerate(data['logits']):
        # Most likely class per frame (index 0 of the logits is the single batch entry)
        frame_ids = np.argmax(trial_logits[0], axis=-1)

        # Greedy CTC decoding: collapse consecutive duplicates FIRST, then drop
        # blanks (0). Doing it the other way round merges genuine repeats that
        # are separated by a blank (e.g. "A _ A" would become a single "A").
        collapsed = [
            int(p) for i, p in enumerate(frame_ids)
            if i == 0 or p != frame_ids[i - 1]
        ]
        pred_seq = [LOGIT_TO_PHONEME[p] for p in collapsed if p != 0]

        # add to data
        data['pred_seq'].append(pred_seq)

        # print out the predicted sequences
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            # Ground truth is only available for the validation split
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]

            print(f'Sentence label: {sentence_label}')
            print(f'True sequence: {" ".join(true_seq)}')
        print(f'Predicted Sequence: {" ".join(pred_seq)}')
        print()
|
||||
|
||||
|
||||
# language model inference via redis
|
||||
# make sure that the standalone language model is running on the localhost redis ip
|
||||
# see README.md for instructions on how to run the language model
|
||||
|
||||
def connect_to_redis_with_retry(host, port, password, db=0, max_retries=10, retry_delay=3):
    """Connect to Redis, retrying on failure.

    Each attempt creates a client and pings the server. A failed attempt is
    reported on stdout and retried after ``retry_delay`` seconds; the error of
    the final attempt is re-raised.

    Args:
        host: Redis server host name or IP.
        port: Redis server port.
        password: Password passed to the client.
        db: Database number to select.
        max_retries: Total number of connection attempts (must be >= 1).
        retry_delay: Seconds to sleep between attempts.

    Returns:
        The connected ``redis.Redis`` client.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1. With no attempts the
            function used to fall through and return ``None``.
        redis.exceptions.ConnectionError: If the last attempt could not connect.
        Exception: Any other error raised by the last attempt.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            print(f"Attempting to connect to Redis at {host}:{port} (attempt {attempt + 1}/{max_retries})...")
            r = redis.Redis(host=host, port=port, db=db, password=password)
            r.ping()  # Test the connection
            print(f"Successfully connected to Redis at {host}:{port}")
            return r
        except redis.exceptions.ConnectionError as e:
            print(f"Redis connection failed (attempt {attempt + 1}/{max_retries}): {e}")
            if is_last_attempt:
                print("Max retries reached. Could not connect to Redis.")
                raise
        except Exception as e:
            print(f"Unexpected error connecting to Redis: {e}")
            if is_last_attempt:
                raise

        # Only reached after a failed, non-final attempt
        print(f"Retrying in {retry_delay} seconds...")
        time.sleep(retry_delay)
|
||||
|
||||
# NOTE(review): the Redis host and password used to be hard-coded in this call.
# They remain the defaults for backward compatibility, but can now be overridden
# with the REDIS_HOST / REDIS_PORT / REDIS_PASSWORD environment variables.
# A password committed to source control should be rotated.
redis_host = os.environ.get('REDIS_HOST', 'hs.zchens.cn')
redis_port = int(os.environ.get('REDIS_PORT', '6379'))
redis_password = os.environ.get('REDIS_PASSWORD', 'admin01')

r = connect_to_redis_with_retry(redis_host, redis_port, redis_password)
# WARNING: flushall() deletes every key in every database of the target server,
# not only the language-model streams. Make sure the server is dedicated to this run.
r.flushall()  # clear all streams in redis

# define redis streams for the remote language model
remote_lm_input_stream = 'remote_lm_input'
remote_lm_output_partial_stream = 'remote_lm_output_partial'
remote_lm_output_final_stream = 'remote_lm_output_final'

# set timestamps for last entries seen in the redis streams
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)

# Per-trial results collected from the language model (parallel lists)
lm_results = {
    'session': [],
    'block': [],
    'trial': [],
    'true_sentence': [],
    'pred_sentence': [],
}
|
||||
|
||||
# Feed the logits of every trial to the remote language model to get text predictions.
# note: this takes ~15-20 minutes to run on the entire test split with the 5-gram LM + OPT rescoring (RTX 4090)
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
    for session, session_data in test_data.items():
        for trial, trial_logits in enumerate(session_data['logits']):
            # Rearrange the trial logits into the layout the LM expects
            logits = rearrange_speech_logits_pt(trial_logits)[0]

            # Reset the language model before each trial
            remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(
                r, remote_lm_done_resetting_lastEntrySeen
            )

            # Optional (disabled): update language model parameters
            # remote_lm_done_updating_lastEntrySeen = update_remote_lm_params(
            #     r,
            #     remote_lm_done_updating_lastEntrySeen,
            #     acoustic_scale=0.35,
            #     blank_penalty=90.0,
            #     alpha=0.55,
            # )

            # Send the logits; the partial decode returned here is not used
            remote_lm_output_partial_lastEntrySeen, _ = send_logits_to_remote_lm(
                r,
                remote_lm_input_stream,
                remote_lm_output_partial_stream,
                remote_lm_output_partial_lastEntrySeen,
                logits,
            )

            # Finalize and take the first candidate sentence as the prediction
            remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
                r,
                remote_lm_output_final_stream,
                remote_lm_output_final_lastEntrySeen,
            )
            best_candidate_sentence = lm_out['candidate_sentences'][0]

            # Store results; the reference sentence only exists for the validation split
            reference = session_data['sentence_label'][trial] if eval_type == 'val' else None
            lm_results['session'].append(session)
            lm_results['block'].append(session_data['block_num'][trial])
            lm_results['trial'].append(session_data['trial_num'][trial])
            lm_results['true_sentence'].append(reference)
            lm_results['pred_sentence'].append(best_candidate_sentence)

            # update progress bar
            pbar.update(1)
pbar.close()
|
||||
|
||||
|
||||
# If using the validation set, calculate the aggregate word error rate (WER)
if eval_type == 'val':
    total_true_length = 0
    total_edit_distance = 0

    lm_results['edit_distance'] = []
    lm_results['num_words'] = []

    for i in range(len(lm_results['pred_sentence'])):
        true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
        pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()

        # Word-level edit distance between reference and prediction
        true_words = true_sentence.split()
        num_words = len(true_words)
        ed = editdistance.eval(true_words, pred_sentence.split())

        total_true_length += num_words
        total_edit_distance += ed

        lm_results['edit_distance'].append(ed)
        lm_results['num_words'].append(num_words)

        # Per-trial WER as a percentage. The factor 100 belongs on the ratio,
        # not on the word count, and an empty reference has no defined WER.
        trial_wer = f'{100 * ed / num_words:.2f}%' if num_words > 0 else 'n/a'

        print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
        print(f'True sentence: {true_sentence}')
        print(f'Predicted sentence: {pred_sentence}')
        print(f'WER: {ed} / {num_words} = {trial_wer}')
        print()

    print(f'Total true sentence length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    if total_true_length > 0:
        print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')
    else:
        # No reference words at all (e.g. no validation trials were loaded)
        print('Aggregate Word Error Rate (WER): n/a')
|
||||
|
||||
|
||||
# Save the predicted sentences as a CSV next to the model; the file name
# carries a timestamp (YYYYMMDD_HHMMSS)
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{timestamp}.csv')

# One row per prediction, numbered from 0 in the order the trials were decoded
ids = list(range(len(lm_results['pred_sentence'])))
df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
df_out.to_csv(output_file, index=False)
|
||||
@@ -1,580 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import cast
|
||||
|
||||
class GradientReversalFn(torch.autograd.Function):
    """
    Gradient Reversal Layer (GRL).

    Forward pass: returns the input unchanged (as a view of itself).
    Backward pass: scales the incoming gradient by -lambda. No gradient is
    produced for lambda itself.
    """

    @staticmethod
    def forward(ctx, x, lambd: float):
        # Remember the scaling factor for the backward pass
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = -ctx.lambd * grad_output
        # Second entry corresponds to `lambd`, which receives no gradient
        return flipped, None
|
||||
|
||||
def gradient_reverse(x, lambd: float = 1.0):
    """Return `x` unchanged in the forward pass while its gradient is multiplied by -lambd."""
    reversed_x = GradientReversalFn.apply(x, lambd)
    return reversed_x
|
||||
|
||||
class NoiseModel(nn.Module):
    '''
    Noise Model: 2-layer GRU that learns to estimate noise in the neural data.

    The input first goes through a day-specific affine layer with a Softsign
    activation, is optionally grouped into patches of consecutive timesteps,
    and is then fed to a GRU whose hidden size equals its input size, so the
    output has the same feature dimension as the (patched) input.
    '''

    N_GRU_LAYERS = 2

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        '''
        neural_dim (int)      - number of channels in a single timestep
        n_units (int)         - stored for reference; the GRU width is the input size
        n_days (int)          - number of day-specific input layers to create
        rnn_dropout (float)   - dropout between GRU layers
        input_dropout (float) - dropout applied after the day-specific layer
        patch_size (int)      - timesteps concatenated per patch (0 disables patching)
        patch_stride (int)    - stride between consecutive patches
        '''
        super().__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers, initialised to the identity map (eye weights, zero bias)
        self.day_layer_activation = nn.Softsign()
        self.day_weights = nn.ParameterList(
            [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_biases = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # Feature size seen by the GRU: patching concatenates patch_size timesteps
        if self.patch_size > 0:
            self.input_size = self.neural_dim * self.patch_size
        else:
            self.input_size = self.neural_dim

        # GRU for noise estimation; hidden size matches the input size
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.input_size,
            num_layers=self.N_GRU_LAYERS,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Orthogonal recurrent weights, Xavier-uniform input weights
        for param_name, param in self.gru.named_parameters():
            if "weight_hh" in param_name:
                nn.init.orthogonal_(param)
            elif "weight_ih" in param_name:
                nn.init.xavier_uniform_(param)

        # Learnable initial hidden state, shared by all layers and batch entries
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))

    def _day_transform(self, x, day_idx):
        '''Apply each sample's day-specific affine layer followed by the activation.'''
        weight_bank = torch.stack(list(self.day_weights), dim=0)                        # [n_days, neural_dim, neural_dim]
        bias_bank = torch.stack([b.squeeze(0) for b in self.day_biases], dim=0)         # [n_days, neural_dim]

        # Gather the parameters of every sample's day
        weights = torch.index_select(weight_bank, 0, day_idx)                           # [batch, neural_dim, neural_dim]
        biases = torch.index_select(bias_bank, 0, day_idx).unsqueeze(1)                 # [batch, 1, neural_dim]

        # Parameters are cast to the input dtype for mixed precision training
        projected = torch.bmm(x, weights.to(x.dtype)) + biases.to(x.dtype)
        return self.day_layer_activation(projected)

    def _to_patches(self, x, batch_size):
        '''Group consecutive timesteps into patches and flatten each patch, keeping the dtype.'''
        dtype_before = x.dtype
        windows = (
            x.unsqueeze(1)
            .permute(0, 3, 1, 2)
            .unfold(3, self.patch_size, self.patch_stride)
            .squeeze(2)
            .permute(0, 2, 3, 1)
        )
        patched = windows.reshape(batch_size, windows.size(1), -1)
        return patched.to(dtype_before)

    def forward(self, x, day_idx, states=None):
        '''
        x        - input of shape [batch, time, neural_dim]
        day_idx  - 1-D index tensor selecting the day layer of every sample
        states   - optional initial GRU hidden state; the learned h0 is used when None

        Returns (output, hidden_states) of the GRU.
        '''
        batch_size = x.size(0)

        x = self._day_transform(x, day_idx)

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        if self.patch_size > 0:
            x = self._to_patches(x, batch_size)

        # Run the GRU in the dtype of its own parameters
        gru_dtype = next(self.gru.parameters()).dtype
        x = x.to(gru_dtype)

        if states is None:
            states = self.h0.expand(self.N_GRU_LAYERS, batch_size, self.input_size).contiguous()
        states = states.to(gru_dtype)

        # Autocast is disabled around the GRU (the original author notes this
        # avoids dtype mismatches on XLA)
        with torch.autocast(device_type=x.device.type, enabled=False):
            output, hidden_states = self.gru(x, states)

        return output, hidden_states
|
||||
|
||||
|
||||
class CleanSpeechModel(nn.Module):
    '''
    Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition.

    The input goes through a day-specific affine layer with a Softsign
    activation, is optionally grouped into patches of consecutive timesteps,
    passes through the GRU, and is mapped to class logits by a linear layer.
    '''

    N_GRU_LAYERS = 3

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        '''
        neural_dim (int)      - number of channels in a single timestep
        n_units (int)         - number of hidden units in each recurrent layer
        n_days (int)          - number of day-specific input layers to create
        n_classes (int)       - number of output classes
        rnn_dropout (float)   - dropout between GRU layers
        input_dropout (float) - dropout applied after the day-specific layer
        patch_size (int)      - timesteps concatenated per patch (0 disables patching)
        patch_stride (int)    - stride between consecutive patches
        '''
        super().__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers, initialised to the identity map (eye weights, zero bias)
        self.day_layer_activation = nn.Softsign()
        self.day_weights = nn.ParameterList(
            [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_biases = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # Feature size seen by the GRU: patching concatenates patch_size timesteps
        if self.patch_size > 0:
            self.input_size = self.neural_dim * self.patch_size
        else:
            self.input_size = self.neural_dim

        # GRU for clean speech recognition
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=self.N_GRU_LAYERS,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Orthogonal recurrent weights, Xavier-uniform input weights
        for param_name, param in self.gru.named_parameters():
            if "weight_hh" in param_name:
                nn.init.orthogonal_(param)
            elif "weight_ih" in param_name:
                nn.init.xavier_uniform_(param)

        # Output classification layer
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state, shared by all layers and batch entries
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def _day_transform(self, x, day_idx):
        '''Apply each sample's day-specific affine layer followed by the activation.'''
        weight_bank = torch.stack(list(self.day_weights), dim=0)                        # [n_days, neural_dim, neural_dim]
        bias_bank = torch.stack([b.squeeze(0) for b in self.day_biases], dim=0)         # [n_days, neural_dim]

        # Gather the parameters of every sample's day
        weights = torch.index_select(weight_bank, 0, day_idx)                           # [batch, neural_dim, neural_dim]
        biases = torch.index_select(bias_bank, 0, day_idx).unsqueeze(1)                 # [batch, 1, neural_dim]

        # Parameters are cast to the input dtype for mixed precision training
        projected = torch.bmm(x, weights.to(x.dtype)) + biases.to(x.dtype)
        return self.day_layer_activation(projected)

    def _to_patches(self, x, batch_size):
        '''Group consecutive timesteps into patches and flatten each patch, keeping the dtype.'''
        dtype_before = x.dtype
        windows = (
            x.unsqueeze(1)
            .permute(0, 3, 1, 2)
            .unfold(3, self.patch_size, self.patch_stride)
            .squeeze(2)
            .permute(0, 2, 3, 1)
        )
        patched = windows.reshape(batch_size, windows.size(1), -1)
        return patched.to(dtype_before)

    def forward(self, x, day_idx, states=None, return_state=False):
        '''
        x            - input of shape [batch, time, neural_dim]
        day_idx      - 1-D index tensor selecting the day layer of every sample
        states       - optional initial GRU hidden state; the learned h0 is used when None
        return_state - when True, return (logits, hidden_states) instead of logits only
        '''
        batch_size = x.size(0)

        x = self._day_transform(x, day_idx)

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        if self.patch_size > 0:
            x = self._to_patches(x, batch_size)

        # Run the GRU in the dtype of its own parameters
        gru_dtype = next(self.gru.parameters()).dtype
        x = x.to(gru_dtype)

        if states is None:
            states = self.h0.expand(self.N_GRU_LAYERS, batch_size, self.n_units).contiguous()
        states = states.to(gru_dtype)

        # Autocast is disabled around the GRU only; the classifier runs outside of it
        with torch.autocast(device_type=x.device.type, enabled=False):
            output, hidden_states = self.gru(x, states)

        # Classification
        logits = self.out(output)

        return (logits, hidden_states) if return_state else logits
|
||||
|
||||
|
||||
class NoisySpeechModel(nn.Module):
    '''
    Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition.

    Unlike the other sub-models it has no day-specific input layer and does no
    patching of its own: the input must already have `input_size` features
    (neural_dim, multiplied by patch_size when patching is enabled).
    '''

    N_GRU_LAYERS = 2

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        '''
        neural_dim (int)      - number of channels in a single timestep
        n_units (int)         - number of hidden units in each recurrent layer
        n_days (int)          - stored only; this model has no day-specific layers
        n_classes (int)       - number of output classes
        rnn_dropout (float)   - dropout between GRU layers
        input_dropout (float) - stored only; not applied by this model
        patch_size (int)      - timesteps per patch of the incoming signal (0 = no patching)
        patch_stride (int)    - stored only
        '''
        super().__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Feature size expected by the GRU
        if self.patch_size > 0:
            self.input_size = self.neural_dim * self.patch_size
        else:
            self.input_size = self.neural_dim

        # GRU for noisy speech recognition
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=self.N_GRU_LAYERS,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Orthogonal recurrent weights, Xavier-uniform input weights
        for param_name, param in self.gru.named_parameters():
            if "weight_hh" in param_name:
                nn.init.orthogonal_(param)
            elif "weight_ih" in param_name:
                nn.init.xavier_uniform_(param)

        # Output classification layer
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state, shared by all layers and batch entries
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, states=None, return_state=False):
        '''
        x            - input of shape [batch, time, input_size]
        states       - optional initial GRU hidden state; the learned h0 is used when None
        return_state - when True, return (logits, hidden_states) instead of logits only
        '''
        batch_size = x.size(0)

        # Run the GRU in the dtype of its own parameters
        gru_dtype = next(self.gru.parameters()).dtype
        x = x.to(gru_dtype)

        if states is None:
            states = self.h0.expand(self.N_GRU_LAYERS, batch_size, self.n_units).contiguous()
        states = states.to(gru_dtype)

        # Autocast is disabled around the GRU only; the classifier runs outside of it
        with torch.autocast(device_type=x.device.type, enabled=False):
            output, hidden_states = self.gru(x, states)

        # Classification
        logits = self.out(output)

        return (logits, hidden_states) if return_state else logits
|
||||
|
||||
|
||||
class TripleGRUDecoder(nn.Module):
|
||||
'''
|
||||
Three-model adversarial architecture for neural speech decoding
|
||||
|
||||
Combines:
|
||||
- NoiseModel: estimates noise in neural data
|
||||
- CleanSpeechModel: processes denoised signal for recognition
|
||||
- NoisySpeechModel: processes noise signal for recognition
|
||||
'''
|
||||
def __init__(self,
|
||||
neural_dim,
|
||||
n_units,
|
||||
n_days,
|
||||
n_classes,
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=0,
|
||||
patch_stride=0,
|
||||
):
|
||||
'''
|
||||
neural_dim (int) - number of channels in a single timestep (e.g. 512)
|
||||
n_units (int) - number of hidden units in each recurrent layer
|
||||
n_days (int) - number of days in the dataset
|
||||
n_classes (int) - number of classes (phonemes)
|
||||
rnn_dropout (float) - percentage of units to dropout during training
|
||||
input_dropout (float) - percentage of input units to dropout during training
|
||||
patch_size (int) - number of timesteps to concat on initial input layer
|
||||
patch_stride(int) - number of timesteps to stride over when concatenating initial input
|
||||
'''
|
||||
super(TripleGRUDecoder, self).__init__()
|
||||
|
||||
self.neural_dim = neural_dim
|
||||
self.n_units = n_units
|
||||
self.n_classes = n_classes
|
||||
self.n_days = n_days
|
||||
|
||||
self.rnn_dropout = rnn_dropout
|
||||
self.input_dropout = input_dropout
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# Create the three models
|
||||
self.noise_model = NoiseModel(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
rnn_dropout=rnn_dropout,
|
||||
input_dropout=input_dropout,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride
|
||||
)
|
||||
|
||||
self.clean_speech_model = CleanSpeechModel(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
n_classes=n_classes,
|
||||
rnn_dropout=rnn_dropout,
|
||||
input_dropout=input_dropout,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride
|
||||
)
|
||||
|
||||
self.noisy_speech_model = NoisySpeechModel(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
n_classes=n_classes,
|
||||
rnn_dropout=rnn_dropout,
|
||||
input_dropout=input_dropout,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride
|
||||
)
|
||||
|
||||
# Training mode flag
|
||||
self.training_mode = 'full' # 'full', 'inference'
|
||||
|
||||
def _apply_preprocessing(self, x, day_idx):
    '''
    Apply the clean speech model's day-specific layer (and optional patching) to raw input.

    x (tensor)       - batch of inputs; it is the left operand of a batched matmul, so it must be 3-D
                       (presumably batch x time x neural_dim - confirm against caller)
    day_idx (tensor) - 1-D index tensor picking one day weight/bias per example

    Returns the transformed tensor. When self.patch_size > 0 the time axis is cut into
    windows of patch_size steps (stride patch_stride) and each window is flattened into
    the last dimension.
    '''
    clean = self.clean_speech_model
    n_examples = x.size(0)

    # Stack the per-day parameters so selection is a single static gather (XLA-friendly)
    stacked_weights = torch.stack(list(clean.day_weights), dim=0)
    stacked_biases = torch.stack([b.squeeze(0) for b in clean.day_biases], dim=0)

    # Pick each example's day parameters and match the input dtype for mixed precision
    per_example_weights = torch.index_select(stacked_weights, 0, day_idx).to(x.dtype)
    per_example_biases = torch.index_select(stacked_biases, 0, day_idx).unsqueeze(1).to(x.dtype)

    out = clean.day_layer_activation(torch.bmm(x, per_example_weights) + per_example_biases)

    if self.patch_size <= 0:
        return out

    # Remember the dtype so the patched result can be cast back to it
    dtype_before_patching = out.dtype
    windows = (
        out.unsqueeze(1)
        .permute(0, 3, 1, 2)
        .unfold(3, self.patch_size, self.patch_stride)
        .squeeze(2)
        .permute(0, 2, 3, 1)
    )
    out = windows.reshape(n_examples, windows.size(1), -1)
    return out.to(dtype_before_patching)
|
||||
|
||||
def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
    '''
    Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)

    x_processed (tensor) - input that is fed straight into the clean model's GRU
    day_idx (tensor)     - unused here; kept so the signature mirrors the other forward helpers
    states (tensor)      - optional initial GRU hidden state; built from the model's h0 when None

    Returns the logits produced by the clean model's output layer.
    '''
    model = self.clean_speech_model
    batch_size = x_processed.size(0)

    # The GRU runs with autocast disabled below, so input and state must match its parameter dtype
    gru_dtype = next(model.gru.parameters()).dtype
    if x_processed.dtype != gru_dtype:
        x_processed = x_processed.to(gru_dtype)

    if states is None:
        # Size the initial state from the GRU itself rather than a hard-coded layer count
        states = model.h0.expand(model.gru.num_layers, batch_size, model.n_units).contiguous()

    if states.dtype != gru_dtype:
        states = states.to(gru_dtype)

    # GRU forward pass (skip preprocessing since input is already processed)
    with torch.autocast(device_type=x_processed.device.type, enabled=False):
        output, _ = model.gru(x_processed, states)

    # Classification
    return model.out(output)
|
||||
|
||||
def _noisy_forward_with_processed_input(self, x_processed, states=None):
    '''
    Forward pass for NoisySpeechModel with already processed input

    x_processed (tensor) - input that is fed straight into the noisy model's GRU
    states (tensor)      - optional initial GRU hidden state; built from the model's h0 when None

    Returns the logits produced by the noisy model's output layer.
    '''
    model = self.noisy_speech_model
    batch_size = x_processed.size(0)

    # The GRU runs with autocast disabled below, so input and state must match its parameter dtype
    gru_dtype = next(model.gru.parameters()).dtype
    if x_processed.dtype != gru_dtype:
        x_processed = x_processed.to(gru_dtype)

    if states is None:
        # Size the initial state from the GRU itself rather than a hard-coded layer count
        states = model.h0.expand(model.gru.num_layers, batch_size, model.n_units).contiguous()

    if states.dtype != gru_dtype:
        states = states.to(gru_dtype)

    # GRU forward pass (no day layers are applied on this path)
    with torch.autocast(device_type=x_processed.device.type, enabled=False):
        output, _ = model.gru(x_processed, states)

    # Classification
    return model.out(output)
|
||||
|
||||
def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
    '''
    Three-model adversarial forward pass

    x (tensor)         - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
    day_idx (tensor)   - tensor of day indices for each example in the batch
    states (dict)      - dictionary with 'noise', 'clean', 'noisy' states or None
    return_state (bool)- when True, the noise model's hidden state is returned alongside the outputs
    mode (str)         - 'full' for training (all three models), 'inference' for inference (noise + clean only)
    grl_lambda (float) - when non-zero and mode='full', applies Gradient Reversal to the noise branch input

    Returns
        mode='inference': clean_logits, or (clean_logits, noise_hidden) if return_state
        mode='full':      (clean_logits, noisy_logits, noise_output), or that tuple paired
                          with noise_hidden if return_state
    Raises ValueError for any other mode.
    '''
    if mode not in ('full', 'inference'):
        raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

    # --- Steps shared by both modes -------------------------------------------------

    # 1. Noise model estimates noise in the data
    noise_output, noise_hidden = self.noise_model(
        x, day_idx, states['noise'] if states else None
    )

    # 2. Bring x into the same (preprocessed) space as noise_output for the residual
    x_processed = self._apply_preprocessing(x, day_idx)
    clean_dtype = next(self.clean_speech_model.parameters()).dtype
    if x_processed.dtype != clean_dtype:
        x_processed = x_processed.to(clean_dtype)
    if noise_output.dtype != clean_dtype:
        noise_output = noise_output.to(clean_dtype)

    # 3. Clean speech model sees the denoised signal; its own preprocessing is bypassed
    #    because the residual is already in processed space
    denoised_input = x_processed - noise_output
    clean_logits = self._clean_forward_with_processed_input(
        denoised_input, day_idx, states['clean'] if states else None
    )

    if mode == 'inference':
        # Inference only needs the noise + clean models
        return (clean_logits, noise_hidden) if return_state else clean_logits

    # --- Training-only branch ---------------------------------------------------------

    # 4. Noisy speech model consumes the noise estimate, optionally through Gradient Reversal
    if grl_lambda and grl_lambda != 0.0:
        noisy_input = gradient_reverse(noise_output, grl_lambda)
    else:
        noisy_input = noise_output
    noisy_input = cast(torch.Tensor, noisy_input)

    noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
    if noisy_input.dtype != noisy_dtype:
        noisy_input = noisy_input.to(noisy_dtype)
    noisy_logits = self._noisy_forward_with_processed_input(
        noisy_input, states['noisy'] if states else None
    )

    # XLA-friendly return - tuples instead of dicts
    outputs = (clean_logits, noisy_logits, noise_output)
    return (outputs, noise_hidden) if return_state else outputs
|
||||
|
||||
def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
    '''
    Apply combined gradients to noise model parameters

    clean_grad (tensor)   - gradients from clean speech model output layer
    noisy_grad (tensor)   - gradients from noisy speech model output layer
    learning_rate (float) - step size applied to the combined gradient

    NOTE: this is a placeholder update rule. Every noise-model parameter that currently
    has a .grad is shifted by the same scalar (learning_rate * mean of the combined
    gradient); the parameter's own gradient is not used. Parameters without a .grad
    are left untouched. Returns None.
    '''
    # Combine gradients: negative from clean model, positive from noisy model
    combined_grad = -clean_grad + noisy_grad

    with torch.no_grad():
        # The step is identical for every parameter, so reduce once outside the loop
        step = learning_rate * combined_grad.mean()
        for param in self.noise_model.parameters():
            if param.grad is not None:
                param.data -= step * torch.ones_like(param.data)
|
||||
|
||||
def set_mode(self, mode):
    '''
    Record the operating mode on the instance.

    mode - value stored verbatim in self.training_mode (no validation is performed)
    '''
    self.training_mode = mode
|
||||
|
||||
|
||||
@@ -1,952 +0,0 @@
|
||||
import os
|
||||
|
||||
# XLA multi-threading optimization - MUST be set before importing torch_xla
# Set these environment variables early to ensure they take effect
if 'TPU_CORES' in os.environ or 'COLAB_TPU_ADDR' in os.environ:
    # os.cpu_count() returns None when the count cannot be determined; fall back to 1
    # so the flags never contain the literal string "None"
    _cpu_count = os.cpu_count() or 1

    # Enable XLA multi-threading for compilation speedup
    # (setdefault: values already present in the environment are left alone)
    os.environ.setdefault(
        'XLA_FLAGS',
        '--xla_cpu_multi_thread_eigen=true '
        '--xla_cpu_enable_fast_math=true '
        f'--xla_force_host_platform_device_count={_cpu_count}'
    )

    # Set PyTorch XLA threading
    os.environ.setdefault('PYTORCH_XLA_COMPILATION_THREADS', str(_cpu_count))
    print(f"Set XLA compilation threads to {_cpu_count}")
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
import random
|
||||
import time
|
||||
import numpy as np
|
||||
import math
|
||||
import pathlib
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
from contextlib import nullcontext
|
||||
|
||||
from dataset import BrainToTextDataset, train_test_split_indicies
|
||||
from data_augmentations import gauss_smooth
|
||||
|
||||
import torchaudio.functional as F # for edit distance
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# Import Accelerate for TPU support
|
||||
from accelerate import Accelerator, DataLoaderConfiguration
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
# Import XLA after setting environment variables
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
|
||||
torch.backends.cudnn.deterministic = True # makes training more reproducible
|
||||
torch._dynamo.config.cache_size_limit = 64
|
||||
|
||||
from rnn_model import TripleGRUDecoder
|
||||
|
||||
class BrainToTextDecoder_Trainer:
|
||||
"""
|
||||
This class will initialize and train a brain-to-text phoneme decoder
|
||||
|
||||
Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
|
||||
"""
|
||||
|
||||
def __init__(self, args):
    '''
    Set up the Accelerator, logging, model, datasets/dataloaders, optimizer,
    learning rate scheduler and CTC loss, then hand everything to Accelerator.prepare.

    args : dictionary of training arguments
    '''

    # Configure DataLoader behavior for TPU compatibility
    dataloader_config = DataLoaderConfiguration(
        even_batches=False  # Required for batch_size=None DataLoaders on TPU
    )

    # Initialize Accelerator for TPU/multi-device support
    self.use_xla = bool(xm.get_xla_supported_devices())
    self.amp_requested = args.get('use_amp', True)
    mixed_precision_mode = 'bf16' if self.amp_requested else 'no'

    self.accelerator = Accelerator(
        mixed_precision=mixed_precision_mode,
        gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
        log_with=None,  # We'll use our own logging
        project_dir=args.get('output_dir', './output'),
        dataloader_config=dataloader_config,
    )

    # Trainer fields
    self.args = args
    self.logger = None
    self.device = self.accelerator.device  # Use accelerator device instead of manual device selection
    self.model = None
    self.optimizer = None
    self.learning_rate_scheduler = None
    self.ctc_loss = None

    self.best_val_PER = torch.inf  # track best PER for checkpointing
    self.best_val_loss = torch.inf  # track best loss for checkpointing

    self.train_dataset = None
    self.val_dataset = None
    self.train_loader = None
    self.val_loader = None

    self.transform_args = self.args['dataset']['data_transforms']

    # Adversarial training config (safe defaults if not provided)
    adv_cfg = self.args.get('adversarial', {})
    self.adv_enabled = adv_cfg.get('enabled', False)
    self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))  # GRL strength
    self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))  # weight for noisy branch CTC
    self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))  # optional L2 on noise output
    self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))  # delay enabling adversarial after N steps

    # Create output directory
    if args['mode'] == 'train':
        os.makedirs(self.args['output_dir'], exist_ok=True)

    # Create checkpoint directory
    if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
        os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

    # Set up logging; existing handlers are removed first so repeated construction
    # does not attach duplicate handlers to the same named logger
    self.logger = logging.getLogger(__name__)
    for handler in self.logger.handlers[:]:  # make a copy of the list
        self.logger.removeHandler(handler)
    self.logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

    if args['mode'] == 'train':
        # During training, save logs to file in output directory
        fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
        fh.setFormatter(formatter)
        self.logger.addHandler(fh)

    # Always print logs to stdout
    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(formatter)
    self.logger.addHandler(sh)

    # Log device information (managed by Accelerator)
    self.logger.info(f'Using device: {self.device}')
    self.logger.info(f'Accelerator state: {self.accelerator.state}')
    if self.accelerator.num_processes > 1:
        self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
    if self.use_xla and self.amp_requested:
        self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.')

    # Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
    if self.args['seed'] != -1:
        set_seed(self.args['seed'])

    # Initialize the model
    self.model = TripleGRUDecoder(
        neural_dim = self.args['model']['n_input_features'],
        n_units = self.args['model']['n_units'],
        n_days = len(self.args['dataset']['sessions']),
        n_classes = self.args['dataset']['n_classes'],
        rnn_dropout = self.args['model']['rnn_dropout'],
        input_dropout = self.args['model']['input_network']['input_layer_dropout'],
        patch_size = self.args['model']['patch_size'],
        patch_stride = self.args['model']['patch_stride'],
    )

    if self.use_xla and self.amp_requested:
        self.model = self.model.to(torch.bfloat16)
        self.logger.info('Converted model parameters to bfloat16 for TPU training.')

    self.model_dtype = next(self.model.parameters()).dtype

    # torch.compile is temporarily disabled for compatibility with the new model architecture
    # TODO: Re-enable torch.compile once model is stable
    self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")

    self.logger.info("Initialized RNN decoding model")

    self.logger.info(self.model)

    # Log how many parameters are in the model
    total_params = sum(p.numel() for p in self.model.parameters())
    self.logger.info(f"Model has {total_params:,} parameters")

    # Determine how many day-specific parameters are in the model
    day_params = 0
    for name, param in self.model.named_parameters():
        if 'day' in name:
            day_params += param.numel()

    self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")

    # Create datasets and dataloaders
    train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5') for s in self.args['dataset']['sessions']]
    val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5') for s in self.args['dataset']['sessions']]

    # Ensure that there are no duplicate days
    if len(set(train_file_paths)) != len(train_file_paths):
        raise ValueError("There are duplicate sessions listed in the train dataset")
    if len(set(val_file_paths)) != len(val_file_paths):
        raise ValueError("There are duplicate sessions listed in the val dataset")

    # Split trials into train and test sets
    train_trials, _ = train_test_split_indicies(
        file_paths = train_file_paths,
        test_percentage = 0,
        seed = self.args['dataset']['seed'],
        bad_trials_dict = None,
    )
    _, val_trials = train_test_split_indicies(
        file_paths = val_file_paths,
        test_percentage = 1,
        seed = self.args['dataset']['seed'],
        bad_trials_dict = None,
    )

    # Save dictionaries to output directory to know which trials were train vs val
    with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
        json.dump({'train' : train_trials, 'val': val_trials}, f)

    # Determine if only a subset of neural features should be used
    feature_subset = None
    if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] is not None:
        feature_subset = self.args['dataset']['feature_subset']
        self.logger.info(f'Using only a subset of features: {feature_subset}')

    # train dataset and dataloader
    self.train_dataset = BrainToTextDataset(
        trial_indicies = train_trials,
        split = 'train',
        days_per_batch = self.args['dataset']['days_per_batch'],
        n_batches = self.args['num_training_batches'],
        batch_size = self.args['dataset']['batch_size'],
        must_include_days = None,
        random_seed = self.args['dataset']['seed'],
        feature_subset = feature_subset
    )

    # Custom collate function that handles pre-batched data from our dataset
    def collate_fn(batch):
        # The DataLoader runs with batch_size=1, so `batch` is a one-element list.
        # When that element is a dict it is returned directly.
        if len(batch) == 1 and isinstance(batch[0], dict):
            return batch[0]
        else:
            # Fallback for unexpected batch structure
            return batch

    # DataLoader configuration compatible with Accelerate
    self.train_loader = DataLoader(
        self.train_dataset,
        batch_size = 1,  # Use batch_size=1 since dataset returns full batches
        shuffle = self.args['dataset']['loader_shuffle'],
        num_workers = self.args['dataset']['num_dataloader_workers'],
        pin_memory = True,
        collate_fn = collate_fn
    )

    # val dataset and dataloader
    self.val_dataset = BrainToTextDataset(
        trial_indicies = val_trials,
        split = 'test',
        days_per_batch = None,
        n_batches = None,
        batch_size = self.args['dataset']['batch_size'],
        must_include_days = None,
        random_seed = self.args['dataset']['seed'],
        feature_subset = feature_subset
    )

    # Validation DataLoader with same collate function
    self.val_loader = DataLoader(
        self.val_dataset,
        batch_size = 1,  # Use batch_size=1 since dataset returns full batches
        shuffle = False,
        num_workers = 0,  # Keep validation dataloader single-threaded for consistency
        pin_memory = True,
        collate_fn = collate_fn  # Use same collate function
    )

    self.logger.info("Successfully initialized datasets")

    # Create optimizer, learning rate scheduler, and loss
    self.optimizer = self.create_optimizer()

    if self.args['lr_scheduler_type'] == 'linear':
        self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer = self.optimizer,
            start_factor = 1.0,
            end_factor = self.args['lr_min'] / self.args['lr_max'],
            total_iters = self.args['lr_decay_steps'],
        )
    elif self.args['lr_scheduler_type'] == 'cosine':
        self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
    else:
        raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")

    self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False)

    # If a checkpoint is provided, then load from checkpoint
    if self.args['init_from_checkpoint']:
        self.load_model_checkpoint(self.args['init_checkpoint_path'])

    # Set rnn and/or input layers to not trainable if specified
    for name, param in self.model.named_parameters():
        if not self.args['model']['rnn_trainable'] and 'gru' in name:
            param.requires_grad = False

        elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
            param.requires_grad = False

    # Prepare model, optimizer, scheduler, and dataloaders for distributed training
    # Let Accelerator handle everything automatically for both GPU and TPU
    (
        self.model,
        self.optimizer,
        self.learning_rate_scheduler,
        self.train_loader,
        self.val_loader,
    ) = self.accelerator.prepare(
        self.model,
        self.optimizer,
        self.learning_rate_scheduler,
        self.train_loader,
        self.val_loader,
    )

    # Re-read the parameter dtype from the prepared model
    self.model_dtype = next(self.model.parameters()).dtype

    self.logger.info("Prepared model and dataloaders with Accelerator")
    if self.adv_enabled:
        self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
|
||||
|
||||
def autocast_context(self):
    """
    Pick the context manager to wrap forward computation in.

    Returns a no-op context when self.device is an XLA device (autocast is skipped
    there to avoid dtype mismatches); otherwise returns the Accelerator's autocast context.
    """
    on_xla = self.device.type == 'xla'
    return nullcontext() if on_xla else self.accelerator.autocast()
|
||||
|
||||
def create_optimizer(self):
    '''
    Create the optimizer with special param groups

    Biases and day weights should not be decayed

    Day weights should have a separate learning rate

    Groups are emitted in the order: bias, day_layer (only when day parameters exist), other.
    Returns a torch.optim.AdamW instance.
    '''
    # Walk the parameters once and reuse the list for every group
    named_params = list(self.model.named_parameters())

    def is_bias(name):
        return 'gru.bias' in name or 'out.bias' in name

    bias_params = [p for name, p in named_params if is_bias(name)]
    day_params = [p for name, p in named_params if 'day_' in name]
    other_params = [p for name, p in named_params if 'day_' not in name and not is_bias(name)]

    param_groups = [{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'}]
    if day_params:
        param_groups.append(
            {'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'}
        )
    param_groups.append({'params' : other_params, 'group_type' : 'other'})

    # PyTorch's fused AdamW kernels exist only for a fixed set of backends, and XLA is not
    # one of them, so request the fused path everywhere except on an XLA device
    use_fused = self.device.type != 'xla'

    return torch.optim.AdamW(
        param_groups,
        lr = self.args['lr_max'],
        betas = (self.args['beta0'], self.args['beta1']),
        eps = self.args['epsilon'],
        weight_decay = self.args['weight_decay'],
        fused = use_fused
    )
|
||||
|
||||
def create_cosine_lr_scheduler(self, optim):
    '''
    Build a LambdaLR scheduler with linear warmup followed by cosine decay.

    optim - optimizer with either 2 param groups (bias, other) or 3 (bias, day, other);
            the day group gets its own min/max ratio, decay length and warmup length

    Each multiplier ramps from 0 to 1 over the warmup steps, follows a half cosine down
    to min_lr / max_lr until the decay step count is reached, and stays there afterwards.
    Raises ValueError for any other number of param groups.
    '''
    cfg = self.args
    lr_max, lr_min = cfg['lr_max'], cfg['lr_min']
    lr_decay_steps = cfg['lr_decay_steps']

    lr_max_day, lr_min_day = cfg['lr_max_day'], cfg['lr_min_day']
    lr_decay_steps_day = cfg['lr_decay_steps_day']

    lr_warmup_steps = cfg['lr_warmup_steps']
    lr_warmup_steps_day = cfg['lr_warmup_steps_day']

    def cosine_with_warmup(min_lr_ratio, decay_steps, warmup_steps):
        '''Return the step -> lr multiplier function for one param group'''
        def schedule(current_step):
            # Warmup phase
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))

            # After cosine decay is complete, maintain min_lr_ratio
            if current_step >= decay_steps:
                return min_lr_ratio

            # Cosine decay phase: scale from 1.0 to min_lr_ratio
            progress = float(current_step - warmup_steps) / float(max(1, decay_steps - warmup_steps))
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)

        return schedule

    n_groups = len(optim.param_groups)
    if n_groups not in (2, 3):
        raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")

    main_ratio = lr_min / lr_max

    lr_lambdas = [cosine_with_warmup(main_ratio, lr_decay_steps, lr_warmup_steps)]  # biases
    if n_groups == 3:
        lr_lambdas.append(
            cosine_with_warmup(lr_min_day / lr_max_day, lr_decay_steps_day, lr_warmup_steps_day)  # day params
        )
    lr_lambdas.append(cosine_with_warmup(main_ratio, lr_decay_steps, lr_warmup_steps))  # rest of model weights

    return LambdaLR(optim, lr_lambdas, -1)
|
||||
|
||||
def load_model_checkpoint(self, load_path):
    '''
    Load a training checkpoint for distributed training

    load_path (str or path-like) - location of a checkpoint dict containing
        'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'val_PER'
        and optionally 'val_loss'

    Restores model, optimizer and scheduler state and updates best_val_PER / best_val_loss.
    '''
    # Load checkpoint on CPU first to avoid OOM issues
    # SECURITY: weights_only=False unpickles arbitrary objects - only load checkpoints
    # from a trusted source
    checkpoint = torch.load(load_path, map_location='cpu', weights_only = False)  # checkpoint is just a dict

    # Get unwrapped model for loading state dict
    unwrapped_model = self.accelerator.unwrap_model(self.model)
    unwrapped_model.load_state_dict(checkpoint['model_state_dict'])

    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    self.best_val_PER = checkpoint['val_PER']  # best phoneme error rate
    # Checkpoints without a stored loss fall back to +inf
    self.best_val_loss = checkpoint.get('val_loss', torch.inf)

    # Device handling is managed by Accelerator, no need to manually move to device

    # f-string rather than "+" so a pathlib.Path load_path does not raise TypeError here
    self.logger.info(f"Loaded model from checkpoint: {load_path}")
|
||||
|
||||
def save_model_checkpoint(self, save_path, PER, loss):
    '''
    Write a training checkpoint (main process only) and then synchronize all processes.

    save_path - destination passed to torch.save
    PER       - value stored under 'val_PER'
    loss      - value stored under 'val_loss'

    The training args are also written to args.yaml inside args['checkpoint_dir'].
    '''
    acc = self.accelerator

    # Only the main process writes, to avoid conflicting writes to the same files
    if acc.is_main_process:
        # Save the underlying model, not the Accelerator wrapper
        base_model = acc.unwrap_model(self.model)

        state = dict(
            model_state_dict=base_model.state_dict(),
            optimizer_state_dict=self.optimizer.state_dict(),
            scheduler_state_dict=self.learning_rate_scheduler.state_dict(),
            val_PER=PER,
            val_loss=loss,
        )
        torch.save(state, save_path)

        self.logger.info("Saved model to checkpoint: " + save_path)

        # Save the args file alongside the checkpoint
        args_path = os.path.join(self.args['checkpoint_dir'], 'args.yaml')
        with open(args_path, 'w') as f:
            OmegaConf.save(config=self.args, f=f)

    # Every process (main or not) waits here until checkpoint saving is complete
    acc.wait_for_everyone()
|
||||
|
||||
def create_attention_mask(self, sequence_lengths):
    '''
    Build a boolean key-padding mask broadcast over all query positions.

    sequence_lengths (tensor) - 1-D tensor with the valid length of each sequence

    Returns a bool tensor of shape [batch_size, 1, max_length, max_length] where
    entry [b, 0, q, k] is True when key position k < sequence_lengths[b] and False
    for padding. Note that the mask is boolean, not an additive 0 / -inf float mask.
    '''
    # .item() pulls the maximum length to the host so it can be used as a Python size
    longest = torch.max(sequence_lengths).item()
    n_seqs = sequence_lengths.size(0)

    # [batch_size, max_length]: True where the key position is inside the sequence
    positions = torch.arange(longest, device=sequence_lengths.device).expand(n_seqs, longest)
    valid_keys = positions < sequence_lengths.unsqueeze(1)

    # [batch_size, 1, 1, max_length] -> [batch_size, 1, max_length, max_length]
    broadcast = valid_keys.unsqueeze(1).unsqueeze(1).expand(n_seqs, 1, longest, longest)

    # torch.where materializes the expanded view into a fresh bool tensor
    return torch.where(broadcast, True, False)
|
||||
|
||||
def transform_data(self, features, n_time_steps, mode = 'train'):
    '''
    Apply various augmentations and smoothing to data
    Performing augmentations is much faster on GPU than CPU

    features (tensor)     - input batch; first dim is the batch, last dim is the channels
    n_time_steps (tensor) - per-example number of valid time steps
    mode (str)            - augmentations are only applied when mode == 'train';
                            smoothing (if enabled) is applied in every mode

    Returns (features, n_time_steps); n_time_steps is reduced by the random cut, and
    features are cast to self.model_dtype when that attribute exists.
    '''
    data_shape = features.shape
    batch_size = data_shape[0]
    channels = data_shape[-1]

    # We only apply these augmentations in training
    if mode == 'train':
        # add static gain noise
        if self.transform_args['static_gain_std'] > 0:
            # Build the identity directly on the training device (and in the feature dtype):
            # adding device-side noise to a CPU matrix in place fails on any non-CPU device
            identity = torch.eye(channels, device=self.device, dtype=features.dtype)
            warp_mat = torch.tile(torch.unsqueeze(identity, dim = 0), (batch_size, 1, 1))
            warp_mat += torch.randn_like(warp_mat) * self.transform_args['static_gain_std']

            features = torch.matmul(features, warp_mat)

        # add white noise
        if self.transform_args['white_noise_std'] > 0:
            features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']

        # add constant offset noise
        if self.transform_args['constant_offset_std'] > 0:
            features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']

        # add random walk noise
        if self.transform_args['random_walk_std'] > 0:
            features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis'])

        # randomly cutoff part of the data timecourse
        if self.transform_args['random_cut'] > 0:
            # randint's upper bound is exclusive, so cut is in [0, random_cut - 1]
            cut = np.random.randint(0, self.transform_args['random_cut'])
            features = features[:, cut:, :]
            n_time_steps = n_time_steps - cut

    # Apply Gaussian smoothing to data
    # This is done in both training and validation
    if self.transform_args['smooth_data']:
        features = gauss_smooth(
            inputs = features,
            device = self.device,
            smooth_kernel_std = self.transform_args['smooth_kernel_std'],
            smooth_kernel_size= self.transform_args['smooth_kernel_size'],
        )

    # Match the model's parameter dtype when it is known
    if hasattr(self, 'model_dtype'):
        features = features.to(self.model_dtype)

    return features, n_time_steps
|
||||
|
||||
def train(self):
    '''
    Train the model.

    Runs one optimisation step per batch yielded by ``self.train_loader``,
    periodically evaluates on ``self.val_loader``, checkpoints whenever the
    validation PER improves (ties broken by validation loss) and optionally
    stops early when no improvement is seen for a configured number of
    validation steps.

    Returns:
        dict with keys ``train_losses`` (one float per training batch),
        ``val_losses`` / ``val_PERs`` (one entry per validation step) and
        ``val_metrics`` (the full metrics dict of every validation step).
    '''

    # Set model to train mode (specifically to make sure dropout layers are engaged)
    self.model.train()

    # create vars to track performance
    train_losses = []
    val_losses = []
    val_PERs = []
    val_results = []

    val_steps_since_improvement = 0

    # training params
    save_best_checkpoint = self.args.get('save_best_checkpoint', True)
    early_stopping = self.args.get('early_stopping', True)
    early_stopping_val_steps = self.args.get('early_stopping_val_steps', 20)

    train_start_time = time.time()

    # Stays at -1 when the loader yields nothing, so the final checkpoint
    # name below is always defined.
    i = -1

    # train for specified number of batches
    self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...")
    for i, batch in enumerate(self.train_loader):

        # validation() switches the model to eval mode, so re-enable train mode every step
        self.model.train()
        self.optimizer.zero_grad()

        # Train step
        start_time = time.time()

        # Data is automatically moved to device by Accelerator
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indicies = batch['day_indicies']

        # Use Accelerator's autocast (mixed precision handled by Accelerator init)
        with self.autocast_context():

            # Apply augmentations to the data
            features, n_time_steps = self.transform_data(features, n_time_steps, 'train')

            # Number of output frames after patching; computed in float and
            # truncated to int32 for the CTC loss
            adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)

            # Ensure features tensor matches model parameter dtype for TPU compatibility
            if features.dtype != self.model_dtype:
                features = features.to(self.model_dtype)

            # Forward pass: enable full adversarial mode if configured and past warmup
            use_full = self.adv_enabled and (i >= self.adv_warmup_steps)

            if use_full:
                clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)

                # Clean CTC loss (log-softmax is taken in float32 on the permuted logits)
                clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2)
                clean_loss = self.ctc_loss(
                    clean_log_probs,
                    labels,
                    adjusted_lens,
                    phone_seq_lens
                )
                clean_loss = torch.mean(clean_loss)

                # Noisy-branch CTC loss: makes the noisy branch more recognisable,
                # while the GRL turns it into an adversarial signal for the noise model
                noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2)
                noisy_loss = self.ctc_loss(
                    noisy_log_probs,
                    labels,
                    adjusted_lens,
                    phone_seq_lens
                )
                noisy_loss = torch.mean(noisy_loss)

                # Optional noise energy regularization
                noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
                if self.adv_noise_l2_weight > 0.0:
                    noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype)

                loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
            else:
                # Inference mode is used during training for simplicity:
                # only the clean logits are needed for the CTC loss
                logits = self.model(features, day_indicies, None, False, 'inference')

                log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
                loss = self.ctc_loss(
                    log_probs=log_probs,
                    targets=labels,
                    input_lengths=adjusted_lens,
                    target_lengths=phone_seq_lens
                )
                loss = torch.mean(loss) # take mean loss over batches

        # Use Accelerator's backward for distributed training
        self.accelerator.backward(loss)

        # Clip gradient using Accelerator's clip_grad_norm_.
        # grad_norm defaults to NaN so the progress log below still works when
        # clipping is disabled (it was previously unbound in that case).
        grad_norm = float('nan')
        if self.args['grad_norm_clip_value'] > 0:
            clipped_norm = self.accelerator.clip_grad_norm_(self.model.parameters(),
                                                            max_norm = self.args['grad_norm_clip_value'])
            if clipped_norm is not None:
                grad_norm = clipped_norm

        self.optimizer.step()
        self.learning_rate_scheduler.step()

        # Save training metrics
        train_step_duration = time.time() - start_time
        # Read the loss back once and reuse the value for logging
        loss_value = loss.detach().item()
        train_losses.append(loss_value)

        # Incrementally log training progress
        if i % self.args['batches_per_train_log'] == 0:
            self.logger.info(f'Train batch {i}: ' +
                    f'loss: {loss_value:.2f} ' +
                    f'grad norm: {grad_norm:.2f} '
                    f'time: {train_step_duration:.3f}')

        # Incrementally run a test step
        if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
            self.logger.info(f"Running test after training batch: {i}")

            # Calculate metrics on val data
            start_time = time.time()
            val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data'])
            val_step_duration = time.time() - start_time

            # Log info
            self.logger.info(f'Val batch {i}: ' +
                    f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
                    f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
                    f'time: {val_step_duration:.3f}')

            if self.args['log_individual_day_val_PER']:
                for day, day_stats in val_metrics['day_PERs'].items():
                    # A day without any validation batch has length 0; report NaN instead of dividing by zero
                    day_length = day_stats['total_seq_length']
                    day_PER = day_stats['total_edit_distance'] / day_length if day_length > 0 else float('nan')
                    self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {day_PER:0.4f}")

            # Save metrics
            val_PERs.append(val_metrics['avg_PER'])
            val_losses.append(val_metrics['avg_loss'])
            val_results.append(val_metrics)

            # Determine if new best day. Based on if PER is lower, or in the case of a PER tie, if loss is lower
            new_best = False
            if val_metrics['avg_PER'] < self.best_val_PER:
                self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
                self.best_val_PER = val_metrics['avg_PER']
                self.best_val_loss = val_metrics['avg_loss']
                new_best = True
            elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
                self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                self.best_val_loss = val_metrics['avg_loss']
                new_best = True

            if new_best:

                # Checkpoint if metrics have improved
                if save_best_checkpoint:
                    self.logger.info(f"Checkpointing model")
                    self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)

                # save validation metrics to pickle file
                if self.args['save_val_metrics']:
                    with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                        pickle.dump(val_metrics, f)

                val_steps_since_improvement = 0

            else:
                val_steps_since_improvement +=1

            # Optionally save this validation checkpoint, regardless of performance
            if self.args['save_all_val_steps']:
                self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])

            # Early stopping
            if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
                self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
                break

    # Log final training steps
    training_duration = time.time() - train_start_time

    self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
    self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

    # Save final model
    if self.args['save_final_model']:
        # Fall back to inf when no validation step ever ran (e.g. empty loader)
        last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
        last_PER = val_PERs[-1] if len(val_PERs) > 0 else float('inf')
        self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', last_PER, last_loss)

    train_stats = {}
    train_stats['train_losses'] = train_losses
    train_stats['val_losses'] = val_losses
    train_stats['val_PERs'] = val_PERs
    train_stats['val_metrics'] = val_results

    return train_stats
|
||||
|
||||
def validation(self, loader, return_logits = False, return_data = False):
    '''
    Calculate metrics on the validation dataset.

    Args:
        loader: iterable of batches (dicts) to evaluate.
        return_logits: if True, also store the logits and the adjusted
            sequence lengths of every batch in the result.
        return_data: if True, also store the raw input features of every batch.

    Returns:
        dict with per-batch lists (decoded sequences, labels, losses, ...),
        ``day_PERs`` (edit distance / sequence length totals per day),
        ``avg_PER`` and ``avg_loss``.
    '''
    self.model.eval()

    metrics = {}

    # Record metrics
    if return_logits:
        metrics['logits'] = []
        metrics['n_time_steps'] = []

    if return_data:
        metrics['input_features'] = []

    metrics['decoded_seqs'] = []
    metrics['true_seq'] = []
    metrics['phone_seq_lens'] = []
    metrics['transcription'] = []
    metrics['losses'] = []
    metrics['block_nums'] = []
    metrics['trial_nums'] = []
    metrics['day_indicies'] = []

    total_edit_distance = 0
    total_seq_length = 0

    # Calculate PER for each specific day
    day_per = {}
    for d in range(len(self.args['dataset']['sessions'])):
        if self.args['dataset']['dataset_probability_val'][d] == 1:
            day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}

    for batch in loader:

        # Data is automatically moved to device by Accelerator
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indicies = batch['day_indicies']

        # Determine if we should perform validation on this batch
        # (the day of the first trial decides for the whole batch)
        day = day_indicies[0].item()
        if self.args['dataset']['dataset_probability_val'][day] == 0:
            if self.args['log_val_skip_logs']:
                self.logger.info(f"Skipping validation on day {day}")
            continue

        with torch.no_grad():

            with self.autocast_context():
                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                # Number of output frames after patching; computed in float and
                # truncated to int32 for the CTC loss
                adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)

                # Ensure features tensor matches model parameter dtype for TPU compatibility
                model_param = next(self.model.parameters()) if self.model is not None else None
                if model_param is not None and features.dtype != model_param.dtype:
                    features = features.to(model_param.dtype)

                logits = self.model(features, day_indicies, None, False, 'inference')

                val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
                loss = self.ctc_loss(
                    val_log_probs,
                    labels,
                    adjusted_lens,
                    phone_seq_lens,
                )
                loss = torch.mean(loss)

            # Calculate PER per day and also avg over entire validation set
            batch_edit_distance = 0
            decoded_seqs = []
            for iterIdx in range(logits.shape[0]):
                # Greedy decode: argmax per frame, collapse repeats, drop class 0
                decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1)
                decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
                decoded_seq = decoded_seq.cpu().detach().numpy()
                decoded_seq = decoded_seq[decoded_seq != 0]

                trueSeq = np.array(
                    labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach()
                )

                batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)

                decoded_seqs.append(decoded_seq)

            # Sum the target lengths once and reuse the Python int for both tallies
            batch_seq_length = torch.sum(phone_seq_lens).item()

            # setdefault covers days whose validation probability is neither 0 nor 1,
            # which are evaluated here but were never pre-registered above
            day_totals = day_per.setdefault(day, {'total_edit_distance' : 0, 'total_seq_length' : 0})
            day_totals['total_edit_distance'] += batch_edit_distance
            day_totals['total_seq_length'] += batch_seq_length

            total_edit_distance += batch_edit_distance
            total_seq_length += batch_seq_length

            # Record metrics
            if return_logits:
                metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32
                metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())

            if return_data:
                metrics['input_features'].append(batch['input_features'].cpu().numpy())

            metrics['decoded_seqs'].append(decoded_seqs)
            metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
            metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
            metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
            # Exactly one loss entry per evaluated batch (it used to be appended twice)
            metrics['losses'].append(loss.detach().item())
            # .cpu() is a no-op for host tensors and required for device-resident ones
            metrics['block_nums'].append(batch['block_nums'].cpu().numpy())
            metrics['trial_nums'].append(batch['trial_nums'].cpu().numpy())
            metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())

    # Guard against division by zero when no batch was evaluated
    avg_PER = total_edit_distance / max(float(total_seq_length), 1e-6)

    metrics['day_PERs'] = day_per
    metrics['avg_PER'] = avg_PER
    metrics['avg_loss'] = float(np.mean(metrics['losses']))

    return metrics
|
||||
|
||||
def inference(self, features, day_indicies, n_time_steps, mode='inference'):
    '''
    TPU-compatible inference method for generating phoneme logits.

    The features go through the 'val' data transform (no training
    augmentation), are cast to the model dtype if needed and are then fed to
    the model in the requested ``mode``. Returns the model output.
    '''
    self.model.eval()

    with torch.no_grad(), self.autocast_context():
        # 'val' mode: no augmentation is applied for inference
        transformed, _ = self.transform_data(features, n_time_steps, 'val')

        # The model input must match the parameter dtype for TPU compatibility
        target_dtype = self.model_dtype
        model_input = transformed if transformed.dtype == target_dtype else transformed.to(target_dtype)

        # Get phoneme predictions
        logits = self.model(model_input, day_indicies, None, False, mode)

    return logits
|
||||
|
||||
def inference_batch(self, batch, mode='inference'):
    '''
    Inference method for processing a full batch.

    Reads 'input_features', 'day_indicies' and 'n_time_steps' from ``batch``
    and returns ``(logits, adjusted_lens)``, where ``adjusted_lens`` are the
    int32 sequence lengths after patching.
    '''
    self.model.eval()

    # Data is automatically moved to device by Accelerator
    raw_features = batch['input_features']
    day_indicies = batch['day_indicies']
    raw_lengths = batch['n_time_steps']

    patch_size = self.args['model']['patch_size']
    patch_stride = self.args['model']['patch_stride']

    with torch.no_grad(), self.autocast_context():
        # 'val' mode: no augmentation is applied for inference
        model_input, lengths = self.transform_data(raw_features, raw_lengths, 'val')

        # Sequence lengths for CTC: computed in float, truncated to int32
        frame_counts = (lengths.float() - patch_size) / patch_stride + 1
        adjusted_lens = frame_counts.to(torch.int32)

        # The model input must match the parameter dtype for TPU compatibility
        if model_input.dtype != self.model_dtype:
            model_input = model_input.to(self.model_dtype)

        # Get phoneme predictions
        logits = self.model(model_input, day_indicies, None, False, mode)

    return logits, adjusted_lens
|
||||
@@ -1,150 +0,0 @@
|
||||
#!/bin/bash
# Setup script for TensorFlow Brain-to-Text training on TPU v5e-8
#
# Usage: ./setup_tensorflow_tpu.sh
#
# This script prepares the environment for training the brain-to-text model
# using TensorFlow on TPU v5e-8 hardware. It creates/activates a conda
# environment, installs dependencies, appends TPU settings to ~/.bashrc
# (only once), runs a TPU connectivity check and creates output directories.

set -e # Exit on any error

echo "=== TensorFlow TPU v5e-8 Setup Script ==="
echo "Setting up environment for brain-to-text training..."

# Check if we're in a TPU environment
if [[ -z "${TPU_NAME}" ]] && [[ -z "${COLAB_TPU_ADDR}" ]]; then
    echo "Warning: TPU environment variables not detected."
    echo "Make sure you're running on a TPU v5e-8 instance."
fi

# Create conda environment for TensorFlow TPU
ENV_NAME="b2txt_tf"
echo "Creating conda environment: ${ENV_NAME}"

# `conda activate` is a shell function that is only defined once conda's shell
# hook has been loaded; a non-interactive script does not get it automatically.
eval "$(conda shell.bash hook)"

if conda env list | grep -q "^${ENV_NAME} "; then
    echo "Environment ${ENV_NAME} already exists. Activating..."
else
    echo "Creating new environment..."
    conda create -n "${ENV_NAME}" python=3.10 -y
fi
conda activate "${ENV_NAME}"

# Install TensorFlow with TPU support
echo "Installing TensorFlow with TPU support..."
# The requirement must be quoted: unquoted, `>=2.15.0` is parsed by the shell
# as an output redirection to a file named "=2.15.0" and the version
# constraint never reaches pip.
pip install "tensorflow[and-cuda]>=2.15.0"

# Install additional requirements
echo "Installing additional requirements..."
pip install -r requirements_tf.txt

# Set up TPU environment variables
echo "Configuring TPU environment variables..."

# Create or update .bashrc with TPU optimizations.
# The marker line makes this idempotent so re-running the script does not
# append the same block again.
BASHRC_MARKER="# TPU v5e-8 Environment Variables"
if ! grep -qF "${BASHRC_MARKER}" ~/.bashrc 2>/dev/null; then
cat >> ~/.bashrc << 'EOF'

# TPU v5e-8 Environment Variables
export TPU_ML_PLATFORM="TensorFlow"
export XLA_USE_BF16=1
export TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"
export TPU_MEGACORE=1
export LIBTPU_INIT_ARGS="--xla_tpu_spmd_threshold_for_allgather_cse=10000"

# Disable TensorFlow warnings for cleaner output
export TF_CPP_MIN_LOG_LEVEL=2

# Memory optimizations
export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_GPU_THREAD_MODE=gpu_private

EOF
fi

# Source the updated .bashrc
source ~/.bashrc

# Test TPU connectivity
echo "Testing TPU connectivity..."
python3 << 'EOF'
import tensorflow as tf
print("TensorFlow version:", tf.__version__)

try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print(f"TPU cluster initialized successfully!")
    print(f"Number of TPU cores: {strategy.num_replicas_in_sync}")
    print(f"TPU devices: {tf.config.list_logical_devices('TPU')}")
except Exception as e:
    print(f"TPU initialization failed: {e}")
    print("You may be running on CPU/GPU instead of TPU")

# Test mixed precision
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(policy)
print(f"Mixed precision policy: {policy.name}")
EOF

# Verify data directory exists
DATA_DIR="../data/hdf5_data_final"
if [ -d "$DATA_DIR" ]; then
    echo "Data directory found: $DATA_DIR"
    # Count available sessions
    SESSION_COUNT=$(ls -d "$DATA_DIR"/t*.20* 2>/dev/null | wc -l)
    echo "Available sessions: $SESSION_COUNT"
else
    echo "Warning: Data directory not found at $DATA_DIR"
    echo "Please ensure the dataset is available before training."
fi

# Create output directories
echo "Creating output directories..."
mkdir -p trained_models/tensorflow_tpu
mkdir -p logs/tensorflow_tpu
mkdir -p eval_output

# Make scripts executable
echo "Setting script permissions..."
for script in train_model_tf.py evaluate_model_tf.py; do
    # A missing script should not abort the whole setup under `set -e`
    if [ -f "$script" ]; then
        chmod +x "$script"
    else
        echo "Warning: $script not found, skipping chmod"
    fi
done

# Display system information
echo "=== System Information ==="
echo "Python version: $(python --version)"
echo "Conda environment: $CONDA_DEFAULT_ENV"
echo "Available memory: $(free -h | grep '^Mem:' | awk '{print $7}')"
echo "CPU cores: $(nproc)"

# Check for GPU/TPU
echo "=== Hardware Information ==="
if nvidia-smi &> /dev/null; then
    echo "NVIDIA GPUs detected:"
    nvidia-smi --list-gpus
else
    echo "No NVIDIA GPUs detected"
fi

if [[ -n "${TPU_NAME}" ]]; then
    echo "TPU Name: $TPU_NAME"
elif [[ -n "${COLAB_TPU_ADDR}" ]]; then
    echo "Colab TPU Address: $COLAB_TPU_ADDR"
else
    echo "No TPU environment variables detected"
fi

echo ""
echo "=== Setup Complete ==="
echo "Environment '$ENV_NAME' is ready for TensorFlow TPU training."
echo ""
echo "To activate the environment:"
echo " conda activate $ENV_NAME"
echo ""
echo "To start training:"
echo " python train_model_tf.py --config_path rnn_args.yaml"
echo ""
echo "To run evaluation:"
echo " python evaluate_model_tf.py --model_path path/to/checkpoint --config_path rnn_args.yaml"
echo ""
echo "For more options, use --help with any script."
|
||||
@@ -1,236 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TPU内存监控工具 - 专门用于训练过程
|
||||
解决tf.config.experimental.get_memory_info()在TPU上无法工作的问题
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
import time
|
||||
import psutil
|
||||
import os
|
||||
|
||||
class TPUMemoryMonitor:
    """TPU memory monitor that works from size estimates and probe allocations."""

    def __init__(self):
        # Logical TPU devices visible to TensorFlow when the monitor is created
        self.tpu_devices = tf.config.list_logical_devices('TPU')
        self.baseline_memory = None
        # name -> accumulated estimated size in MB (see track_allocation)
        self.peak_allocations = {}

    def get_tpu_status(self) -> str:
        """Return a one-line TPU status string without relying on get_memory_info."""
        try:
            if not self.tpu_devices:
                return "TPU: No devices"

            num_cores = len(self.tpu_devices)

            # Probe TPU responsiveness with a tiny computation
            try:
                with tf.device('/TPU:0'):
                    test_tensor = tf.constant([1.0, 2.0, 3.0])
                    result = tf.reduce_sum(test_tensor)
                    _ = result.numpy()  # force execution
                activity = "active"
            except Exception:
                activity = "inactive"

            # Report host memory as a reference value
            try:
                memory = psutil.virtual_memory()
                host_mem = f"Host:{memory.percent:.1f}%"
            except Exception:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
                host_mem = "Host:unknown"

            return f"TPU: {num_cores}cores {activity} {host_mem}"

        except Exception as e:
            return f"TPU: error({str(e)[:20]})"

    def estimate_tensor_memory(self, tensor_shape, dtype=tf.float32):
        """Estimate the size in MB of a tensor with the given shape and dtype.

        Unknown dtypes are counted as 4 bytes per element.
        """
        if dtype == tf.float32:
            bytes_per_element = 4
        elif dtype == tf.float16 or dtype == tf.bfloat16:
            bytes_per_element = 2
        elif dtype == tf.int32:
            bytes_per_element = 4
        elif dtype == tf.int64:
            bytes_per_element = 8
        else:
            bytes_per_element = 4  # default

        total_elements = 1
        for dim in tensor_shape:
            total_elements *= dim

        total_bytes = total_elements * bytes_per_element
        return total_bytes / (1024 * 1024)  # MB

    def track_allocation(self, name: str, tensor_shape, dtype=tf.float32):
        """Add the estimated size of a tensor to the tally for ``name`` and return it in MB."""
        mb = self.estimate_tensor_memory(tensor_shape, dtype)
        self.peak_allocations[name] = self.peak_allocations.get(name, 0) + mb
        return mb

    def get_allocation_summary(self) -> str:
        """Return the tracked total and the single largest tracked allocation."""
        if not self.peak_allocations:
            return "No allocations tracked"

        total_mb = sum(self.peak_allocations.values())
        # Only the largest entry is reported, so max() replaces the former sort;
        # max() returns the first maximal item, as the stable reversed sort did
        top_name, top_mb = max(self.peak_allocations.items(), key=lambda item: item[1])

        summary = f"Tracked:{total_mb:.1f}MB "
        summary += f"Top:({top_name}:{top_mb:.1f}MB)"

        return summary

    def test_memory_allocation_across_cores(self):
        """Try a series of growing allocations on every TPU device.

        Returns a list with the estimated MB successfully allocated per
        device (0 for a device that failed entirely).
        """
        print("🧪 测试所有TPU核心内存分配")
        print("=" * 40)

        # Test tensors of increasing size (shape, label); identical for every device
        test_sizes = [
            ([1000, 1000], "1K×1K"),
            ([3000, 3000], "3K×3K"),
            ([5000, 5000], "5K×5K"),
            ([7000, 7000], "7K×7K"),
        ]

        allocations_per_core = []

        for i, device in enumerate(self.tpu_devices):
            print(f"核心 {i+1}: {device.name}")

            try:
                with tf.device(device.name):
                    core_total = 0
                    successful_allocs = []

                    for shape, desc in test_sizes:
                        try:
                            tensor = tf.ones(shape, dtype=tf.float32)
                            mb = self.estimate_tensor_memory(shape)
                            core_total += mb
                            successful_allocs.append(f"{desc}({mb:.1f}MB)")

                            # Use the tensor so it is not optimised away
                            _ = tf.reduce_mean(tensor)

                        except Exception as e:
                            print(f" {desc} 失败: {str(e)[:30]}")
                            break

                    allocations_per_core.append(core_total)
                    print(f" 成功分配: {' + '.join(successful_allocs)}")
                    print(f" 核心总计: {core_total:.1f}MB")

            except Exception as e:
                print(f" 核心{i+1}失败: {e}")
                allocations_per_core.append(0)

        # Summarise the results
        total_all_cores = sum(allocations_per_core)
        avg_per_core = total_all_cores / len(self.tpu_devices) if self.tpu_devices else 0

        print(f"\n📊 汇总结果:")
        print(f" 总分配: {total_all_cores:.1f}MB ({total_all_cores/1024:.2f}GB)")
        print(f" 平均每核: {avg_per_core:.1f}MB ({avg_per_core/1024:.2f}GB)")

        # Rough guess at the per-core memory configuration
        if avg_per_core > 8000:  # > 8GB
            print(" 推测: 每核心≥16GB (高端配置)")
        elif avg_per_core > 4000:  # > 4GB
            print(" 推测: 每核心8-16GB (标准配置)")
        elif avg_per_core > 1000:  # > 1GB
            print(" 推测: 每核心2-8GB (受限或共享)")
        else:
            print(" 推测: 每核心<2GB (严重受限)")

        return allocations_per_core
|
||||
|
||||
def test_training_memory_pattern():
    """Allocate tensors shaped like a simulated training run and print size estimates."""
    print("\n🏋️ 模拟训练内存模式测试")
    print("=" * 30)

    monitor = TPUMemoryMonitor()

    # Dimensions of the simulated brain-to-text model
    batch_size, seq_len, features = 32, 1000, 512
    n_layers, n_units = 3, 256
    n_classes = 41

    with tf.device('/TPU:0'):
        print("创建模拟模型组件...")

        # 1. Input data
        input_shape = [batch_size, seq_len, features]
        input_data = tf.random.normal(input_shape)
        input_mb = monitor.track_allocation("input_data", input_shape)
        print(f" 输入数据: {input_mb:.1f}MB")

        # 2. GRU weights: three gates per layer; the first layer reads the input features
        for layer in range(n_layers):
            in_dim = n_units if layer else features
            weight_shape = [in_dim, n_units * 3]
            weights = tf.random.normal(weight_shape)
            weight_mb = monitor.track_allocation(f"gru_layer_{layer}", weight_shape)
            print(f" GRU层{layer+1}权重: {weight_mb:.1f}MB")

        # 3. Output projection (n_units -> n_classes)
        projection_shape = [n_units, n_classes]
        output_weights = tf.random.normal(projection_shape)
        output_mb = monitor.track_allocation("output_projection", projection_shape)
        print(f" 输出投影: {output_mb:.1f}MB")

        # 4. Intermediate activations of the forward pass
        hidden_shape = [batch_size, seq_len, n_units]
        hidden_states = tf.random.normal(hidden_shape)
        hidden_mb = monitor.track_allocation("hidden_states", hidden_shape)
        print(f" 隐藏状态: {hidden_mb:.1f}MB")

        # 5. Gradients: estimated as equal to the tracked parameter memory
        total_params_mb = sum(
            size_mb
            for key, size_mb in monitor.peak_allocations.items()
            if 'layer' in key or 'projection' in key
        )
        gradient_mb = total_params_mb
        print(f" 梯度内存: {gradient_mb:.1f}MB (估算)")

        print(f"\n模型总内存估算: {monitor.get_allocation_summary()}")

        # Run a real computation on the large tensors so they are actually materialised
        result = tf.reduce_mean(input_data) + tf.reduce_mean(hidden_states)
        print(f"验证计算结果: {result.numpy():.4f}")
|
||||
|
||||
if __name__ == "__main__":
    print("🚀 TPU内存监控工具启动")

    monitor = TPUMemoryMonitor()

    # Basic status check
    print(f"当前TPU状态: {monitor.get_tpu_status()}")

    # Probe every core
    print("\n" + "="*50)
    core_allocations = monitor.test_memory_allocation_across_cores()

    # Simulated training memory pattern
    print("\n" + "="*50)
    test_training_memory_pattern()

    print(f"\n🎯 关键发现:")
    if core_allocations:
        max_core = max(core_allocations)
        # default=0 keeps this from raising ValueError when every core failed
        # (all entries are 0, so the filtered sequence is empty)
        min_core = min((x for x in core_allocations if x > 0), default=0)
        print(f" 最大单核分配: {max_core:.1f}MB")
        print(f" 最小单核分配: {min_core:.1f}MB")

        if max_core > 9000:  # threshold in MB
            print(" ✅ 内存充足,可支持大模型训练")
        elif max_core > 5000:
            print(" ⚠️ 内存中等,建议优化模型大小")
        else:
            print(" ❌ 内存不足,需要大幅减少模型参数")

    print(f"\n💡 针对你的训练卡顿问题:")
    print(f" - SetPriority错误通常是XLA编译问题,不是内存问题")
    print(f" - 你的9.4GB测试说明TPU内存工作正常")
    print(f" - 建议检查模型是否有导致XLA编译卡顿的操作")
    print(f" - 考虑使用更简单的操作或关闭某些XLA优化")
|
||||
@@ -1,25 +0,0 @@
|
||||
import argparse
|
||||
from omegaconf import OmegaConf
|
||||
from rnn_trainer import BrainToTextDecoder_Trainer
|
||||
|
||||
def main():
    """Parse the CLI, load the config file and run a full training session."""
    cli = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
    cli.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to configuration file (default: rnn_args.yaml)',
    )
    config_path = cli.parse_args().config_path

    # Build the trainer from the loaded configuration and run training
    trainer = BrainToTextDecoder_Trainer(OmegaConf.load(config_path))
    trainer.train()

    print("Training completed successfully!")
    print(f"Best validation PER: {trainer.best_val_PER:.5f}")


if __name__ == "__main__":
    main()
|
||||
@@ -8,7 +8,7 @@ It provides the same functionality as the PyTorch version but with TensorFlow
|
||||
operations optimized for TPU performance.
|
||||
|
||||
Usage:
|
||||
python train_model_tf.py --config_path rnn_args.yaml
|
||||
python train_model_tf.py -config_path rnn_args.yaml
|
||||
|
||||
Requirements:
|
||||
- TensorFlow >= 2.15.0
|
||||
|
||||
Reference in New Issue
Block a user