# Source listing metadata (from the original paste): 327 lines, 14 KiB, Python
import pymysql
|
||
import decimal
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
import argparse
|
||
|
||
# -------------------------- 配置项 --------------------------
|
||
import os

# Database connection settings. Every value may be overridden through an
# environment variable so credentials need not be committed to source control;
# the literal defaults reproduce the previous hard-coded configuration.
# NOTE(review): consider removing the default password entirely once all
# deployments set FREQAI_DB_PASSWORD.
DB_CONFIG = {
    "host": os.environ.get("FREQAI_DB_HOST", "192.168.1.215"),  # database host
    "port": int(os.environ.get("FREQAI_DB_PORT", "3306")),  # TCP port (int for pymysql)
    "user": os.environ.get("FREQAI_DB_USER", "freqai"),  # user name
    "password": os.environ.get("FREQAI_DB_PASSWORD", "Abcd@1234"),  # password
    "database": os.environ.get("FREQAI_DB_NAME", "freqai"),  # schema name
    "charset": os.environ.get("FREQAI_DB_CHARSET", "utf8mb4"),  # connection charset
}

# Minimum span length in days (not referenced by the code in this section).
MIN_SPAN_DAYS = 5

# Spans starting before this date are treated as "old" by is_span_valid().
CUTOFF_DATE = datetime(2025, 9, 10)

# Old spans (see CUTOFF_DATE) shorter than this many days are rejected by is_span_valid().
SHORT_SPAN_THRESHOLD = 50
|
||
|
||
# -------------------------- 工具函数 --------------------------
|
||
def truncate_commitid(commitid):
    """Return the first eight characters of ``commitid``; falsy input yields ``""``."""
    if not commitid:
        return ""
    return commitid[:8]
|
||
|
||
def format_timespan(from_str, to_str):
    """Join two timestamp strings into a compact ``YYYYMMDD-YYYYMMDD`` span.

    Each input has everything from its first space onward dropped (the time
    part), then all dashes removed from the remaining date part.
    """
    def _compact(stamp):
        date_part = stamp.split(" ")[0]
        return date_part.replace("-", "")

    return "-".join((_compact(from_str), _compact(to_str)))
|
||
|
||
def parse_formatted_timespan(formatted_span):
    """Parse a ``YYYYMMDD-YYYYMMDD`` span into a pair of datetime objects.

    Args:
        formatted_span: span text such as ``"20250101-20250131"``.

    Returns:
        ``(from_date, to_date)`` as naive datetimes at midnight, or
        ``(None, None)`` when the input is not a string, does not contain
        exactly one ``-``, either side is not exactly eight ASCII digits, or
        either side is not a real calendar date. Ordering of the two dates is
        not checked here (see validate_timespan_format).
    """
    if not isinstance(formatted_span, str):
        return None, None

    parts = formatted_span.split("-")
    if len(parts) != 2:
        return None, None

    # strptime's %m/%d accept single digits ("2025111" parses as 2025-11-01),
    # so enforce the fixed 8-digit width explicitly to reject ambiguous input.
    for part in parts:
        if len(part) != 8 or not all(ch in "0123456789" for ch in part):
            return None, None

    try:
        from_date = datetime.strptime(parts[0], "%Y%m%d")
        to_date = datetime.strptime(parts[1], "%Y%m%d")
    except ValueError:
        # e.g. month 13 or February 30.
        return None, None
    return from_date, to_date
|
||
|
||
def get_span_days(formatted_span):
    """Return the inclusive day count of a ``YYYYMMDD-YYYYMMDD`` span, or 0 if unparseable."""
    start, end = parse_formatted_timespan(formatted_span)
    if start and end:
        # +1 so both endpoints count ("20250101-20250101" is one day).
        return (end - start).days + 1
    return 0
|
||
|
||
def is_span_valid(formatted_span):
    """Decide whether a span is usable.

    A span is rejected when it cannot be parsed, or when it starts before
    CUTOFF_DATE and covers fewer than SHORT_SPAN_THRESHOLD days.
    """
    start, _unused_end = parse_formatted_timespan(formatted_span)
    if not start:
        return False

    length_in_days = get_span_days(formatted_span)
    old_and_short = start < CUTOFF_DATE and length_in_days < SHORT_SPAN_THRESHOLD
    return not old_and_short
|
||
|
||
def get_missing_spans(target_spans, covered_spans):
    """Return the target spans a CommitID has no data for (target minus covered).

    Args:
        target_spans: spans requested by the user; order and duplicates are kept.
        covered_spans: spans the CommitID has results for; may be empty or None.

    Returns:
        A new list of the entries of ``target_spans`` not present in ``covered_spans``.
    """
    # Build the lookup set once instead of scanning the covered list for every target.
    covered = set(covered_spans or ())
    return [span for span in target_spans if span not in covered]
