Compare commits
10 Commits
9bad591e8a
...
c0eaefa633
Author | SHA1 | Date | |
---|---|---|---|
![]() |
c0eaefa633 | ||
![]() |
7fa7b0fa3b | ||
![]() |
c60dc6c316 | ||
![]() |
1a4f2575d6 | ||
![]() |
16bcbc0229 | ||
![]() |
8f32512949 | ||
![]() |
25111c6df3 | ||
![]() |
763dbe064d | ||
![]() |
76c06ed277 | ||
![]() |
ce55d7ac9e |
@ -582,15 +582,15 @@ For more information, see [Custom Providers](https://github.com/yetone/avante.nv
|
|||||||
|
|
||||||
## Web Search Engines
|
## Web Search Engines
|
||||||
|
|
||||||
Avante's tools include some web search engines, currently support [tavily](https://tavily.com/), [serpapi](https://serpapi.com/) and google's [programmable search engine](https://developers.google.com/custom-search/v1/overview). The default is tavily, and can be changed through configuring `Config.web_search_engine.provider`:
|
Avante's tools include some web search engines, currently support [tavily](https://tavily.com/), [serpapi](https://serpapi.com/), [searchapi](https://www.searchapi.io/) and google's [programmable search engine](https://developers.google.com/custom-search/v1/overview). The default is tavily, and can be changed through configuring `Config.web_search_engine.provider`:
|
||||||
|
|
||||||
```lua
|
```lua
|
||||||
web_search_engine = {
|
web_search_engine = {
|
||||||
provider = "tavily", -- tavily, serpapi or google
|
provider = "tavily", -- tavily, serpapi, searchapi or google
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
You need to set the environment variable `TAVILY_API_KEY` , `SERPAPI_API_KEY` to use tavily or serpapi.
|
You need to set the environment variable `TAVILY_API_KEY` , `SERPAPI_API_KEY`, `SEARCHAPI_API_KEY` to use tavily or serpapi or searchapi.
|
||||||
To use google, set the `GOOGLE_SEARCH_API_KEY` as the [API key](https://developers.google.com/custom-search/v1/overview), and `GOOGLE_SEARCH_ENGINE_ID` as the [search engine](https://programmablesearchengine.google.com) ID.
|
To use google, set the `GOOGLE_SEARCH_API_KEY` as the [API key](https://developers.google.com/custom-search/v1/overview), and `GOOGLE_SEARCH_ENGINE_ID` as the [search engine](https://programmablesearchengine.google.com) ID.
|
||||||
|
|
||||||
## Disable Tools
|
## Disable Tools
|
||||||
|
@ -56,10 +56,39 @@ M._defaults = {
|
|||||||
title = result.title,
|
title = result.title,
|
||||||
link = result.link,
|
link = result.link,
|
||||||
snippet = result.snippet,
|
snippet = result.snippet,
|
||||||
|
date = result.date,
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
)
|
)
|
||||||
:take(5)
|
:take(10)
|
||||||
|
:totable()
|
||||||
|
return vim.json.encode(jsn), nil
|
||||||
|
end
|
||||||
|
return "", nil
|
||||||
|
end,
|
||||||
|
},
|
||||||
|
searchapi = {
|
||||||
|
api_key_name = "SEARCHAPI_API_KEY",
|
||||||
|
extra_request_body = {
|
||||||
|
engine = "google",
|
||||||
|
},
|
||||||
|
---@type WebSearchEngineProviderResponseBodyFormatter
|
||||||
|
format_response_body = function(body)
|
||||||
|
if body.answer_box ~= nil then return body.answer_box.result, nil end
|
||||||
|
if body.organic_results ~= nil then
|
||||||
|
local jsn = vim
|
||||||
|
.iter(body.organic_results)
|
||||||
|
:map(
|
||||||
|
function(result)
|
||||||
|
return {
|
||||||
|
title = result.title,
|
||||||
|
link = result.link,
|
||||||
|
snippet = result.snippet,
|
||||||
|
date = result.date,
|
||||||
|
}
|
||||||
|
end
|
||||||
|
)
|
||||||
|
:take(10)
|
||||||
:totable()
|
:totable()
|
||||||
return vim.json.encode(jsn), nil
|
return vim.json.encode(jsn), nil
|
||||||
end
|
end
|
||||||
@ -84,7 +113,7 @@ M._defaults = {
|
|||||||
}
|
}
|
||||||
end
|
end
|
||||||
)
|
)
|
||||||
:take(5)
|
:take(10)
|
||||||
:totable()
|
:totable()
|
||||||
return vim.json.encode(jsn), nil
|
return vim.json.encode(jsn), nil
|
||||||
end
|
end
|
||||||
|
@ -25,8 +25,8 @@ M.generate_prompts = function(opts)
|
|||||||
local Provider = opts.provider or P[Config.provider]
|
local Provider = opts.provider or P[Config.provider]
|
||||||
local mode = opts.mode or "planning"
|
local mode = opts.mode or "planning"
|
||||||
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||||
local _, body_opts = P.parse_config(Provider)
|
local _, request_body = P.parse_config(Provider)
|
||||||
local max_tokens = body_opts.max_tokens or 4096
|
local max_tokens = request_body.max_tokens or 4096
|
||||||
|
|
||||||
-- Check if the instructions contains an image path
|
-- Check if the instructions contains an image path
|
||||||
local image_paths = {}
|
local image_paths = {}
|
||||||
|
@ -42,13 +42,12 @@ function M.list_files(opts, on_log)
|
|||||||
add_dirs = true,
|
add_dirs = true,
|
||||||
depth = opts.depth,
|
depth = opts.depth,
|
||||||
})
|
})
|
||||||
local result = ""
|
local filepaths = {}
|
||||||
for _, file in ipairs(files) do
|
for _, file in ipairs(files) do
|
||||||
local uniform_path = Utils.uniform_path(file)
|
local uniform_path = Utils.uniform_path(file)
|
||||||
result = result .. uniform_path .. "\n"
|
table.insert(filepaths, uniform_path)
|
||||||
end
|
end
|
||||||
result = result:gsub("\n$", "")
|
return vim.json.encode(filepaths), nil
|
||||||
return result, nil
|
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param opts { rel_path: string, keyword: string }
|
---@param opts { rel_path: string, keyword: string }
|
||||||
@ -63,12 +62,11 @@ function M.search_files(opts, on_log)
|
|||||||
local files = Utils.scan_directory_respect_gitignore({
|
local files = Utils.scan_directory_respect_gitignore({
|
||||||
directory = abs_path,
|
directory = abs_path,
|
||||||
})
|
})
|
||||||
local result = ""
|
local filepaths = {}
|
||||||
for _, file in ipairs(files) do
|
for _, file in ipairs(files) do
|
||||||
if file:find(opts.keyword) then result = result .. file .. "\n" end
|
if file:find(opts.keyword) then table.insert(filepaths, file) end
|
||||||
end
|
end
|
||||||
result = result:gsub("\n$", "")
|
return vim.json.encode(filepaths), nil
|
||||||
return result, nil
|
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param opts { rel_path: string, keyword: string }
|
---@param opts { rel_path: string, keyword: string }
|
||||||
@ -105,7 +103,9 @@ function M.search(opts, on_log)
|
|||||||
if on_log then on_log("Running command: " .. cmd) end
|
if on_log then on_log("Running command: " .. cmd) end
|
||||||
local result = vim.fn.system(cmd)
|
local result = vim.fn.system(cmd)
|
||||||
|
|
||||||
return result or "", nil
|
local filepaths = vim.split(result, "\n")
|
||||||
|
|
||||||
|
return vim.json.encode(filepaths), nil
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param opts { rel_path: string }
|
---@param opts { rel_path: string }
|
||||||
@ -184,9 +184,10 @@ function M.rename_file(opts, on_log)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@param opts { rel_path: string, new_rel_path: string }
|
---@param opts { rel_path: string, new_rel_path: string }
|
||||||
|
---@param on_log? fun(log: string): nil
|
||||||
---@return boolean success
|
---@return boolean success
|
||||||
---@return string|nil error
|
---@return string|nil error
|
||||||
function M.copy_file(opts)
|
function M.copy_file(opts, on_log)
|
||||||
local abs_path = get_abs_path(opts.rel_path)
|
local abs_path = get_abs_path(opts.rel_path)
|
||||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||||
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
|
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
|
||||||
@ -194,38 +195,47 @@ function M.copy_file(opts)
|
|||||||
local new_abs_path = get_abs_path(opts.new_rel_path)
|
local new_abs_path = get_abs_path(opts.new_rel_path)
|
||||||
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
|
if not has_permission_to_access(new_abs_path) then return false, "No permission to access path: " .. new_abs_path end
|
||||||
if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end
|
if Path:new(new_abs_path):exists() then return false, "File already exists: " .. new_abs_path end
|
||||||
|
if on_log then on_log("Copying file: " .. abs_path .. " to " .. new_abs_path) end
|
||||||
Path:new(new_abs_path):write(Path:new(abs_path):read())
|
Path:new(new_abs_path):write(Path:new(abs_path):read())
|
||||||
return true, nil
|
return true, nil
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param opts { rel_path: string }
|
---@param opts { rel_path: string }
|
||||||
|
---@param on_log? fun(log: string): nil
|
||||||
---@return boolean success
|
---@return boolean success
|
||||||
---@return string|nil error
|
---@return string|nil error
|
||||||
function M.delete_file(opts)
|
function M.delete_file(opts, on_log)
|
||||||
local abs_path = get_abs_path(opts.rel_path)
|
local abs_path = get_abs_path(opts.rel_path)
|
||||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||||
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
|
if not Path:new(abs_path):exists() then return false, "File not found: " .. abs_path end
|
||||||
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
|
if not Path:new(abs_path):is_file() then return false, "Path is not a file: " .. abs_path end
|
||||||
if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end
|
if not M.confirm("Are you sure you want to delete the file: " .. abs_path) then return false, "User canceled" end
|
||||||
|
if on_log then on_log("Deleting file: " .. abs_path) end
|
||||||
os.remove(abs_path)
|
os.remove(abs_path)
|
||||||
return true, nil
|
return true, nil
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param opts { rel_path: string }
|
---@param opts { rel_path: string }
|
||||||
|
---@param on_log? fun(log: string): nil
|
||||||
---@return boolean success
|
---@return boolean success
|
||||||
---@return string|nil error
|
---@return string|nil error
|
||||||
function M.create_dir(opts)
|
function M.create_dir(opts, on_log)
|
||||||
local abs_path = get_abs_path(opts.rel_path)
|
local abs_path = get_abs_path(opts.rel_path)
|
||||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||||
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
|
if Path:new(abs_path):exists() then return false, "Directory already exists: " .. abs_path end
|
||||||
|
if not M.confirm("Are you sure you want to create the directory: " .. abs_path) then
|
||||||
|
return false, "User canceled"
|
||||||
|
end
|
||||||
|
if on_log then on_log("Creating directory: " .. abs_path) end
|
||||||
Path:new(abs_path):mkdir({ parents = true })
|
Path:new(abs_path):mkdir({ parents = true })
|
||||||
return true, nil
|
return true, nil
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param opts { rel_path: string, new_rel_path: string }
|
---@param opts { rel_path: string, new_rel_path: string }
|
||||||
|
---@param on_log? fun(log: string): nil
|
||||||
---@return boolean success
|
---@return boolean success
|
||||||
---@return string|nil error
|
---@return string|nil error
|
||||||
function M.rename_dir(opts)
|
function M.rename_dir(opts, on_log)
|
||||||
local abs_path = get_abs_path(opts.rel_path)
|
local abs_path = get_abs_path(opts.rel_path)
|
||||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||||
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
|
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
|
||||||
@ -236,14 +246,16 @@ function M.rename_dir(opts)
|
|||||||
if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then
|
if not M.confirm("Are you sure you want to rename directory " .. abs_path .. " to " .. new_abs_path .. "?") then
|
||||||
return false, "User canceled"
|
return false, "User canceled"
|
||||||
end
|
end
|
||||||
|
if on_log then on_log("Renaming directory: " .. abs_path .. " to " .. new_abs_path) end
|
||||||
os.rename(abs_path, new_abs_path)
|
os.rename(abs_path, new_abs_path)
|
||||||
return true, nil
|
return true, nil
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param opts { rel_path: string }
|
---@param opts { rel_path: string }
|
||||||
|
---@param on_log? fun(log: string): nil
|
||||||
---@return boolean success
|
---@return boolean success
|
||||||
---@return string|nil error
|
---@return string|nil error
|
||||||
function M.delete_dir(opts)
|
function M.delete_dir(opts, on_log)
|
||||||
local abs_path = get_abs_path(opts.rel_path)
|
local abs_path = get_abs_path(opts.rel_path)
|
||||||
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
if not has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||||
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
|
if not Path:new(abs_path):exists() then return false, "Directory not found: " .. abs_path end
|
||||||
@ -251,6 +263,7 @@ function M.delete_dir(opts)
|
|||||||
if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then
|
if not M.confirm("Are you sure you want to delete the directory: " .. abs_path) then
|
||||||
return false, "User canceled"
|
return false, "User canceled"
|
||||||
end
|
end
|
||||||
|
if on_log then on_log("Deleting directory: " .. abs_path) end
|
||||||
os.remove(abs_path)
|
os.remove(abs_path)
|
||||||
return true, nil
|
return true, nil
|
||||||
end
|
end
|
||||||
@ -327,6 +340,23 @@ function M.web_search(opts, on_log)
|
|||||||
if resp.status ~= 200 then return nil, "Error: " .. resp.body end
|
if resp.status ~= 200 then return nil, "Error: " .. resp.body end
|
||||||
local jsn = vim.json.decode(resp.body)
|
local jsn = vim.json.decode(resp.body)
|
||||||
return search_engine.format_response_body(jsn)
|
return search_engine.format_response_body(jsn)
|
||||||
|
elseif provider_type == "searchapi" then
|
||||||
|
local query_params = vim.tbl_deep_extend("force", {
|
||||||
|
api_key = api_key,
|
||||||
|
q = opts.query,
|
||||||
|
}, search_engine.extra_request_body)
|
||||||
|
local query_string = ""
|
||||||
|
for key, value in pairs(query_params) do
|
||||||
|
query_string = query_string .. key .. "=" .. vim.uri_encode(value) .. "&"
|
||||||
|
end
|
||||||
|
local resp = curl.get("https://searchapi.io/api/v1/search?" .. query_string, {
|
||||||
|
headers = {
|
||||||
|
["Content-Type"] = "application/json",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if resp.status ~= 200 then return nil, "Error: " .. resp.body end
|
||||||
|
local jsn = vim.json.decode(resp.body)
|
||||||
|
return search_engine.format_response_body(jsn)
|
||||||
elseif provider_type == "google" then
|
elseif provider_type == "google" then
|
||||||
local engine_id = os.getenv(search_engine.engine_id_name)
|
local engine_id = os.getenv(search_engine.engine_id_name)
|
||||||
if engine_id == nil or engine_id == "" then
|
if engine_id == nil or engine_id == "" then
|
||||||
|
@ -18,31 +18,34 @@ M.parse_response = O.parse_response
|
|||||||
M.parse_response_without_stream = O.parse_response_without_stream
|
M.parse_response_without_stream = O.parse_response_without_stream
|
||||||
|
|
||||||
M.parse_curl_args = function(provider, prompt_opts)
|
M.parse_curl_args = function(provider, prompt_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
}
|
}
|
||||||
if P.env.require_api_key(base) then headers["api-key"] = provider.parse_api_key() end
|
if P.env.require_api_key(provider_conf) then headers["api-key"] = provider.parse_api_key() end
|
||||||
|
|
||||||
-- NOTE: When using "o" series set the supported parameters only
|
-- NOTE: When using "o" series set the supported parameters only
|
||||||
if O.is_o_series_model(base.model) then
|
if O.is_o_series_model(provider_conf.model) then
|
||||||
body_opts.max_tokens = nil
|
request_body.max_tokens = nil
|
||||||
body_opts.temperature = 1
|
request_body.temperature = 1
|
||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(
|
url = Utils.url_join(
|
||||||
base.endpoint,
|
provider_conf.endpoint,
|
||||||
"/openai/deployments/" .. base.deployment .. "/chat/completions?api-version=" .. base.api_version
|
"/openai/deployments/"
|
||||||
|
.. provider_conf.deployment
|
||||||
|
.. "/chat/completions?api-version="
|
||||||
|
.. provider_conf.api_version
|
||||||
),
|
),
|
||||||
proxy = base.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = headers,
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
messages = M.parse_messages(prompt_opts),
|
messages = M.parse_messages(prompt_opts),
|
||||||
stream = true,
|
stream = true,
|
||||||
}, body_opts),
|
}, request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
local Utils = require("avante.utils")
|
local Utils = require("avante.utils")
|
||||||
local Clipboard = require("avante.clipboard")
|
|
||||||
local P = require("avante.providers")
|
local P = require("avante.providers")
|
||||||
|
|
||||||
---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
|
---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
|
||||||
@ -17,17 +16,14 @@ M.api_key_name = "BEDROCK_KEYS"
|
|||||||
M.use_xml_format = true
|
M.use_xml_format = true
|
||||||
|
|
||||||
M.load_model_handler = function()
|
M.load_model_handler = function()
|
||||||
local base, _ = P.parse_config(P["bedrock"])
|
local provider_conf, _ = P.parse_config(P["bedrock"])
|
||||||
local bedrock_model = base.model
|
local bedrock_model = provider_conf.model
|
||||||
if base.model:match("anthropic") then bedrock_model = "claude" end
|
if provider_conf.model:match("anthropic") then bedrock_model = "claude" end
|
||||||
|
|
||||||
local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model)
|
local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model)
|
||||||
if ok then
|
if ok then return model_module end
|
||||||
return model_module
|
local error_msg = "Bedrock model handler not found: " .. bedrock_model
|
||||||
else
|
error(error_msg)
|
||||||
local error_msg = "Bedrock model handler not found: " .. bedrock_model
|
|
||||||
Utils.error(error_msg, { once = true, title = "Avante" })
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
M.parse_response = function(ctx, data_stream, event_state, opts)
|
M.parse_response = function(ctx, data_stream, event_state, opts)
|
||||||
@ -46,8 +42,8 @@ M.parse_stream_data = function(data, opts)
|
|||||||
-- The `type` field in the decoded JSON determines how the response is handled.
|
-- The `type` field in the decoded JSON determines how the response is handled.
|
||||||
local bedrock_match = data:gmatch("event(%b{})")
|
local bedrock_match = data:gmatch("event(%b{})")
|
||||||
for bedrock_data_match in bedrock_match do
|
for bedrock_data_match in bedrock_match do
|
||||||
local data = vim.json.decode(bedrock_data_match)
|
local jsn = vim.json.decode(bedrock_data_match)
|
||||||
local data_stream = vim.base64.decode(data.bytes)
|
local data_stream = vim.base64.decode(jsn.bytes)
|
||||||
local json = vim.json.decode(data_stream)
|
local json = vim.json.decode(data_stream)
|
||||||
M.parse_response({}, data_stream, json.type, opts)
|
M.parse_response({}, data_stream, json.type, opts)
|
||||||
end
|
end
|
||||||
@ -60,6 +56,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
local base, body_opts = P.parse_config(provider)
|
local base, body_opts = P.parse_config(provider)
|
||||||
|
|
||||||
local api_key = provider.parse_api_key()
|
local api_key = provider.parse_api_key()
|
||||||
|
if api_key == nil then error("Cannot get the bedrock api key!") end
|
||||||
local parts = vim.split(api_key, ",")
|
local parts = vim.split(api_key, ",")
|
||||||
local aws_access_key_id = parts[1]
|
local aws_access_key_id = parts[1]
|
||||||
local aws_secret_access_key = parts[2]
|
local aws_secret_access_key = parts[2]
|
||||||
@ -108,7 +105,6 @@ M.on_error = function(result)
|
|||||||
end
|
end
|
||||||
|
|
||||||
local error_msg = body.error.message
|
local error_msg = body.error.message
|
||||||
local error_type = body.error.type
|
|
||||||
|
|
||||||
Utils.error(error_msg, { once = true, title = "Avante" })
|
Utils.error(error_msg, { once = true, title = "Avante" })
|
||||||
end
|
end
|
||||||
|
@ -226,7 +226,7 @@ end
|
|||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@return table
|
---@return table
|
||||||
M.parse_curl_args = function(provider, prompt_opts)
|
M.parse_curl_args = function(provider, prompt_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
@ -234,7 +234,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
["anthropic-beta"] = "prompt-caching-2024-07-31",
|
||||||
}
|
}
|
||||||
|
|
||||||
if P.env.require_api_key(base) then headers["x-api-key"] = provider.parse_api_key() end
|
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end
|
||||||
|
|
||||||
local messages = M.parse_messages(prompt_opts)
|
local messages = M.parse_messages(prompt_opts)
|
||||||
|
|
||||||
@ -246,12 +246,12 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(base.endpoint, "/v1/messages"),
|
url = Utils.url_join(provider_conf.endpoint, "/v1/messages"),
|
||||||
proxy = base.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = headers,
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = base.model,
|
model = provider_conf.model,
|
||||||
system = {
|
system = {
|
||||||
{
|
{
|
||||||
type = "text",
|
type = "text",
|
||||||
@ -262,7 +262,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
messages = messages,
|
messages = messages,
|
||||||
tools = tools,
|
tools = tools,
|
||||||
stream = true,
|
stream = true,
|
||||||
}, body_opts),
|
}, request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ M.parse_stream_data = function(data, opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
M.parse_curl_args = function(provider, prompt_opts)
|
M.parse_curl_args = function(provider, prompt_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
["Accept"] = "application/json",
|
["Accept"] = "application/json",
|
||||||
@ -82,17 +82,17 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
.. "."
|
.. "."
|
||||||
.. vim.version().patch,
|
.. vim.version().patch,
|
||||||
}
|
}
|
||||||
if P.env.require_api_key(base) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
|
if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(base.endpoint, "/chat"),
|
url = Utils.url_join(provider_conf.endpoint, "/chat"),
|
||||||
proxy = base.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = headers,
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = base.model,
|
model = provider_conf.model,
|
||||||
stream = true,
|
stream = true,
|
||||||
}, M.parse_messages(prompt_opts), body_opts),
|
}, M.parse_messages(prompt_opts), request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -249,7 +249,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
-- (this should rarely happen, as we refresh the token in the background)
|
-- (this should rarely happen, as we refresh the token in the background)
|
||||||
H.refresh_token(false, false)
|
H.refresh_token(false, false)
|
||||||
|
|
||||||
local base, body_opts = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
|
|
||||||
local tools = {}
|
local tools = {}
|
||||||
if prompt_opts.tools then
|
if prompt_opts.tools then
|
||||||
@ -259,10 +259,10 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = H.chat_completion_url(base.endpoint),
|
url = H.chat_completion_url(provider_conf.endpoint),
|
||||||
timeout = base.timeout,
|
timeout = provider_conf.timeout,
|
||||||
proxy = base.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = {
|
headers = {
|
||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
["Authorization"] = "Bearer " .. M.state.github_token.token,
|
||||||
@ -270,11 +270,11 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch),
|
||||||
},
|
},
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = base.model,
|
model = provider_conf.model,
|
||||||
messages = M.parse_messages(prompt_opts),
|
messages = M.parse_messages(prompt_opts),
|
||||||
stream = true,
|
stream = true,
|
||||||
tools = tools,
|
tools = tools,
|
||||||
}, body_opts),
|
}, request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -82,26 +82,29 @@ M.parse_response = function(ctx, data_stream, _, opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
M.parse_curl_args = function(provider, prompt_opts)
|
M.parse_curl_args = function(provider, prompt_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
|
|
||||||
body_opts = vim.tbl_deep_extend("force", body_opts, {
|
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||||
generationConfig = {
|
generationConfig = {
|
||||||
temperature = body_opts.temperature,
|
temperature = request_body.temperature,
|
||||||
maxOutputTokens = body_opts.max_tokens,
|
maxOutputTokens = request_body.max_tokens,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
body_opts.temperature = nil
|
request_body.temperature = nil
|
||||||
body_opts.max_tokens = nil
|
request_body.max_tokens = nil
|
||||||
|
|
||||||
local api_key = provider.parse_api_key()
|
local api_key = provider.parse_api_key()
|
||||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
if api_key == nil then error("Cannot get the gemini api key!") end
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(base.endpoint, base.model .. ":streamGenerateContent?alt=sse&key=" .. api_key),
|
url = Utils.url_join(
|
||||||
proxy = base.proxy,
|
provider_conf.endpoint,
|
||||||
insecure = base.allow_insecure,
|
provider_conf.model .. ":streamGenerateContent?alt=sse&key=" .. api_key
|
||||||
|
),
|
||||||
|
proxy = provider_conf.proxy,
|
||||||
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = { ["Content-Type"] = "application/json" },
|
headers = { ["Content-Type"] = "application/json" },
|
||||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts),
|
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -346,9 +346,9 @@ M = setmetatable(M, {
|
|||||||
if t[k].has == nil then t[k].has = function() return E.parse_envvar(t[k]) ~= nil end end
|
if t[k].has == nil then t[k].has = function() return E.parse_envvar(t[k]) ~= nil end end
|
||||||
|
|
||||||
if t[k].setup == nil then
|
if t[k].setup == nil then
|
||||||
local base = M.parse_config(t[k])
|
local provider_conf = M.parse_config(t[k])
|
||||||
t[k].setup = function()
|
t[k].setup = function()
|
||||||
if E.require_api_key(base) then t[k].parse_api_key() end
|
if E.require_api_key(provider_conf) then t[k].parse_api_key() end
|
||||||
require("avante.tokenizers").setup(t[k].tokenizer_id)
|
require("avante.tokenizers").setup(t[k].tokenizer_id)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -275,14 +275,14 @@ M.parse_response_without_stream = function(data, _, opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
M.parse_curl_args = function(provider, prompt_opts)
|
M.parse_curl_args = function(provider, prompt_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
local disable_tools = base.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
["Content-Type"] = "application/json",
|
["Content-Type"] = "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
if P.env.require_api_key(base) then
|
if P.env.require_api_key(provider_conf) then
|
||||||
local api_key = provider.parse_api_key()
|
local api_key = provider.parse_api_key()
|
||||||
if api_key == nil then
|
if api_key == nil then
|
||||||
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
|
error(Config.provider .. " API key is not set, please set it in your environment variable or config file")
|
||||||
@ -290,18 +290,18 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
headers["Authorization"] = "Bearer " .. api_key
|
headers["Authorization"] = "Bearer " .. api_key
|
||||||
end
|
end
|
||||||
|
|
||||||
if M.is_openrouter(base.endpoint) then
|
if M.is_openrouter(provider_conf.endpoint) then
|
||||||
headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim"
|
headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim"
|
||||||
headers["X-Title"] = "Avante.nvim"
|
headers["X-Title"] = "Avante.nvim"
|
||||||
body_opts.include_reasoning = true
|
request_body.include_reasoning = true
|
||||||
end
|
end
|
||||||
|
|
||||||
-- NOTE: When using "o" series set the supported parameters only
|
-- NOTE: When using "o" series set the supported parameters only
|
||||||
local stream = true
|
local stream = true
|
||||||
if M.is_o_series_model(base.model) then
|
if M.is_o_series_model(provider_conf.model) then
|
||||||
body_opts.max_completion_tokens = body_opts.max_tokens
|
request_body.max_completion_tokens = request_body.max_tokens
|
||||||
body_opts.max_tokens = nil
|
request_body.max_tokens = nil
|
||||||
body_opts.temperature = 1
|
request_body.temperature = 1
|
||||||
end
|
end
|
||||||
|
|
||||||
local tools = nil
|
local tools = nil
|
||||||
@ -312,20 +312,20 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
Utils.debug("endpoint", base.endpoint)
|
Utils.debug("endpoint", provider_conf.endpoint)
|
||||||
Utils.debug("model", base.model)
|
Utils.debug("model", provider_conf.model)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
url = Utils.url_join(base.endpoint, "/chat/completions"),
|
url = Utils.url_join(provider_conf.endpoint, "/chat/completions"),
|
||||||
proxy = base.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = headers,
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = base.model,
|
model = provider_conf.model,
|
||||||
messages = M.parse_messages(prompt_opts),
|
messages = M.parse_messages(prompt_opts),
|
||||||
stream = stream,
|
stream = stream,
|
||||||
tools = tools,
|
tools = tools,
|
||||||
}, body_opts),
|
}, request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -32,22 +32,22 @@ M.parse_api_key = function()
|
|||||||
end
|
end
|
||||||
|
|
||||||
M.parse_curl_args = function(provider, prompt_opts)
|
M.parse_curl_args = function(provider, prompt_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
local location = vim.fn.getenv("LOCATION") or "default-location"
|
local location = vim.fn.getenv("LOCATION") or "default-location"
|
||||||
local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id"
|
local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id"
|
||||||
local model_id = base.model or "default-model-id"
|
local model_id = provider_conf.model or "default-model-id"
|
||||||
local url = base.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
||||||
|
|
||||||
url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
|
url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
|
||||||
|
|
||||||
body_opts = vim.tbl_deep_extend("force", body_opts, {
|
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||||
generationConfig = {
|
generationConfig = {
|
||||||
temperature = body_opts.temperature,
|
temperature = request_body.temperature,
|
||||||
maxOutputTokens = body_opts.max_tokens,
|
maxOutputTokens = request_body.max_tokens,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
body_opts.temperature = nil
|
request_body.temperature = nil
|
||||||
body_opts.max_tokens = nil
|
request_body.max_tokens = nil
|
||||||
local bearer_token = M.parse_api_key()
|
local bearer_token = M.parse_api_key()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -56,9 +56,9 @@ M.parse_curl_args = function(provider, prompt_opts)
|
|||||||
["Authorization"] = "Bearer " .. bearer_token,
|
["Authorization"] = "Bearer " .. bearer_token,
|
||||||
["Content-Type"] = "application/json; charset=utf-8",
|
["Content-Type"] = "application/json; charset=utf-8",
|
||||||
},
|
},
|
||||||
proxy = base.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = base.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts),
|
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -219,39 +219,31 @@ local function transform_result_content(selected_files, result_content, prev_fil
|
|||||||
while true do
|
while true do
|
||||||
if i > #result_lines then break end
|
if i > #result_lines then break end
|
||||||
local line_content = result_lines[i]
|
local line_content = result_lines[i]
|
||||||
if line_content:match("<FILEPATH>.+</FILEPATH>") then
|
if line_content:match("<[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>.+</[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>") then
|
||||||
local filepath = line_content:match("<FILEPATH>(.+)</FILEPATH>")
|
local filepath = line_content:match("<[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>(.+)</[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>")
|
||||||
if filepath then
|
if filepath then
|
||||||
current_filepath = filepath
|
current_filepath = filepath
|
||||||
table.insert(transformed_lines, string.format("Filepath: %s", filepath))
|
table.insert(transformed_lines, string.format("Filepath: %s", filepath))
|
||||||
goto continue
|
goto continue
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if line_content:match("<filepath>.+</filepath>") then
|
if line_content:match("^%s*<[Ss][Ee][Aa][Rr][Cc][Hh]>") then
|
||||||
local filepath = line_content:match("<filepath>(.+)</filepath>")
|
|
||||||
if filepath then
|
|
||||||
current_filepath = filepath
|
|
||||||
table.insert(transformed_lines, string.format("Filepath: %s", filepath))
|
|
||||||
goto continue
|
|
||||||
end
|
|
||||||
end
|
|
||||||
if line_content:match("^%s*<SEARCH>") then
|
|
||||||
is_searching = true
|
is_searching = true
|
||||||
|
|
||||||
if not line_content:match("^%s*<SEARCH>%s*$") then
|
if not line_content:match("^%s*<[Ss][Ee][Aa][Rr][Cc][Hh]>%s*$") then
|
||||||
local search_start_line = line_content:match("<SEARCH>(.+)$")
|
local search_start_line = line_content:match("<[Ss][Ee][Aa][Rr][Cc][Hh]>(.+)$")
|
||||||
line_content = "<SEARCH>"
|
line_content = "<SEARCH>"
|
||||||
result_lines[i] = line_content
|
result_lines[i] = line_content
|
||||||
if search_start_line and search_start_line ~= "" then table.insert(result_lines, i + 1, search_start_line) end
|
if search_start_line and search_start_line ~= "" then table.insert(result_lines, i + 1, search_start_line) end
|
||||||
end
|
end
|
||||||
|
line_content = "<SEARCH>"
|
||||||
|
|
||||||
local prev_line = result_lines[i - 1]
|
local prev_line = result_lines[i - 1]
|
||||||
if
|
if
|
||||||
prev_line
|
prev_line
|
||||||
and prev_filepath
|
and prev_filepath
|
||||||
and not prev_line:match("Filepath:.+")
|
and not prev_line:match("Filepath:.+")
|
||||||
and not prev_line:match("<FILEPATH>.+</FILEPATH>")
|
and not prev_line:match("<[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>.+</[Ff][Ii][Ll][Ee][Pp][Aa][Tt][Hh]>")
|
||||||
and not prev_line:match("<filepath>.+</filepath>")
|
|
||||||
then
|
then
|
||||||
table.insert(transformed_lines, string.format("Filepath: %s", prev_filepath))
|
table.insert(transformed_lines, string.format("Filepath: %s", prev_filepath))
|
||||||
end
|
end
|
||||||
@ -259,15 +251,15 @@ local function transform_result_content(selected_files, result_content, prev_fil
|
|||||||
if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end
|
if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end
|
||||||
search_start = i + 1
|
search_start = i + 1
|
||||||
last_search_tag_start_line = i
|
last_search_tag_start_line = i
|
||||||
elseif line_content:match("</SEARCH>%s*$") then
|
elseif line_content:match("</[Ss][Ee][Aa][Rr][Cc][Hh]>%s*$") then
|
||||||
if is_replacing then
|
if is_replacing then
|
||||||
result_lines[i] = line_content:gsub("</SEARCH>", "</REPLACE>")
|
result_lines[i] = line_content:gsub("</[Ss][Ee][Aa][Rr][Cc][Hh]>", "</REPLACE>")
|
||||||
goto continue_without_increment
|
goto continue_without_increment
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Handle case where </SEARCH> is a suffix
|
-- Handle case where </SEARCH> is a suffix
|
||||||
if not line_content:match("^%s*</SEARCH>%s*$") then
|
if not line_content:match("^%s*</[Ss][Ee][Aa][Rr][Cc][Hh]>%s*$") then
|
||||||
local search_end_line = line_content:match("^(.+)</SEARCH>")
|
local search_end_line = line_content:match("^(.+)</[Ss][Ee][Aa][Rr][Cc][Hh]>")
|
||||||
line_content = "</SEARCH>"
|
line_content = "</SEARCH>"
|
||||||
result_lines[i] = line_content
|
result_lines[i] = line_content
|
||||||
if search_end_line and search_end_line ~= "" then
|
if search_end_line and search_end_line ~= "" then
|
||||||
@ -364,10 +356,10 @@ local function transform_result_content(selected_files, result_content, prev_fil
|
|||||||
string.format("```%s", match_filetype),
|
string.format("```%s", match_filetype),
|
||||||
})
|
})
|
||||||
goto continue
|
goto continue
|
||||||
elseif line_content:match("^%s*<REPLACE>") then
|
elseif line_content:match("^%s*<[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>") then
|
||||||
is_replacing = true
|
is_replacing = true
|
||||||
if not line_content:match("^%s*<REPLACE>%s*$") then
|
if not line_content:match("^%s*<[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>%s*$") then
|
||||||
local replace_first_line = line_content:match("<REPLACE>(.+)$")
|
local replace_first_line = line_content:match("<[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>(.+)$")
|
||||||
line_content = "<REPLACE>"
|
line_content = "<REPLACE>"
|
||||||
result_lines[i] = line_content
|
result_lines[i] = line_content
|
||||||
if replace_first_line and replace_first_line ~= "" then
|
if replace_first_line and replace_first_line ~= "" then
|
||||||
@ -378,10 +370,10 @@ local function transform_result_content(selected_files, result_content, prev_fil
|
|||||||
if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end
|
if next_line and next_line:match("^%s*```%w+$") then i = i + 1 end
|
||||||
last_replace_tag_start_line = i
|
last_replace_tag_start_line = i
|
||||||
goto continue
|
goto continue
|
||||||
elseif line_content:match("</REPLACE>%s*$") then
|
elseif line_content:match("</[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>%s*$") then
|
||||||
-- Handle case where </REPLACE> is a suffix
|
-- Handle case where </REPLACE> is a suffix
|
||||||
if not line_content:match("^%s*</REPLACE>%s*$") then
|
if not line_content:match("^%s*</[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>%s*$") then
|
||||||
local replace_end_line = line_content:match("^(.+)</REPLACE>")
|
local replace_end_line = line_content:match("^(.+)</[Rr][Ee][Pp][Ll][Aa][Cc][Ee]>")
|
||||||
line_content = "</REPLACE>"
|
line_content = "</REPLACE>"
|
||||||
result_lines[i] = line_content
|
result_lines[i] = line_content
|
||||||
if replace_end_line and replace_end_line ~= "" then
|
if replace_end_line and replace_end_line ~= "" then
|
||||||
@ -769,27 +761,22 @@ local function minimize_snippet(original_lines, snippet)
|
|||||||
return new_snippets
|
return new_snippets
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param snippets_map table<string, AvanteCodeSnippet[]>
|
---@param filepath string
|
||||||
|
---@param snippets AvanteCodeSnippet[]
|
||||||
---@return table<string, AvanteCodeSnippet[]>
|
---@return table<string, AvanteCodeSnippet[]>
|
||||||
function Sidebar:minimize_snippets(snippets_map)
|
function Sidebar:minimize_snippets(filepath, snippets)
|
||||||
local original_lines = {}
|
local original_lines = {}
|
||||||
|
|
||||||
if vim.tbl_count(snippets_map) > 0 then
|
local original_lines_ = Utils.read_file_from_buf_or_disk(filepath)
|
||||||
local filepaths = vim.tbl_keys(snippets_map)
|
if original_lines_ then original_lines = original_lines_ end
|
||||||
local original_lines_ = Utils.read_file_from_buf_or_disk(filepaths[1])
|
|
||||||
if original_lines_ then original_lines = original_lines_ end
|
|
||||||
end
|
|
||||||
|
|
||||||
local results = {}
|
local results = {}
|
||||||
|
|
||||||
for filepath, snippets in pairs(snippets_map) do
|
for _, snippet in ipairs(snippets) do
|
||||||
for _, snippet in ipairs(snippets) do
|
local new_snippets = minimize_snippet(original_lines, snippet)
|
||||||
local new_snippets = minimize_snippet(original_lines, snippet)
|
if new_snippets then
|
||||||
if new_snippets then
|
for _, new_snippet in ipairs(new_snippets) do
|
||||||
results[filepath] = results[filepath] or {}
|
table.insert(results, new_snippet)
|
||||||
for _, new_snippet in ipairs(new_snippets) do
|
|
||||||
table.insert(results[filepath], new_snippet)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -822,11 +809,10 @@ function Sidebar:apply(current_cursor)
|
|||||||
selected_snippets_map = all_snippets_map
|
selected_snippets_map = all_snippets_map
|
||||||
end
|
end
|
||||||
|
|
||||||
if Config.behaviour.minimize_diff then selected_snippets_map = self:minimize_snippets(selected_snippets_map) end
|
|
||||||
|
|
||||||
vim.defer_fn(function()
|
vim.defer_fn(function()
|
||||||
api.nvim_set_current_win(self.code.winid)
|
api.nvim_set_current_win(self.code.winid)
|
||||||
for filepath, snippets in pairs(selected_snippets_map) do
|
for filepath, snippets in pairs(selected_snippets_map) do
|
||||||
|
if Config.behaviour.minimize_diff then snippets = self:minimize_snippets(filepath, snippets) end
|
||||||
local bufnr = Utils.get_or_create_buffer_with_filepath(filepath)
|
local bufnr = Utils.get_or_create_buffer_with_filepath(filepath)
|
||||||
local path_ = PPath:new(filepath)
|
local path_ = PPath:new(filepath)
|
||||||
path_:parent():mkdir({ parents = true, exists_ok = true })
|
path_:parent():mkdir({ parents = true, exists_ok = true })
|
||||||
|
@ -20,6 +20,7 @@ Tools Usage Guide:
|
|||||||
- When attempting to modify a file that is not in the context, please first use the `list_files` tool and `search_files` tool to check if the file you want to modify exists, then use the `read_file` tool to read the file content. Don't modify blindly!
|
- When attempting to modify a file that is not in the context, please first use the `list_files` tool and `search_files` tool to check if the file you want to modify exists, then use the `read_file` tool to read the file content. Don't modify blindly!
|
||||||
- When generating files, first use `list_files` tool to read the directory structure, don't generate blindly!
|
- When generating files, first use `list_files` tool to read the directory structure, don't generate blindly!
|
||||||
- When creating files, first check if the directory exists. If it doesn't exist, create the directory before creating the file.
|
- When creating files, first check if the directory exists. If it doesn't exist, create the directory before creating the file.
|
||||||
|
- After `web_search`, if you don't get detailed enough information, do not continue use `web_search`, just continue using the `fetch` tool to get more information you need from the links in the search results.
|
||||||
|
|
||||||
{% if system_info -%}
|
{% if system_info -%}
|
||||||
Use the appropriate shell based on the user's system info:
|
Use the appropriate shell based on the user's system info:
|
||||||
|
@ -665,6 +665,19 @@ function M.scan_directory_respect_gitignore(options)
|
|||||||
local directory = options.directory
|
local directory = options.directory
|
||||||
local gitignore_path = directory .. "/.gitignore"
|
local gitignore_path = directory .. "/.gitignore"
|
||||||
local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path)
|
local gitignore_patterns, gitignore_negate_patterns = M.parse_gitignore(gitignore_path)
|
||||||
|
|
||||||
|
-- Convert relative paths in gitignore to absolute paths based on project root
|
||||||
|
local project_root = M.get_project_root()
|
||||||
|
local function to_absolute_path(pattern)
|
||||||
|
-- Skip if already absolute path
|
||||||
|
if pattern:sub(1, 1) == "/" then return pattern end
|
||||||
|
-- Convert relative path to absolute
|
||||||
|
return Path:new(project_root, pattern):absolute()
|
||||||
|
end
|
||||||
|
|
||||||
|
gitignore_patterns = vim.tbl_map(to_absolute_path, gitignore_patterns)
|
||||||
|
gitignore_negate_patterns = vim.tbl_map(to_absolute_path, gitignore_negate_patterns)
|
||||||
|
|
||||||
return M.scan_directory({
|
return M.scan_directory({
|
||||||
directory = directory,
|
directory = directory,
|
||||||
gitignore_patterns = gitignore_patterns,
|
gitignore_patterns = gitignore_patterns,
|
||||||
@ -895,7 +908,7 @@ function M.get_filetype(filepath)
|
|||||||
-- https://github.com/neovim/neovim/issues/27265
|
-- https://github.com/neovim/neovim/issues/27265
|
||||||
|
|
||||||
local buf = vim.api.nvim_create_buf(false, true)
|
local buf = vim.api.nvim_create_buf(false, true)
|
||||||
local filetype = vim.filetype.match({ filename = filepath, buf = buf })
|
local filetype = vim.filetype.match({ filename = filepath, buf = buf }) or ""
|
||||||
vim.api.nvim_buf_delete(buf, { force = true })
|
vim.api.nvim_buf_delete(buf, { force = true })
|
||||||
|
|
||||||
return filetype
|
return filetype
|
||||||
|
Loading…
x
Reference in New Issue
Block a user