f
This commit is contained in:
@@ -633,12 +633,12 @@ class DataAugmentationTF:
|
||||
reshaped_inputs = tf.expand_dims(inputs, axis=2)
|
||||
|
||||
# Execute depthwise convolution
|
||||
# This is a single, efficient operation replacing the original Python for-loop
|
||||
# Use VALID padding to ensure output length <= input length
|
||||
smoothed = tf.nn.depthwise_conv2d(
|
||||
reshaped_inputs,
|
||||
kernel,
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME'
|
||||
padding='VALID' # Changed from 'SAME' to 'VALID' to prevent length increase
|
||||
)
|
||||
|
||||
# Remove the dummy width dimension to restore original shape
|
||||
@@ -963,15 +963,15 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF,
|
||||
'n_features': dataset_tf.feature_dim
|
||||
}
|
||||
|
||||
# 5. 添加适当的安全边际 - 现在分析已包含数据增强效果
|
||||
# 5. 添加适当的安全边际 - Gaussian平滑现在使用VALID padding,不会增加长度
|
||||
if transform_args:
|
||||
# 已应用数据增强分析:只需要很小的边际应对随机性
|
||||
# 已应用数据增强分析:VALID padding确保长度不增加,只需小边际
|
||||
if sample_size == -1:
|
||||
safety_margin = 1.05 # 5% buffer for augmentation randomness (full data)
|
||||
margin_reason = "small buffer for augmentation randomness after full analysis"
|
||||
safety_margin = 1.05 # 5% buffer for minor augmentation effects (full data)
|
||||
margin_reason = "small buffer after full analysis with length-preserving augmentation"
|
||||
else:
|
||||
safety_margin = 1.15 # 15% buffer for augmentation randomness + sampling
|
||||
margin_reason = f"buffer for augmentation randomness + sampling {sample_size} trials"
|
||||
safety_margin = 1.15 # 15% buffer for augmentation + sampling
|
||||
margin_reason = f"buffer for augmentation + sampling {sample_size} trials"
|
||||
else:
|
||||
# 未应用数据增强:使用原始逻辑
|
||||
if sample_size == -1:
|
||||
|
||||
Reference in New Issue
Block a user