tpu支持

This commit is contained in:
Zchen
2025-10-12 18:41:26 +08:00
parent 1a906d3248
commit 40e4d00576
5 changed files with 456 additions and 226 deletions

View File

@@ -90,12 +90,12 @@ class BrainToTextDataset(Dataset):
self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
def __len__(self):
'''
How many batches are in this dataset.
Because training data is sampled randomly, there is no fixed dataset length,
however this method is required for DataLoader to work
'''
return self.n_batches
How many batches are in this dataset.
Because training data is sampled randomly, there is no fixed dataset length,
however this method is required for DataLoader to work
'''
return self.n_batches if self.n_batches is not None else 0
def __getitem__(self, idx):
'''
@@ -269,7 +269,9 @@ def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_
# Get trials in each day
trials_per_day = {}
for i, path in enumerate(file_paths):
session = [s for s in path.split('/') if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
# Handle both Windows and Unix path separators
path_parts = path.replace('\\', '/').split('/')
session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
good_trial_indices = []