This commit is contained in:
Zchen
2025-10-12 21:43:12 +08:00
parent 6e1d8e18f7
commit dfb3f7312c
2 changed files with 44 additions and 7 deletions

View File

@@ -126,8 +126,8 @@ class BrainToTextDataset(Dataset):
try:
g = f[f'trial_{t:04d}']
# Remove features is neccessary
input_features = torch.from_numpy(g['input_features'][:]) # neural data
# Remove features is neccessary
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # neural data - convert to bf16 for TPU compatibility
if self.feature_subset:
input_features = input_features[:,self.feature_subset]