Remove setup script, TPU memory monitor, and model training script

- Deleted `setup_tensorflow_tpu.sh`, which set up the TensorFlow environment on TPU v5e-8.
- Removed `tpu_memory_monitor.py`, a tool for monitoring TPU memory usage during training.
- Eliminated `train_model.py`, the script for training the Brain-to-Text RNN model.
Zchen
2025-10-20 11:05:03 +08:00
parent 7c272b7c5b
commit f8fb4d7133
10 changed files with 1 addition and 3302 deletions

View File

@@ -1,315 +0,0 @@
#!/usr/bin/env python3
"""
使用AMP的TPU训练脚本
正确处理混合精度训练避免dtype不匹配问题
"""
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# Set AMP-related environment variables
os.environ['XLA_FLAGS'] = (
'--xla_cpu_multi_thread_eigen=true '
'--xla_cpu_enable_fast_math=true'
)
os.environ['XLA_USE_BF16'] = '1' # enable bf16
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.amp as xla_amp
class AMPModel(nn.Module):
"""支持AMP的简单模型"""
def __init__(self, input_size=784, hidden_size=512, num_classes=10):
super(AMPModel, self).__init__()
self.network = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
nn.Linear(hidden_size // 2, num_classes)
)
def forward(self, x):
# Flatten the input
x = x.view(x.size(0), -1)
return self.network(x)
class AMPTrainer:
"""AMP训练器"""
def __init__(self, model, device, learning_rate=0.001):
self.model = model
self.device = device
self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
self.criterion = nn.CrossEntropyLoss()
# Initialize the AMP gradient scaler
self.scaler = xla_amp.GradScaler()
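# Note (added comment, an assumption rather than part of the original script): with
# XLA_USE_BF16=1 the model runs in bfloat16, which shares float32's exponent range,
# so loss scaling is not strictly required on TPU; GradScaler primarily guards
# against float16 underflow. It is kept here exactly as the script used it.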
print(f"✅ AMP训练器初始化完成")
print(f" 设备: {device}")
print(f" 模型参数: {sum(p.numel() for p in model.parameters()):,}")
def train_step(self, data, target):
"""单个AMP训练步骤"""
self.model.train()
self.optimizer.zero_grad()
# Mixed-precision forward pass with autocast
with xla_amp.autocast():
output = self.model(data)
loss = self.criterion(output, target)
# Backward pass through the gradient scaler
self.scaler.scale(loss).backward()
# Gradient clipping (optional)
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# Update parameters
self.scaler.step(self.optimizer)
self.scaler.update()
# Compute accuracy
pred = output.argmax(dim=1)
correct = pred.eq(target).sum().item()
accuracy = correct / target.size(0)
return loss.item(), accuracy
def evaluate_step(self, data, target):
"""单个评估步骤"""
self.model.eval()
with torch.no_grad():
with xla_amp.autocast():
output = self.model(data)
loss = self.criterion(output, target)
pred = output.argmax(dim=1)
correct = pred.eq(target).sum().item()
accuracy = correct / target.size(0)
return loss.item(), accuracy
def get_mnist_loaders(batch_size=64):
"""获取MNIST数据加载器"""
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = torchvision.datasets.MNIST(
root='./mnist_data',
train=True,
download=True,
transform=transform
)
test_dataset = torchvision.datasets.MNIST(
root='./mnist_data',
train=False,
download=True,
transform=transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0
)
return train_loader, test_loader
def train_with_amp():
"""使用AMP进行训练"""
print("🚀 开始AMP TPU训练...")
# 获取设备
device = xm.xla_device()
print(f"📱 设备: {device}")
# 创建模型
model = AMPModel(input_size=784, hidden_size=512, num_classes=10).to(device)
# Create the trainer
trainer = AMPTrainer(model, device, learning_rate=0.001)
# Load the data
print("📥 Loading MNIST data...")
train_loader, test_loader = get_mnist_loaders(batch_size=64)
# Use XLA parallel loaders
train_device_loader = pl.MpDeviceLoader(train_loader, device)
test_device_loader = pl.MpDeviceLoader(test_loader, device)
print("🎯 开始AMP训练...")
# 训练循环
num_epochs = 2
train_losses = []
train_accuracies = []
for epoch in range(num_epochs):
print(f"\n📊 Epoch {epoch + 1}/{num_epochs}")
epoch_start = time.time()
epoch_loss = 0.0
epoch_acc = 0.0
num_batches = 0
max_batches_per_epoch = 200 # cap the number of batches per epoch
for batch_idx, (data, target) in enumerate(train_device_loader):
if batch_idx >= max_batches_per_epoch:
break
# Training step
loss, accuracy = trainer.train_step(data, target)
epoch_loss += loss
epoch_acc += accuracy
num_batches += 1
# Synchronize every 20 batches
if batch_idx % 20 == 0:
xm.mark_step()
avg_loss = epoch_loss / num_batches
avg_acc = epoch_acc / num_batches * 100
print(f" 批次 {batch_idx:3d}/{max_batches_per_epoch} | "
f"损失: {avg_loss:.4f} | "
f"准确率: {avg_acc:.2f}%")
# Synchronize at the end of the epoch
xm.mark_step()
xm.wait_device_ops()
epoch_time = time.time() - epoch_start
final_loss = epoch_loss / num_batches
final_acc = epoch_acc / num_batches * 100
train_losses.append(final_loss)
train_accuracies.append(final_acc)
print(f"✅ Epoch {epoch + 1} 完成 | "
f"耗时: {epoch_time:.2f}s | "
f"平均损失: {final_loss:.4f} | "
f"平均准确率: {final_acc:.2f}%")
return trainer, train_losses, train_accuracies
def test_with_amp(trainer):
"""使用AMP进行测试"""
print("\n🧪 开始AMP测试...")
device = xm.xla_device()
_, test_loader = get_mnist_loaders(batch_size=64)
test_device_loader = pl.MpDeviceLoader(test_loader, device)
total_loss = 0.0
total_acc = 0.0
num_batches = 0
max_test_batches = 100
start_time = time.time()
for batch_idx, (data, target) in enumerate(test_device_loader):
if batch_idx >= max_test_batches:
break
loss, accuracy = trainer.evaluate_step(data, target)
total_loss += loss
total_acc += accuracy
num_batches += 1
if batch_idx % 20 == 0:
xm.mark_step()
xm.mark_step()
xm.wait_device_ops()
test_time = time.time() - start_time
avg_loss = total_loss / num_batches
avg_acc = total_acc / num_batches * 100
print(f"✅ 测试完成!")
print(f"⏱️ 测试时间: {test_time:.2f}")
print(f"🎯 测试损失: {avg_loss:.4f}")
print(f"🎯 测试准确率: {avg_acc:.2f}%")
return avg_loss, avg_acc
def main():
"""主函数"""
print("=" * 60)
print("⚡ AMP TPU训练示例")
print("=" * 60)
try:
# Train
trainer, train_losses, train_accuracies = train_with_amp()
# Evaluate
test_loss, test_acc = test_with_amp(trainer)
# Save the model
print("\n💾 Saving model...")
model_cpu = trainer.model.cpu()
torch.save({
'model_state_dict': model_cpu.state_dict(),
'train_losses': train_losses,
'train_accuracies': train_accuracies,
'test_loss': test_loss,
'test_accuracy': test_acc
}, 'amp_mnist_model.pth')
print("✅ 模型已保存到 amp_mnist_model.pth")
print("\n🎉 AMP训练完成!")
print(f"📊 最终训练准确率: {train_accuracies[-1]:.2f}%")
print(f"📊 测试准确率: {test_acc:.2f}%")
if train_accuracies[-1] > 85 and test_acc > 80:
print("✅ AMP训练成功! 模型性能优秀")
else:
print("⚠️ 模型性能一般但AMP功能正常")
except Exception as e:
print(f"❌ AMP训练失败: {e}")
import traceback
traceback.print_exc()
print("\n💡 故障排除建议:")
print(" 1. 确保PyTorch XLA版本支持AMP")
print(" 2. 检查TPU资源是否充足")
print(" 3. 尝试减小batch_size")
if __name__ == "__main__":
main()
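# Illustrative usage note (not part of the original file): assuming this script is saved
# as amp_tpu_train.py on a TPU VM with a recent torch_xla (PJRT runtime), it could be
# launched with:
#   PJRT_DEVICE=TPU python amp_tpu_train.py
# Older torch_xla releases configured the device through XRT_TPU_CONFIG instead.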

View File

@@ -1,403 +0,0 @@
#!/usr/bin/env python3
"""
TPU训练内存监控工具 - 专注于训练过程中的实时内存和MXU监控
适用于TPU v5e-8环境
"""
import tensorflow as tf
import time
import numpy as np
def monitor_tpu_during_training():
"""训练过程中的TPU实时内存和MXU监控"""
print("📊 TPU训练实时监控工具")
print("=" * 50)
# Get the TPU devices
try:
tpu_devices = tf.config.list_logical_devices('TPU')
print(f"📍 发现TPU设备: {len(tpu_devices)}")
if not tpu_devices:
print("❌ 未发现TPU设备")
return
except Exception as e:
print(f"❌ 无法检测TPU设备: {e}")
return
def get_detailed_memory_snapshot():
"""获取详细的内存快照,包含所有核心信息"""
snapshot = {}
total_current = 0
total_peak = 0
active_cores = 0
core_details = []
for i, device in enumerate(tpu_devices):
try:
memory_info = tf.config.experimental.get_memory_info(device.name)
if memory_info and 'current' in memory_info:
current_mb = memory_info['current'] // (1024 * 1024)
peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
if current_mb > 1: # treat >1MB as an active core
active_cores += 1
total_current += current_mb
total_peak += peak_mb
core_details.append(f"Core{i}:{current_mb}MB")
snapshot[f'core_{i}'] = {
'current': current_mb,
'peak': peak_mb,
'device': device.name
}
else:
snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name}
except Exception as e:
snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name, 'error': str(e)}
snapshot['summary'] = {
'total_current': total_current,
'total_peak': total_peak,
'active_cores': active_cores,
'total_cores': len(tpu_devices),
'core_details': core_details
}
return snapshot
def test_mxu_performance():
"""测试MXU性能和计算能力"""
print("\n🧮 MXU计算性能测试:")
mxu_results = []
try:
with tf.device(tpu_devices[0].name):
# Benchmark matrix multiplications at several sizes
test_configs = [
(2000, "2K×2K", tf.bfloat16),
(4000, "4K×4K", tf.bfloat16),
(6000, "6K×6K", tf.bfloat16),
]
for size, desc, dtype in test_configs:
try:
# Take a memory snapshot before the test
pre_mem = get_detailed_memory_snapshot()
start_time = time.time()
# Create matrices and run MXU-heavy operations
matrix_a = tf.random.normal([size, size], dtype=dtype)
matrix_b = tf.random.normal([size, size], dtype=dtype)
@tf.function
def mxu_operation():
# Back-to-back matrix multiplications to keep the MXU busy
result = tf.matmul(matrix_a, matrix_b)
result = tf.matmul(result, matrix_a)
return tf.reduce_sum(result)
result = mxu_operation()
# Touch the result to make sure the computation actually runs
_ = result.numpy()
end_time = time.time()
# Take a memory snapshot after the test
post_mem = get_detailed_memory_snapshot()
duration = end_time - start_time
# Estimate FLOPs: two size x size matmuls, each roughly 2 * size**3 FLOPs
flops = 2 * (2 * size**3)
tflops = flops / duration / 1e12
memory_used = post_mem['summary']['total_current'] - pre_mem['summary']['total_current']
print(f" {desc} ({dtype.name}): {duration:.3f}s, {tflops:.1f}TFLOPS, 内存+{memory_used}MB")
mxu_results.append({
'size': size,
'tflops': tflops,
'duration': duration,
'memory_used': memory_used
})
except Exception as e:
print(f" {desc}: 测试失败 - {str(e)[:50]}")
# MXU performance analysis
if mxu_results:
max_tflops = max(r['tflops'] for r in mxu_results)
total_memory = sum(r['memory_used'] for r in mxu_results if r['memory_used'] > 0)
# Theoretical single-core peak for TPU v5e-8
theoretical_tflops = 275 # bf16 peak performance
efficiency = (max_tflops / theoretical_tflops) * 100
print(f"\n 📊 MXU性能汇总:")
print(f" 峰值性能: {max_tflops:.1f} TFLOPS")
print(f" 理论峰值: {theoretical_tflops} TFLOPS")
print(f" MXU效率: {efficiency:.1f}%")
print(f" 计算内存占用: {total_memory}MB")
if efficiency > 80:
status = "🟢 excellent"
elif efficiency > 50:
status = "🟡 good"
elif efficiency > 20:
status = "🟠 moderate"
else:
status = "🔴 needs optimization"
print(f" Performance rating: {status}")
except Exception as e:
print(f" MXU测试失败: {e}")
try:
print("🎯 开始TPU训练监控...")
# 1. Capture the initial state
print("\n📸 Initial TPU state:")
baseline_snapshot = get_detailed_memory_snapshot()
print(f" Total memory in use: {baseline_snapshot['summary']['total_current']}MB")
print(f" Active cores: {baseline_snapshot['summary']['active_cores']}/{baseline_snapshot['summary']['total_cores']}")
# Show the detailed state of each core
for i in range(len(tpu_devices)):
core = baseline_snapshot[f'core_{i}']
if core['current'] > 0 or core['peak'] > 0:
print(f" Core{i}: 当前{core['current']}MB, 峰值{core['peak']}MB")
# 2. MXU performance baseline test
test_mxu_performance()
# 3. Create the distributed strategy - use the TPU initialization code validated in this project
print(f"\n🔄 Using the project's standard TPU initialization...")
try:
# Use the TPU initialization code validated in this project
# Disable GPUs to avoid conflicts
try:
tf.config.set_visible_devices([], 'GPU')
print("🚫 GPU disabled to avoid CUDA conflicts")
except:
pass
# Standard TPU initialization flow
print("🚀 Using the official TensorFlow TPU initialization...")
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# Verify the TPU devices
tpu_devices_check = tf.config.list_logical_devices('TPU')
print(f"✅ TPU设备验证: 发现 {len(tpu_devices_check)} 个设备")
# Create the TPU strategy
strategy = tf.distribute.TPUStrategy(resolver)
print(f"✅ 成功创建TPU策略: {strategy.num_replicas_in_sync}个副本")
use_distributed = True
except Exception as e:
print(f"⚠️ 分布式策略失败: {str(e)[:80]}")
print(" 将使用单设备模拟")
use_distributed = False
# 4. Simulate a Brain-to-Text training scenario
print(f"\n🧠 Simulating a Brain-to-Text training scenario...")
if use_distributed:
# Distributed training simulation
with strategy.scope():
print("📦 创建分布式模型参数...")
# Parameters sized close to the real Brain-to-Text model (dimensions fixed to match)
model_components = {
# GRU layer weights: the first layer takes 512-dim input, later layers take 256-dim
'gru_layer_0': tf.Variable(tf.random.normal([512, 256]), name='gru_0'),
'gru_layer_1': tf.Variable(tf.random.normal([256, 256]), name='gru_1'),
'gru_layer_2': tf.Variable(tf.random.normal([256, 256]), name='gru_2'),
'output_projection': tf.Variable(tf.random.normal([256, 41]), name='output'),
# Simulated day-specific layers (512-dim input, 512-dim output)
'day_weights': [tf.Variable(tf.random.normal([512, 512]), name=f'day_{i}') for i in range(8)]
}
# Check memory after the model is created
after_model = get_detailed_memory_snapshot()
model_memory = after_model['summary']['total_current'] - baseline_snapshot['summary']['total_current']
print(f"🧠 模型加载完成: +{model_memory}MB, {after_model['summary']['active_cores']}个活跃核心")
# Training-loop simulation
print(f"\n🔄 Starting training-loop monitoring...")
for step in range(10):
step_start_time = time.time()
@tf.function
def distributed_training_step():
# Use data sizes close to real training
batch_size = 32
seq_length = 1000
features = 512
# Input data
neural_data = tf.random.normal([batch_size, seq_length, features])
targets = tf.random.uniform([batch_size, seq_length], maxval=41, dtype=tf.int32)
# Simulate the forward pass
x = neural_data
# Day-specific transformation (simplified to avoid complex reshaping)
# Apply the same day-specific transform to every time step
day_weight = model_components['day_weights'][0] # simplified: use the first day weight
# Transform the last dimension: [batch, seq, 512] @ [512, 512] -> [batch, seq, 512]
x = tf.matmul(x, day_weight)
# Simulate target usage for the CTC loss
target_length = tf.reduce_sum(tf.cast(targets > 0, tf.int32), axis=1)
# Simplified CTC-related computation
batch_loss_weight = tf.reduce_mean(tf.cast(target_length, tf.float32))
# GRU layers simulation
for i in range(3):
layer_name = f'gru_layer_{i}'
weight = model_components[layer_name]
# Handle tensor rank: the first layer receives 3D input, later layers receive 2D
if i == 0:
# First layer: take the last time step [batch, seq, features] -> [batch, features]
if len(x.shape) == 3:
x = x[:, -1, :] # take the last time step
x = tf.nn.tanh(tf.matmul(x, weight))
else:
# Later layers work directly on 2D tensors [batch, features] -> [batch, features]
x = tf.nn.tanh(tf.matmul(x, weight))
# Output projection
logits = tf.matmul(x, model_components['output_projection'])
# Simulated CTC loss, weighted by batch_loss_weight
base_loss = tf.reduce_mean(tf.square(logits))
loss = base_loss * batch_loss_weight
return loss
# Execute the training step
per_replica_loss = strategy.run(distributed_training_step)
# Aggregate the distributed results
loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
step_duration = time.time() - step_start_time
# Get the current memory state
current_snapshot = get_detailed_memory_snapshot()
step_memory = current_snapshot['summary']['total_current']
memory_delta = step_memory - baseline_snapshot['summary']['total_current']
# Show detailed training status
active_cores_info = f"({', '.join(current_snapshot['summary']['core_details'])})" if current_snapshot['summary']['core_details'] else "(none active)"
print(f" Step {step:2d}: loss={float(loss.numpy()):.4f}, "
f"时间={step_duration:.3f}s, "
f"内存={step_memory}MB(+{memory_delta}), "
f"活跃={current_snapshot['summary']['active_cores']}/{current_snapshot['summary']['total_cores']} {active_cores_info}")
# Report peak memory every 5 steps
if step % 5 == 0:
peak_info = f"峰值: {current_snapshot['summary']['total_peak']}MB"
print(f" {peak_info}")
time.sleep(0.2) # brief pause for observation
else:
# Single-device training simulation (improved version)
print("🔸 Single-device training simulation...")
with tf.device(tpu_devices[0].name):
# Create a smaller set of model parameters
simple_weights = tf.Variable(tf.random.normal([512, 256]), name='simple_net')
for step in range(8):
step_start = time.time()
# Create a larger data batch
batch_data = tf.random.normal([64, 1000, 512]) # larger batch size
# Simulate a compute-heavy operation
@tf.function
def compute_step():
x = tf.reshape(batch_data, [-1, 512])
result = tf.matmul(x, simple_weights)
result = tf.nn.relu(result)
return tf.reduce_mean(result)
result = compute_step()
step_duration = time.time() - step_start
# Get the memory state
snapshot = get_detailed_memory_snapshot()
memory_change = snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current']
print(f" Step {step}: result={result.numpy():.4f}, "
f"时间={step_duration:.3f}s, "
f"内存变化=+{memory_change}MB, "
f"峰值={snapshot['summary']['total_peak']}MB")
# 5. Final analysis report
final_snapshot = get_detailed_memory_snapshot()
total_growth = final_snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current']
peak_usage = final_snapshot['summary']['total_peak']
print(f"\n📈 训练监控报告:")
print(f" 总内存增长: +{total_growth}MB")
print(f" 峰值内存使用: {peak_usage}MB ({peak_usage/1024:.2f}GB)")
print(f" 最终活跃核心: {final_snapshot['summary']['active_cores']}/{final_snapshot['summary']['total_cores']}")
# Final state of each core
print(f" Final state of each core:")
has_changes = False
for i in range(len(tpu_devices)):
final_core = final_snapshot[f'core_{i}']
baseline_core = baseline_snapshot[f'core_{i}']
current_change = final_core['current'] - baseline_core['current']
peak_change = final_core['peak'] - baseline_core['peak']
if current_change != 0 or peak_change != 0:
has_changes = True
print(f" Core{i}: 当前{final_core['current']}MB(+{current_change}), 峰值{final_core['peak']}MB(+{peak_change})")
if not has_changes:
print(f" 所有核心内存无明显变化")
# Distributed-usage analysis
if final_snapshot['summary']['active_cores'] == 1:
print(f"\n⚠️ 分布式问题诊断:")
print(f" 只有1个核心活跃其他7个核心空闲")
print(f" 可能原因: TPU策略配置问题或模型未正确分布")
print(f" 建议: 检查分布式策略和模型分片")
elif final_snapshot['summary']['active_cores'] > 4:
print(f"\n✅ 分布式状态良好:")
print(f" {final_snapshot['summary']['active_cores']}个核心活跃,多核心并行工作正常")
else:
print(f"\n🟡 分布式部分工作:")
print(f" {final_snapshot['summary']['active_cores']}个核心活跃,可能存在负载不均衡")
print("✅ TPU训练监控完成")
except Exception as e:
print(f"❌ 训练监控失败: {e}")
import traceback
print(f"详细错误: {traceback.format_exc()[:300]}")
if __name__ == "__main__":
print("🚀 TPU训练内存监控工具")
print("专注于训练过程中的实时内存和性能监控")
print("适用于TPU v5e-8环境")
print()
monitor_tpu_during_training()
print(f"\n🎯 监控要点总结:")
print(f" 1. 确认所有8个TPU核心是否活跃")
print(f" 2. 监控内存增长模式和峰值使用")
print(f" 3. 检测MXU计算性能和效率")
print(f" 4. 验证分布式策略是否正常工作")
print(f" 5. 识别可能的内存泄漏或性能瓶颈")

View File

@@ -1,336 +0,0 @@
import os
import torch
from torch.utils.data import Dataset
import h5py
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import math
class BrainToTextDataset(Dataset):
'''
Dataset for brain-to-text data
Returns an entire batch of data instead of a single example
'''
def __init__(
self,
trial_indicies,
n_batches,
split = 'train',
batch_size = 64,
days_per_batch = 1,
random_seed = -1,
must_include_days = None,
feature_subset = None
):
'''
trial_indicies: (dict) - dictionary with day numbers as keys and lists of trial indices as values
n_batches: (int) - number of random training batches to create
split: (string) - string specifying if this is a train or test dataset
batch_size: (int) - number of examples to include in the batch returned from __getitem__()
days_per_batch: (int) - how many unique days can exist in a batch; this is important for making sure that updates
to individual day layers in the GRU are not excessively noisy. Validation data will always have 1 day per batch
random_seed: (int) - seed to set for randomly assigning trials to a batch. If set to -1, trial assignment will be random
must_include_days ([int]) - list of days that must be included in every batch
feature_subset ([int]) - list of neural feature indices that should be the only features included in the neural data
'''
# Set random seed for reproducibility
if random_seed != -1:
np.random.seed(random_seed)
torch.manual_seed(random_seed)
self.split = split
# Ensure the split is valid
if self.split not in ['train', 'test']:
raise ValueError(f'split must be either "train" or "test". Received {self.split}')
self.days_per_batch = days_per_batch
self.batch_size = batch_size
self.n_batches = n_batches
self.days = {}
self.n_trials = 0
self.trial_indicies = trial_indicies
self.n_days = len(trial_indicies.keys())
self.feature_subset = feature_subset
# Calculate total number of trials in the dataset
for d in trial_indicies:
self.n_trials += len(trial_indicies[d]['trials'])
if must_include_days is not None and len(must_include_days) > days_per_batch:
raise ValueError(f'must_include_days must be less than or equal to days_per_batch. Received {must_include_days} and days_per_batch {days_per_batch}')
if must_include_days is not None and len(must_include_days) > self.n_days and split != 'train':
raise ValueError(f'must_include_days is not valid for test data. Received {must_include_days}, but there are only {self.n_days} days in the dataset')
if must_include_days is not None:
# Map must_include_days to correct indices if they are negative
for i, d in enumerate(must_include_days):
if d < 0:
must_include_days[i] = self.n_days + d
self.must_include_days = must_include_days
# Raise an error if days_per_batch is greater than the number of days in the dataset
if self.split == 'train' and self.days_per_batch > self.n_days:
raise ValueError(f'Requested days_per_batch: {days_per_batch} is greater than available days {self.n_days}.')
if self.split == 'train':
self.batch_index = self.create_batch_index_train()
else:
self.batch_index = self.create_batch_index_test()
self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
def __len__(self):
'''
How many batches are in this dataset.
Because training data is sampled randomly, there is no fixed dataset length,
however this method is required for DataLoader to work
'''
return self.n_batches if self.n_batches is not None else 0
def __getitem__(self, idx):
'''
Gets an entire batch of data from the dataset, not just a single item
'''
batch = {
'input_features' : [],
'seq_class_ids' : [],
'n_time_steps' : [],
'phone_seq_lens' : [],
'day_indicies' : [],
'transcriptions' : [],
'block_nums' : [],
'trial_nums' : [],
}
index = self.batch_index[idx]
# Iterate through each day in the index
for d in index.keys():
# Open the hdf5 file for that day
with h5py.File(self.trial_indicies[d]['session_path'], 'r') as f:
# For each trial in the selected trials in that day
for t in index[d]:
try:
g = f[f'trial_{t:04d}']
# Remove features if necessary
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # neural data - convert to bf16 for TPU compatibility
if self.feature_subset:
input_features = input_features[:,self.feature_subset]
batch['input_features'].append(input_features)
batch['seq_class_ids'].append(torch.from_numpy(g['seq_class_ids'][:])) # phoneme labels
batch['transcriptions'].append(torch.from_numpy(g['transcription'][:])) # character level transcriptions
batch['n_time_steps'].append(g.attrs['n_time_steps']) # number of time steps in the trial - required since we are padding
batch['phone_seq_lens'].append(g.attrs['seq_len']) # number of phonemes in the label - required since we are padding
batch['day_indicies'].append(int(d)) # day index of each trial - required for the day specific layers
batch['block_nums'].append(g.attrs['block_num'])
batch['trial_nums'].append(g.attrs['trial_num'])
except Exception as e:
print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
continue
# Pad data to form a cohesive batch - ensure bf16 dtype is preserved
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
batch['phone_seq_lens'] = torch.tensor(batch['phone_seq_lens'])
batch['day_indicies'] = torch.tensor(batch['day_indicies'])
batch['transcriptions'] = torch.stack(batch['transcriptions'])
batch['block_nums'] = torch.tensor(batch['block_nums'])
batch['trial_nums'] = torch.tensor(batch['trial_nums'])
return batch
def create_batch_index_train(self):
'''
Create an index that maps a batch_number to batch_size number of trials
Each batch will have days_per_batch unique days of data, with the number of trials for each day evenly split between the days
(or as evenly as possible if batch_size is not divisible by days_per_batch)
'''
batch_index = {}
# Precompute the days that are not in must_include_days
if self.must_include_days is not None:
non_must_include_days = [d for d in self.trial_indicies.keys() if d not in self.must_include_days]
for batch_idx in range(self.n_batches):
batch = {}
# Which days will be used for this batch. Picked randomly without replacement
# TODO: In the future we may want to consider sampling days in proportion to the number of trials in each day
# If must_include_days is not empty, we will use those days and then randomly sample the rest
if self.must_include_days is not None and len(self.must_include_days) > 0:
days = np.concatenate((self.must_include_days, np.random.choice(non_must_include_days, size = self.days_per_batch - len(self.must_include_days), replace = False)))
# Otherwise we will select random days without replacement
else:
days = np.random.choice(list(self.trial_indicies.keys()), size = self.days_per_batch, replace = False)
# How many trials will be sampled from each day
num_trials = math.ceil(self.batch_size / self.days_per_batch) # Use ceiling to make sure we get at least batch_size trials
for d in days:
# Trials are sampled with replacement, so it is not a problem if a day has fewer than (batch_size / days_per_batch) trials
trial_idxs = np.random.choice(self.trial_indicies[d]['trials'], size = num_trials, replace = True)
batch[d] = trial_idxs
# Remove extra trials
extra_trials = (num_trials * len(days)) - self.batch_size
# While we still have extra trials, remove the last trial from a random day
while extra_trials > 0:
d = np.random.choice(days)
batch[d] = batch[d][:-1]
extra_trials -= 1
batch_index[batch_idx] = batch
return batch_index
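# Illustrative example (not in the original source) of the structure returned above,
# assuming days_per_batch=2 and batch_size=4:
#   {0: {3: array([12, 40]), 7: array([5, 19])},
#    1: {0: array([2, 2]), 5: array([33, 8])}, ...}
# i.e. batch_number -> {day_index -> trial indices sampled for that day}.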
def create_batch_index_test(self):
'''
Create an index that is all validation/testing data in batches of up to self.batch_size
If a day does not have at least self.batch_size trials, then the batch size will be less than self.batch_size
This index ensures that every trial in the validation set is seen once and only once
'''
batch_index = {}
batch_idx = 0
for d in self.trial_indicies.keys():
# Calculate how many batches we need for this day
num_trials = len(self.trial_indicies[d]['trials'])
num_batches = (num_trials + self.batch_size - 1) // self.batch_size
# Create batches for this day
for i in range(num_batches):
start_idx = i * self.batch_size
end_idx = min((i + 1) * self.batch_size, num_trials)
# Get the trial indices for this batch
batch_trials = self.trial_indicies[d]['trials'][start_idx:end_idx]
# Add to batch_index
batch_index[batch_idx] = {d : batch_trials}
batch_idx += 1
return batch_index
def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_trials_dict = None):
'''
Split data from file_paths into train and test splits
Returns two dictionaries that detail which trials in each day will be a part of that split:
Example:
{
0: trials[1,2,3], session_path: 'path'
1: trials[2,5,6], session_path: 'path'
}
Args:
file_paths (list): List of file paths to the hdf5 files containing the data
test_percentage (float): Percentage of trials to use for testing. 0 will use all trials for training, 1 will use all trials for testing
seed (int): Seed for reproducibility. If set to -1, the split will be random
bad_trials_dict (dict): Dictionary of trials to exclude from the dataset. Formatted as:
{
'session_name_1': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
'session_name_2': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
...
}
'''
# Set seed for reproducibility
if seed != -1:
np.random.seed(seed)
# Get trials in each day
trials_per_day = {}
for i, path in enumerate(file_paths):
# Handle both Windows and Unix path separators
path_parts = path.replace('\\', '/').split('/')
session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
good_trial_indices = []
if os.path.exists(path):
with h5py.File(path, 'r') as f:
num_trials = len(list(f.keys()))
for t in range(num_trials):
key = f'trial_{t:04d}'
block_num = f[key].attrs['block_num']
trial_num = f[key].attrs['trial_num']
if (
bad_trials_dict is not None
and session in bad_trials_dict
and str(block_num) in bad_trials_dict[session]
and trial_num in bad_trials_dict[session][str(block_num)]
):
# print(f'Bad trial: {session}_{block_num}_{trial_num}')
continue
good_trial_indices.append(t)
trials_per_day[i] = {'num_trials': len(good_trial_indices), 'trial_indices': good_trial_indices, 'session_path': path}
# Pick test_percentage of trials from each day for testing and (1 - test_percentage) for training
train_trials = {}
test_trials = {}
for day in trials_per_day.keys():
num_trials = trials_per_day[day]['num_trials']
# Generate all trial indices for this day (assuming 0-indexed)
all_trial_indices = trials_per_day[day]['trial_indices']
# If test_percentage is 0 or 1, we can just assign all trials to either train or test
if test_percentage == 0:
train_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
test_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
continue
elif test_percentage == 1:
train_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
test_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
continue
else:
# Calculate how many trials to use for testing
num_test = max(1, int(num_trials * test_percentage))
# Randomly select indices for testing
test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()
# Remaining indices go to training
train_indices = [idx for idx in all_trial_indices if idx not in test_indices]
# Store the split indices
train_trials[day] = {'trials' : train_indices, 'session_path' : trials_per_day[day]['session_path']}
test_trials[day] = {'trials' : test_indices, 'session_path' : trials_per_day[day]['session_path']}
return train_trials, test_trials
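# Minimal usage sketch (an illustration under assumed paths, not part of the original file).
# Because __getitem__ already returns a full batch, a DataLoader would wrap the dataset
# with batch_size=None so that no extra batch dimension is added:
#
#   hdf5_paths = ['.../t15.2023.08.11/data_train.hdf5']  # assumed file layout
#   train_trials, val_trials = train_test_split_indicies(hdf5_paths, test_percentage=0.1, seed=0)
#   train_ds = BrainToTextDataset(train_trials, n_batches=1000, split='train',
#                                 batch_size=64, days_per_batch=1, random_seed=0)
#   val_ds = BrainToTextDataset(val_trials, n_batches=None, split='test', batch_size=64)
#   train_loader = torch.utils.data.DataLoader(train_ds, batch_size=None, shuffle=False)
#   batch = next(iter(train_loader))  # dict with 'input_features', 'seq_class_ids', ...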

View File

@@ -1,304 +0,0 @@
import os
import torch
import numpy as np
import pandas as pd
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
from rnn_model import GRUDecoder
from evaluate_model_helpers import *
# argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
parser.add_argument('--model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
help='Path to the pretrained model directory (relative to the current working directory).')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
help='Evaluation type: "val" for validation set, "test" for test set. '
'If "test", ground truth is not available.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
parser.add_argument('--gpu_number', type=int, default=-1,
help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
args = parser.parse_args()
# paths to model and data directories
# Note: these paths are relative to the current working directory
model_path = args.model_path
data_dir = args.data_dir
# define evaluation type
eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
# load csv file
b2txt_csv_df = pd.read_csv(args.csv_path)
# load model args
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
# set up gpu device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
if gpu_number >= torch.cuda.device_count():
raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
device = f'cuda:{gpu_number}'
device = torch.device(device)
print(f'Using {device} for model inference.')
else:
if gpu_number >= 0:
print(f'GPU number {gpu_number} requested but not available.')
print('Using CPU for model inference.')
device = torch.device('cpu')
# define model
model = GRUDecoder(
neural_dim = model_args['model']['n_input_features'],
n_units = model_args['model']['n_units'],
n_days = len(model_args['dataset']['sessions']),
n_classes = model_args['dataset']['n_classes'],
rnn_dropout = model_args['model']['rnn_dropout'],
input_dropout = model_args['model']['input_network']['input_layer_dropout'],
n_layers = model_args['model']['n_layers'],
patch_size = model_args['model']['patch_size'],
patch_stride = model_args['model']['patch_stride'],
)
# load model weights
checkpoint = torch.load(
os.path.join(model_path, 'checkpoint/best_checkpoint'),
map_location=device,
weights_only=False,
)
# rename keys to strip the "module." prefix (added by DataParallel) and the "_orig_mod." prefix (added by torch.compile)
for key in list(checkpoint['model_state_dict'].keys()):
new_key = key.replace("module.", "").replace("_orig_mod.", "")
checkpoint['model_state_dict'][new_key] = checkpoint['model_state_dict'].pop(key)
model.load_state_dict(checkpoint['model_state_dict'])
# add model to device
model.to(device)
# set model to eval mode
model.eval()
# load data for each session
test_data = {}
total_test_trials = 0
for session in model_args['dataset']['sessions']:
files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
if f'data_{eval_type}.hdf5' in files:
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_test_trials += len(test_data[session]["neural_features"])
print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()
# put neural data through the pretrained model to get phoneme predictions (logits)
with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
for session, data in test_data.items():
data['logits'] = []
data['pred_seq'] = []
input_layer = model_args['dataset']['sessions'].index(session)
for trial in range(len(data['neural_features'])):
# get neural input for the trial
neural_input = data['neural_features'][trial]
# add batch dimension
neural_input = np.expand_dims(neural_input, axis=0)
# convert to torch tensor
neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
# run decoding step
logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)
data['logits'].append(logits)
pbar.update(1)
pbar.close()
# convert logits to phoneme sequences and print them out
for session, data in test_data.items():
data['pred_seq'] = []
for trial in range(len(data['logits'])):
logits = data['logits'][trial][0]
pred_seq = np.argmax(logits, axis=-1)
# remove blanks (0)
pred_seq = [int(p) for p in pred_seq if p != 0]
# remove consecutive duplicates
pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
# convert to phonemes
pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
# add to data
data['pred_seq'].append(pred_seq)
# print out the predicted sequences
block_num = data['block_num'][trial]
trial_num = data['trial_num'][trial]
print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
if eval_type == 'val':
sentence_label = data['sentence_label'][trial]
true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
print(f'Sentence label: {sentence_label}')
print(f'True sequence: {" ".join(true_seq)}')
print(f'Predicted Sequence: {" ".join(pred_seq)}')
print()
# language model inference via redis
# make sure that the standalone language model is running on the localhost redis ip
# see README.md for instructions on how to run the language model
def connect_to_redis_with_retry(host, port, password, db=0, max_retries=10, retry_delay=3):
"""Connect to Redis with retry logic"""
for attempt in range(max_retries):
try:
print(f"Attempting to connect to Redis at {host}:{port} (attempt {attempt + 1}/{max_retries})...")
r = redis.Redis(host=host, port=port, db=db, password=password)
r.ping() # Test the connection
print(f"Successfully connected to Redis at {host}:{port}")
return r
except redis.exceptions.ConnectionError as e:
print(f"Redis connection failed (attempt {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
print(f"Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
print("Max retries reached. Could not connect to Redis.")
raise e
except Exception as e:
print(f"Unexpected error connecting to Redis: {e}")
if attempt < max_retries - 1:
print(f"Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
raise e
r = connect_to_redis_with_retry('hs.zchens.cn', 6379, 'admin01')
r.flushall() # clear all streams in redis
# define redis streams for the remote language model
remote_lm_input_stream = 'remote_lm_input'
remote_lm_output_partial_stream = 'remote_lm_output_partial'
remote_lm_output_final_stream = 'remote_lm_output_final'
# set timestamps for last entries seen in the redis streams
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)
lm_results = {
'session': [],
'block': [],
'trial': [],
'true_sentence': [],
'pred_sentence': [],
}
# loop through all trials and put logits into the remote language model to get text predictions
# note: this takes ~15-20 minutes to run on the entire test split with the 5-gram LM + OPT rescoring (RTX 4090)
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
for session in test_data.keys():
for trial in range(len(test_data[session]['logits'])):
# get trial logits and rearrange them for the LM
logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]
# reset language model
remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)
'''
# update language model parameters
remote_lm_done_updating_lastEntrySeen = update_remote_lm_params(
r,
remote_lm_done_updating_lastEntrySeen,
acoustic_scale=0.35,
blank_penalty=90.0,
alpha=0.55,
)
'''
# put logits into LM
remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
r,
remote_lm_input_stream,
remote_lm_output_partial_stream,
remote_lm_output_partial_lastEntrySeen,
logits,
)
# finalize remote LM
remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
r,
remote_lm_output_final_stream,
remote_lm_output_final_lastEntrySeen,
)
# get the best candidate sentence
best_candidate_sentence = lm_out['candidate_sentences'][0]
# store results
lm_results['session'].append(session)
lm_results['block'].append(test_data[session]['block_num'][trial])
lm_results['trial'].append(test_data[session]['trial_num'][trial])
if eval_type == 'val':
lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
else:
lm_results['true_sentence'].append(None)
lm_results['pred_sentence'].append(best_candidate_sentence)
# update progress bar
pbar.update(1)
pbar.close()
# if using the validation set, let's calculate the aggregate word error rate (WER)
if eval_type == 'val':
total_true_length = 0
total_edit_distance = 0
lm_results['edit_distance'] = []
lm_results['num_words'] = []
for i in range(len(lm_results['pred_sentence'])):
true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
ed = editdistance.eval(true_sentence.split(), pred_sentence.split())
total_true_length += len(true_sentence.split())
total_edit_distance += ed
lm_results['edit_distance'].append(ed)
lm_results['num_words'].append(len(true_sentence.split()))
print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
print(f'True sentence: {true_sentence}')
print(f'Predicted sentence: {pred_sentence}')
print(f'WER: {ed} / {len(true_sentence.split())} = {100 * ed / len(true_sentence.split()):.2f}%')
print()
print(f'Total true sentence length: {total_true_length}')
print(f'Total edit distance: {total_edit_distance}')
print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')
# write predicted sentences to a csv file. put a timestamp in the filename (YYYYMMDD_HHMMSS)
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.csv')
ids = [i for i in range(len(lm_results['pred_sentence']))]
df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
df_out.to_csv(output_file, index=False)

View File

@@ -1,580 +0,0 @@
import torch
from torch import nn
from typing import cast
class GradientReversalFn(torch.autograd.Function):
"""
Gradient Reversal Layer (GRL)
Forward: identity
Backward: multiply incoming gradient by -lambda
"""
@staticmethod
def forward(ctx, x, lambd: float):
ctx.lambd = lambd
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
return -ctx.lambd * grad_output, None
def gradient_reverse(x, lambd: float = 1.0):
return GradientReversalFn.apply(x, lambd)
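# Tiny behavioral sketch (illustration only, not from the original file): the GRL is an
# identity in the forward pass and scales gradients by -lambda in the backward pass.
#
#   x = torch.ones(3, requires_grad=True)
#   y = gradient_reverse(x, lambd=0.5).sum()
#   y.backward()
#   # x.grad is now tensor([-0.5000, -0.5000, -0.5000])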
class NoiseModel(nn.Module):
'''
Noise Model: 2-layer GRU that learns to estimate noise in the neural data
'''
def __init__(self,
neural_dim,
n_units,
n_days,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0):
super(NoiseModel, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_days = n_days
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Day-specific input layers
self.day_layer_activation = nn.Softsign()
# Let Accelerator handle dtype automatically for TPU compatibility
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
self.day_layer_dropout = nn.Dropout(input_dropout)
# Calculate input size after patching
self.input_size = self.neural_dim
if self.patch_size > 0:
self.input_size *= self.patch_size
# 2-layer GRU for noise estimation
self.gru = nn.GRU(
input_size=self.input_size,
hidden_size=self.input_size, # Output same dimension as input
num_layers=2,
dropout=self.rnn_dropout,
batch_first=True,
bidirectional=False,
)
# Initialize GRU parameters
for name, param in self.gru.named_parameters():
if "weight_hh" in name:
nn.init.orthogonal_(param)
if "weight_ih" in name:
nn.init.xavier_uniform_(param)
# Learnable initial hidden state - let Accelerator handle dtype
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
def forward(self, x, day_idx, states=None):
# XLA-friendly day-specific transformation using gather instead of dynamic indexing
batch_size = x.size(0)
# Stack all day weights and biases upfront for static indexing
all_day_weights = torch.stack(list(self.day_weights), dim=0) # [n_days, neural_dim, neural_dim]
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0) # [n_days, neural_dim]
# XLA-friendly gather operation
day_weights = torch.index_select(all_day_weights, 0, day_idx) # [batch_size, neural_dim, neural_dim]
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
# Use bmm (batch matrix multiply) which is highly optimized in XLA
# Ensure dtype consistency for mixed precision training
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x = self.day_layer_activation(x)
# XLA-friendly conditional dropout
if self.input_dropout > 0:
x = self.day_layer_dropout(x)
# Apply patch processing if enabled with dtype preservation for mixed precision training
if self.patch_size > 0:
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x = x.to(original_dtype)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
# XLA-friendly hidden state initialization - avoid dynamic allocation
if states is None:
states = self.h0.expand(2, batch_size, self.input_size).contiguous()
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
# Disable autocast for GRU to avoid dtype mismatches on XLA
device_type = x.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.gru(x, states)
return output, hidden_states
class CleanSpeechModel(nn.Module):
'''
Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition
'''
def __init__(self,
neural_dim,
n_units,
n_days,
n_classes,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0):
super(CleanSpeechModel, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_days = n_days
self.n_classes = n_classes
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Day-specific input layers
self.day_layer_activation = nn.Softsign()
# Let Accelerator handle dtype automatically for TPU compatibility
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
self.day_layer_dropout = nn.Dropout(input_dropout)
# Calculate input size after patching
self.input_size = self.neural_dim
if self.patch_size > 0:
self.input_size *= self.patch_size
# 3-layer GRU for clean speech recognition
self.gru = nn.GRU(
input_size=self.input_size,
hidden_size=self.n_units,
num_layers=3,
dropout=self.rnn_dropout,
batch_first=True,
bidirectional=False,
)
# Initialize GRU parameters
for name, param in self.gru.named_parameters():
if "weight_hh" in name:
nn.init.orthogonal_(param)
if "weight_ih" in name:
nn.init.xavier_uniform_(param)
# Output classification layer
self.out = nn.Linear(self.n_units, self.n_classes)
nn.init.xavier_uniform_(self.out.weight)
# Learnable initial hidden state
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
def forward(self, x, day_idx, states=None, return_state=False):
# XLA-friendly day-specific transformation using gather instead of dynamic indexing
batch_size = x.size(0)
# Stack all day weights and biases upfront for static indexing
all_day_weights = torch.stack(list(self.day_weights), dim=0) # [n_days, neural_dim, neural_dim]
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0) # [n_days, neural_dim]
# XLA-friendly gather operation
day_weights = torch.index_select(all_day_weights, 0, day_idx) # [batch_size, neural_dim, neural_dim]
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
# Use bmm (batch matrix multiply) which is highly optimized in XLA
# Ensure dtype consistency for mixed precision training
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x = self.day_layer_activation(x)
if self.input_dropout > 0:
x = self.day_layer_dropout(x)
# Apply patch processing if enabled with dtype preservation for mixed precision training
if self.patch_size > 0:
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x = x.to(original_dtype)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
# XLA-friendly hidden state initialization
if states is None:
states = self.h0.expand(3, batch_size, self.n_units).contiguous()
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
device_type = x.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.gru(x, states)
# Classification
logits = self.out(output)
if return_state:
return logits, hidden_states
return logits
class NoisySpeechModel(nn.Module):
'''
Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition
'''
def __init__(self,
neural_dim,
n_units,
n_days,
n_classes,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0):
super(NoisySpeechModel, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_days = n_days
self.n_classes = n_classes
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Calculate input size after patching
self.input_size = self.neural_dim
if self.patch_size > 0:
self.input_size *= self.patch_size
# 2-layer GRU for noisy speech recognition
self.gru = nn.GRU(
input_size=self.input_size,
hidden_size=self.n_units,
num_layers=2,
dropout=self.rnn_dropout,
batch_first=True,
bidirectional=False,
)
# Initialize GRU parameters
for name, param in self.gru.named_parameters():
if "weight_hh" in name:
nn.init.orthogonal_(param)
if "weight_ih" in name:
nn.init.xavier_uniform_(param)
# Output classification layer
self.out = nn.Linear(self.n_units, self.n_classes)
nn.init.xavier_uniform_(self.out.weight)
# Learnable initial hidden state
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
def forward(self, x, states=None, return_state=False):
# Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
batch_size = x.size(0)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
# XLA-friendly hidden state initialization
if states is None:
states = self.h0.expand(2, batch_size, self.n_units).contiguous()
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
device_type = x.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.gru(x, states)
# Classification
logits = self.out(output)
if return_state:
return logits, hidden_states
return logits
class TripleGRUDecoder(nn.Module):
'''
Three-model adversarial architecture for neural speech decoding
Combines:
- NoiseModel: estimates noise in neural data
- CleanSpeechModel: processes denoised signal for recognition
- NoisySpeechModel: processes noise signal for recognition
'''
def __init__(self,
neural_dim,
n_units,
n_days,
n_classes,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0,
):
'''
neural_dim (int) - number of channels in a single timestep (e.g. 512)
n_units (int) - number of hidden units in each recurrent layer
n_days (int) - number of days in the dataset
n_classes (int) - number of classes (phonemes)
rnn_dropout (float) - percentage of units to dropout during training
input_dropout (float) - percentage of input units to dropout during training
patch_size (int) - number of timesteps to concat on initial input layer
patch_stride (int) - number of timesteps to stride over when concatenating initial input
'''
super(TripleGRUDecoder, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_classes = n_classes
self.n_days = n_days
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Create the three models
self.noise_model = NoiseModel(
neural_dim=neural_dim,
n_units=n_units,
n_days=n_days,
rnn_dropout=rnn_dropout,
input_dropout=input_dropout,
patch_size=patch_size,
patch_stride=patch_stride
)
self.clean_speech_model = CleanSpeechModel(
neural_dim=neural_dim,
n_units=n_units,
n_days=n_days,
n_classes=n_classes,
rnn_dropout=rnn_dropout,
input_dropout=input_dropout,
patch_size=patch_size,
patch_stride=patch_stride
)
self.noisy_speech_model = NoisySpeechModel(
neural_dim=neural_dim,
n_units=n_units,
n_days=n_days,
n_classes=n_classes,
rnn_dropout=rnn_dropout,
input_dropout=input_dropout,
patch_size=patch_size,
patch_stride=patch_stride
)
# Training mode flag
self.training_mode = 'full' # 'full', 'inference'
def _apply_preprocessing(self, x, day_idx):
'''XLA-friendly preprocessing with static operations'''
batch_size = x.size(0)
# XLA-friendly day-specific transformation using gather instead of dynamic indexing
all_day_weights = torch.stack(list(self.clean_speech_model.day_weights), dim=0)
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.clean_speech_model.day_biases], dim=0)
# XLA-friendly gather operation
day_weights = torch.index_select(all_day_weights, 0, day_idx)
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
# Use bmm (batch matrix multiply) which is highly optimized in XLA
# Ensure dtype consistency for mixed precision training
x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x_processed = self.clean_speech_model.day_layer_activation(x_processed)
# Apply patch processing if enabled with dtype preservation for mixed precision training
if self.patch_size > 0:
original_dtype = x_processed.dtype # Preserve original dtype for XLA/TPU compatibility
x_processed = x_processed.unsqueeze(1)
x_processed = x_processed.permute(0, 3, 1, 2)
x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x_processed = x_processed.to(original_dtype)
return x_processed
def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
'''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
batch_size = x_processed.size(0)
clean_gru_dtype = next(self.clean_speech_model.gru.parameters()).dtype
if x_processed.dtype != clean_gru_dtype:
x_processed = x_processed.to(clean_gru_dtype)
# XLA-friendly hidden state initialization with dtype consistency
if states is None:
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
if states.dtype != clean_gru_dtype:
states = states.to(clean_gru_dtype)
# GRU forward pass (skip preprocessing since input is already processed)
device_type = x_processed.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
# Classification
logits = self.clean_speech_model.out(output)
return logits
def _noisy_forward_with_processed_input(self, x_processed, states=None):
'''Forward pass for NoisySpeechModel with already processed input'''
batch_size = x_processed.size(0)
noisy_gru_dtype = next(self.noisy_speech_model.gru.parameters()).dtype
if x_processed.dtype != noisy_gru_dtype:
x_processed = x_processed.to(noisy_gru_dtype)
# XLA-friendly hidden state initialization with dtype consistency
if states is None:
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
if states.dtype != noisy_gru_dtype:
states = states.to(noisy_gru_dtype)
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
device_type = x_processed.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
# Classification
logits = self.noisy_speech_model.out(output)
return logits
def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
'''
Three-model adversarial forward pass
x (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
day_idx (tensor) - tensor of day indices for each example in the batch
states (dict) - dictionary with 'noise', 'clean', 'noisy' states or None
mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only)
grl_lambda (float) - when > 0 and mode='full', applies Gradient Reversal to the noise branch input
'''
if mode == 'full':
# Training mode: run all three models
# 1. Noise model estimates noise in the data
noise_output, noise_hidden = self.noise_model(x, day_idx,
states['noise'] if states else None)
# 2. For residual connection, we need x in the same space as noise_output
# Apply the same preprocessing that the models use internally
x_processed = self._apply_preprocessing(x, day_idx)
clean_dtype = next(self.clean_speech_model.parameters()).dtype
if x_processed.dtype != clean_dtype:
x_processed = x_processed.to(clean_dtype)
# Ensure dtype consistency between processed input and noise output
if noise_output.dtype != clean_dtype:
noise_output = noise_output.to(clean_dtype)
# 3. Clean speech model processes denoised signal
denoised_input = x_processed - noise_output # Residual connection in processed space
# The clean speech model normally applies its own day layers and patching, so rather than
# undoing the preprocessing and letting it redo the work, we pass the residual straight to
# a forward path that bypasses that preprocessing (_clean_forward_with_processed_input)
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
states['clean'] if states else None)
# 4. Noisy speech model processes noise signal directly (no day layers needed)
# Optionally apply Gradient Reversal to enforce adversarial training on noise output
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
noisy_input = cast(torch.Tensor, noisy_input)
noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
if noisy_input.dtype != noisy_dtype:
noisy_input = noisy_input.to(noisy_dtype)
noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
states['noisy'] if states else None)
# XLA-friendly return - use tuple instead of dict for better compilation
if return_state:
return (clean_logits, noisy_logits, noise_output), noise_hidden
return clean_logits, noisy_logits, noise_output
elif mode == 'inference':
# Inference mode: only noise model + clean speech model
# 1. Estimate noise
noise_output, noise_hidden = self.noise_model(x, day_idx,
states['noise'] if states else None)
# 2. For residual connection, we need x in the same space as noise_output
x_processed = self._apply_preprocessing(x, day_idx)
clean_dtype = next(self.clean_speech_model.parameters()).dtype
if x_processed.dtype != clean_dtype:
x_processed = x_processed.to(clean_dtype)
# Ensure dtype consistency for mixed precision residual connection
if noise_output.dtype != clean_dtype:
noise_output = noise_output.to(clean_dtype)
denoised_input = x_processed - noise_output
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
states['clean'] if states else None)
# XLA-friendly return - use tuple for consistency
if return_state:
return clean_logits, noise_hidden
return clean_logits
else:
raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")
def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
'''
Apply combined gradients to noise model parameters
clean_grad (tensor) - gradients from clean speech model output layer
noisy_grad (tensor) - gradients from noisy speech model output layer
'''
# Combine gradients: negative from clean model, positive from noisy model
combined_grad = -clean_grad + noisy_grad
# Apply gradients to noise model parameters
# This is a simplified implementation - in practice you'd want more sophisticated update rules
with torch.no_grad():
for param in self.noise_model.parameters():
if param.grad is not None:
# Scale the combined gradient appropriately
# This is a placeholder - you'd need to implement proper gradient mapping
param.data -= learning_rate * combined_grad.mean() * torch.ones_like(param.data)
def set_mode(self, mode):
'''Set the operating mode'''
self.training_mode = mode
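# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal sketch of the two forward modes, assuming the constructor keywords used when
# the model is instantiated in the trainer; the sizes and values below are placeholders.
#
#   model = TripleGRUDecoder(neural_dim=512, n_units=768, n_days=4, n_classes=41,
#                            rnn_dropout=0.0, input_dropout=0.0,
#                            patch_size=14, patch_stride=4)
#   x = torch.randn(8, 200, 512)                  # (batch, time, neural_dim)
#   day_idx = torch.zeros(8, dtype=torch.long)    # one day index per trial
#
#   # Training: all three branches, with gradient reversal applied to the noise branch
#   clean_logits, noisy_logits, noise_output = model(x, day_idx, mode='full', grl_lambda=0.5)
#
#   # Inference: noise estimation + clean decoder only
#   clean_logits = model(x, day_idx, mode='inference')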

View File

@@ -1,952 +0,0 @@
import os
# XLA multi-threading optimization - MUST be set before importing torch_xla
# Set these environment variables early to ensure they take effect
if 'TPU_CORES' in os.environ or 'COLAB_TPU_ADDR' in os.environ:
# Enable XLA multi-threading for compilation speedup
os.environ.setdefault('XLA_FLAGS',
'--xla_cpu_multi_thread_eigen=true ' +
'--xla_cpu_enable_fast_math=true ' +
f'--xla_force_host_platform_device_count={os.cpu_count()}'
)
# Set PyTorch XLA threading
os.environ.setdefault('PYTORCH_XLA_COMPILATION_THREADS', str(os.cpu_count()))
print(f"Set XLA compilation threads to {os.cpu_count()}")
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import random
import time
import numpy as np
import math
import pathlib
import logging
import sys
import json
import pickle
from contextlib import nullcontext
from dataset import BrainToTextDataset, train_test_split_indicies
from data_augmentations import gauss_smooth
import torchaudio.functional as F # for edit distance
from omegaconf import OmegaConf
# Import Accelerate for TPU support
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import set_seed
# Import XLA after setting environment variables
import torch_xla.core.xla_model as xm
torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
torch.backends.cudnn.deterministic = True # makes training more reproducible
torch._dynamo.config.cache_size_limit = 64
from rnn_model import TripleGRUDecoder
class BrainToTextDecoder_Trainer:
"""
This class will initialize and train a brain-to-text phoneme decoder
Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
"""
def __init__(self, args):
'''
args : dictionary of training arguments
'''
# Configure DataLoader behavior for TPU compatibility
dataloader_config = DataLoaderConfiguration(
even_batches=False # Required for batch_size=None DataLoaders on TPU
)
# Initialize Accelerator for TPU/multi-device support
self.use_xla = bool(xm.get_xla_supported_devices())
self.amp_requested = args.get('use_amp', True)
mixed_precision_mode = 'bf16' if self.amp_requested else 'no'
self.accelerator = Accelerator(
mixed_precision=mixed_precision_mode,
gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
log_with=None, # We'll use our own logging
project_dir=args.get('output_dir', './output'),
dataloader_config=dataloader_config,
)
# Trainer fields
self.args = args
self.logger = None
self.device = self.accelerator.device # Use accelerator device instead of manual device selection
self.model = None
self.optimizer = None
self.learning_rate_scheduler = None
self.ctc_loss = None
self.best_val_PER = torch.inf # track best PER for checkpointing
self.best_val_loss = torch.inf # track best loss for checkpointing
self.train_dataset = None
self.val_dataset = None
self.train_loader = None
self.val_loader = None
self.transform_args = self.args['dataset']['data_transforms']
# Adversarial training config (safe defaults if not provided)
adv_cfg = self.args.get('adversarial', {})
self.adv_enabled = adv_cfg.get('enabled', False)
self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5)) # GRL strength
self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2)) # weight for noisy branch CTC
self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) # optional L2 on noise output
self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) # delay enabling adversarial after N steps
# Create output directory
if args['mode'] == 'train':
os.makedirs(self.args['output_dir'], exist_ok=True)
# Create checkpoint directory
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
# Set up logging
self.logger = logging.getLogger(__name__)
for handler in self.logger.handlers[:]: # make a copy of the list
self.logger.removeHandler(handler)
self.logger.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')
if args['mode']=='train':
# During training, save logs to file in output directory
fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log')))
fh.setFormatter(formatter)
self.logger.addHandler(fh)
# Always print logs to stdout
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
self.logger.addHandler(sh)
# Log device information (managed by Accelerator)
self.logger.info(f'Using device: {self.device}')
self.logger.info(f'Accelerator state: {self.accelerator.state}')
if self.accelerator.num_processes > 1:
self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
if self.use_xla and self.amp_requested:
self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.')
# Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
if self.args['seed'] != -1:
set_seed(self.args['seed'])
# Initialize the model
self.model = TripleGRUDecoder(
neural_dim = self.args['model']['n_input_features'],
n_units = self.args['model']['n_units'],
n_days = len(self.args['dataset']['sessions']),
n_classes = self.args['dataset']['n_classes'],
rnn_dropout = self.args['model']['rnn_dropout'],
input_dropout = self.args['model']['input_network']['input_layer_dropout'],
patch_size = self.args['model']['patch_size'],
patch_stride = self.args['model']['patch_stride'],
)
if self.use_xla and self.amp_requested:
self.model = self.model.to(torch.bfloat16)
self.logger.info('Converted model parameters to bfloat16 for TPU training.')
self.model_dtype = next(self.model.parameters()).dtype
# Temporarily disable torch.compile for compatibility with new model architecture
# TODO: Re-enable torch.compile once model is stable
# self.logger.info("Using torch.compile")
# self.model = torch.compile(self.model)
self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
self.logger.info(f"Initialized RNN decoding model")
self.logger.info(self.model)
# Log how many parameters are in the model
total_params = sum(p.numel() for p in self.model.parameters())
self.logger.info(f"Model has {total_params:,} parameters")
# Determine how many day-specific parameters are in the model
day_params = 0
for name, param in self.model.named_parameters():
if 'day' in name:
day_params += param.numel()
self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")
# Create datasets and dataloaders
train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']]
val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']]
# Ensure that there are no duplicate days
if len(set(train_file_paths)) != len(train_file_paths):
raise ValueError("There are duplicate sessions listed in the train dataset")
if len(set(val_file_paths)) != len(val_file_paths):
raise ValueError("There are duplicate sessions listed in the val dataset")
# Split trials into train and test sets
train_trials, _ = train_test_split_indicies(
file_paths = train_file_paths,
test_percentage = 0,
seed = self.args['dataset']['seed'],
bad_trials_dict = None,
)
_, val_trials = train_test_split_indicies(
file_paths = val_file_paths,
test_percentage = 1,
seed = self.args['dataset']['seed'],
bad_trials_dict = None,
)
# Save dictionaries to output directory to know which trials were train vs val
with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
json.dump({'train' : train_trials, 'val': val_trials}, f)
# Determine if only a subset of neural features should be used
feature_subset = None
if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] is not None:
feature_subset = self.args['dataset']['feature_subset']
self.logger.info(f'Using only a subset of features: {feature_subset}')
# train dataset and dataloader
self.train_dataset = BrainToTextDataset(
trial_indicies = train_trials,
split = 'train',
days_per_batch = self.args['dataset']['days_per_batch'],
n_batches = self.args['num_training_batches'],
batch_size = self.args['dataset']['batch_size'],
must_include_days = None,
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
# Our dataset returns full batches, so batch will be a list of single batch dict
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
if len(batch) == 1 and isinstance(batch[0], dict):
return batch[0]
else:
# Fallback for unexpected batch structure
return batch
# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
self.train_dataset,
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = self.args['dataset']['loader_shuffle'],
num_workers = self.args['dataset']['num_dataloader_workers'],
pin_memory = True,
collate_fn = collate_fn
)
# val dataset and dataloader
self.val_dataset = BrainToTextDataset(
trial_indicies = val_trials,
split = 'test',
days_per_batch = None,
n_batches = None,
batch_size = self.args['dataset']['batch_size'],
must_include_days = None,
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Validation DataLoader with same collate function
self.val_loader = DataLoader(
self.val_dataset,
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = False,
num_workers = 0, # Keep validation dataloader single-threaded for consistency
pin_memory = True,
collate_fn = collate_fn # Use same collate function
)
self.logger.info("Successfully initialized datasets")
# Create optimizer, learning rate scheduler, and loss
self.optimizer = self.create_optimizer()
if self.args['lr_scheduler_type'] == 'linear':
self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer = self.optimizer,
start_factor = 1.0,
end_factor = self.args['lr_min'] / self.args['lr_max'],
total_iters = self.args['lr_decay_steps'],
)
elif self.args['lr_scheduler_type'] == 'cosine':
self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
else:
raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")
self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False)
# If a checkpoint is provided, then load from checkpoint
if self.args['init_from_checkpoint']:
self.load_model_checkpoint(self.args['init_checkpoint_path'])
# Set rnn and/or input layers to not trainable if specified
for name, param in self.model.named_parameters():
if not self.args['model']['rnn_trainable'] and 'gru' in name:
param.requires_grad = False
elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
param.requires_grad = False
# Prepare model, optimizer, scheduler, and dataloaders for distributed training
# Let Accelerator handle everything automatically for both GPU and TPU
(
self.model,
self.optimizer,
self.learning_rate_scheduler,
self.train_loader,
self.val_loader,
) = self.accelerator.prepare(
self.model,
self.optimizer,
self.learning_rate_scheduler,
self.train_loader,
self.val_loader,
)
self.model_dtype = next(self.model.parameters()).dtype
self.logger.info("Prepared model and dataloaders with Accelerator")
if self.adv_enabled:
self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
def autocast_context(self):
"""Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
if self.device.type == 'xla':
return nullcontext()
return self.accelerator.autocast()
def create_optimizer(self):
'''
Create the optimizer with special param groups
Biases and day weights should not be decayed
Day weights should have a separate learning rate
'''
bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name]
if len(day_params) != 0:
param_groups = [
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
{'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'},
{'params' : other_params, 'group_type' : 'other'}
]
else:
param_groups = [
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
{'params' : other_params, 'group_type' : 'other'}
]
optim = torch.optim.AdamW(
param_groups,
lr = self.args['lr_max'],
betas = (self.args['beta0'], self.args['beta1']),
eps = self.args['epsilon'],
weight_decay = self.args['weight_decay'],
fused = True
)
return optim
def create_cosine_lr_scheduler(self, optim):
lr_max = self.args['lr_max']
lr_min = self.args['lr_min']
lr_decay_steps = self.args['lr_decay_steps']
lr_max_day = self.args['lr_max_day']
lr_min_day = self.args['lr_min_day']
lr_decay_steps_day = self.args['lr_decay_steps_day']
lr_warmup_steps = self.args['lr_warmup_steps']
lr_warmup_steps_day = self.args['lr_warmup_steps_day']
def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
'''
Create lr lambdas for each param group that implement cosine decay
Different lr lambda decaying for day params vs rest of the model
'''
# Warmup phase
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
# Cosine decay phase
if current_step < decay_steps:
progress = float(current_step - warmup_steps) / float(
max(1, decay_steps - warmup_steps)
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
# Scale from 1.0 to min_lr_ratio
return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)
# After cosine decay is complete, maintain min_lr_ratio
return min_lr_ratio
if len(optim.param_groups) == 3:
lr_lambdas = [
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # biases
lambda step: lr_lambda(
step,
lr_min_day / lr_max_day,
lr_decay_steps_day,
lr_warmup_steps_day,
), # day params
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # rest of model weights
]
elif len(optim.param_groups) == 2:
lr_lambdas = [
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # biases
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # rest of model weights
]
else:
raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")
return LambdaLR(optim, lr_lambdas, -1)
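# Illustration of the resulting schedule (assumed example values, not taken from the config):
# with lr_warmup_steps=1000, lr_decay_steps=100000, lr_max=5e-3 and lr_min=1e-4
# (so min_lr_ratio = 0.02), the per-step multiplier behaves roughly as follows:
#   step 0         -> 0.00  (start of linear warmup)
#   step 500       -> 0.50  (halfway through warmup)
#   step 1000      -> 1.00  (warmup complete, cosine decay begins)
#   step 50500     -> ~0.51 (cosine midpoint: 0.02 + 0.98 * 0.5)
#   step >= 100000 -> 0.02  (held at lr_min / lr_max afterwards)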
def load_model_checkpoint(self, load_path):
'''
Load a training checkpoint for distributed training
'''
# Load checkpoint on CPU first to avoid OOM issues
checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict
# Get unwrapped model for loading state dict
unwrapped_model = self.accelerator.unwrap_model(self.model)
unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
# Device handling is managed by Accelerator, no need to manually move to device
self.logger.info("Loaded model from checkpoint: " + load_path)
def save_model_checkpoint(self, save_path, PER, loss):
'''
Save a training checkpoint using Accelerator for distributed training
'''
# Only save on main process to avoid conflicts
if self.accelerator.is_main_process:
# Unwrap model to get base model for saving
unwrapped_model = self.accelerator.unwrap_model(self.model)
checkpoint = {
'model_state_dict' : unwrapped_model.state_dict(),
'optimizer_state_dict' : self.optimizer.state_dict(),
'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
'val_PER' : PER,
'val_loss' : loss
}
torch.save(checkpoint, save_path)
self.logger.info("Saved model to checkpoint: " + save_path)
# Save the args file alongside the checkpoint
with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
OmegaConf.save(config=self.args, f=f)
# Wait for all processes to complete checkpoint saving
self.accelerator.wait_for_everyone()
def create_attention_mask(self, sequence_lengths):
max_length = torch.max(sequence_lengths).item()
batch_size = sequence_lengths.size(0)
# Create a mask for valid key positions (columns)
# Shape: [batch_size, max_length]
key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
key_mask = key_mask < sequence_lengths.unsqueeze(1)
# Expand key_mask to [batch_size, 1, 1, max_length]
# This will be broadcast across all query positions
key_mask = key_mask.unsqueeze(1).unsqueeze(1)
# Create the attention mask of shape [batch_size, 1, max_length, max_length]
# by broadcasting key_mask across all query positions
attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)
# Convert boolean mask to float mask:
# - True (valid key positions) -> 0.0 (no change to attention scores)
# - False (padding key positions) -> -inf (will become 0 after softmax)
attention_mask_float = torch.where(attention_mask,
torch.tensor(0.0, device=sequence_lengths.device),
torch.tensor(float('-inf'), device=sequence_lengths.device))
return attention_mask_float
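# Illustrative example (assumed values): for sequence_lengths = torch.tensor([3, 1]) the
# returned mask has shape [2, 1, 3, 3]; the second trial (length 1) gets -inf in its last
# two key columns for every query row, so those padded positions receive zero attention
# weight after softmax, while valid positions are left unchanged (0.0).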
def transform_data(self, features, n_time_steps, mode = 'train'):
'''
Apply various augmentations and smoothing to data
Performing augmentations is much faster on GPU than CPU
'''
# TPU and GPU should now handle data consistently with our improved DataLoader configuration
data_shape = features.shape
batch_size = data_shape[0]
channels = data_shape[-1]
# We only apply these augmentations in training
if mode == 'train':
# add static gain noise
if self.transform_args['static_gain_std'] > 0:
# create the gain matrix directly on the training device so the in-place noise add below does not mix devices
warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels, device=self.device), dim = 0), (batch_size, 1, 1))
warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']
features = torch.matmul(features, warp_mat)
# add white noise
if self.transform_args['white_noise_std'] > 0:
features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']
# add constant offset noise
if self.transform_args['constant_offset_std'] > 0:
features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']
# add random walk noise
if self.transform_args['random_walk_std'] > 0:
features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis'])
# randomly cutoff part of the data timecourse
if self.transform_args['random_cut'] > 0:
cut = np.random.randint(0, self.transform_args['random_cut'])
features = features[:, cut:, :]
n_time_steps = n_time_steps - cut
# Apply Gaussian smoothing to data
# This is done in both training and validation
if self.transform_args['smooth_data']:
features = gauss_smooth(
inputs = features,
device = self.device,
smooth_kernel_std = self.transform_args['smooth_kernel_std'],
smooth_kernel_size= self.transform_args['smooth_kernel_size'],
)
if hasattr(self, 'model_dtype'):
features = features.to(self.model_dtype)
return features, n_time_steps
def train(self):
'''
Train the model
'''
# Set model to train mode (specifically to make sure dropout layers are engaged)
self.model.train()
# create vars to track performance
train_losses = []
val_losses = []
val_PERs = []
val_results = []
val_steps_since_improvement = 0
# training params
save_best_checkpoint = self.args.get('save_best_checkpoint', True)
early_stopping = self.args.get('early_stopping', True)
early_stopping_val_steps = self.args.get('early_stopping_val_steps', 20)
train_start_time = time.time()
# train for specified number of batches
self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...")
for i, batch in enumerate(self.train_loader):
self.model.train()
self.optimizer.zero_grad()
# Train step
start_time = time.time()
# Data is automatically moved to device by Accelerator
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indicies = batch['day_indicies']
# Use Accelerator's autocast (mixed precision handled by Accelerator init)
with self.autocast_context():
# Apply augmentations to the data
features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
# Ensure proper dtype handling for TPU mixed precision
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Get phoneme predictions using inference mode during training
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Forward pass: enable full adversarial mode if configured and past warmup
use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
if use_full:
clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)
else:
logits = self.model(features, day_indicies, None, False, 'inference')
# Calculate CTC Loss
if use_full:
# Clean CTC loss
clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2)
clean_loss = self.ctc_loss(
clean_log_probs,
labels,
adjusted_lens,
phone_seq_lens
)
clean_loss = torch.mean(clean_loss)
# Noisy branch CTC loss (keeps the noisy branch decodable, while the GRL makes it adversarial for the NoiseModel)
noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2)
noisy_loss = self.ctc_loss(
noisy_log_probs,
labels,
adjusted_lens,
phone_seq_lens
)
noisy_loss = torch.mean(noisy_loss)
# Optional noise energy regularization
noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
if self.adv_noise_l2_weight > 0.0:
noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype)
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else:
log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
loss = self.ctc_loss(
log_probs=log_probs,
targets=labels,
input_lengths=adjusted_lens,
target_lengths=phone_seq_lens
)
loss = torch.mean(loss) # take mean loss over batches
# Use Accelerator's backward for distributed training
self.accelerator.backward(loss)
# Clip gradient using Accelerator's clip_grad_norm_
# Initialize grad_norm so the periodic logging below is safe even when clipping is disabled
grad_norm = 0.0
if self.args['grad_norm_clip_value'] > 0:
grad_norm = self.accelerator.clip_grad_norm_(self.model.parameters(),
max_norm = self.args['grad_norm_clip_value'])
self.optimizer.step()
self.learning_rate_scheduler.step()
# Save training metrics
train_step_duration = time.time() - start_time
train_losses.append(loss.detach().item())
# Incrementally log training progress
if i % self.args['batches_per_train_log'] == 0:
self.logger.info(f'Train batch {i}: ' +
f'loss: {(loss.detach().item()):.2f} ' +
f'grad norm: {grad_norm:.2f} '
f'time: {train_step_duration:.3f}')
# Incrementally run a test step
if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
self.logger.info(f"Running test after training batch: {i}")
# Calculate metrics on val data
start_time = time.time()
val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data'])
val_step_duration = time.time() - start_time
# Log info
self.logger.info(f'Val batch {i}: ' +
f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
f'time: {val_step_duration:.3f}')
if self.args['log_individual_day_val_PER']:
for day in val_metrics['day_PERs'].keys():
self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")
# Save metrics
val_PERs.append(val_metrics['avg_PER'])
val_losses.append(val_metrics['avg_loss'])
val_results.append(val_metrics)
# Determine if this is a new best checkpoint: PER is lower, or in the case of a PER tie, loss is lower
new_best = False
if val_metrics['avg_PER'] < self.best_val_PER:
self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
self.best_val_PER = val_metrics['avg_PER']
self.best_val_loss = val_metrics['avg_loss']
new_best = True
elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
self.best_val_loss = val_metrics['avg_loss']
new_best = True
if new_best:
# Checkpoint if metrics have improved
if save_best_checkpoint:
self.logger.info(f"Checkpointing model")
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)
# save validation metrics to pickle file
if self.args['save_val_metrics']:
with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
pickle.dump(val_metrics, f)
val_steps_since_improvement = 0
else:
val_steps_since_improvement +=1
# Optionally save this validation checkpoint, regardless of performance
if self.args['save_all_val_steps']:
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])
# Early stopping
if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
break
# Log final training steps
training_duration = time.time() - train_start_time
self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')
# Save final model
if self.args['save_final_model']:
last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss)
train_stats = {}
train_stats['train_losses'] = train_losses
train_stats['val_losses'] = val_losses
train_stats['val_PERs'] = val_PERs
train_stats['val_metrics'] = val_results
return train_stats
def validation(self, loader, return_logits = False, return_data = False):
'''
Calculate metrics on the validation dataset
'''
self.model.eval()
metrics = {}
# Record metrics
if return_logits:
metrics['logits'] = []
metrics['n_time_steps'] = []
if return_data:
metrics['input_features'] = []
metrics['decoded_seqs'] = []
metrics['true_seq'] = []
metrics['phone_seq_lens'] = []
metrics['transcription'] = []
metrics['losses'] = []
metrics['block_nums'] = []
metrics['trial_nums'] = []
metrics['day_indicies'] = []
total_edit_distance = 0
total_seq_length = 0
# Calculate PER for each specific day
day_per = {}
for d in range(len(self.args['dataset']['sessions'])):
if self.args['dataset']['dataset_probability_val'][d] == 1:
day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}
for i, batch in enumerate(loader):
# Data is automatically moved to device by Accelerator
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indicies = batch['day_indicies']
# Determine if we should perform validation on this batch
day = day_indicies[0].item()
if self.args['dataset']['dataset_probability_val'][day] == 0:
if self.args['log_val_skip_logs']:
self.logger.info(f"Skipping validation on day {day}")
continue
with torch.no_grad():
with self.autocast_context():
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure proper dtype handling for TPU mixed precision
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
logits = self.model(features, day_indicies, None, False, 'inference')
val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
loss = self.ctc_loss(
val_log_probs,
labels,
adjusted_lens,
phone_seq_lens,
)
loss = torch.mean(loss)
metrics['losses'].append(loss.cpu().detach().numpy())
# Calculate PER per day and also avg over entire validation set
batch_edit_distance = 0
decoded_seqs = []
for iterIdx in range(logits.shape[0]):
decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1)
decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
decoded_seq = decoded_seq.cpu().detach().numpy()
decoded_seq = np.array([i for i in decoded_seq if i != 0])
trueSeq = np.array(
labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach()
)
batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)
decoded_seqs.append(decoded_seq)
day = batch['day_indicies'][0].item()
day_per[day]['total_edit_distance'] += batch_edit_distance
day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item()
total_edit_distance += batch_edit_distance
total_seq_length += torch.sum(phone_seq_lens)
# Record metrics
if return_logits:
metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32
metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())
if return_data:
metrics['input_features'].append(batch['input_features'].cpu().numpy())
metrics['decoded_seqs'].append(decoded_seqs)
metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
metrics['losses'].append(loss.detach().item())
metrics['block_nums'].append(batch['block_nums'].numpy())
metrics['trial_nums'].append(batch['trial_nums'].numpy())
metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())
if isinstance(total_seq_length, torch.Tensor):
total_length_value = float(total_seq_length.item())
else:
total_length_value = float(total_seq_length)
avg_PER = total_edit_distance / max(total_length_value, 1e-6)
metrics['day_PERs'] = day_per
metrics['avg_PER'] = avg_PER
metrics['avg_loss'] = float(np.mean(metrics['losses']))
return metrics
def inference(self, features, day_indicies, n_time_steps, mode='inference'):
'''
TPU-compatible inference method for generating phoneme logits
'''
self.model.eval()
with torch.no_grad():
with self.autocast_context():
# Apply data transformations (no augmentation for inference)
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure features tensor matches model parameter dtype for TPU compatibility
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)
return logits
def inference_batch(self, batch, mode='inference'):
'''
Inference method for processing a full batch
'''
self.model.eval()
# Data is automatically moved to device by Accelerator
features = batch['input_features']
day_indicies = batch['day_indicies']
n_time_steps = batch['n_time_steps']
with torch.no_grad():
with self.autocast_context():
# Apply data transformations (no augmentation for inference)
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Calculate adjusted sequence lengths for CTC with proper dtype handling
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)
return logits, adjusted_lens

View File

@@ -1,150 +0,0 @@
#!/bin/bash
# Setup script for TensorFlow Brain-to-Text training on TPU v5e-8
#
# Usage: ./setup_tensorflow_tpu.sh
#
# This script prepares the environment for training the brain-to-text model
# using TensorFlow on TPU v5e-8 hardware.
set -e # Exit on any error
echo "=== TensorFlow TPU v5e-8 Setup Script ==="
echo "Setting up environment for brain-to-text training..."
# Check if we're in a TPU environment
if [[ -z "${TPU_NAME}" ]] && [[ -z "${COLAB_TPU_ADDR}" ]]; then
echo "Warning: TPU environment variables not detected."
echo "Make sure you're running on a TPU v5e-8 instance."
fi
# Create conda environment for TensorFlow TPU
ENV_NAME="b2txt_tf"
echo "Creating conda environment: ${ENV_NAME}"
# Make `conda activate` available in this non-interactive shell
eval "$(conda shell.bash hook)"
if conda env list | grep -q "^${ENV_NAME} "; then
echo "Environment ${ENV_NAME} already exists. Activating..."
conda activate ${ENV_NAME}
else
echo "Creating new environment..."
conda create -n ${ENV_NAME} python=3.10 -y
conda activate ${ENV_NAME}
fi
# Install TensorFlow with TPU support
echo "Installing TensorFlow with TPU support..."
# Quote the requirement so the shell does not treat ">=2.15.0" as an output redirection
pip install "tensorflow[and-cuda]>=2.15.0"
# Install additional requirements
echo "Installing additional requirements..."
pip install -r requirements_tf.txt
# Set up TPU environment variables
echo "Configuring TPU environment variables..."
# Create or update .bashrc with TPU optimizations
cat >> ~/.bashrc << 'EOF'
# TPU v5e-8 Environment Variables
export TPU_ML_PLATFORM="TensorFlow"
export XLA_USE_BF16=1
export TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"
export TPU_MEGACORE=1
export LIBTPU_INIT_ARGS="--xla_tpu_spmd_threshold_for_allgather_cse=10000"
# Disable TensorFlow warnings for cleaner output
export TF_CPP_MIN_LOG_LEVEL=2
# Memory optimizations
export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_GPU_THREAD_MODE=gpu_private
EOF
# Source the updated .bashrc
source ~/.bashrc
# Test TPU connectivity
echo "Testing TPU connectivity..."
python3 << 'EOF'
import tensorflow as tf
print("TensorFlow version:", tf.__version__)
try:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
print(f"TPU cluster initialized successfully!")
print(f"Number of TPU cores: {strategy.num_replicas_in_sync}")
print(f"TPU devices: {tf.config.list_logical_devices('TPU')}")
except Exception as e:
print(f"TPU initialization failed: {e}")
print("You may be running on CPU/GPU instead of TPU")
# Test mixed precision
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(policy)
print(f"Mixed precision policy: {policy.name}")
EOF
# Verify data directory exists
DATA_DIR="../data/hdf5_data_final"
if [ -d "$DATA_DIR" ]; then
echo "Data directory found: $DATA_DIR"
# Count available sessions
SESSION_COUNT=$(ls -d $DATA_DIR/t*.20* 2>/dev/null | wc -l)
echo "Available sessions: $SESSION_COUNT"
else
echo "Warning: Data directory not found at $DATA_DIR"
echo "Please ensure the dataset is available before training."
fi
# Create output directories
echo "Creating output directories..."
mkdir -p trained_models/tensorflow_tpu
mkdir -p logs/tensorflow_tpu
mkdir -p eval_output
# Make scripts executable
echo "Setting script permissions..."
chmod +x train_model_tf.py
chmod +x evaluate_model_tf.py
# Display system information
echo "=== System Information ==="
echo "Python version: $(python --version)"
echo "Conda environment: $CONDA_DEFAULT_ENV"
echo "Available memory: $(free -h | grep '^Mem:' | awk '{print $7}')"
echo "CPU cores: $(nproc)"
# Check for GPU/TPU
echo "=== Hardware Information ==="
if nvidia-smi &> /dev/null; then
echo "NVIDIA GPUs detected:"
nvidia-smi --list-gpus
else
echo "No NVIDIA GPUs detected"
fi
if [[ -n "${TPU_NAME}" ]]; then
echo "TPU Name: $TPU_NAME"
elif [[ -n "${COLAB_TPU_ADDR}" ]]; then
echo "Colab TPU Address: $COLAB_TPU_ADDR"
else
echo "No TPU environment variables detected"
fi
echo ""
echo "=== Setup Complete ==="
echo "Environment '$ENV_NAME' is ready for TensorFlow TPU training."
echo ""
echo "To activate the environment:"
echo " conda activate $ENV_NAME"
echo ""
echo "To start training:"
echo " python train_model_tf.py --config_path rnn_args.yaml"
echo ""
echo "To run evaluation:"
echo " python evaluate_model_tf.py --model_path path/to/checkpoint --config_path rnn_args.yaml"
echo ""
echo "For more options, use --help with any script."

View File

@@ -1,236 +0,0 @@
#!/usr/bin/env python3
"""
TPU内存监控工具 - 专门用于训练过程
解决tf.config.experimental.get_memory_info()在TPU上无法工作的问题
"""
import tensorflow as tf
import time
import psutil
import os
class TPUMemoryMonitor:
"""TPU内存监控类"""
def __init__(self):
self.tpu_devices = tf.config.list_logical_devices('TPU')
self.baseline_memory = None
self.peak_allocations = {}
def get_tpu_status(self) -> str:
"""获取TPU状态 - 实用版本不依赖get_memory_info"""
try:
if not self.tpu_devices:
return "TPU: No devices"
num_cores = len(self.tpu_devices)
# Test TPU responsiveness
try:
with tf.device('/TPU:0'):
test_tensor = tf.constant([1.0, 2.0, 3.0])
result = tf.reduce_sum(test_tensor)
_ = result.numpy() # force execution
activity = "active"
except Exception:
activity = "inactive"
# Get host memory usage as a reference
try:
memory = psutil.virtual_memory()
host_mem = f"Host:{memory.percent:.1f}%"
except Exception:
host_mem = "Host:unknown"
return f"TPU: {num_cores}cores {activity} {host_mem}"
except Exception as e:
return f"TPU: error({str(e)[:20]})"
def estimate_tensor_memory(self, tensor_shape, dtype=tf.float32):
"""估算张量内存使用量"""
if dtype == tf.float32:
bytes_per_element = 4
elif dtype == tf.float16 or dtype == tf.bfloat16:
bytes_per_element = 2
elif dtype == tf.int32:
bytes_per_element = 4
elif dtype == tf.int64:
bytes_per_element = 8
else:
bytes_per_element = 4 # default
total_elements = 1
for dim in tensor_shape:
total_elements *= dim
total_bytes = total_elements * bytes_per_element
return total_bytes / (1024 * 1024) # return MB
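# Worked example: a float32 tensor of shape [32, 1000, 512] holds 32 * 1000 * 512
# = 16,384,000 elements; at 4 bytes each that is 65,536,000 bytes ≈ 62.5 MB, which
# is the value this method returns.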
def track_allocation(self, name: str, tensor_shape, dtype=tf.float32):
"""跟踪内存分配"""
mb = self.estimate_tensor_memory(tensor_shape, dtype)
self.peak_allocations[name] = self.peak_allocations.get(name, 0) + mb
return mb
def get_allocation_summary(self) -> str:
"""获取分配汇总"""
if not self.peak_allocations:
return "No allocations tracked"
total_mb = sum(self.peak_allocations.values())
top_3 = sorted(self.peak_allocations.items(), key=lambda x: x[1], reverse=True)[:3]
summary = f"Tracked:{total_mb:.1f}MB "
summary += f"Top:({top_3[0][0]}:{top_3[0][1]:.1f}MB)"
return summary
def test_memory_allocation_across_cores(self):
"""测试8个核心的内存分配"""
print("🧪 测试所有TPU核心内存分配")
print("=" * 40)
allocations_per_core = []
for i, device in enumerate(self.tpu_devices):
print(f"核心 {i+1}: {device.name}")
try:
with tf.device(device.name):
# 创建不同大小的测试张量
test_sizes = [
([1000, 1000], "1K×1K"),
([3000, 3000], "3K×3K"),
([5000, 5000], "5K×5K"),
([7000, 7000], "7K×7K"),
]
core_total = 0
successful_allocs = []
for shape, desc in test_sizes:
try:
tensor = tf.ones(shape, dtype=tf.float32)
mb = self.estimate_tensor_memory(shape)
core_total += mb
successful_allocs.append(f"{desc}({mb:.1f}MB)")
# Actually use the tensor so it is not optimized away
_ = tf.reduce_mean(tensor)
except Exception as e:
print(f" {desc} failed: {str(e)[:30]}")
break
allocations_per_core.append(core_total)
print(f" Successful allocations: {' + '.join(successful_allocs)}")
print(f" Core total: {core_total:.1f}MB")
except Exception as e:
print(f" Core {i+1} failed: {e}")
allocations_per_core.append(0)
# Aggregate results
total_all_cores = sum(allocations_per_core)
avg_per_core = total_all_cores / len(self.tpu_devices) if self.tpu_devices else 0
print(f"\n📊 Summary:")
print(f" Total allocated: {total_all_cores:.1f}MB ({total_all_cores/1024:.2f}GB)")
print(f" Average per core: {avg_per_core:.1f}MB ({avg_per_core/1024:.2f}GB)")
# Infer the memory configuration
if avg_per_core > 8000: # > 8GB
print(" Estimate: >=16GB per core (high-end configuration)")
elif avg_per_core > 4000: # > 4GB
print(" Estimate: 8-16GB per core (standard configuration)")
elif avg_per_core > 1000: # > 1GB
print(" Estimate: 2-8GB per core (constrained or shared)")
else:
print(" Estimate: <2GB per core (severely constrained)")
return allocations_per_core
def test_training_memory_pattern():
"""测试模拟训练的内存模式"""
print("\n🏋️ 模拟训练内存模式测试")
print("=" * 30)
monitor = TPUMemoryMonitor()
# 模拟典型的brain-to-text模型内存使用
with tf.device('/TPU:0'):
print("创建模拟模型组件...")
# 1. 输入数据 (batch_size=32, seq_len=1000, features=512)
batch_size, seq_len, features = 32, 1000, 512
input_data = tf.random.normal([batch_size, seq_len, features])
input_mb = monitor.track_allocation("input_data", [batch_size, seq_len, features])
print(f" 输入数据: {input_mb:.1f}MB")
# 2. GRU权重 (假设3层, 每层256单元)
n_layers, n_units = 3, 256
for layer in range(n_layers):
# GRU有3个门每个门需要权重矩阵
weight_shape = [features if layer == 0 else n_units, n_units * 3]
weights = tf.random.normal(weight_shape)
weight_mb = monitor.track_allocation(f"gru_layer_{layer}", weight_shape)
print(f" GRU层{layer+1}权重: {weight_mb:.1f}MB")
# 3. 输出投影层 (n_units -> n_classes=41)
n_classes = 41
output_weights = tf.random.normal([n_units, n_classes])
output_mb = monitor.track_allocation("output_projection", [n_units, n_classes])
print(f" 输出投影: {output_mb:.1f}MB")
# 4. 中间激活值 (前向传播)
hidden_states = tf.random.normal([batch_size, seq_len, n_units])
hidden_mb = monitor.track_allocation("hidden_states", [batch_size, seq_len, n_units])
print(f" 隐藏状态: {hidden_mb:.1f}MB")
# 5. 梯度 (反向传播时会翻倍内存)
total_params_mb = sum([v for k, v in monitor.peak_allocations.items() if 'layer' in k or 'projection' in k])
gradient_mb = total_params_mb # 梯度内存约等于参数内存
print(f" 梯度内存: {gradient_mb:.1f}MB (估算)")
print(f"\n模型总内存估算: {monitor.get_allocation_summary()}")
# 实际执行一些操作确保内存被分配
result = tf.reduce_mean(input_data) + tf.reduce_mean(hidden_states)
print(f"验证计算结果: {result.numpy():.4f}")
if __name__ == "__main__":
print("🚀 TPU内存监控工具启动")
monitor = TPUMemoryMonitor()
# 基础状态检查
print(f"当前TPU状态: {monitor.get_tpu_status()}")
# 测试所有核心
print("\n" + "="*50)
core_allocations = monitor.test_memory_allocation_across_cores()
# 训练内存模式测试
print("\n" + "="*50)
test_training_memory_pattern()
print(f"\n🎯 关键发现:")
if core_allocations:
max_core = max(core_allocations)
min_core = min([x for x in core_allocations if x > 0])
print(f" 最大单核分配: {max_core:.1f}MB")
print(f" 最小单核分配: {min_core:.1f}MB")
if max_core > 9000: # 你之前测试到9.4GB
print(" ✅ 内存充足,可支持大模型训练")
elif max_core > 5000:
print(" ⚠️ 内存中等,建议优化模型大小")
else:
print(" ❌ 内存不足,需要大幅减少模型参数")
print(f"\n💡 针对你的训练卡顿问题:")
print(f" - SetPriority错误通常是XLA编译问题不是内存问题")
print(f" - 你的9.4GB测试说明TPU内存工作正常")
print(f" - 建议检查模型是否有导致XLA编译卡顿的操作")
print(f" - 考虑使用更简单的操作或关闭某些XLA优化")

View File

@@ -1,25 +0,0 @@
import argparse
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
def main():
parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
parser.add_argument('--config_path', default='rnn_args.yaml',
help='Path to configuration file (default: rnn_args.yaml)')
args = parser.parse_args()
# Load configuration
config = OmegaConf.load(args.config_path)
# Initialize trainer
trainer = BrainToTextDecoder_Trainer(config)
# Start training
trainer.train()
print("Training completed successfully!")
print(f"Best validation PER: {trainer.best_val_PER:.5f}")
if __name__ == "__main__":
main()

View File

@@ -8,7 +8,7 @@ It provides the same functionality as the PyTorch version but with TensorFlow
operations optimized for TPU performance.
Usage:
python train_model_tf.py --config_path rnn_args.yaml
python train_model_tf.py -config_path rnn_args.yaml
Requirements:
- TensorFlow >= 2.15.0