feat(llm): add support for parsing secret vault (#200)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
		
							parent
							
								
									8d375dd591
								
							
						
					
					
						commit
						a7d3defa3d
					
				@ -15,6 +15,7 @@ M.defaults = {
 | 
			
		||||
  openai = {
 | 
			
		||||
    endpoint = "https://api.openai.com/v1",
 | 
			
		||||
    model = "gpt-4o",
 | 
			
		||||
    timeout = 30000, -- Timeout in milliseconds
 | 
			
		||||
    temperature = 0,
 | 
			
		||||
    max_tokens = 4096,
 | 
			
		||||
    ["local"] = false,
 | 
			
		||||
@ -34,6 +35,7 @@ M.defaults = {
 | 
			
		||||
    endpoint = "", -- example: "https://<your-resource-name>.openai.azure.com"
 | 
			
		||||
    deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment")
 | 
			
		||||
    api_version = "2024-06-01",
 | 
			
		||||
    timeout = 30000, -- Timeout in milliseconds
 | 
			
		||||
    temperature = 0,
 | 
			
		||||
    max_tokens = 4096,
 | 
			
		||||
    ["local"] = false,
 | 
			
		||||
@ -42,22 +44,25 @@ M.defaults = {
 | 
			
		||||
  claude = {
 | 
			
		||||
    endpoint = "https://api.anthropic.com",
 | 
			
		||||
    model = "claude-3-5-sonnet-20240620",
 | 
			
		||||
    timeout = 30000, -- Timeout in milliseconds
 | 
			
		||||
    temperature = 0,
 | 
			
		||||
    max_tokens = 4096,
 | 
			
		||||
    ["local"] = false,
 | 
			
		||||
  },
 | 
			
		||||
  ---@type AvanteGeminiProvider
 | 
			
		||||
  ---@type AvanteSupportedProvider
 | 
			
		||||
  gemini = {
 | 
			
		||||
    endpoint = "https://generativelanguage.googleapis.com/v1beta/models",
 | 
			
		||||
    model = "gemini-1.5-pro",
 | 
			
		||||
    timeout = 30000, -- Timeout in milliseconds
 | 
			
		||||
    temperature = 0,
 | 
			
		||||
    max_tokens = 4096,
 | 
			
		||||
    ["local"] = false,
 | 
			
		||||
  },
 | 
			
		||||
  ---@type AvanteGeminiProvider
 | 
			
		||||
  cohere = {
 | 
			
		||||
    endpoint = "https://api.cohere.com",
 | 
			
		||||
    endpoint = "https://api.cohere.com/v1",
 | 
			
		||||
    model = "command-r-plus",
 | 
			
		||||
    timeout = 30000, -- Timeout in milliseconds
 | 
			
		||||
    temperature = 0,
 | 
			
		||||
    max_tokens = 3072,
 | 
			
		||||
    ["local"] = false,
 | 
			
		||||
@ -189,7 +194,7 @@ end
 | 
			
		||||
 | 
			
		||||
---get supported providers
 | 
			
		||||
---@param provider Provider
 | 
			
		||||
---@return AvanteProvider | fun(): AvanteProvider
 | 
			
		||||
---@return AvanteProviderFunctor
 | 
			
		||||
M.get_provider = function(provider)
 | 
			
		||||
  if M.options[provider] ~= nil then
 | 
			
		||||
    return vim.deepcopy(M.options[provider], true)
 | 
			
		||||
@ -200,8 +205,17 @@ M.get_provider = function(provider)
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
M.BASE_PROVIDER_KEYS =
 | 
			
		||||
  { "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure", "api_key_name" }
 | 
			
		||||
M.BASE_PROVIDER_KEYS = {
 | 
			
		||||
  "endpoint",
 | 
			
		||||
  "model",
 | 
			
		||||
  "local",
 | 
			
		||||
  "deployment",
 | 
			
		||||
  "api_version",
 | 
			
		||||
  "proxy",
 | 
			
		||||
  "allow_insecure",
 | 
			
		||||
  "api_key_name",
 | 
			
		||||
  "timeout",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
---@return {width: integer, height: integer}
 | 
			
		||||
function M.get_sidebar_layout_options()
 | 
			
		||||
 | 
			
		||||
@ -94,7 +94,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
 | 
			
		||||
  ---@type AvanteHandlerOptions
 | 
			
		||||
  local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }
 | 
			
		||||
  ---@type AvanteCurlOutput
 | 
			
		||||
  local spec = Provider.parse_curl_args(Config.get_provider(provider), code_opts)
 | 
			
		||||
  local spec = Provider.parse_curl_args(Provider, code_opts)
 | 
			
		||||
 | 
			
		||||
  ---@param line string
 | 
			
		||||
  local function parse_stream_data(line)
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,10 @@
 | 
			
		||||
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
 | 
			
		||||
---@field deployment string
 | 
			
		||||
---@field api_version string
 | 
			
		||||
---@field temperature number
 | 
			
		||||
---@field max_tokens number
 | 
			
		||||
 | 
			
		||||
local Utils = require("avante.utils")
 | 
			
		||||
local Config = require("avante.config")
 | 
			
		||||
local P = require("avante.providers")
 | 
			
		||||
local O = require("avante.providers").openai
 | 
			
		||||
 | 
			
		||||
@ -8,10 +13,6 @@ local M = {}
 | 
			
		||||
 | 
			
		||||
M.api_key_name = "AZURE_OPENAI_API_KEY"
 | 
			
		||||
 | 
			
		||||
M.has = function()
 | 
			
		||||
  return os.getenv(M.api_key_name) and true or false
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
M.parse_message = O.parse_message
 | 
			
		||||
M.parse_response = O.parse_response
 | 
			
		||||
 | 
			
		||||
@ -22,7 +23,7 @@ M.parse_curl_args = function(provider, code_opts)
 | 
			
		||||
    ["Content-Type"] = "application/json",
 | 
			
		||||
  }
 | 
			
		||||
  if not P.env.is_local("azure") then
 | 
			
		||||
    headers["api-key"] = os.getenv(base.api_key_name or M.api_key_name)
 | 
			
		||||
    headers["api-key"] = provider.parse_api_key()
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  return {
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,4 @@
 | 
			
		||||
local Utils = require("avante.utils")
 | 
			
		||||
local Config = require("avante.config")
 | 
			
		||||
local Tiktoken = require("avante.tiktoken")
 | 
			
		||||
local P = require("avante.providers")
 | 
			
		||||
 | 
			
		||||
@ -8,10 +7,6 @@ local M = {}
 | 
			
		||||
 | 
			
		||||
M.api_key_name = "ANTHROPIC_API_KEY"
 | 
			
		||||
 | 
			
		||||
M.has = function()
 | 
			
		||||
  return os.getenv(M.api_key_name) and true or false
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
M.parse_message = function(opts)
 | 
			
		||||
  local code_prompt_obj = {
 | 
			
		||||
    type = "text",
 | 
			
		||||
@ -93,7 +88,7 @@ M.parse_curl_args = function(provider, code_opts)
 | 
			
		||||
    ["anthropic-beta"] = "prompt-caching-2024-07-31",
 | 
			
		||||
  }
 | 
			
		||||
  if not P.env.is_local("claude") then
 | 
			
		||||
    headers["x-api-key"] = os.getenv(base.api_key_name or M.api_key_name)
 | 
			
		||||
    headers["x-api-key"] = provider.parse_api_key()
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  return {
 | 
			
		||||
 | 
			
		||||
@ -30,10 +30,6 @@ local M = {}
 | 
			
		||||
 | 
			
		||||
M.api_key_name = "CO_API_KEY"
 | 
			
		||||
 | 
			
		||||
M.has = function()
 | 
			
		||||
  return os.getenv(M.api_key_name) and true or false
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
M.parse_message = function(opts)
 | 
			
		||||
  local user_prompt = opts.base_prompt
 | 
			
		||||
    .. "\n\nCODE:\n"
 | 
			
		||||
@ -103,11 +99,11 @@ M.parse_curl_args = function(provider, code_opts)
 | 
			
		||||
      .. vim.version().patch,
 | 
			
		||||
  }
 | 
			
		||||
  if not P.env.is_local("openai") then
 | 
			
		||||
    headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
 | 
			
		||||
    headers["Authorization"] = "Bearer " .. provider.parse_api_key()
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  return {
 | 
			
		||||
    url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/chat",
 | 
			
		||||
    url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat",
 | 
			
		||||
    proxy = base.proxy,
 | 
			
		||||
    insecure = base.allow_insecure,
 | 
			
		||||
    headers = headers,
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,9 @@
 | 
			
		||||
---@class AvanteCopilotProvider: AvanteSupportedProvider
 | 
			
		||||
---@field timeout number
 | 
			
		||||
 | 
			
		||||
local curl = require("plenary.curl")
 | 
			
		||||
 | 
			
		||||
local Utils = require("avante.utils")
 | 
			
		||||
local Config = require("avante.config")
 | 
			
		||||
local P = require("avante.providers")
 | 
			
		||||
local O = require("avante.providers").openai
 | 
			
		||||
 | 
			
		||||
@ -196,7 +198,7 @@ M.parse_curl_args = function(provider, code_opts)
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    local response = curl.get(url, {
 | 
			
		||||
      timeout = Config.copilot.timeout,
 | 
			
		||||
      timeout = base.timeout,
 | 
			
		||||
      headers = headers,
 | 
			
		||||
      proxy = base.proxy,
 | 
			
		||||
      insecure = base.allow_insecure,
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,4 @@
 | 
			
		||||
local Utils = require("avante.utils")
 | 
			
		||||
local Config = require("avante.config")
 | 
			
		||||
local P = require("avante.providers")
 | 
			
		||||
 | 
			
		||||
---@class AvanteProviderFunctor
 | 
			
		||||
@ -7,10 +6,6 @@ local M = {}
 | 
			
		||||
 | 
			
		||||
M.api_key_name = "GEMINI_API_KEY"
 | 
			
		||||
 | 
			
		||||
M.has = function()
 | 
			
		||||
  return os.getenv(M.api_key_name) and true or false
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
M.parse_message = function(opts)
 | 
			
		||||
  local code_prompt_obj = {
 | 
			
		||||
    text = string.format("<code>```%s\n%s```</code>", opts.code_lang, opts.code_content),
 | 
			
		||||
@ -68,7 +63,7 @@ M.parse_curl_args = function(provider, code_opts)
 | 
			
		||||
      .. "/"
 | 
			
		||||
      .. base.model
 | 
			
		||||
      .. ":streamGenerateContent?alt=sse&key="
 | 
			
		||||
      .. os.getenv(base.api_key_name or M.api_key_name),
 | 
			
		||||
      .. provider.parse_api_key(),
 | 
			
		||||
    proxy = base.proxy,
 | 
			
		||||
    insecure = base.allow_insecure,
 | 
			
		||||
    headers = { ["Content-Type"] = "application/json" },
 | 
			
		||||
 | 
			
		||||
@ -52,35 +52,24 @@ local Dressing = require("avante.ui.dressing")
 | 
			
		||||
---@field temperature? number
 | 
			
		||||
---@field max_tokens? number
 | 
			
		||||
---
 | 
			
		||||
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
 | 
			
		||||
---@field deployment string
 | 
			
		||||
---@field api_version string
 | 
			
		||||
---@field temperature number
 | 
			
		||||
---@field max_tokens number
 | 
			
		||||
---
 | 
			
		||||
---@class AvanteCopilotProvider: AvanteSupportedProvider
 | 
			
		||||
---@field timeout number
 | 
			
		||||
---
 | 
			
		||||
---@class AvanteGeminiProvider: AvanteDefaultBaseProvider
 | 
			
		||||
---@field model string
 | 
			
		||||
---
 | 
			
		||||
---@class AvanteProvider: AvanteDefaultBaseProvider
 | 
			
		||||
---@field parse_response_data AvanteResponseParser
 | 
			
		||||
---@field parse_curl_args AvanteCurlArgsParser
 | 
			
		||||
---@field parse_stream_data? AvanteStreamParser
 | 
			
		||||
---
 | 
			
		||||
---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil
 | 
			
		||||
---@alias AvanteChunkParser fun(chunk: string): any
 | 
			
		||||
---@alias AvanteCompleteParser fun(err: string|nil): nil
 | 
			
		||||
---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any>
 | 
			
		||||
---
 | 
			
		||||
---@class AvanteProvider: AvanteSupportedProvider
 | 
			
		||||
---@field parse_response_data AvanteResponseParser
 | 
			
		||||
---@field parse_curl_args? AvanteCurlArgsParser
 | 
			
		||||
---@field parse_stream_data? AvanteStreamParser
 | 
			
		||||
---
 | 
			
		||||
---@class AvanteProviderFunctor
 | 
			
		||||
---@field parse_message AvanteMessageParser
 | 
			
		||||
---@field parse_response AvanteResponseParser
 | 
			
		||||
---@field parse_curl_args AvanteCurlArgsParser
 | 
			
		||||
---@field setup? fun(): nil
 | 
			
		||||
---@field setup fun(): nil
 | 
			
		||||
---@field has fun(): boolean
 | 
			
		||||
---@field api_key_name string
 | 
			
		||||
---@field parse_api_key fun(): string | nil
 | 
			
		||||
---@field parse_stream_data? AvanteStreamParser
 | 
			
		||||
---
 | 
			
		||||
---@class avante.Providers
 | 
			
		||||
@ -92,38 +81,50 @@ local Dressing = require("avante.ui.dressing")
 | 
			
		||||
---@field cohere AvanteProviderFunctor
 | 
			
		||||
local M = {}
 | 
			
		||||
 | 
			
		||||
setmetatable(M, {
 | 
			
		||||
  ---@param t avante.Providers
 | 
			
		||||
  ---@param k Provider
 | 
			
		||||
  __index = function(t, k)
 | 
			
		||||
    if Config.vendors[k] ~= nil then
 | 
			
		||||
      ---@type AvanteProvider
 | 
			
		||||
      local v = Config.vendors[k]
 | 
			
		||||
 | 
			
		||||
      -- Patch from vendors similar to supported providers.
 | 
			
		||||
      ---@type AvanteProviderFunctor
 | 
			
		||||
      t[k] = setmetatable({}, { __index = v })
 | 
			
		||||
      -- Hack for aliasing and makes it sane for us.
 | 
			
		||||
      t[k].parse_response = v.parse_response_data
 | 
			
		||||
      t[k].has = function()
 | 
			
		||||
        return os.getenv(t[k].api_key_name) and true or false
 | 
			
		||||
      end
 | 
			
		||||
 | 
			
		||||
      return t[k]
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    ---@type AvanteProviderFunctor
 | 
			
		||||
    t[k] = require("avante.providers." .. k)
 | 
			
		||||
    return t[k]
 | 
			
		||||
  end,
 | 
			
		||||
})
 | 
			
		||||
 | 
			
		||||
---@class EnvironmentHandler
 | 
			
		||||
local E = {}
 | 
			
		||||
 | 
			
		||||
---@private
 | 
			
		||||
E._once = false
 | 
			
		||||
 | 
			
		||||
---@private
 | 
			
		||||
---@type table<string, string>
 | 
			
		||||
E.cache = {}
 | 
			
		||||
 | 
			
		||||
---@param Opts AvanteSupportedProvider | AvanteProviderFunctor
 | 
			
		||||
---@return string | nil
 | 
			
		||||
E.parse_envvar = function(Opts)
 | 
			
		||||
  local api_key_name = Opts.api_key_name
 | 
			
		||||
  if api_key_name == nil then
 | 
			
		||||
    error("Requires api_key_name")
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  if E.cache[api_key_name] ~= nil then
 | 
			
		||||
    return E.cache[api_key_name]
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local cmd = api_key_name:match("^cmd:(.*)")
 | 
			
		||||
 | 
			
		||||
  local key = nil
 | 
			
		||||
  if cmd ~= nil then
 | 
			
		||||
    local ok, job = pcall(vim.system, vim.split(cmd, " ", { trimempty = true }), { text = true })
 | 
			
		||||
    if not ok then
 | 
			
		||||
      Utils.error("Failed to execute command to retrieve secrets: " .. cmd, { once = true, title = "Avante" })
 | 
			
		||||
    else
 | 
			
		||||
      local out = job:wait()
 | 
			
		||||
      key = out.stdout
 | 
			
		||||
    end
 | 
			
		||||
  else
 | 
			
		||||
    key = os.getenv(api_key_name)
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  if key ~= nil then
 | 
			
		||||
    E.cache[api_key_name] = key
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  return key
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- initialize the environment variable for current neovim session.
 | 
			
		||||
--- This will only run once and spawn a UI for users to input the envvar.
 | 
			
		||||
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
 | 
			
		||||
@ -131,7 +132,8 @@ E._once = false
 | 
			
		||||
E.setup = function(opts)
 | 
			
		||||
  local var = opts.provider.api_key_name
 | 
			
		||||
 | 
			
		||||
  if var == M.AVANTE_INTERNAL_KEY then
 | 
			
		||||
  -- check if var is a all caps string
 | 
			
		||||
  if var == M.AVANTE_INTERNAL_KEY or var:match("^cmd:(.*)") then
 | 
			
		||||
    return
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
@ -149,51 +151,54 @@ E.setup = function(opts)
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  if refresh then
 | 
			
		||||
  local function mount_dressing_buffer()
 | 
			
		||||
    vim.defer_fn(function()
 | 
			
		||||
      Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm })
 | 
			
		||||
      -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
 | 
			
		||||
      local exclude_buftypes = { "qf", "nofile" }
 | 
			
		||||
      local exclude_filetypes = {
 | 
			
		||||
        "NvimTree",
 | 
			
		||||
        "Outline",
 | 
			
		||||
        "help",
 | 
			
		||||
        "dashboard",
 | 
			
		||||
        "alpha",
 | 
			
		||||
        "qf",
 | 
			
		||||
        "ministarter",
 | 
			
		||||
        "TelescopePrompt",
 | 
			
		||||
        "gitcommit",
 | 
			
		||||
        "gitrebase",
 | 
			
		||||
        "DressingInput",
 | 
			
		||||
      }
 | 
			
		||||
      if
 | 
			
		||||
        not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
 | 
			
		||||
        and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
 | 
			
		||||
        and not opts.provider.has()
 | 
			
		||||
      then
 | 
			
		||||
        Dressing.initialize_input_buffer({
 | 
			
		||||
          opts = { prompt = "Enter " .. var .. ": " },
 | 
			
		||||
          on_confirm = on_confirm,
 | 
			
		||||
        })
 | 
			
		||||
      end
 | 
			
		||||
    end, 200)
 | 
			
		||||
  elseif not E._once then
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  if refresh then
 | 
			
		||||
    mount_dressing_buffer()
 | 
			
		||||
    return
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  if not E._once then
 | 
			
		||||
    E._once = true
 | 
			
		||||
    api.nvim_create_autocmd({ "BufEnter", "BufWinEnter", "WinEnter" }, {
 | 
			
		||||
      pattern = "*",
 | 
			
		||||
      once = true,
 | 
			
		||||
      callback = function()
 | 
			
		||||
        vim.defer_fn(function()
 | 
			
		||||
          -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
 | 
			
		||||
          local exclude_buftypes = { "qf", "nofile" }
 | 
			
		||||
          local exclude_filetypes = {
 | 
			
		||||
            "NvimTree",
 | 
			
		||||
            "Outline",
 | 
			
		||||
            "help",
 | 
			
		||||
            "dashboard",
 | 
			
		||||
            "alpha",
 | 
			
		||||
            "qf",
 | 
			
		||||
            "ministarter",
 | 
			
		||||
            "TelescopePrompt",
 | 
			
		||||
            "gitcommit",
 | 
			
		||||
            "gitrebase",
 | 
			
		||||
            "DressingInput",
 | 
			
		||||
          }
 | 
			
		||||
          if
 | 
			
		||||
            not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
 | 
			
		||||
            and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
 | 
			
		||||
            and not opts.provider.has()
 | 
			
		||||
          then
 | 
			
		||||
            Dressing.initialize_input_buffer({
 | 
			
		||||
              opts = { prompt = "Enter " .. var .. ": " },
 | 
			
		||||
              on_confirm = on_confirm,
 | 
			
		||||
            })
 | 
			
		||||
          end
 | 
			
		||||
        end, 200)
 | 
			
		||||
      end,
 | 
			
		||||
      callback = mount_dressing_buffer,
 | 
			
		||||
    })
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
---@param provider Provider
 | 
			
		||||
E.is_local = function(provider)
 | 
			
		||||
  local cur = M.get(provider)
 | 
			
		||||
  local cur = M.get_config(provider)
 | 
			
		||||
  return cur["local"] ~= nil and cur["local"] or false
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
@ -201,14 +206,47 @@ M.env = E
 | 
			
		||||
 | 
			
		||||
M.AVANTE_INTERNAL_KEY = "__avante_env_internal"
 | 
			
		||||
 | 
			
		||||
M = setmetatable(M, {
 | 
			
		||||
  ---@param t avante.Providers
 | 
			
		||||
  ---@param k Provider
 | 
			
		||||
  __index = function(t, k)
 | 
			
		||||
    ---@type AvanteProviderFunctor
 | 
			
		||||
    local Opts = M.get_config(k)
 | 
			
		||||
 | 
			
		||||
    if Config.vendors[k] ~= nil then
 | 
			
		||||
      Opts.parse_response = Opts.parse_response_data
 | 
			
		||||
      t[k] = Opts
 | 
			
		||||
    else
 | 
			
		||||
      t[k] = vim.tbl_deep_extend("keep", Opts, require("avante.providers." .. k))
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    t[k].parse_api_key = function()
 | 
			
		||||
      return E.parse_envvar(t[k])
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    if t[k].has == nil then
 | 
			
		||||
      t[k].has = function()
 | 
			
		||||
        return E.parse_envvar(t[k]) ~= nil
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    if t[k].setup == nil then
 | 
			
		||||
      t[k].setup = function()
 | 
			
		||||
        t[k].parse_api_key()
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    return t[k]
 | 
			
		||||
  end,
 | 
			
		||||
})
 | 
			
		||||
 | 
			
		||||
M.setup = function()
 | 
			
		||||
  ---@type AvanteProviderFunctor
 | 
			
		||||
  local provider = M[Config.provider]
 | 
			
		||||
  E.setup({ provider = provider })
 | 
			
		||||
 | 
			
		||||
  if provider.setup ~= nil then
 | 
			
		||||
  vim.schedule(function()
 | 
			
		||||
    provider.setup()
 | 
			
		||||
  end
 | 
			
		||||
  end)
 | 
			
		||||
 | 
			
		||||
  M.commands()
 | 
			
		||||
end
 | 
			
		||||
@ -216,6 +254,8 @@ end
 | 
			
		||||
---@private
 | 
			
		||||
---@param provider Provider
 | 
			
		||||
function M.refresh(provider)
 | 
			
		||||
  require("avante.config").override({ provider = provider })
 | 
			
		||||
 | 
			
		||||
  ---@type AvanteProviderFunctor
 | 
			
		||||
  local p = M[Config.provider]
 | 
			
		||||
  if not p.has() then
 | 
			
		||||
@ -223,7 +263,6 @@ function M.refresh(provider)
 | 
			
		||||
  else
 | 
			
		||||
    Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" })
 | 
			
		||||
  end
 | 
			
		||||
  require("avante.config").override({ provider = provider })
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
local default_providers = { "openai", "claude", "azure", "gemini", "copilot" }
 | 
			
		||||
@ -242,7 +281,8 @@ M.commands = function()
 | 
			
		||||
      end
 | 
			
		||||
      local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
 | 
			
		||||
      -- join two tables
 | 
			
		||||
      local Keys = vim.list_extend(default_providers, vim.tbl_keys(Config.vendors or {}))
 | 
			
		||||
      local Keys = vim.list_extend({}, default_providers)
 | 
			
		||||
      Keys = vim.list_extend(Keys, vim.tbl_keys(Config.vendors or {}))
 | 
			
		||||
      return vim.tbl_filter(function(key)
 | 
			
		||||
        return key:find(prefix) == 1
 | 
			
		||||
      end, Keys)
 | 
			
		||||
@ -280,7 +320,8 @@ end
 | 
			
		||||
 | 
			
		||||
---@private
 | 
			
		||||
---@param provider Provider
 | 
			
		||||
M.get = function(provider)
 | 
			
		||||
---@return AvanteProviderFunctor
 | 
			
		||||
M.get_config = function(provider)
 | 
			
		||||
  local cur = Config.get_provider(provider or Config.provider)
 | 
			
		||||
  return type(cur) == "function" and cur() or cur
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
@ -26,10 +26,6 @@ local M = {}
 | 
			
		||||
 | 
			
		||||
M.api_key_name = "OPENAI_API_KEY"
 | 
			
		||||
 | 
			
		||||
M.has = function()
 | 
			
		||||
  return os.getenv(M.api_key_name) and true or false
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
M.parse_message = function(opts)
 | 
			
		||||
  local user_prompt = opts.base_prompt
 | 
			
		||||
    .. "\n\nCODE:\n"
 | 
			
		||||
@ -91,7 +87,7 @@ M.parse_curl_args = function(provider, code_opts)
 | 
			
		||||
    ["Content-Type"] = "application/json",
 | 
			
		||||
  }
 | 
			
		||||
  if not P.env.is_local("openai") then
 | 
			
		||||
    headers["Authorization"] = "Bearer " .. os.getenv(base.api_key_name or M.api_key_name)
 | 
			
		||||
    headers["Authorization"] = "Bearer " .. provider.parse_api_key()
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  return {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user