312 lines
8.2 KiB
Go
312 lines
8.2 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"flag"
|
||
"fmt"
|
||
"os"
|
||
"os/signal"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/go-redis/redis/v8"
|
||
"gitea.zjmud.xyz/phyer/myTestFreqAI/goflow/config"
|
||
)
|
||
|
||
var (
|
||
ctx = context.Background()
|
||
hostname string
|
||
redisAddr string
|
||
redisPassword string
|
||
redisDB int
|
||
client *redis.Client
|
||
wg sync.WaitGroup
|
||
)
|
||
|
||
// 辅助函数 - 获取正则匹配的第一个结果
|
||
func getFirstMatch(re *regexp.Regexp, line string) string {
|
||
matches := re.FindStringSubmatch(line)
|
||
if len(matches) > 1 {
|
||
return matches[1]
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// 提取阶段ID
|
||
func extractPhaseId(phaseInfo string) int {
|
||
phasePattern := regexp.MustCompile(`ID\s*:\s*(\d+)`)
|
||
idStr := getFirstMatch(phasePattern, phaseInfo)
|
||
id, _ := strconv.Atoi(idStr)
|
||
return id
|
||
}
|
||
|
||
// 提取阶段名称
|
||
func extractPhaseName(phaseInfo string) string {
|
||
namePattern := regexp.MustCompile(`Name\s*:\s*"([^"]*)"`)
|
||
return getFirstMatch(namePattern, phaseInfo)
|
||
}
|
||
|
||
// 提取阶段空间
|
||
func extractPhaseSpace(phaseInfo string) string {
|
||
spacePattern := regexp.MustCompile(`Space\s*:\s*"([^"]*)"`)
|
||
return getFirstMatch(spacePattern, phaseInfo)
|
||
}
|
||
|
||
// 提取阶段描述
|
||
func extractPhaseDescription(phaseInfo string) string {
|
||
descPattern := regexp.MustCompile(`Description\s*:\s*"([^"]*)"`)
|
||
return getFirstMatch(descPattern, phaseInfo)
|
||
}
|
||
|
||
// 提取参数名称
|
||
func extractParameterName(paramLine string) string {
|
||
namePattern := regexp.MustCompile(`"([^"]+)"\s*:`)
|
||
return getFirstMatch(namePattern, paramLine)
|
||
}
|
||
|
||
// 提取步骤ID
|
||
func extractStepID(stepInfo string) int {
|
||
stepPattern := regexp.MustCompile(`ID\s*:\s*(\d+)`)
|
||
idStr := getFirstMatch(stepPattern, stepInfo)
|
||
id, _ := strconv.Atoi(idStr)
|
||
return id
|
||
}
|
||
|
||
// 尝试JSON解析消息
|
||
func tryParseJSON(message string) (map[string]interface{}, bool) {
|
||
var result map[string]interface{}
|
||
err := json.Unmarshal([]byte(message), &result)
|
||
if err == nil {
|
||
return result, true
|
||
}
|
||
return nil, false
|
||
}
|
||
|
||
// 处理任务消息
|
||
func handleTaskMessage(message string) {
|
||
fmt.Println("[消息处理] 开始处理任务消息")
|
||
|
||
// 输出消息的预览(前50个字符)
|
||
preview := message
|
||
if len(message) > 50 {
|
||
preview = message[:50] + "..."
|
||
}
|
||
fmt.Printf("[消息处理] 消息预览: %s\n", preview)
|
||
|
||
// 1. 尝试JSON解析
|
||
if jsonData, parsed := tryParseJSON(message); parsed {
|
||
fmt.Println("[消息处理] 成功解析JSON格式消息")
|
||
printJSONFields(jsonData)
|
||
} else {
|
||
// 2. 尝试文本解析
|
||
fmt.Println("[消息处理] 消息不是JSON格式,尝试文本解析")
|
||
|
||
// 解析消息中的关键信息
|
||
phaseID := extractPhaseId(message)
|
||
phaseName := extractPhaseName(message)
|
||
phaseSpace := extractPhaseSpace(message)
|
||
phaseDescription := extractPhaseDescription(message)
|
||
stepID := extractStepID(message)
|
||
|
||
// 输出解析结果
|
||
if phaseID > 0 {
|
||
fmt.Printf("[消息处理] 提取到阶段ID: %d\n", phaseID)
|
||
}
|
||
if phaseName != "" {
|
||
fmt.Printf("[消息处理] 提取到阶段名称: %s\n", phaseName)
|
||
}
|
||
if phaseSpace != "" {
|
||
fmt.Printf("[消息处理] 提取到阶段空间: %s\n", phaseSpace)
|
||
}
|
||
if phaseDescription != "" {
|
||
fmt.Printf("[消息处理] 提取到阶段描述: %s\n", phaseDescription)
|
||
}
|
||
if stepID > 0 {
|
||
fmt.Printf("[消息处理] 提取到步骤ID: %d\n", stepID)
|
||
}
|
||
}
|
||
|
||
// 3. 提取消息中的关键字段(通用方法)
|
||
extractCommonFields(message)
|
||
|
||
// 4. 模拟任务处理
|
||
fmt.Println("[消息处理] 开始执行任务...")
|
||
// 模拟任务处理延迟
|
||
time.Sleep(1 * time.Second)
|
||
fmt.Println("[消息处理] 任务执行完成")
|
||
fmt.Println("[消息处理] 任务处理结果: 成功")
|
||
|
||
// 输出任务处理状态确认
|
||
fmt.Println("[消息处理] 任务已成功处理完成,不再显示'收到消息, 但是什么都没干'")
|
||
}
|
||
|
||
// 按行号打印消息内容
|
||
func printMessageWithLineNumbers(message string) {
|
||
lines := strings.Split(message, "\n")
|
||
for i, line := range lines {
|
||
fmt.Printf("[消息处理] 行 %3d: %s\n", i+1, line)
|
||
// 避免输出过长
|
||
if i >= 10 {
|
||
fmt.Printf("[消息处理] ... 还有 %d 行未显示\n", len(lines)-10)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// 打印JSON字段
|
||
func printJSONFields(jsonData map[string]interface{}) {
|
||
// 检查常见字段
|
||
if phaseID, ok := jsonData["phase_id"].(float64); ok {
|
||
fmt.Printf("[消息处理] JSON字段 - 阶段ID: %.0f\n", phaseID)
|
||
}
|
||
if phaseName, ok := jsonData["phase_name"].(string); ok {
|
||
fmt.Printf("[消息处理] JSON字段 - 阶段名称: %s\n", phaseName)
|
||
}
|
||
if stepID, ok := jsonData["step_id"].(float64); ok {
|
||
fmt.Printf("[消息处理] JSON字段 - 步骤ID: %.0f\n", stepID)
|
||
}
|
||
|
||
// 列出所有顶级字段
|
||
fmt.Printf("[消息处理] JSON所有顶级字段: %v\n", getAllKeys(jsonData))
|
||
}
|
||
|
||
// 获取map的所有键
|
||
func getAllKeys(m map[string]interface{}) []string {
|
||
keys := make([]string, 0, len(m))
|
||
for k := range m {
|
||
keys = append(keys, k)
|
||
}
|
||
return keys
|
||
}
|
||
|
||
// 提取常见字段
|
||
func extractCommonFields(message string) {
|
||
// 检查关键字段
|
||
keywords := []string{"phase", "step", "task", "hyperopt", "parameter", "result"}
|
||
foundKeywords := []string{}
|
||
|
||
for _, keyword := range keywords {
|
||
if strings.Contains(strings.ToLower(message), keyword) {
|
||
foundKeywords = append(foundKeywords, keyword)
|
||
}
|
||
}
|
||
|
||
if len(foundKeywords) > 0 {
|
||
fmt.Printf("[消息处理] 检测到关键字段: %v\n", foundKeywords)
|
||
}
|
||
}
|
||
|
||
// 订阅任务频道函数
|
||
func subscribeToTasks() {
|
||
wg.Add(1)
|
||
go func() {
|
||
defer wg.Done()
|
||
fmt.Println("[订阅协程] 启动任务订阅协程")
|
||
|
||
// 创建订阅
|
||
pubsub := client.Subscribe(ctx, config.HyperoptTasksChannel, config.HyperoptChannel)
|
||
defer func() {
|
||
fmt.Println("[订阅协程] 关闭任务订阅连接")
|
||
pubsub.Close()
|
||
}()
|
||
|
||
fmt.Println("[订阅协程] 等待接收任务消息...")
|
||
ch := pubsub.Channel()
|
||
|
||
// 接收消息循环
|
||
for msg := range ch {
|
||
fmt.Printf("[订阅协程] 收到来自 %s 的任务消息\n", msg.Channel)
|
||
fmt.Printf("[订阅协程] 消息内容长度: %d 字节\n", len(msg.Payload))
|
||
|
||
// 记录开始处理时间
|
||
startTime := time.Now()
|
||
|
||
// 处理收到的任务消息
|
||
handleTaskMessage(msg.Payload)
|
||
|
||
// 记录处理完成时间
|
||
elapsed := time.Since(startTime)
|
||
fmt.Printf("[订阅协程] 消息处理完成,耗时: %v\n", elapsed)
|
||
}
|
||
}()
|
||
}
|
||
|
||
// 主函数
|
||
func main() {
|
||
fmt.Println("开始初始化HyperOpt客户端...")
|
||
|
||
// 解析命令行参数
|
||
redisAddrFlag := flag.String("redis-addr", config.DefaultRedisConfig().Addr, "Redis服务器地址")
|
||
redisPasswordFlag := flag.String("redis-password", config.DefaultRedisConfig().Password, "Redis密码")
|
||
redisDBFlag := flag.Int("redis-db", config.DefaultRedisConfig().DB, "Redis数据库索引")
|
||
flag.Parse()
|
||
|
||
// 设置全局变量
|
||
redisAddr = *redisAddrFlag
|
||
redisPassword = *redisPasswordFlag
|
||
redisDB = *redisDBFlag
|
||
|
||
fmt.Println("命令行参数解析完成")
|
||
|
||
// 获取主机名
|
||
hostname, err := os.Hostname()
|
||
if err != nil {
|
||
fmt.Printf("错误: 获取主机名失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
fmt.Printf("主机名: %s\n", hostname)
|
||
|
||
// 初始化Redis客户端
|
||
fmt.Printf("Redis客户端初始化,连接地址: %s, 数据库: %d\n", redisAddr, redisDB)
|
||
client = redis.NewClient(&redis.Options{
|
||
Addr: redisAddr,
|
||
Password: redisPassword,
|
||
DB: redisDB,
|
||
})
|
||
defer client.Close()
|
||
|
||
// 测试Redis连接
|
||
_, err = client.Ping(ctx).Result()
|
||
if err != nil {
|
||
fmt.Printf("错误: 无法连接到Redis: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
fmt.Println("Redis连接测试成功")
|
||
|
||
// 订阅频道
|
||
hyperoptChannelName := config.HyperoptTasksChannel
|
||
fmt.Printf("开始订阅 %s 频道\n", hyperoptChannelName)
|
||
subscribeToTasks()
|
||
|
||
fmt.Println("客户端初始化完成")
|
||
fmt.Println("客户端初始化完成,等待接收任务...")
|
||
fmt.Println("按Ctrl+C退出")
|
||
|
||
// 保持主协程运行,定期输出等待信息
|
||
ticker := time.NewTicker(5 * time.Second)
|
||
defer ticker.Stop()
|
||
|
||
// 等待中断信号
|
||
quit := make(chan os.Signal, 1)
|
||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||
|
||
// 主循环
|
||
for {
|
||
select {
|
||
case <-quit:
|
||
fmt.Println("接收到中断信号,开始关闭客户端")
|
||
fmt.Println("正在关闭客户端...")
|
||
wg.Wait()
|
||
fmt.Println("客户端已关闭")
|
||
return
|
||
case <-ticker.C:
|
||
fmt.Println("等待接收任务消息...")
|
||
}
|
||
}
|
||
}
|