|
||
|
||
def validate_timespan_format(formatted_span):
    """Return True if the span parses as YYYYMMDD-YYYYMMDD and its start is not after its end."""
    start, end = parse_formatted_timespan(formatted_span)
    return bool(start and end and start <= end)
|
||
|
||
# -------------------------- 命令行参数解析函数 --------------------------
|
||
def parse_command_line_args(argv=None):
    """Parse ``--timespans`` and return the valid, de-duplicated target spans.

    Args:
        argv: optional argument list for argparse; ``None`` (the default)
            makes argparse read ``sys.argv[1:]`` as before.

    Returns:
        The valid spans in first-seen order, each appearing once.

    Raises:
        ValueError: when no valid span remains after validation.
        SystemExit: raised by argparse on bad usage (e.g. missing --timespans).
    """
    parser = argparse.ArgumentParser(description="查询指定时段组合下表现最优的CommitID")
    parser.add_argument(
        "--timespans",
        type=str,
        required=True,
        help="指定目标时段组合,格式:YYYYMMDD-YYYYMMDD,YYYYMMDD-YYYYMMDD(多个时段用逗号分隔)",
    )
    args = parser.parse_args(argv)

    # Split on commas, strip whitespace, drop empty entries and de-duplicate
    # while keeping first-seen order. Duplicates would otherwise inflate the
    # span count in the report and repeat entries in the missing-span lists.
    stripped = (span.strip() for span in args.timespans.split(","))
    target_timespans = list(dict.fromkeys(span for span in stripped if span))

    valid_timespans = []
    invalid_timespans = []
    for span in target_timespans:
        if validate_timespan_format(span):
            valid_timespans.append(span)
        else:
            invalid_timespans.append(span)

    if invalid_timespans:
        print("⚠️ 以下时段格式不合法(需为YYYYMMDD-YYYYMMDD,且开始日期≤结束日期),已忽略:")
        for span in invalid_timespans:
            print(f" - {span}")

    if not valid_timespans:
        raise ValueError("❌ 无有效时段传入,程序终止")

    return valid_timespans
|
||
|
||
# -------------------------- 数据库操作 --------------------------
|
||
def get_db_connection():
    """Open a new pymysql connection from DB_CONFIG.

    On failure the error is printed and the original pymysql.MySQLError is re-raised.
    """
    try:
        connection = pymysql.connect(**DB_CONFIG)
    except pymysql.MySQLError as err:
        print(f"数据库连接失败: {err}")
        raise
    return connection
|
||
|
||
def _span_to_query_bounds(formatted_span):
    """Convert ``YYYYMMDD-YYYYMMDD`` into the ``('YYYY-MM-DD 00:00:00', ...)`` pair used in the SQL filter."""
    from_date, to_date = parse_formatted_timespan(formatted_span)
    if from_date is None or to_date is None:
        raise ValueError(f"invalid timespan: {formatted_span!r}")
    return (
        from_date.strftime("%Y-%m-%d 00:00:00"),
        to_date.strftime("%Y-%m-%d 00:00:00"),
    )


def _rank_span_scores(records, top_score=100, step=10):
    """Turn one span's query rows into ``{commitid: rank_score}``.

    Rows without a commit id or score are ignored. When a commit has several
    rows, its best score is used. Commits are then sorted by score, highest
    first. The top one gets ``top_score``, each following rank ``step`` less,
    floored at 0 (with the defaults, rank 11 and below score 0). Equal scores
    keep first-seen order and still get distinct ranks.
    """
    best_scores = {}
    for record in records:
        cid = record["commitid"]
        raw_score = record["corrected_product"]
        if not cid or raw_score is None:
            continue
        score = decimal.Decimal(str(raw_score))
        if cid not in best_scores or score > best_scores[cid]:
            best_scores[cid] = score

    ranked = sorted(best_scores.items(), key=lambda item: item[1], reverse=True)
    return {
        cid: max(0, top_score - idx * step)
        for idx, (cid, _score) in enumerate(ranked)
    }


def get_backtest_data_in_target_spans(target_timespans):
    """Fetch per-span rank scores for every CommitID and the spans each one covers.

    Args:
        target_timespans: spans formatted as ``YYYYMMDD-YYYYMMDD``.

    Returns:
        ``(target_backtest_data, cid_covered_spans)``:
          - ``target_backtest_data``: ``{span: {commitid: rank_score}}``; spans
            with no usable rows (or whose query failed) are absent.
          - ``cid_covered_spans``: ``{commitid: [span, ...]}`` in query order.

    Raises:
        ValueError: if a span cannot be parsed (callers validate beforehand).
        pymysql.MySQLError: if the connection cannot be opened. Per-span query
            errors are printed and that span is skipped.
    """
    # Each row yields SIGN(total_profit) * total_profit * sqn, i.e.
    # |total_profit| * sqn, or 0 when either metric is SQL NULL.
    # NOTE(review): JSON_UNQUOTE of a JSON null returns the string 'null', not
    # SQL NULL, so explicit JSON nulls bypass the IS NULL guard — confirm data.
    sql = """
        SELECT
            JSON_UNQUOTE(JSON_EXTRACT(result, '$.commitid')) AS commitid,
            IF(
                JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.total_profit')) IS NULL
                OR JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.sqn')) IS NULL,
                0,
                SIGN(CAST(JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.total_profit')) AS DECIMAL(10,4)))
                * CAST(JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.total_profit')) AS DECIMAL(10,4))
                * CAST(JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.sqn')) AS DECIMAL(10,4))
            ) AS corrected_product
        FROM backtestresult
        WHERE
            JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.backtesting_from_str')) = %s
            AND JSON_UNQUOTE(JSON_EXTRACT(result, '$.report.backtesting_to_str')) = %s;
    """

    # Convert spans before connecting so bad input never opens a connection.
    # A dict also collapses duplicate spans into one query each.
    span_bounds = {span: _span_to_query_bounds(span) for span in target_timespans}

    target_backtest_data = defaultdict(dict)
    cid_covered_spans = defaultdict(list)

    conn = get_db_connection()
    try:
        cursor = conn.cursor(pymysql.cursors.DictCursor)
        try:
            for formatted_span, (from_str, to_str) in span_bounds.items():
                try:
                    cursor.execute(sql, (from_str, to_str))
                    records = cursor.fetchall()
                except pymysql.MySQLError as e:
                    print(f"查询目标时段 {formatted_span} 失败: {e}")
                    continue

                for cid, rank_score in _rank_span_scores(records).items():
                    target_backtest_data[formatted_span][cid] = rank_score
                    if formatted_span not in cid_covered_spans[cid]:
                        cid_covered_spans[cid].append(formatted_span)
        finally:
            cursor.close()
    finally:
        # Always release the connection, even if ranking raises unexpectedly.
        conn.close()

    return target_backtest_data, cid_covered_spans
|
||
|
||
# -------------------------- 核心函数:寻找目标时段组合下的最优CommitID --------------------------
|
||
def find_best_commitid_in_target_spans(target_timespans):
    """Rank CommitIDs by their summed per-span rank scores over the target spans.

    Args:
        target_timespans: spans formatted as ``YYYYMMDD-YYYYMMDD``.

    Returns:
        ``(best, results)``. ``results`` is a list of dicts with keys
        ``commitid``, ``short_id``, ``total_score`` (float), ``avg_score``
        (float, total divided by the number of spans with data),
        ``valid_span_count``, ``covered_spans`` and ``missing_spans``. It is
        sorted by ``total_score`` descending, then ``commitid`` ascending.
        ``best`` is ``results[0]``. Returns ``(None, [])`` (after printing a
        message) when no backtest data exists for any target span.
    """
    target_backtest_data, cid_covered_spans = get_backtest_data_in_target_spans(target_timespans)
    if not target_backtest_data:
        print("❌ 目标时段组合内无回测数据")
        return None, []

    # Sum each commit's rank scores and count the spans it scored in.
    totals = defaultdict(lambda: decimal.Decimal("0"))
    span_counts = defaultdict(int)
    for cid_scores in target_backtest_data.values():
        for cid, score in cid_scores.items():
            totals[cid] += score
            span_counts[cid] += 1

    result_list = []
    for cid, total in totals.items():
        count = span_counts[cid]
        covered = cid_covered_spans.get(cid, [])
        avg = total / decimal.Decimal(count) if count else decimal.Decimal("0")
        result_list.append({
            "commitid": cid,
            "short_id": truncate_commitid(cid),
            "total_score": float(total),
            "avg_score": float(avg),
            "valid_span_count": count,
            "covered_spans": covered,
            "missing_spans": get_missing_spans(target_timespans, covered),
        })

    # Rank scores are multiples of 10, so equal totals are common. Without a
    # tie-breaker the winner would depend on the (unordered) SQL row order.
    result_list.sort(key=lambda item: (-item["total_score"], item["commitid"]))

    best_commitid = result_list[0] if result_list else None
    return best_commitid, result_list
|
||
|
||
# -------------------------- 主函数 --------------------------
|
||
def _print_target_spans(target_timespans):
    """Print the report title and each target span with its inclusive day count."""
    print("# 特定时段组合下最优CommitID查询报告(含缺失时段)")
    print(f"\n## 目标时段组合(共 {len(target_timespans)} 个有效时段)")
    for idx, span in enumerate(target_timespans, 1):
        days = get_span_days(span)
        print(f"{idx}. {span}(跨度:{days}天)")


def _print_best_commitid(best_commitid, span_total):
    """Print the summary section for the top-ranked CommitID."""
    print("### 表现最出色的CommitID")
    print(f"- CommitID(完整):{best_commitid['commitid']}")
    print(f"- CommitID(前8位):{best_commitid['short_id']}")
    print(f"- 目标时段总得分:{round(best_commitid['total_score'], 4)}")
    print(f"- 目标时段平均得分:{round(best_commitid['avg_score'], 4)}")
    print(f"- 有效覆盖时段数:{best_commitid['valid_span_count']}/{span_total}")
    print(f"- 已覆盖时段:{', '.join(best_commitid['covered_spans']) if best_commitid['covered_spans'] else '无'}")
    print(f"- 缺失时段:{', '.join(best_commitid['missing_spans']) if best_commitid['missing_spans'] else '无'}")


def _print_ranking_table(result_list, span_total, limit):
    """Print a Markdown table of the first ``limit`` CommitIDs."""
    print(f"\n### 目标时段组合下CommitID排名(前{limit},含缺失时段)")
    print("| 排名 | CommitID(前8位) | 总得分 | 平均得分 | 有效覆盖时段数 | 缺失时段数 | 缺失时段详情 |")
    print("|------|------------------|--------|----------|----------------|------------|--------------|")
    for idx, item in enumerate(result_list[:limit], 1):
        missing_count = len(item["missing_spans"])
        missing_detail = ", ".join(item["missing_spans"]) if missing_count > 0 else "无"
        # Keep table cells short: past 100 characters show only the first three spans.
        if len(missing_detail) > 100:
            missing_detail = f"{', '.join(item['missing_spans'][:3])} 等 {missing_count} 个"
        print(f"| {idx} | {item['short_id']} | {round(item['total_score'],4)} | {round(item['avg_score'],4)} | {item['valid_span_count']}/{span_total} | {missing_count} | {missing_detail} |")


def _print_missing_span_details(result_list, limit):
    """Print the full missing-span list of the first ``limit`` CommitIDs as Markdown."""
    print(f"\n### 前{limit}名CommitID缺失时段详细清单(Markdown版本)")
    for idx, item in enumerate(result_list[:limit], 1):
        print(f"\n{idx}. {item['short_id']}(缺失 {len(item['missing_spans'])} 个时段):")
        if item["missing_spans"]:
            for span in item["missing_spans"]:
                print(f" - {span}(跨度:{get_span_days(span)}天)")
        else:
            print(" - 无缺失时段(完整覆盖所有目标时段)")


def _write_missing_spans_json(result_list, limit, path):
    """Write the missing spans of the first ``limit`` CommitIDs to ``path`` as UTF-8 JSON."""
    import json

    json_output = [
        {
            "rank": idx,
            "commitid_short": item["short_id"],
            "missing_count": len(item["missing_spans"]),
            "missing_spans": [
                {"timespan": span, "duration_days": get_span_days(span)}
                for span in item["missing_spans"]
            ],
        }
        for idx, item in enumerate(result_list[:limit], 1)
    ]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(json_output, f, indent=4, ensure_ascii=False)


def main():
    """Run the report.

    Steps: parse ``--timespans``, rank the CommitIDs over those spans, print a
    Markdown report to stdout, and write the top entries' missing spans to a
    JSON file in the working directory.
    """
    limit = 50  # number of CommitIDs shown in the table, detail list and JSON
    json_path = "commitid_missing_spans.json"

    try:
        target_timespans = parse_command_line_args()
    except ValueError as e:
        print(e)
        return

    _print_target_spans(target_timespans)

    print("\n## 查询结果")
    best_commitid, result_list = find_best_commitid_in_target_spans(target_timespans)
    if not best_commitid:
        return

    span_total = len(target_timespans)
    _print_best_commitid(best_commitid, span_total)
    _print_ranking_table(result_list, span_total, limit)
    _print_missing_span_details(result_list, limit)

    print(f"\n### 前{limit}名CommitID缺失时段详细清单(JSON版本)")
    _write_missing_spans_json(result_list, limit, json_path)
    # Report the path only once the write has succeeded, so a failed write
    # is never announced as saved.
    print(f"JSON数据已保存到:{json_path}")
|
||
|
||
if __name__ == "__main__":
    # Produce the report only when run as a script, not when imported.
    main()
|