diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 123ac11..ea660fd 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -134,6 +134,17 @@ M._defaults = { ---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details ---@type {[string]: AvanteProvider} vendors = { + ---@type AvanteSupportedProvider + ["baidu"] = { + endpoint = "https://qianfan.baidubce.com/v2", + model = "deepseek-v3", + timeout = 30000, + temperature = 0, + max_tokens = 4096, + appid = "", -- Required for baidu provider + disable_search = false, + enable_citation = false, + }, ---@type AvanteSupportedProvider ["claude-haiku"] = { __inherited_from = "claude", diff --git a/lua/avante/providers/baidu.lua b/lua/avante/providers/baidu.lua new file mode 100644 index 0000000..f4c27d3 --- /dev/null +++ b/lua/avante/providers/baidu.lua @@ -0,0 +1,75 @@ +local Utils = require("avante.utils") +local Config = require("avante.config") +local P = require("avante.providers") + +---@class AvanteProviderFunctor +local M = {} + +M.api_key_name = "BAIDU_API_KEY" + +M.role_map = { + user = "user", + assistant = "assistant", +} + +---@param opts AvantePromptOptions +M.parse_messages = function(opts) + local messages = {} + + table.insert(messages, { role = "user", content = opts.system_prompt }) + + vim + .iter(opts.messages) + :each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end) + + return messages +end + +M.parse_response = function(ctx, data_stream, _, opts) + if data_stream:match('"%[DONE%]":') then + opts.on_stop({ reason = "complete" }) + return + end + + ---@type BaiduChatResponse + local jsn = vim.json.decode(data_stream) + if jsn.result then + opts.on_chunk(jsn.result) + if jsn.is_end then opts.on_stop({ reason = "complete" }) end + end +end + +M.parse_curl_args = function(provider, prompt_opts) + local base, body_opts = P.parse_config(provider) + + -- Validate required appid + if not base.appid or base.appid == "" then error("Baidu provider requires appid to be set in config") end + + local headers = { + ["Content-Type"] = "application/json", + ["appid"] = base.appid, + } + + if P.env.require_api_key(base) then + local api_key = provider.parse_api_key() + if api_key == nil then + error("Baidu API key is not set, please set BAIDU_API_KEY in your environment variable or config file") + end + headers["Authorization"] = "Bearer " .. api_key + end + + return { + url = Utils.url_join(base.endpoint, "/chat/completions"), + proxy = base.proxy, + insecure = base.allow_insecure, + headers = headers, + body = vim.tbl_deep_extend("force", { + model = base.model, + messages = M.parse_messages(prompt_opts), + disable_search = base.disable_search, + enable_citation = base.enable_citation, + }, body_opts), + } +end + +return M diff --git a/tree.txt b/tree.txt index 0ec1a83..6f112aa 100644 --- a/tree.txt +++ b/tree.txt @@ -3,7 +3,7 @@ │   └── avante.vim ├── Build.ps1 ├── build.sh -├── Cargo.lock +├─ ─ Cargo.lock ├── Cargo.toml ├── crates │   ├── avante-html2md