support baidu bce
This commit is contained in:
parent
3d639b9eaf
commit
be74f25b82
@ -134,6 +134,17 @@ M._defaults = {
|
|||||||
---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details
|
---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details
|
||||||
---@type {[string]: AvanteProvider}
|
---@type {[string]: AvanteProvider}
|
||||||
vendors = {
|
vendors = {
|
||||||
|
---@type AvanteSupportedProvider
|
||||||
|
["baidu"] = {
|
||||||
|
endpoint = "https://qianfan.baidubce.com/v2",
|
||||||
|
model = "deepseek-v3",
|
||||||
|
timeout = 30000,
|
||||||
|
temperature = 0,
|
||||||
|
max_tokens = 4096,
|
||||||
|
appid = "", -- Required for baidu provider
|
||||||
|
disable_search = false,
|
||||||
|
enable_citation = false,
|
||||||
|
},
|
||||||
---@type AvanteSupportedProvider
|
---@type AvanteSupportedProvider
|
||||||
["claude-haiku"] = {
|
["claude-haiku"] = {
|
||||||
__inherited_from = "claude",
|
__inherited_from = "claude",
|
||||||
|
75
lua/avante/providers/baidu.lua
Normal file
75
lua/avante/providers/baidu.lua
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
local Utils = require("avante.utils")
|
||||||
|
local Config = require("avante.config")
|
||||||
|
local P = require("avante.providers")
|
||||||
|
|
||||||
|
---@class AvanteProviderFunctor
|
||||||
|
local M = {}
|
||||||
|
|
||||||
|
M.api_key_name = "BAIDU_API_KEY"
|
||||||
|
|
||||||
|
M.role_map = {
|
||||||
|
user = "user",
|
||||||
|
assistant = "assistant",
|
||||||
|
}
|
||||||
|
|
||||||
|
---@param opts AvantePromptOptions
|
||||||
|
M.parse_messages = function(opts)
|
||||||
|
local messages = {}
|
||||||
|
|
||||||
|
table.insert(messages, { role = "user", content = opts.system_prompt })
|
||||||
|
|
||||||
|
vim
|
||||||
|
.iter(opts.messages)
|
||||||
|
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
|
||||||
|
|
||||||
|
return messages
|
||||||
|
end
|
||||||
|
|
||||||
|
M.parse_response = function(ctx, data_stream, _, opts)
|
||||||
|
if data_stream:match('"%[DONE%]":') then
|
||||||
|
opts.on_stop({ reason = "complete" })
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
---@type BaiduChatResponse
|
||||||
|
local jsn = vim.json.decode(data_stream)
|
||||||
|
if jsn.result then
|
||||||
|
opts.on_chunk(jsn.result)
|
||||||
|
if jsn.is_end then opts.on_stop({ reason = "complete" }) end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
M.parse_curl_args = function(provider, prompt_opts)
|
||||||
|
local base, body_opts = P.parse_config(provider)
|
||||||
|
|
||||||
|
-- Validate required appid
|
||||||
|
if not base.appid or base.appid == "" then error("Baidu provider requires appid to be set in config") end
|
||||||
|
|
||||||
|
local headers = {
|
||||||
|
["Content-Type"] = "application/json",
|
||||||
|
["appid"] = base.appid,
|
||||||
|
}
|
||||||
|
|
||||||
|
if P.env.require_api_key(base) then
|
||||||
|
local api_key = provider.parse_api_key()
|
||||||
|
if api_key == nil then
|
||||||
|
error("Baidu API key is not set, please set BAIDU_API_KEY in your environment variable or config file")
|
||||||
|
end
|
||||||
|
headers["Authorization"] = "Bearer " .. api_key
|
||||||
|
end
|
||||||
|
|
||||||
|
return {
|
||||||
|
url = Utils.url_join(base.endpoint, "/chat/completions"),
|
||||||
|
proxy = base.proxy,
|
||||||
|
insecure = base.allow_insecure,
|
||||||
|
headers = headers,
|
||||||
|
body = vim.tbl_deep_extend("force", {
|
||||||
|
model = base.model,
|
||||||
|
messages = M.parse_messages(prompt_opts),
|
||||||
|
disable_search = base.disable_search,
|
||||||
|
enable_citation = base.enable_citation,
|
||||||
|
}, body_opts),
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
return M
|
Loading…
x
Reference in New Issue
Block a user