Enhance dataset shape analysis to incorporate data augmentation effects and adjust safety margins accordingly
@@ -830,17 +830,20 @@ def train_test_split_indices(file_paths: List[str],
     return train_trials, test_trials
 
 
-def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
+def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
+                           sample_size: int = 100,
+                           transform_args: Optional[Dict[str, Any]] = None) -> Dict[str, int]:
     """
     Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
-    utilizing multiple CPU cores and the dataset's caching mechanism.
+    applying data augmentation first to capture post-augmentation shapes.
 
     Args:
         dataset_tf: Dataset instance to analyze
         sample_size: Number of samples to analyze (set to -1 for all data)
+        transform_args: Data transformation configuration to apply during analysis
 
     Returns:
-        Dictionary with maximum dimensions
+        Dictionary with maximum dimensions after data augmentation
     """
     print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
     start_time = time.time()
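A minimal sketch of how the new signature might be called (the transform_args keys below are hypothetical; the diff does not reveal the actual configuration schema expected by DataAugmentationTF):

    # Hypothetical configuration; key names are assumptions, not taken from the diff
    transform_args = {
        'smooth_data': True,        # e.g. enable the Gaussian smoothing mentioned below
        'smooth_kernel_std': 2.0,
    }
    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1,
                                        transform_args=transform_args)
    print(max_shapes['max_time_steps'])  # padded time dimension after augmentation + margin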
@@ -867,29 +870,55 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     total_trials_to_analyze = len(trials_to_check)
     print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")
 
-    # Helper function called by each worker thread
+    # Helper function called by each worker thread, including data augmentation
     def analyze_single_trial(day_trial_pair):
-        """Loads and analyzes a single trial, returns its shapes."""
+        """Loads and analyzes a single trial with data augmentation, returns its shapes."""
         day, trial = day_trial_pair
         try:
             # Reuse dataset_tf's loading and caching logic
             trial_data = dataset_tf._load_trial_data(day, trial)
 
-            # Read the dimensions directly from the loaded data
-            time_steps = int(trial_data['n_time_steps'])
+            # Simulate the augmentation pipeline (if transform_args is provided)
+            if transform_args:
+                # Convert the numpy data to TensorFlow tensors
+                features = tf.constant(trial_data['input_features'], dtype=tf.float32)
+                n_time_steps = tf.constant(trial_data['n_time_steps'], dtype=tf.int32)
+
+                # Add a batch dimension (batch_size=1)
+                features = tf.expand_dims(features, axis=0)
+                n_time_steps = tf.expand_dims(n_time_steps, axis=0)
+
+                # Apply the data transformations
+                augmented_features, augmented_n_time_steps = DataAugmentationTF.transform_data(
+                    features, n_time_steps, transform_args, training=True
+                )
+
+                # Remove the batch dimension and read the augmented shapes
+                augmented_features = tf.squeeze(augmented_features, axis=0)
+                augmented_n_time_steps = tf.squeeze(augmented_n_time_steps, axis=0)
+
+                # Get the actual number of time steps after augmentation
+                time_steps = int(augmented_n_time_steps.numpy())
+                actual_features_shape = augmented_features.shape[0]  # may differ from n_time_steps
+                time_steps = max(time_steps, actual_features_shape)  # keep the larger value
+            else:
+                # No augmentation applied; use the raw data
+                time_steps = int(trial_data['n_time_steps'])
+
+            # Read the other dimensions (these are unaffected by augmentation)
             phone_seq_len = int(trial_data['phone_seq_lens'])
 
-            # Handle the transcription data - it may be an array
+            # Handle the transcription data
             transcription_data = trial_data['transcription']
             if hasattr(transcription_data, '__len__'):
                 transcription_len = len(transcription_data)
             else:
-                transcription_len = 1  # scalar, so length 1
+                transcription_len = 1
 
             return (time_steps, phone_seq_len, transcription_len)
         except Exception as e:
             logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
-            return None  # None signals failure
+            return None
 
     # 3. Parallelize with ThreadPoolExecutor
     # Use dynamic calculation based on CPU cores with reasonable upper limit
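The pool setup that step 3 refers to sits outside this hunk; a minimal sketch of the fan-out it describes, with a worker cap that is an assumption rather than the commit's actual formula:

    import os
    from concurrent.futures import ThreadPoolExecutor

    # Assumed cap: scale with CPU cores but bound the total thread count
    max_workers = min(32, (os.cpu_count() or 4) * 2)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = [r for r in executor.map(analyze_single_trial, trials_to_check)
                   if r is not None]  # drop trials that failed to load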
@@ -934,15 +963,23 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'n_features': dataset_tf.feature_dim
     }
 
-    # 5. Add an appropriate safety margin - scaled by analysis scope and data augmentation
-    if sample_size == -1:
-        # Full-data analysis: reserve headroom for data augmentation (especially Gaussian smoothing)
-        safety_margin = 1.15  # 15% buffer for data augmentation effects
-        margin_reason = "buffer for full dataset analysis + data augmentation"
+    # 5. Add an appropriate safety margin - the analysis now includes augmentation effects
+    if transform_args:
+        # Augmentation was applied during analysis: only a small margin is needed for randomness
+        if sample_size == -1:
+            safety_margin = 1.05  # 5% buffer for augmentation randomness (full data)
+            margin_reason = "small buffer for augmentation randomness after full analysis"
+        else:
+            safety_margin = 1.15  # 15% buffer for augmentation randomness + sampling
+            margin_reason = f"buffer for augmentation randomness + sampling {sample_size} trials"
     else:
-        # Sampled analysis: needs a larger margin for unsampled extremes + data augmentation
-        safety_margin = 1.35  # 35% buffer for sampling uncertainty + data augmentation
-        margin_reason = f"larger buffer due to sampling only {sample_size} trials + data augmentation"
+        # No augmentation applied: keep the original logic
+        if sample_size == -1:
+            safety_margin = 1.02  # 2% buffer for rounding errors
+            margin_reason = "minimal buffer for full dataset analysis without augmentation"
+        else:
+            safety_margin = 1.3  # 30% buffer for sampling uncertainty
+            margin_reason = f"larger buffer due to sampling only {sample_size} trials without augmentation"
 
     final_max_shapes = {
         'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
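To make the margin change concrete: with an observed maximum of 1200 time steps, full-data analysis with augmentation now pads to int(1200 * 1.05) = 1260 steps, where the old logic reserved int(1200 * 1.15) = 1380; sampling with augmentation pads to int(1200 * 1.15) = 1380, and sampling without augmentation still pads to int(1200 * 1.3) = 1560.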
@@ -1005,7 +1042,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
 
     # Analyze dataset to get maximum shapes
     print("📊 Analyzing dataset for maximum shapes...")
-    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1)  # Analyze ALL data for maximum accuracy
+    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1, transform_args=transform_args)  # Pass transform_args
 
     # Use static shapes based on analysis
     padded_shapes = {
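Downstream, the analyzed maxima feed the static shapes for padded_batch; a sketch of how that dictionary plausibly continues (the hunk ends at the opening brace, so the entries and key names below are assumptions):

    padded_shapes = {
        'input_features': (max_shapes['max_time_steps'], max_shapes['n_features']),
        'seq_class_ids': (max_shapes['max_phone_seq_len'],),       # assumed key
        'transcription': (max_shapes['max_transcription_len'],),   # assumed key
    }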