This commit is contained in:
Zchen
2025-10-22 13:43:52 +08:00
parent 8d94b706f7
commit 7511f4cf68

View File

@@ -633,12 +633,12 @@ class DataAugmentationTF:
reshaped_inputs = tf.expand_dims(inputs, axis=2)
# Execute depthwise convolution
# This is a single, efficient operation replacing the original Python for-loop
# Use VALID padding to ensure output length <= input length
smoothed = tf.nn.depthwise_conv2d(
reshaped_inputs,
kernel,
strides=[1, 1, 1, 1],
padding='SAME'
padding='VALID' # Changed from 'SAME' to 'VALID' to prevent length increase
)
# Remove the dummy width dimension to restore original shape
@@ -963,15 +963,15 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
'n_features': dataset_tf.feature_dim
}
# 5. 添加适当的安全边际 - 现在分析已包含数据增强效果
# 5. 添加适当的安全边际 - Gaussian平滑现在使用VALID padding不会增加长度
if transform_args:
# 已应用数据增强分析:只需要很小的边际应对随机性
# 已应用数据增强分析:VALID padding确保长度不增加只需小边际
if sample_size == -1:
safety_margin = 1.05 # 5% buffer for augmentation randomness (full data)
margin_reason = "small buffer for augmentation randomness after full analysis"
safety_margin = 1.05 # 5% buffer for minor augmentation effects (full data)
margin_reason = "small buffer after full analysis with length-preserving augmentation"
else:
safety_margin = 1.15 # 15% buffer for augmentation randomness + sampling
margin_reason = f"buffer for augmentation randomness + sampling {sample_size} trials"
safety_margin = 1.15 # 15% buffer for augmentation + sampling
margin_reason = f"buffer for augmentation + sampling {sample_size} trials"
else:
# 未应用数据增强:使用原始逻辑
if sample_size == -1: