254 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			254 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | #!/usr/bin/env python3 | ||
|  | """
 | ||
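Test script for the performance of the optimized data loading pipeline.

Compares standard (uncached) loading against cached, fully preloaded loading
and reports per-batch time and process memory usage.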
|  | """
 | ||
|  | 
 | ||
|  | import os | ||
|  | import time | ||
|  | import psutil | ||
|  | import tensorflow as tf | ||
|  | from omegaconf import OmegaConf | ||
|  | from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn | ||
|  | 
 | ||
def get_memory_usage():
    """Return the resident set size (RSS) of this process, in megabytes."""
    rss_bytes = psutil.Process().memory_info().rss
    # bytes -> KiB -> MiB
    return rss_bytes / 1024 / 1024
|  | 
 | ||
def _load_benchmark_config(config_path="../rnn_args.yaml", num_batches=10):
    """Return the benchmark configuration as a plain dict.

    Loads ``config_path`` through OmegaConf when the file exists; otherwise
    falls back to a minimal built-in configuration. In both cases
    ``num_training_batches`` is set to ``num_batches`` to keep the run short.
    """
    if not os.path.exists(config_path):
        print("❌ Configuration file not found. Creating minimal test config...")
        # Minimal configuration used when no YAML file is available.
        return {
            'dataset': {
                'dataset_dir': '../data/hdf5_data_final',
                'sessions': ['t15.2022.03.14', 't15.2022.03.16'],
                'batch_size': 32,
                'days_per_batch': 1,
                'seed': 42,
                'data_transforms': {
                    'smooth_data': False,
                    'white_noise_std': 0.0,
                    'constant_offset_std': 0.0,
                    'random_walk_std': 0.0,
                    'static_gain_std': 0.0,
                    'random_cut': 0,
                },
            },
            'num_training_batches': num_batches,
        }

    args = OmegaConf.to_container(OmegaConf.load(config_path), resolve=True)
    # Limit the number of batches used for the test.
    args['num_training_batches'] = num_batches
    return args


def _build_dataset(train_trials, args, use_cache):
    """Construct a training BrainToTextDatasetTF with caching/preloading on or off."""
    return BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=args['num_training_batches'],
        split='train',
        batch_size=args['dataset']['batch_size'],
        days_per_batch=args['dataset']['days_per_batch'],
        random_seed=args['dataset']['seed'],
        cache_data=use_cache,
        preload_all_data=use_cache,
    )


def _time_batches(tf_dataset, num_batches=3):
    """Fetch up to ``num_batches`` batches and return the time spent on each.

    Returns a list of per-batch durations in seconds; the list is shorter than
    ``num_batches`` (possibly empty) if the dataset runs out of batches.
    """
    batch_times = []
    iterator = iter(tf_dataset.take(num_batches))
    for i in range(num_batches):
        # Start the clock *before* next(): the batch is produced inside the
        # iterator, so timing only the .numpy() call would miss the loading work.
        batch_start = time.perf_counter()
        try:
            batch = next(iterator)
        except StopIteration:
            break
        # Force the tensor to be materialized on the host.
        _ = batch['input_features'].numpy()
        batch_time = time.perf_counter() - batch_start
        batch_times.append(batch_time)
        print(f"   Batch {i}: {batch_time:.3f}s")
    return batch_times


def _mean(values):
    """Return the arithmetic mean of a non-empty sequence of numbers."""
    return sum(values) / len(values)


def test_data_loading_performance():
    """Benchmark batch loading with and without caching/preloading.

    Builds the dataset twice from the same trial split (once with
    ``cache_data``/``preload_all_data`` disabled, once with both enabled),
    times the first three batches of each, and prints time and memory figures
    plus a speed comparison.

    If any of the expected ``data_train.hdf5`` files is missing, the run is
    delegated to ``test_with_dummy_data`` and its result is returned.

    Returns:
        True when the benchmark ran to completion (or whatever
        ``test_with_dummy_data`` returns on the fallback path).

    Raises:
        RuntimeError: if either dataset yields no batches, since no timing
            statistics can be computed in that case.
    """
    args = _load_benchmark_config()

    print("🔍 Starting data loading performance test...")
    print(f"📊 Test configuration: {args['num_training_batches']} batches, batch_size={args['dataset']['batch_size']}")

    # Resolve the training file for every configured session.
    train_file_paths = [
        os.path.join(args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
        for s in args['dataset']['sessions']
    ]

    print(f"📁 Testing with files: {train_file_paths}")

    # Fall back to the dummy-data path when real data is not available.
    missing_files = [f for f in train_file_paths if not os.path.exists(f)]
    if missing_files:
        print(f"❌ Missing files: {missing_files}")
        print("⚠️ Creating dummy test data...")
        return test_with_dummy_data(args)

    # Split the data (test_percentage=0: everything goes to the training split).
    print("🔄 Splitting data...")
    train_trials, _ = train_test_split_indices(
        file_paths=train_file_paths,
        test_percentage=0,
        seed=args['dataset']['seed']
    )

    print(f"📈 Found {sum(len(trials['trials']) for trials in train_trials.values())} training trials")

    # Test 1: no caching.
    print("\n" + "="*60)
    print("🐌 TEST 1: 标准数据加载 (无缓存)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.perf_counter()

    dataset_no_cache = _build_dataset(train_trials, args, use_cache=False)
    tf_dataset_no_cache = create_input_fn(
        dataset_no_cache,
        args['dataset']['data_transforms'],
        training=True
    )

    batch_times = _time_batches(tf_dataset_no_cache)
    if not batch_times:
        raise RuntimeError("Uncached dataset produced no batches; cannot compute timing statistics")

    # Total includes dataset construction as well as the timed batches.
    no_cache_time = time.perf_counter() - start_time
    no_cache_memory = get_memory_usage() - initial_memory
    avg_no_cache = _mean(batch_times)

    print(f"💾 Memory usage: +{no_cache_memory:.1f} MB")
    print(f"⏱️ Total time: {no_cache_time:.3f}s")
    print(f"📊 Avg batch time: {avg_no_cache:.3f}s")

    # Test 2: caching with full preloading.
    print("\n" + "="*60)
    print("🚀 TEST 2: 优化数据加载 (全缓存预加载)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.perf_counter()

    dataset_with_cache = _build_dataset(train_trials, args, use_cache=True)

    # Construction time/memory is reported separately as the preloading cost.
    preload_time = time.perf_counter() - start_time
    preload_memory = get_memory_usage() - initial_memory

    print(f"📝 Preloading completed in {preload_time:.3f}s")
    print(f"💾 Preloading memory: +{preload_memory:.1f} MB")

    tf_dataset_with_cache = create_input_fn(
        dataset_with_cache,
        args['dataset']['data_transforms'],
        training=True
    )

    batch_start_time = time.perf_counter()
    batch_times_cached = _time_batches(tf_dataset_with_cache)
    if not batch_times_cached:
        raise RuntimeError("Cached dataset produced no batches; cannot compute timing statistics")

    cached_batch_time = time.perf_counter() - batch_start_time
    cached_memory = get_memory_usage() - initial_memory
    avg_cached = _mean(batch_times_cached)

    print(f"💾 Total memory usage: +{cached_memory:.1f} MB")
    print(f"⏱️ Batch loading time: {cached_batch_time:.3f}s")
    print(f"📊 Avg batch time: {avg_cached:.3f}s")

    # Performance comparison.
    print("\n" + "="*60)
    print("📈 PERFORMANCE COMPARISON")
    print("="*60)

    # Guard against a zero denominator when cached batches are below timer resolution.
    speedup = avg_no_cache / avg_cached if avg_cached > 0 else float('inf')
    memory_cost = cached_memory - no_cache_memory

    print(f"🚀 Speed improvement: {speedup:.1f}x faster")
    print(f"💾 Memory cost: +{memory_cost:.1f} MB for caching")
    print(f"⚡ First batch time: {batch_times[0]:.3f}s → {batch_times_cached[0]:.3f}s")

    if speedup > 2:
        print("✅ Excellent! 缓存优化显著提升了数据加载速度")
    elif speedup > 1.5:
        print("✅ Good! 缓存优化有效提升了数据加载速度")
    else:
        print("⚠️ Warning: 缓存优化效果不明显,可能数据量太小")

    return True
|  | 
 | ||
def test_with_dummy_data(args):
    """Time dataset construction on placeholder trial indices.

    Builds ``BrainToTextDatasetTF`` twice over the same dummy index (first with
    caching and preloading disabled, then with both enabled) and prints how
    long each construction took.

    Args:
        args: Not read by this function; accepted so callers can pass their
            configuration unchanged.

    Returns:
        Always True.
    """
    print("🔧 Creating dummy data for testing...")

    # One fake day holding 100 placeholder trial ids.
    dummy_trials = {
        0: {
            'trials': list(range(100)),
            'session_path': 'dummy_path',
        },
    }

    print("📊 Testing with dummy data (100 trials)...")

    # Constructor arguments shared by both scenarios.
    common_kwargs = dict(
        trial_indices=dummy_trials,
        n_batches=5,
        split='train',
        batch_size=32,
        days_per_batch=1,
        random_seed=42,
    )

    # (banner, caching/preloading enabled?) — run in this order.
    scenarios = (
        ("\n🐌 Testing without cache...", False),
        ("\n🚀 Testing with cache...", True),
    )

    # Keep every dataset referenced until the function returns.
    built_datasets = []
    for banner, caching_enabled in scenarios:
        print(banner)
        started = time.time()
        dataset = BrainToTextDatasetTF(
            cache_data=caching_enabled,
            preload_all_data=caching_enabled,
            **common_kwargs,
        )
        elapsed = time.time() - started
        built_datasets.append(dataset)
        print(f"   Initialization time: {elapsed:.3f}s")

    print("\n✅ 缓存机制已成功集成到数据加载管道中")
    print("📝 实际性能需要用真实的HDF5数据进行测试")

    return True
|  | 
 | ||
if __name__ == "__main__":
    # Script entry point: run the benchmark and report the outcome through
    # both console output and the process exit status.
    import sys
    import traceback

    print("🧪 Data Loading Performance Test")
    print("="*60)

    try:
        # Keep the try body minimal: only the benchmark call is guarded.
        success = test_data_loading_performance()
    except Exception as e:
        # Top-level boundary: report the failure with a traceback, then exit
        # non-zero so shells and CI can detect that the test did not pass.
        print(f"\n❌ Test failed with error: {e}")
        traceback.print_exc()
        sys.exit(1)

    if success:
        print("\n🎉 Data loading optimization test completed successfully!")
        print("💡 你现在可以运行 train_model_tf.py 来享受快速的数据加载了")
    else:
        # A falsy result also counts as a failed run.
        sys.exit(1)