99 lines
3.1 KiB
Python
99 lines
3.1 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
修复DataFrame长度不匹配问题的脚本
|
||
错误:Dataframe returned from strategy has mismatching length
|
||
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
from typing import Dict, Any
|
||
import logging
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
def validate_dataframe_length(original_df: pd.DataFrame, processed_df: pd.DataFrame, pair: str) -> pd.DataFrame:
|
||
"""
|
||
验证并修复DataFrame长度不匹配问题
|
||
"""
|
||
original_len = len(original_df)
|
||
processed_len = len(processed_df)
|
||
|
||
if original_len != processed_len:
|
||
logger.warning(f"长度不匹配: {pair} - 原始: {original_len}, 处理后: {processed_len}")
|
||
|
||
# 确保索引对齐
|
||
if processed_len < original_len:
|
||
# 重新索引到原始DataFrame的长度
|
||
processed_df = processed_df.reindex(original_df.index)
|
||
# 填充NaN值
|
||
processed_df = processed_df.fillna(0)
|
||
elif processed_len > original_len:
|
||
# 截断到原始长度
|
||
processed_df = processed_df.iloc[:original_len]
|
||
|
||
return processed_df
|
||
|
||
def add_length_validation_to_strategy():
|
||
"""
|
||
为策略添加长度验证代码
|
||
"""
|
||
validation_code = '''
|
||
def populate_indicators(self, dataframe: pd.DataFrame, metadata: dict) -> pd.DataFrame:
|
||
"""
|
||
添加指标 - 包含长度验证
|
||
"""
|
||
original_length = len(dataframe)
|
||
|
||
# 保存原始索引
|
||
original_index = dataframe.index
|
||
|
||
# 你的指标计算代码...
|
||
|
||
# 验证长度
|
||
if len(dataframe) != original_length:
|
||
logger.warning(f"{metadata.get('pair', 'Unknown')} - DataFrame长度不匹配,正在修复...")
|
||
dataframe = dataframe.reindex(original_index)
|
||
dataframe = dataframe.fillna(0)
|
||
|
||
return dataframe
|
||
|
||
def populate_entry_trend(self, dataframe: pd.DataFrame, metadata: dict) -> pd.DataFrame:
|
||
"""
|
||
入场信号 - 包含长度验证
|
||
"""
|
||
original_length = len(dataframe)
|
||
|
||
# 你的入场信号代码...
|
||
|
||
# 验证长度
|
||
if len(dataframe) != original_length:
|
||
dataframe = dataframe.reindex(dataframe.index[:original_length])
|
||
|
||
return dataframe
|
||
|
||
def populate_exit_trend(self, dataframe: pd.DataFrame, metadata: dict) -> pd.DataFrame:
|
||
"""
|
||
出场信号 - 包含长度验证
|
||
"""
|
||
original_length = len(dataframe)
|
||
|
||
# 你的出场信号代码...
|
||
|
||
# 验证长度
|
||
if len(dataframe) != original_length:
|
||
dataframe = dataframe.reindex(dataframe.index[:original_length])
|
||
|
||
return dataframe
|
||
'''
|
||
return validation_code
|
||
|
||
if __name__ == "__main__":
|
||
# 测试修复
|
||
test_df = pd.DataFrame({'close': [1, 2, 3, 4, 5]})
|
||
broken_df = pd.DataFrame({'close': [1, 2, 3]})
|
||
|
||
fixed_df = validate_dataframe_length(test_df, broken_df, "TEST/USDT")
|
||
print(f"修复后长度: {len(fixed_df)}")
|
||
print("修复代码已生成,请应用到策略中") |