Commit Graph

179 Commits

Author SHA1 Message Date
Zchen
a3989c94cc Implement unified shape analysis for training and validation datasets; add support for fixed predefined shapes to enhance TPU compatibility and performance. 2025-10-22 20:17:39 +08:00
Zchen
8ab5697081 Add support for fixed predefined shapes in create_input_fn to optimize shape handling and skip analysis 2025-10-22 15:50:51 +08:00
Zchen
21b8e4f342 Fix circular dependency in trial data loading and improve error handling; ensure sequence length preservation in Gaussian smoothing 2025-10-22 15:31:34 +08:00
Zchen
fde5ea20ad Enhance safety margin calculations in dataset shape analysis to address double augmentation issues caused by random transformations. Implement intelligent detection of random vs deterministic augmentations, applying appropriate safety margins to prevent shape mismatch errors during training. 2025-10-22 15:09:38 +08:00
Zchen
4a99f50afd Add training parameter to analyze_dataset_shapes for improved data augmentation handling 2025-10-22 14:13:26 +08:00
Zchen
7511f4cf68 f 2025-10-22 13:43:52 +08:00
Zchen
8d94b706f7 Enhance dataset shape analysis to incorporate data augmentation effects and adjust safety margins accordingly 2025-10-22 13:08:46 +08:00
Zchen
d92889d435 Adjust safety margin in dataset shape analysis to account for data augmentation effects 2025-10-22 10:23:22 +08:00
Zchen
0d9bf29d07 Adjust safety margin in dataset shape analysis based on sample size for improved accuracy 2025-10-22 01:48:40 +08:00
Zchen
3c993a6268 Increase safety margin to 30% in dataset shape analysis for improved padding accuracy 2025-10-22 01:47:08 +08:00
Zchen
6fb5907c72 Refactor create_input_fn to support static shape handling for XLA compatibility 2025-10-22 01:29:31 +08:00
Zchen
c03441d8f3 Refactor dynamic padding shapes and update device placement configuration for TPU training 2025-10-22 01:03:14 +08:00
Zchen
57f07434ac f 2025-10-22 00:54:20 +08:00
Zchen
52a9b17375 f 2025-10-22 00:38:55 +08:00
Zchen
e715d9ac79 Enhance error handling and deprecate batch generation methods in BrainToTextDatasetTF
- Improved error logging when loading trial data fails, ensuring correct feature dimensions in dummy data.
- Marked _create_batch_generator and create_dataset methods as deprecated, recommending create_input_fn for better performance.
- Adjusted maximum parallel workers in analyze_dataset_shapes based on CPU cores.
2025-10-22 00:28:10 +08:00
Zchen
a031972ba6 f 2025-10-21 01:07:57 +08:00
Zchen
ab12d0b7ee f 2025-10-21 00:31:59 +08:00
Zchen
e7c9b95b00 f 2025-10-21 00:19:05 +08:00
Zchen
5a0079641a f 2025-10-20 23:34:44 +08:00
Zchen
e399cf262a ff 2025-10-20 13:37:11 +08:00
Zchen
7358ff3d79 Enable soft device placement for CTC operations and update related comments 2025-10-20 11:22:13 +08:00
Zchen
f8fb4d7133 Remove setup script, TPU memory monitor, and training model script
- Deleted `setup_tensorflow_tpu.sh` which was responsible for setting up the TensorFlow environment on TPU v5e-8.
- Removed `tpu_memory_monitor.py`, a tool for monitoring TPU memory usage during training.
- Eliminated `train_model.py`, the script for training the Brain-to-Text RNN model.
2025-10-20 11:05:03 +08:00
Zchen
7c272b7c5b Remove test scripts for data loading and TensorFlow implementation 2025-10-20 01:37:22 +08:00
Zchen
0a0e07a193 Remove custom CTC loss implementation for TPU from the TripleGRUDecoder class 2025-10-20 01:16:50 +08:00
Zchen
06ddbc6ac2 Refactor input function to implement batch-first approach with dynamic padding and apply data augmentation post-batching for TPU compatibility 2025-10-20 00:58:29 +08:00
Zchen
fabf70cfa9 Enhance dataset shape analysis by implementing parallel processing and improving sampling logic 2025-10-20 00:35:17 +08:00
Zchen
e1669b5a4c Increase batch size from 256 to 512 for training in rnn_args.yaml 2025-10-20 00:21:33 +08:00
Zchen
6e02894a8a f 2025-10-20 00:13:39 +08:00
Zchen
4db3625dc5 f 2025-10-19 23:55:56 +08:00
Zchen
fed5fd8251 f 2025-10-19 22:25:21 +08:00
Zchen
4b373ab317 ff 2025-10-19 20:16:23 +08:00
Zchen
40d0fc50de f 2025-10-19 13:18:20 +08:00
Zchen
4328114ed6 Add dataset shape analysis function and integrate into input function for TPU optimization 2025-10-19 11:04:36 +08:00
Zchen
cfd9653da9 Enhance dataset caching logic for training and validation sets with improved messaging 2025-10-19 10:31:31 +08:00
Zchen
558be0ad98 Refactor individual dataset creation for improved I/O efficiency and add logging for error handling 2025-10-19 10:31:18 +08:00
Zchen
d83f990beb f 2025-10-17 12:20:17 +08:00
Zchen
eb058fe9d3 ff 2025-10-17 11:57:10 +08:00
Zchen
57360bec8a Remove CPU optimization call and add logging for TPU strategy and data pipeline performance 2025-10-17 11:45:20 +08:00
Zchen
eb4e3fc69f fff 2025-10-17 11:38:57 +08:00
Zchen
6c7abfcca8 f 2025-10-17 10:53:58 +08:00
Zchen
7ede7b5f12 f 2025-10-17 02:09:14 +08:00
Zchen
ca8c615505 f 2025-10-17 02:01:48 +08:00
Zchen
49700456b8 f 2025-10-17 01:58:28 +08:00
Zchen
8ee09b6b5e f 2025-10-17 01:54:32 +08:00
Zchen
a5a3179ca6 f 2025-10-17 01:49:03 +08:00
Zchen
59fb73ee9f f 2025-10-17 01:36:08 +08:00
Zchen
0a72143513 Use legacy Adam optimizer 2025-10-17 01:26:02 +08:00
Zchen
7df78244e6 Switch optimizer from AdamW to Adam 2025-10-17 01:07:01 +08:00
Zchen
a96e272f7b Fix gradient clipping being applied twice 2025-10-17 00:51:53 +08:00
Zchen
7a43ebfb71 refactor: streamline model building and ensure dtype consistency in L2 loss calculation 2025-10-16 23:06:09 +08:00