tpu支持
This commit is contained in:
@@ -90,12 +90,12 @@ class BrainToTextDataset(Dataset):
|
||||
self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
|
||||
|
||||
def __len__(self):
|
||||
'''
|
||||
How many batches are in this dataset.
|
||||
Because training data is sampled randomly, there is no fixed dataset length,
|
||||
however this method is required for DataLoader to work
|
||||
'''
|
||||
return self.n_batches
|
||||
How many batches are in this dataset.
|
||||
Because training data is sampled randomly, there is no fixed dataset length,
|
||||
however this method is required for DataLoader to work
|
||||
'''
|
||||
return self.n_batches if self.n_batches is not None else 0
|
||||
|
||||
def __getitem__(self, idx):
|
||||
'''
|
||||
@@ -269,7 +269,9 @@ def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_
|
||||
# Get trials in each day
|
||||
trials_per_day = {}
|
||||
for i, path in enumerate(file_paths):
|
||||
session = [s for s in path.split('/') if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
|
||||
# Handle both Windows and Unix path separators
|
||||
path_parts = path.replace('\\', '/').split('/')
|
||||
session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
|
||||
|
||||
good_trial_indices = []
|
||||
|
||||
|
Reference in New Issue
Block a user