From c8a764b3a134b8fe42e5e95df1933551fc8e707c Mon Sep 17 00:00:00 2001
From: Hanchin Hsieh <me@yuchanns.xyz>
Date: Sun, 18 Aug 2024 17:54:29 +0800
Subject: [PATCH] fix(env): remove fallback and respect one env (#60)

Make sure to set the corresponding env.
---
 lua/avante/ai_bot.lua | 36 ++++++++++++++----------------------
 1 file changed, 14 insertions(+), 22 deletions(-)

diff --git a/lua/avante/ai_bot.lua b/lua/avante/ai_bot.lua
index 18b9984..8c5d0de 100644
--- a/lua/avante/ai_bot.lua
+++ b/lua/avante/ai_bot.lua
@@ -12,17 +12,13 @@ local Tiktoken = require("avante.tiktoken")
 ---@class avante.AiBot
 local M = {}
 
----@class Environment: table<[string], any>
----@field [string] string the environment variable name
----@field fallback? string Optional fallback API key environment variable name
-
 ---@class EnvironmentHandler: table<[Provider], string>
 local E = {
-  ---@type table<Provider, Environment | string>
+  ---@type table<Provider, string>
   env = {
     openai = "OPENAI_API_KEY",
     claude = "ANTHROPIC_API_KEY",
-    azure = { "AZURE_OPENAI_API_KEY", fallback = "OPENAI_API_KEY" },
+    azure = "AZURE_OPENAI_API_KEY",
   },
   _once = false,
 }
@@ -30,20 +26,7 @@ local E = {
 E = setmetatable(E, {
   ---@param k Provider
   __index = function(_, k)
-    local envvar = E.env[k]
-    if type(envvar) == "string" then
-      local value = os.getenv(envvar)
-      return value and true or false
-    elseif type(envvar) == "table" then
-      local main_key = envvar[1]
-      local value = os.getenv(main_key)
-      if value then
-        return true
-      elseif envvar.fallback then
-        return os.getenv(envvar.fallback) and true or false
-      end
-    end
-    return false
+    return os.getenv(E.env[k]) and true or false
   end,
 })
 
@@ -137,8 +120,17 @@ E.setup = function(var)
       vim.defer_fn(function()
         -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
         local exclude_buftypes = { "dashboard", "alpha", "qf", "nofile" }
-        local exclude_filetypes =
-          { "NvimTree", "Outline", "help", "dashboard", "alpha", "qf", "ministarter", "TelescopePrompt", "gitcommit" }
+        local exclude_filetypes = {
+          "NvimTree",
+          "Outline",
+          "help",
+          "dashboard",
+          "alpha",
+          "qf",
+          "ministarter",
+          "TelescopePrompt",
+          "gitcommit",
+        }
         if
           not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
           and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)