diff --git a/lua/avante/config.lua b/lua/avante/config.lua new file mode 100644 index 0000000..8d55518 --- /dev/null +++ b/lua/avante/config.lua @@ -0,0 +1,51 @@ +local M = {} + +local config = { + provider = "claude", -- "claude" or "openai" or "azure" + openai = { + endpoint = "https://api.openai.com", + model = "gpt-4o", + temperature = 0, + max_tokens = 4096, + }, + azure = { + endpoint = "", -- example: "https://.openai.azure.com" + deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") + api_version = "2024-06-01", + temperature = 0, + max_tokens = 4096, + }, + claude = { + endpoint = "https://api.anthropic.com", + model = "claude-3-5-sonnet-20240620", + temperature = 0, + max_tokens = 4096, + }, + highlights = { + diff = { + current = "DiffText", -- need have background color + incoming = "DiffAdd", -- need have background color + }, + }, + mappings = { + show_sidebar = "aa", + diff = { + ours = "co", + theirs = "ct", + none = "c0", + both = "cb", + next = "]x", + prev = "[x", + }, + }, +} + +function M.update(opts) + config = vim.tbl_deep_extend("force", config, opts or {}) +end + +function M.get() + return config +end + +return M diff --git a/lua/avante/init.lua b/lua/avante/init.lua index 8fc814a..667297c 100644 --- a/lua/avante/init.lua +++ b/lua/avante/init.lua @@ -1,824 +1,10 @@ local M = {} -local curl = require("plenary.curl") -local Path = require("plenary.path") -local n = require("nui-components") -local diff = require("avante.diff") -local utils = require("avante.utils") -local tiktoken = require("avante.tiktoken") -local api = vim.api -local fn = vim.fn - -local RESULT_BUF_NAME = "AVANTE_RESULT" -local CONFLICT_BUF_NAME = "AVANTE_CONFLICT" - -local function create_result_buf() - local buf = api.nvim_create_buf(false, true) - api.nvim_set_option_value("filetype", "markdown", { buf = buf }) - api.nvim_set_option_value("buftype", "nofile", { buf = buf }) - api.nvim_set_option_value("swapfile", false, { buf = buf }) - api.nvim_set_option_value("modifiable", false, { buf = buf }) - api.nvim_set_option_value("bufhidden", "wipe", { buf = buf }) - api.nvim_buf_set_name(buf, RESULT_BUF_NAME) - return buf -end - -local result_buf = create_result_buf() - -local function is_code_buf(buf) - local ignored_filetypes = { - "dashboard", - "alpha", - "neo-tree", - "NvimTree", - "TelescopePrompt", - "Prompt", - "qf", - "help", - } - - if api.nvim_buf_is_valid(buf) and api.nvim_get_option_value("buflisted", { buf = buf }) then - local buftype = api.nvim_get_option_value("buftype", { buf = buf }) - local filetype = api.nvim_get_option_value("filetype", { buf = buf }) - - if buftype == "" and filetype ~= "" and not vim.tbl_contains(ignored_filetypes, filetype) then - local bufname = api.nvim_buf_get_name(buf) - if bufname ~= "" and bufname ~= RESULT_BUF_NAME and bufname ~= CONFLICT_BUF_NAME then - return true - end - end - end - - return false -end - -local _cur_code_buf = nil - -local function get_cur_code_buf() - return _cur_code_buf -end - -local function get_cur_code_buf_name() - local code_buf = get_cur_code_buf() - if code_buf == nil then - print("Error: cannot get code buffer") - return - end - return api.nvim_buf_get_name(code_buf) -end - -local function get_cur_code_win() - local code_buf = get_cur_code_buf() - if code_buf == nil then - print("Error: cannot get code buffer") - return - end - return fn.bufwinid(code_buf) -end - -local function get_cur_code_buf_content() - local code_buf = get_cur_code_buf() - if code_buf == nil then - print("Error: cannot get code buffer") - return {} - end - return api.nvim_buf_get_lines(code_buf, 0, -1, false) -end - -local function prepend_line_number(content) - local lines = vim.split(content, "\n") - local result = {} - for i, line in ipairs(lines) do - table.insert(result, "L" .. i .. ": " .. line) - end - return table.concat(result, "\n") -end - -local function extract_code_snippets(content) - local snippets = {} - local current_snippet = {} - local in_code_block = false - local lang, start_line, end_line - local explanation = "" - - for line in content:gmatch("[^\r\n]+") do - local start_line_str, end_line_str = line:match("^Replace lines: (%d+)-(%d+)") - if start_line_str ~= nil and end_line_str ~= nil then - start_line = tonumber(start_line_str) - end_line = tonumber(end_line_str) - end - if line:match("^```") then - if in_code_block then - if start_line ~= nil and end_line ~= nil then - table.insert(snippets, { - range = { start_line, end_line }, - content = table.concat(current_snippet, "\n"), - lang = lang, - explanation = explanation, - }) - end - current_snippet = {} - start_line, end_line = nil, nil - explanation = "" - in_code_block = false - else - lang = line:match("^```(%w+)") - if not lang or lang == "" then - lang = "text" - end - in_code_block = true - end - elseif in_code_block then - table.insert(current_snippet, line) - else - explanation = explanation .. line .. "\n" - end - end - - return snippets -end - -local system_prompt = [[ -You are an excellent programming expert. -]] - -local base_user_prompt = [[ -Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously: - -1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones. - -2. When suggesting modifications: - a. Explain why the change is necessary or beneficial. - b. Provide the exact code snippet to be replaced using this format: - -Replace lines: {{start_line}}-{{end_line}} -```{{language}} -{{suggested_code}} -``` - -3. Crucial guidelines for line numbers: - - The range {{start_line}}-{{end_line}} is INCLUSIVE. Both start_line and end_line are included in the replacement. - - Count EVERY line, including empty lines, comments, and the LAST line of the file. - - For single-line changes, use the same number for start and end lines. - - For multi-line changes, ensure the range covers ALL affected lines, from the very first to the very last. - - Include the entire block (e.g., complete function) when modifying structured code. - - Pay special attention to the start_line, ensuring it's not omitted or incorrectly set. - - Double-check that your start_line is correct, especially for changes at the beginning of the file. - - Also, be careful with the end_line, especially when it's the last line of the file. - - Double-check that your line numbers align perfectly with the original code structure. - -4. Context and verification: - - Show 1-2 unchanged lines before and after each modification as context. - - These context lines are NOT included in the replacement range. - - After each suggestion, recount the lines to verify the accuracy of your line numbers. - - Double-check that both the start_line and end_line are correct for each modification. - - Verify that your suggested changes align perfectly with the original code structure. - -5. Final check: - - Review all suggestions, ensuring each line number is correct, especially the start_line and end_line. - - Pay extra attention to the start_line of each modification, ensuring it hasn't shifted down. - - Confirm that no unrelated code is accidentally modified or deleted. - - Verify that the start_line and end_line correctly include all intended lines for replacement. - - If a modification involves the first or last line of the file, explicitly state this in your explanation. - - Perform a final alignment check to ensure your line numbers haven't shifted, especially the start_line. - - Double-check that your line numbers align perfectly with the original code structure. - - Do not show the content after these modifications. - -Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift. -]] - -local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) - local api_key = os.getenv("ANTHROPIC_API_KEY") - if not api_key then - error("ANTHROPIC_API_KEY environment variable is not set") - end - - local user_prompt = base_user_prompt - - local tokens = M.config.claude.max_tokens - local headers = { - ["Content-Type"] = "application/json", - ["x-api-key"] = api_key, - ["anthropic-version"] = "2023-06-01", - ["anthropic-beta"] = "prompt-caching-2024-07-31", - } - - local code_prompt_obj = { - type = "text", - text = string.format("```%s\n%s```", code_lang, code_content), - } - - local user_prompt_obj = { - type = "text", - text = user_prompt, - } - - if tiktoken.count(code_prompt_obj.text) > 1024 then - code_prompt_obj.cache_control = { type = "ephemeral" } - end - - if tiktoken.count(user_prompt_obj.text) > 1024 then - user_prompt_obj.cache_control = { type = "ephemeral" } - end - - local body = { - model = M.config.claude.model, - system = system_prompt, - messages = { - { - role = "user", - content = { - code_prompt_obj, - { - type = "text", - text = string.format("%s", question), - }, - user_prompt_obj, - }, - }, - }, - stream = true, - temperature = M.config.claude.temperature, - max_tokens = tokens, - } - - local url = utils.trim_suffix(M.config.claude.endpoint, "/") .. "/v1/messages" - - print("Sending request to Claude API...") - - curl.post(url, { - ---@diagnostic disable-next-line: unused-local - stream = function(err, data, job) - if err then - on_complete(err) - return - end - if not data then - return - end - for line in data:gmatch("[^\r\n]+") do - if line:sub(1, 6) ~= "data: " then - return - end - vim.schedule(function() - local success, parsed = pcall(fn.json_decode, line:sub(7)) - if not success then - error("Error: failed to parse json: " .. parsed) - return - end - if parsed and parsed.type == "content_block_delta" then - on_chunk(parsed.delta.text) - elseif parsed and parsed.type == "message_stop" then - -- Stream request completed - on_complete(nil) - elseif parsed and parsed.type == "error" then - -- Stream request completed - on_complete(parsed) - end - end) - end - end, - headers = headers, - body = fn.json_encode(body), - }) -end - -local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) - local api_key = os.getenv("OPENAI_API_KEY") - if not api_key then - error("OPENAI_API_KEY environment variable is not set") - end - - local user_prompt = base_user_prompt - .. "\n\nQUESTION:\n" - .. question - .. "\n\nCODE:\n" - .. "```" - .. code_lang - .. "\n" - .. code_content - .. "\n```" - - local url, headers, body - if M.config.provider == "azure" then - api_key = os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY") - if not api_key then - error("Azure OpenAI API key is not set. Please set AZURE_OPENAI_API_KEY or OPENAI_API_KEY environment variable.") - end - url = M.config.azure.endpoint - .. "/openai/deployments/" - .. M.config.azure.deployment - .. "/chat/completions?api-version=" - .. M.config.azure.api_version - headers = { - ["Content-Type"] = "application/json", - ["api-key"] = api_key, - } - body = { - messages = { - { role = "system", content = system_prompt }, - { role = "user", content = user_prompt }, - }, - temperature = M.config.azure.temperature, - max_tokens = M.config.azure.max_tokens, - stream = true, - } - else - url = utils.trim_suffix(M.config.openai.endpoint, "/") .. "/v1/chat/completions" - headers = { - ["Content-Type"] = "application/json", - ["Authorization"] = "Bearer " .. api_key, - } - body = { - model = M.config.openai.model, - messages = { - { role = "system", content = system_prompt }, - { role = "user", content = user_prompt }, - }, - temperature = M.config.openai.temperature, - max_tokens = M.config.openai.max_tokens, - stream = true, - } - end - - print("Sending request to " .. (M.config.provider == "azure" and "Azure OpenAI" or "OpenAI") .. " API...") - - curl.post(url, { - ---@diagnostic disable-next-line: unused-local - stream = function(err, data, job) - if err then - on_complete(err) - return - end - if not data then - return - end - for line in data:gmatch("[^\r\n]+") do - if line:sub(1, 6) ~= "data: " then - return - end - vim.schedule(function() - local piece = line:sub(7) - local success, parsed = pcall(fn.json_decode, piece) - if not success then - if piece == "[DONE]" then - on_complete(nil) - return - end - error("Error: failed to parse json: " .. parsed) - return - end - if parsed and parsed.choices and parsed.choices[1].delta.content then - on_chunk(parsed.choices[1].delta.content) - elseif parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then - -- Stream request completed - on_complete(nil) - end - end) - end - end, - headers = headers, - body = fn.json_encode(body), - }) -end - -local function call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete) - if M.config.provider == "openai" or M.config.provider == "azure" then - call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) - elseif M.config.provider == "claude" then - call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) - end -end - -local function update_result_buf_content(content) - local current_win = api.nvim_get_current_win() - local result_win = fn.bufwinid(result_buf) - - vim.defer_fn(function() - api.nvim_set_option_value("modifiable", true, { buf = result_buf }) - api.nvim_buf_set_lines(result_buf, 0, -1, false, vim.split(content, "\n")) - api.nvim_set_option_value("filetype", "markdown", { buf = result_buf }) - if result_win ~= -1 then - -- Move to the bottom - api.nvim_win_set_cursor(result_win, { api.nvim_buf_line_count(result_buf), 0 }) - api.nvim_set_option_value("modifiable", false, { buf = result_buf }) - api.nvim_set_current_win(current_win) - end - end, 0) -end - --- Add a new function to display notifications -local function show_notification(message) - vim.notify(message, vim.log.levels.INFO, { - title = "AI Assistant", - timeout = 3000, - }) -end - --- Function to get the current project root directory -local function get_project_root() - local current_file = fn.expand("%:p") - local current_dir = fn.fnamemodify(current_file, ":h") - local git_root = fn.systemlist("git -C " .. fn.shellescape(current_dir) .. " rev-parse --show-toplevel")[1] - return git_root or current_dir -end - -local function get_chat_history_filename() - local code_buf_name = get_cur_code_buf_name() - if code_buf_name == nil then - print("Error: cannot get code buffer name") - return - end - local relative_path = fn.fnamemodify(code_buf_name, ":~:.") - -- Replace path separators with double underscores - local path_with_separators = fn.substitute(relative_path, "/", "__", "g") - -- Replace other non-alphanumeric characters with single underscores - return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") -end - --- Function to get the chat history file path -local function get_chat_history_file() - local project_root = get_project_root() - local filename = get_chat_history_filename() - local history_dir = Path:new(project_root, ".avante_chat_history") - return history_dir:joinpath(filename .. ".json") -end - --- Function to get current timestamp -local function get_timestamp() - return os.date("%Y-%m-%d %H:%M:%S") -end - --- Function to load chat history -local function load_chat_history() - local history_file = get_chat_history_file() - if history_file:exists() then - local content = history_file:read() - return fn.json_decode(content) - end - return {} -end - --- Function to save chat history -local function save_chat_history(history) - local history_file = get_chat_history_file() - local history_dir = history_file:parent() - - -- Create the directory if it doesn't exist - if not history_dir:exists() then - history_dir:mkdir({ parents = true }) - end - - history_file:write(fn.json_encode(history), "w") -end - -local function update_result_buf_with_history(history) - local content = "" - for _, entry in ipairs(history) do - content = content .. "## " .. entry.timestamp .. "\n\n" - content = content .. "> " .. entry.requirement:gsub("\n", "\n> ") .. "\n\n" - content = content .. entry.response .. "\n\n" - content = content .. "---\n\n" - end - update_result_buf_content(content) -end - -local function trim_line_number_prefix(line) - return line:gsub("^L%d+: ", "") -end - -local function get_conflict_content(content, snippets) - -- sort snippets by start_line - table.sort(snippets, function(a, b) - return a.range[1] < b.range[1] - end) - - local lines = vim.split(content, "\n") - local result = {} - local current_line = 1 - - for _, snippet in ipairs(snippets) do - local start_line, end_line = unpack(snippet.range) - - while current_line < start_line do - table.insert(result, lines[current_line]) - current_line = current_line + 1 - end - - table.insert(result, "<<<<<<< HEAD") - for i = start_line, end_line do - table.insert(result, lines[i]) - end - table.insert(result, "=======") - - for _, line in ipairs(vim.split(snippet.content, "\n")) do - line = trim_line_number_prefix(line) - table.insert(result, line) - end - - table.insert(result, ">>>>>>> Snippet") - - current_line = end_line + 1 - end - - while current_line <= #lines do - table.insert(result, lines[current_line]) - current_line = current_line + 1 - end - - return result -end - -local get_renderer_size_and_position = function() - local renderer_width = math.ceil(vim.o.columns * 0.3) - local renderer_height = vim.o.lines - local renderer_position = vim.o.columns - renderer_width - return renderer_width, renderer_height, renderer_position -end - -function M.render_sidebar() - if result_buf ~= nil and api.nvim_buf_is_valid(result_buf) then - api.nvim_buf_delete(result_buf, { force = true }) - end - - result_buf = create_result_buf() - - local renderer_width, renderer_height, renderer_position = get_renderer_size_and_position() - - local renderer = n.create_renderer({ - width = renderer_width, - height = renderer_height, - position = renderer_position, - relative = "editor", - }) - - local autocmd_id - renderer:on_mount(function() - autocmd_id = api.nvim_create_autocmd("VimResized", { - callback = function() - local width, height, _ = get_renderer_size_and_position() - renderer:set_size({ width = width, height = height }) - end, - }) - end) - - renderer:on_unmount(function() - if autocmd_id ~= nil then - api.nvim_del_autocmd(autocmd_id) - end - end) - - local signal = n.create_signal({ - is_loading = false, - text = "", - }) - - local chat_history = load_chat_history() - update_result_buf_with_history(chat_history) - - local function handle_submit() - local state = signal:get_value() - local user_input = state.text - - local timestamp = get_timestamp() - update_result_buf_content( - "## " - .. timestamp - .. "\n\n> " - .. user_input:gsub("\n", "\n> ") - .. "\n\nGenerating response from " - .. M.config.provider - .. " ...\n" - ) - - local code_buf = get_cur_code_buf() - if code_buf == nil then - error("Error: cannot get code buffer") - return - end - local content = table.concat(get_cur_code_buf_content(), "\n") - local content_with_line_numbers = prepend_line_number(content) - local full_response = "" - - signal.is_loading = true - - local filetype = api.nvim_get_option_value("filetype", { buf = code_buf }) - - call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk) - full_response = full_response .. chunk - update_result_buf_content( - "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response - ) - vim.schedule(function() - vim.cmd("redraw") - end) - end, function(err) - signal.is_loading = false - - if err ~= nil then - update_result_buf_content( - "## " - .. timestamp - .. "\n\n> " - .. user_input:gsub("\n", "\n> ") - .. "\n\n" - .. full_response - .. "\n\n**Error**: " - .. vim.inspect(err) - ) - return - end - - -- Execute when the stream request is actually completed - update_result_buf_content( - "## " - .. timestamp - .. "\n\n> " - .. user_input:gsub("\n", "\n> ") - .. "\n\n" - .. full_response - .. "\n\n**Generation complete!** Please review the code suggestions above.\n\n\n\n" - ) - - -- Display notification - show_notification("Content generation complete!") - - -- Save chat history - table.insert(chat_history or {}, { timestamp = timestamp, requirement = user_input, response = full_response }) - save_chat_history(chat_history) - - local snippets = extract_code_snippets(full_response) - local conflict_content = get_conflict_content(content, snippets) - - vim.defer_fn(function() - api.nvim_buf_set_lines(code_buf, 0, -1, false, conflict_content) - local code_win = get_cur_code_win() - if code_win == nil then - error("Error: cannot get code window") - return - end - api.nvim_set_current_win(code_win) - api.nvim_feedkeys(api.nvim_replace_termcodes("", true, false, true), "n", true) - diff.add_visited_buffer(code_buf) - diff.process(code_buf) - api.nvim_feedkeys("gg", "n", false) - vim.defer_fn(function() - vim.cmd("AvanteConflictNextConflict") - api.nvim_feedkeys("zz", "n", false) - end, 1000) - end, 10) - end) - end - - local body = function() - local code_buf = get_cur_code_buf() - if code_buf == nil then - error("Error: cannot get code buffer") - return - end - local filetype = api.nvim_get_option_value("filetype", { buf = code_buf }) - local icon = require("nvim-web-devicons").get_icon_by_filetype(filetype, {}) - local code_file_fullpath = api.nvim_buf_get_name(code_buf) - local code_filename = fn.fnamemodify(code_file_fullpath, ":t") - - return n.rows( - { flex = 0 }, - n.box( - { - direction = "column", - size = vim.o.lines - 4, - }, - n.buffer({ - id = "response", - flex = 1, - buf = result_buf, - autoscroll = true, - border_label = { - text = "šŸ’¬ Avante Chat", - align = "center", - }, - padding = { - top = 1, - bottom = 1, - left = 1, - right = 1, - }, - }) - ), - n.gap(1), - n.columns( - { flex = 0 }, - n.text_input({ - id = "text-input", - border_label = { - text = string.format(" šŸ™‹ Your question (with %s %s): ", icon, code_filename), - }, - autofocus = true, - wrap = true, - flex = 1, - on_change = function(value) - local state = signal:get_value() - local is_enter = value:sub(-1) == "\n" and #state.text < #value - if is_enter then - value = value:sub(1, -2) - end - signal.text = value - if is_enter and #value > 0 then - handle_submit() - end - end, - padding = { left = 1, right = 1 }, - }), - n.gap(1), - n.spinner({ - is_loading = signal.is_loading, - padding = { top = 1, right = 1 }, - ---@diagnostic disable-next-line: undefined-field - hidden = signal.is_loading:negate(), - }) - ) - ) - end - - renderer:render(body) -end - -M.config = { - provider = "claude", -- "claude" or "openai" or "azure" - openai = { - endpoint = "https://api.openai.com", - model = "gpt-4o", - temperature = 0, - max_tokens = 4096, - }, - azure = { - endpoint = "", -- example: "https://.openai.azure.com" - deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") - api_version = "2024-06-01", - temperature = 0, - max_tokens = 4096, - }, - claude = { - endpoint = "https://api.anthropic.com", - model = "claude-3-5-sonnet-20240620", - temperature = 0, - max_tokens = 4096, - }, - highlights = { - diff = { - current = "DiffText", -- need have background color - incoming = "DiffAdd", -- need have background color - }, - }, - mappings = { - show_sidebar = "aa", - diff = { - ours = "co", - theirs = "ct", - none = "c0", - both = "cb", - next = "]x", - prev = "[x", - }, - }, -} +local sidebar = require("avante.sidebar") +local config = require("avante.config") function M.setup(opts) - M.config = vim.tbl_deep_extend("force", M.config, opts or {}) - - local bufnr = vim.api.nvim_get_current_buf() - if is_code_buf(bufnr) then - _cur_code_buf = bufnr - end - - tiktoken.setup("gpt-4o") - - diff.setup({ - debug = false, -- log output to console - default_mappings = M.config.mappings.diff, -- disable buffer local mapping created by this plugin - default_commands = true, -- disable commands created by this plugin - disable_diagnostics = true, -- This will disable the diagnostics in a buffer whilst it is conflicted - list_opener = "copen", - highlights = M.config.highlights.diff, - }) - - local function on_buf_enter() - bufnr = vim.api.nvim_get_current_buf() - if is_code_buf(bufnr) then - _cur_code_buf = bufnr - end - end - - api.nvim_create_autocmd("BufEnter", { - callback = on_buf_enter, - }) - - api.nvim_create_user_command("AvanteAsk", function() - M.render_sidebar() - end, { - nargs = 0, - }) - - api.nvim_set_keymap("n", M.config.mappings.show_sidebar, "AvanteAsk", { noremap = true, silent = true }) + config.update(opts) + sidebar.setup() end return M diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua new file mode 100644 index 0000000..035d4bb --- /dev/null +++ b/lua/avante/sidebar.lua @@ -0,0 +1,783 @@ +local M = {} +local curl = require("plenary.curl") +local Path = require("plenary.path") +local n = require("nui-components") +local diff = require("avante.diff") +local utils = require("avante.utils") +local tiktoken = require("avante.tiktoken") +local config = require("avante.config") +local api = vim.api +local fn = vim.fn + +local RESULT_BUF_NAME = "AVANTE_RESULT" +local CONFLICT_BUF_NAME = "AVANTE_CONFLICT" + +local function create_result_buf() + local buf = api.nvim_create_buf(false, true) + api.nvim_set_option_value("filetype", "markdown", { buf = buf }) + api.nvim_set_option_value("buftype", "nofile", { buf = buf }) + api.nvim_set_option_value("swapfile", false, { buf = buf }) + api.nvim_set_option_value("modifiable", false, { buf = buf }) + api.nvim_set_option_value("bufhidden", "wipe", { buf = buf }) + api.nvim_buf_set_name(buf, RESULT_BUF_NAME) + return buf +end + +local result_buf = create_result_buf() + +local function is_code_buf(buf) + local ignored_filetypes = { + "dashboard", + "alpha", + "neo-tree", + "NvimTree", + "TelescopePrompt", + "Prompt", + "qf", + "help", + } + + if api.nvim_buf_is_valid(buf) and api.nvim_get_option_value("buflisted", { buf = buf }) then + local buftype = api.nvim_get_option_value("buftype", { buf = buf }) + local filetype = api.nvim_get_option_value("filetype", { buf = buf }) + + if buftype == "" and filetype ~= "" and not vim.tbl_contains(ignored_filetypes, filetype) then + local bufname = api.nvim_buf_get_name(buf) + if bufname ~= "" and bufname ~= RESULT_BUF_NAME and bufname ~= CONFLICT_BUF_NAME then + return true + end + end + end + + return false +end + +local _cur_code_buf = nil + +local function get_cur_code_buf() + return _cur_code_buf +end + +local function get_cur_code_buf_name() + local code_buf = get_cur_code_buf() + if code_buf == nil then + print("Error: cannot get code buffer") + return + end + return api.nvim_buf_get_name(code_buf) +end + +local function get_cur_code_win() + local code_buf = get_cur_code_buf() + if code_buf == nil then + print("Error: cannot get code buffer") + return + end + return fn.bufwinid(code_buf) +end + +local function get_cur_code_buf_content() + local code_buf = get_cur_code_buf() + if code_buf == nil then + print("Error: cannot get code buffer") + return {} + end + return api.nvim_buf_get_lines(code_buf, 0, -1, false) +end + +local function prepend_line_number(content) + local lines = vim.split(content, "\n") + local result = {} + for i, line in ipairs(lines) do + table.insert(result, "L" .. i .. ": " .. line) + end + return table.concat(result, "\n") +end + +local function extract_code_snippets(content) + local snippets = {} + local current_snippet = {} + local in_code_block = false + local lang, start_line, end_line + local explanation = "" + + for line in content:gmatch("[^\r\n]+") do + local start_line_str, end_line_str = line:match("^Replace lines: (%d+)-(%d+)") + if start_line_str ~= nil and end_line_str ~= nil then + start_line = tonumber(start_line_str) + end_line = tonumber(end_line_str) + end + if line:match("^```") then + if in_code_block then + if start_line ~= nil and end_line ~= nil then + table.insert(snippets, { + range = { start_line, end_line }, + content = table.concat(current_snippet, "\n"), + lang = lang, + explanation = explanation, + }) + end + current_snippet = {} + start_line, end_line = nil, nil + explanation = "" + in_code_block = false + else + lang = line:match("^```(%w+)") + if not lang or lang == "" then + lang = "text" + end + in_code_block = true + end + elseif in_code_block then + table.insert(current_snippet, line) + else + explanation = explanation .. line .. "\n" + end + end + + return snippets +end + +local system_prompt = [[ +You are an excellent programming expert. +]] + +local base_user_prompt = [[ +Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously: + +1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones. + +2. When suggesting modifications: + a. Explain why the change is necessary or beneficial. + b. Provide the exact code snippet to be replaced using this format: + +Replace lines: {{start_line}}-{{end_line}} +```{{language}} +{{suggested_code}} +``` + +3. Crucial guidelines for line numbers: + - The range {{start_line}}-{{end_line}} is INCLUSIVE. Both start_line and end_line are included in the replacement. + - Count EVERY line, including empty lines, comments, and the LAST line of the file. + - For single-line changes, use the same number for start and end lines. + - For multi-line changes, ensure the range covers ALL affected lines, from the very first to the very last. + - Include the entire block (e.g., complete function) when modifying structured code. + - Pay special attention to the start_line, ensuring it's not omitted or incorrectly set. + - Double-check that your start_line is correct, especially for changes at the beginning of the file. + - Also, be careful with the end_line, especially when it's the last line of the file. + - Double-check that your line numbers align perfectly with the original code structure. + +4. Context and verification: + - Show 1-2 unchanged lines before and after each modification as context. + - These context lines are NOT included in the replacement range. + - After each suggestion, recount the lines to verify the accuracy of your line numbers. + - Double-check that both the start_line and end_line are correct for each modification. + - Verify that your suggested changes align perfectly with the original code structure. + +5. Final check: + - Review all suggestions, ensuring each line number is correct, especially the start_line and end_line. + - Pay extra attention to the start_line of each modification, ensuring it hasn't shifted down. + - Confirm that no unrelated code is accidentally modified or deleted. + - Verify that the start_line and end_line correctly include all intended lines for replacement. + - If a modification involves the first or last line of the file, explicitly state this in your explanation. + - Perform a final alignment check to ensure your line numbers haven't shifted, especially the start_line. + - Double-check that your line numbers align perfectly with the original code structure. + - Do not show the content after these modifications. + +Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift. +]] + +local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) + local api_key = os.getenv("ANTHROPIC_API_KEY") + if not api_key then + error("ANTHROPIC_API_KEY environment variable is not set") + end + + local user_prompt = base_user_prompt + + local tokens = config.get().claude.max_tokens + local headers = { + ["Content-Type"] = "application/json", + ["x-api-key"] = api_key, + ["anthropic-version"] = "2023-06-01", + ["anthropic-beta"] = "prompt-caching-2024-07-31", + } + + local code_prompt_obj = { + type = "text", + text = string.format("```%s\n%s```", code_lang, code_content), + } + + local user_prompt_obj = { + type = "text", + text = user_prompt, + } + + if tiktoken.count(code_prompt_obj.text) > 1024 then + code_prompt_obj.cache_control = { type = "ephemeral" } + end + + if tiktoken.count(user_prompt_obj.text) > 1024 then + user_prompt_obj.cache_control = { type = "ephemeral" } + end + + local body = { + model = config.get().claude.model, + system = system_prompt, + messages = { + { + role = "user", + content = { + code_prompt_obj, + { + type = "text", + text = string.format("%s", question), + }, + user_prompt_obj, + }, + }, + }, + stream = true, + temperature = config.get().claude.temperature, + max_tokens = tokens, + } + + local url = utils.trim_suffix(config.get().claude.endpoint, "/") .. "/v1/messages" + + print("Sending request to Claude API...") + + curl.post(url, { + ---@diagnostic disable-next-line: unused-local + stream = function(err, data, job) + if err then + on_complete(err) + return + end + if not data then + return + end + for line in data:gmatch("[^\r\n]+") do + if line:sub(1, 6) ~= "data: " then + return + end + vim.schedule(function() + local success, parsed = pcall(fn.json_decode, line:sub(7)) + if not success then + error("Error: failed to parse json: " .. parsed) + return + end + if parsed and parsed.type == "content_block_delta" then + on_chunk(parsed.delta.text) + elseif parsed and parsed.type == "message_stop" then + -- Stream request completed + on_complete(nil) + elseif parsed and parsed.type == "error" then + -- Stream request completed + on_complete(parsed) + end + end) + end + end, + headers = headers, + body = fn.json_encode(body), + }) +end + +local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) + local api_key = os.getenv("OPENAI_API_KEY") + if not api_key then + error("OPENAI_API_KEY environment variable is not set") + end + + local user_prompt = base_user_prompt + .. "\n\nQUESTION:\n" + .. question + .. "\n\nCODE:\n" + .. "```" + .. code_lang + .. "\n" + .. code_content + .. "\n```" + + local url, headers, body + if config.get().provider == "azure" then + api_key = os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY") + if not api_key then + error("Azure OpenAI API key is not set. Please set AZURE_OPENAI_API_KEY or OPENAI_API_KEY environment variable.") + end + url = config.get().azure.endpoint + .. "/openai/deployments/" + .. config.get().azure.deployment + .. "/chat/completions?api-version=" + .. config.get().azure.api_version + headers = { + ["Content-Type"] = "application/json", + ["api-key"] = api_key, + } + body = { + messages = { + { role = "system", content = system_prompt }, + { role = "user", content = user_prompt }, + }, + temperature = config.get().azure.temperature, + max_tokens = config.get().azure.max_tokens, + stream = true, + } + else + url = utils.trim_suffix(config.get().openai.endpoint, "/") .. "/v1/chat/completions" + headers = { + ["Content-Type"] = "application/json", + ["Authorization"] = "Bearer " .. api_key, + } + body = { + model = config.get().openai.model, + messages = { + { role = "system", content = system_prompt }, + { role = "user", content = user_prompt }, + }, + temperature = config.get().openai.temperature, + max_tokens = config.get().openai.max_tokens, + stream = true, + } + end + + print("Sending request to " .. (config.get().provider == "azure" and "Azure OpenAI" or "OpenAI") .. " API...") + + curl.post(url, { + ---@diagnostic disable-next-line: unused-local + stream = function(err, data, job) + if err then + on_complete(err) + return + end + if not data then + return + end + for line in data:gmatch("[^\r\n]+") do + if line:sub(1, 6) ~= "data: " then + return + end + vim.schedule(function() + local piece = line:sub(7) + local success, parsed = pcall(fn.json_decode, piece) + if not success then + if piece == "[DONE]" then + on_complete(nil) + return + end + error("Error: failed to parse json: " .. parsed) + return + end + if parsed and parsed.choices and parsed.choices[1].delta.content then + on_chunk(parsed.choices[1].delta.content) + elseif parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then + -- Stream request completed + on_complete(nil) + end + end) + end + end, + headers = headers, + body = fn.json_encode(body), + }) +end + +local function call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete) + if config.get().provider == "openai" or config.get().provider == "azure" then + call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) + elseif config.get().provider == "claude" then + call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) + end +end + +local function update_result_buf_content(content) + local current_win = api.nvim_get_current_win() + local result_win = fn.bufwinid(result_buf) + + vim.defer_fn(function() + api.nvim_set_option_value("modifiable", true, { buf = result_buf }) + api.nvim_buf_set_lines(result_buf, 0, -1, false, vim.split(content, "\n")) + api.nvim_set_option_value("filetype", "markdown", { buf = result_buf }) + if result_win ~= -1 then + -- Move to the bottom + api.nvim_win_set_cursor(result_win, { api.nvim_buf_line_count(result_buf), 0 }) + api.nvim_set_option_value("modifiable", false, { buf = result_buf }) + api.nvim_set_current_win(current_win) + end + end, 0) +end + +-- Add a new function to display notifications +local function show_notification(message) + vim.notify(message, vim.log.levels.INFO, { + title = "AI Assistant", + timeout = 3000, + }) +end + +-- Function to get the current project root directory +local function get_project_root() + local current_file = fn.expand("%:p") + local current_dir = fn.fnamemodify(current_file, ":h") + local git_root = fn.systemlist("git -C " .. fn.shellescape(current_dir) .. " rev-parse --show-toplevel")[1] + return git_root or current_dir +end + +local function get_chat_history_filename() + local code_buf_name = get_cur_code_buf_name() + if code_buf_name == nil then + print("Error: cannot get code buffer name") + return + end + local relative_path = fn.fnamemodify(code_buf_name, ":~:.") + -- Replace path separators with double underscores + local path_with_separators = fn.substitute(relative_path, "/", "__", "g") + -- Replace other non-alphanumeric characters with single underscores + return fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g") +end + +-- Function to get the chat history file path +local function get_chat_history_file() + local project_root = get_project_root() + local filename = get_chat_history_filename() + local history_dir = Path:new(project_root, ".avante_chat_history") + return history_dir:joinpath(filename .. ".json") +end + +-- Function to get current timestamp +local function get_timestamp() + return os.date("%Y-%m-%d %H:%M:%S") +end + +-- Function to load chat history +local function load_chat_history() + local history_file = get_chat_history_file() + if history_file:exists() then + local content = history_file:read() + return fn.json_decode(content) + end + return {} +end + +-- Function to save chat history +local function save_chat_history(history) + local history_file = get_chat_history_file() + local history_dir = history_file:parent() + + -- Create the directory if it doesn't exist + if not history_dir:exists() then + history_dir:mkdir({ parents = true }) + end + + history_file:write(fn.json_encode(history), "w") +end + +local function update_result_buf_with_history(history) + local content = "" + for _, entry in ipairs(history) do + content = content .. "## " .. entry.timestamp .. "\n\n" + content = content .. "> " .. entry.requirement:gsub("\n", "\n> ") .. "\n\n" + content = content .. entry.response .. "\n\n" + content = content .. "---\n\n" + end + update_result_buf_content(content) +end + +local function trim_line_number_prefix(line) + return line:gsub("^L%d+: ", "") +end + +local function get_conflict_content(content, snippets) + -- sort snippets by start_line + table.sort(snippets, function(a, b) + return a.range[1] < b.range[1] + end) + + local lines = vim.split(content, "\n") + local result = {} + local current_line = 1 + + for _, snippet in ipairs(snippets) do + local start_line, end_line = unpack(snippet.range) + + while current_line < start_line do + table.insert(result, lines[current_line]) + current_line = current_line + 1 + end + + table.insert(result, "<<<<<<< HEAD") + for i = start_line, end_line do + table.insert(result, lines[i]) + end + table.insert(result, "=======") + + for _, line in ipairs(vim.split(snippet.content, "\n")) do + line = trim_line_number_prefix(line) + table.insert(result, line) + end + + table.insert(result, ">>>>>>> Snippet") + + current_line = end_line + 1 + end + + while current_line <= #lines do + table.insert(result, lines[current_line]) + current_line = current_line + 1 + end + + return result +end + +local get_renderer_size_and_position = function() + local renderer_width = math.ceil(vim.o.columns * 0.3) + local renderer_height = vim.o.lines + local renderer_position = vim.o.columns - renderer_width + return renderer_width, renderer_height, renderer_position +end + +function M.render_sidebar() + if result_buf ~= nil and api.nvim_buf_is_valid(result_buf) then + api.nvim_buf_delete(result_buf, { force = true }) + end + + result_buf = create_result_buf() + + local renderer_width, renderer_height, renderer_position = get_renderer_size_and_position() + + local renderer = n.create_renderer({ + width = renderer_width, + height = renderer_height, + position = renderer_position, + relative = "editor", + }) + + local autocmd_id + renderer:on_mount(function() + autocmd_id = api.nvim_create_autocmd("VimResized", { + callback = function() + local width, height, _ = get_renderer_size_and_position() + renderer:set_size({ width = width, height = height }) + end, + }) + end) + + renderer:on_unmount(function() + if autocmd_id ~= nil then + api.nvim_del_autocmd(autocmd_id) + end + end) + + local signal = n.create_signal({ + is_loading = false, + text = "", + }) + + local chat_history = load_chat_history() + update_result_buf_with_history(chat_history) + + local function handle_submit() + local state = signal:get_value() + local user_input = state.text + + local timestamp = get_timestamp() + update_result_buf_content( + "## " + .. timestamp + .. "\n\n> " + .. user_input:gsub("\n", "\n> ") + .. "\n\nGenerating response from " + .. config.get().provider + .. " ...\n" + ) + + local code_buf = get_cur_code_buf() + if code_buf == nil then + error("Error: cannot get code buffer") + return + end + local content = table.concat(get_cur_code_buf_content(), "\n") + local content_with_line_numbers = prepend_line_number(content) + local full_response = "" + + signal.is_loading = true + + local filetype = api.nvim_get_option_value("filetype", { buf = code_buf }) + + call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk) + full_response = full_response .. chunk + update_result_buf_content( + "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response + ) + vim.schedule(function() + vim.cmd("redraw") + end) + end, function(err) + signal.is_loading = false + + if err ~= nil then + update_result_buf_content( + "## " + .. timestamp + .. "\n\n> " + .. user_input:gsub("\n", "\n> ") + .. "\n\n" + .. full_response + .. "\n\nšŸšØ Error: " + .. vim.inspect(err) + ) + return + end + + -- Execute when the stream request is actually completed + update_result_buf_content( + "## " + .. timestamp + .. "\n\n> " + .. user_input:gsub("\n", "\n> ") + .. "\n\n" + .. full_response + .. "\n\n**Generation complete!** Please review the code suggestions above.\n\n\n\n" + ) + + -- Display notification + show_notification("Content generation complete!") + + -- Save chat history + table.insert(chat_history or {}, { timestamp = timestamp, requirement = user_input, response = full_response }) + save_chat_history(chat_history) + + local snippets = extract_code_snippets(full_response) + local conflict_content = get_conflict_content(content, snippets) + + vim.defer_fn(function() + api.nvim_buf_set_lines(code_buf, 0, -1, false, conflict_content) + local code_win = get_cur_code_win() + if code_win == nil then + error("Error: cannot get code window") + return + end + api.nvim_set_current_win(code_win) + api.nvim_feedkeys(api.nvim_replace_termcodes("", true, false, true), "n", true) + diff.add_visited_buffer(code_buf) + diff.process(code_buf) + api.nvim_feedkeys("gg", "n", false) + vim.defer_fn(function() + vim.cmd("AvanteConflictNextConflict") + api.nvim_feedkeys("zz", "n", false) + end, 1000) + end, 10) + end) + end + + local body = function() + local code_buf = get_cur_code_buf() + if code_buf == nil then + error("Error: cannot get code buffer") + return + end + local filetype = api.nvim_get_option_value("filetype", { buf = code_buf }) + local icon = require("nvim-web-devicons").get_icon_by_filetype(filetype, {}) + local code_file_fullpath = api.nvim_buf_get_name(code_buf) + local code_filename = fn.fnamemodify(code_file_fullpath, ":t") + + return n.rows( + { flex = 0 }, + n.box( + { + direction = "column", + size = vim.o.lines - 4, + }, + n.buffer({ + id = "response", + flex = 1, + buf = result_buf, + autoscroll = true, + border_label = { + text = "šŸ’¬ Avante Chat", + align = "center", + }, + padding = { + top = 1, + bottom = 1, + left = 1, + right = 1, + }, + }) + ), + n.gap(1), + n.columns( + { flex = 0 }, + n.text_input({ + id = "text-input", + border_label = { + text = string.format(" šŸ™‹ Your question (with %s %s): ", icon, code_filename), + }, + autofocus = true, + wrap = true, + flex = 1, + on_change = function(value) + local state = signal:get_value() + local is_enter = value:sub(-1) == "\n" and #state.text < #value + if is_enter then + value = value:sub(1, -2) + end + signal.text = value + if is_enter and #value > 0 then + handle_submit() + end + end, + padding = { left = 1, right = 1 }, + }), + n.gap(1), + n.spinner({ + is_loading = signal.is_loading, + padding = { top = 1, right = 1 }, + ---@diagnostic disable-next-line: undefined-field + hidden = signal.is_loading:negate(), + }) + ) + ) + end + + renderer:render(body) +end + +function M.setup() + local bufnr = vim.api.nvim_get_current_buf() + if is_code_buf(bufnr) then + _cur_code_buf = bufnr + end + + tiktoken.setup("gpt-4o") + + diff.setup({ + debug = false, -- log output to console + default_mappings = config.get().mappings.diff, -- disable buffer local mapping created by this plugin + default_commands = true, -- disable commands created by this plugin + disable_diagnostics = true, -- This will disable the diagnostics in a buffer whilst it is conflicted + list_opener = "copen", + highlights = config.get().highlights.diff, + }) + + local function on_buf_enter() + bufnr = vim.api.nvim_get_current_buf() + if is_code_buf(bufnr) then + _cur_code_buf = bufnr + end + end + + api.nvim_create_autocmd("BufEnter", { + callback = on_buf_enter, + }) + + api.nvim_create_user_command("AvanteAsk", function() + M.render_sidebar() + end, { + nargs = 0, + }) + + api.nvim_set_keymap("n", config.get().mappings.show_sidebar, "AvanteAsk", { noremap = true, silent = true }) +end + +return M