diff --git a/src/main/java/net/wurstclient/hacks/AutoCompleteHack.java b/src/main/java/net/wurstclient/hacks/AutoCompleteHack.java
index 5db7e245..0ed38e23 100644
--- a/src/main/java/net/wurstclient/hacks/AutoCompleteHack.java
+++ b/src/main/java/net/wurstclient/hacks/AutoCompleteHack.java
@@ -17,10 +17,8 @@ import net.wurstclient.SearchTags;
 import net.wurstclient.events.ChatOutputListener;
 import net.wurstclient.events.UpdateListener;
 import net.wurstclient.hack.Hack;
-import net.wurstclient.hacks.autocomplete.ApiProviderSetting;
 import net.wurstclient.hacks.autocomplete.MessageCompleter;
 import net.wurstclient.hacks.autocomplete.ModelSettings;
-import net.wurstclient.hacks.autocomplete.OobaboogaMessageCompleter;
 import net.wurstclient.hacks.autocomplete.OpenAiMessageCompleter;
 import net.wurstclient.hacks.autocomplete.SuggestionHandler;
 import net.wurstclient.util.ChatUtils;
@@ -32,7 +30,6 @@ public final class AutoCompleteHack extends Hack
 {
 	private final ModelSettings modelSettings = new ModelSettings();
 	private final SuggestionHandler suggestionHandler = new SuggestionHandler();
-	private final ApiProviderSetting apiProvider = new ApiProviderSetting();
 	
 	private MessageCompleter completer;
 	private String draftMessage;
@@ -47,7 +44,6 @@ public final class AutoCompleteHack
 		super("AutoComplete");
 		setCategory(Category.CHAT);
 		
-		addSetting(apiProvider);
 		modelSettings.forEach(this::addSetting);
 		suggestionHandler.getSettings().forEach(this::addSetting);
 	}
@@ -55,11 +51,7 @@
 	@Override
 	protected void onEnable()
 	{
-		completer = switch(apiProvider.getSelected())
-		{
-			case OPENAI -> new OpenAiMessageCompleter(modelSettings);
-			case OOBABOOGA -> new OobaboogaMessageCompleter(modelSettings);
-		};
+		completer = new OpenAiMessageCompleter(modelSettings);
 		
 		if(completer instanceof OpenAiMessageCompleter
 			&& System.getenv("WURST_OPENAI_KEY") == null)
diff --git a/src/main/java/net/wurstclient/hacks/autocomplete/ApiProviderSetting.java b/src/main/java/net/wurstclient/hacks/autocomplete/ApiProviderSetting.java
deleted file mode 100644
index 7cd1b473..00000000
--- a/src/main/java/net/wurstclient/hacks/autocomplete/ApiProviderSetting.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2014-2024 Wurst-Imperium and contributors.
- *
- * This source code is subject to the terms of the GNU General Public
- * License, version 3. If a copy of the GPL was not distributed with this
- * file, You can obtain one at: https://www.gnu.org/licenses/gpl-3.0.txt
- */
-package net.wurstclient.hacks.autocomplete;
-
-import net.wurstclient.settings.EnumSetting;
-
-public final class ApiProviderSetting
-	extends EnumSetting<ApiProviderSetting.ApiProvider>
-{
-	public ApiProviderSetting()
-	{
-		super("API provider",
-			"\u00a7lOpenAI\u00a7r lets you use models like ChatGPT, but requires an"
-				+ " account with API access, costs money to use and sends your chat"
-				+ " history to their servers. The name is a lie - it's closed"
-				+ " source.\n\n"
-				+ "\u00a7loobabooga\u00a7r lets you use models like LLaMA and many"
-				+ " others. It's a true open source alternative to OpenAI that you"
-				+ " can run locally on your own computer. It's free to use and does"
-				+ " not send your chat history to any servers.",
-			ApiProvider.values(), ApiProvider.OOBABOOGA);
-	}
-	
-	public enum ApiProvider
-	{
-		OPENAI("OpenAI"),
-		OOBABOOGA("oobabooga");
-		
-		private final String name;
-		
-		private ApiProvider(String name)
-		{
-			this.name = name;
-		}
-		
-		@Override
-		public String toString()
-		{
-			return name;
-		}
-	}
-}
diff --git a/src/main/java/net/wurstclient/hacks/autocomplete/ModelSettings.java b/src/main/java/net/wurstclient/hacks/autocomplete/ModelSettings.java
index 4a0e7b03..1288d1ca 100644
--- a/src/main/java/net/wurstclient/hacks/autocomplete/ModelSettings.java
+++ b/src/main/java/net/wurstclient/hacks/autocomplete/ModelSettings.java
@@ -98,8 +98,7 @@ public final class ModelSettings
 				+ " history.\n\n"
 				+ "Positive values encourage the model to use synonyms and"
 				+ " talk about different topics. Negative values encourage the"
-				+ " model to repeat the same word over and over again.\n\n"
-				+ "Only works with OpenAI models.",
+				+ " model to repeat the same word over and over again.",
 			0, -2, 2, 0.01, ValueDisplay.DECIMAL);
 	
 	public final SliderSetting frequencyPenalty =
@@ -108,26 +107,9 @@
 				+ " appears in the chat history.\n\n"
 				+ "Positive values encourage the model to use synonyms and"
 				+ " talk about different topics. Negative values encourage the"
-				+ " model to repeat existing chat messages.\n\n"
-				+ "Only works with OpenAI models.",
+				+ " model to repeat existing chat messages.",
 			0, -2, 2, 0.01, ValueDisplay.DECIMAL);
 	
-	public final SliderSetting repetitionPenalty =
-		new SliderSetting("Repetition penalty",
-			"Similar to presence penalty, but uses a different algorithm.\n\n"
-				+ "1.0 means no penalty, negative values are not possible and"
-				+ " 1.5 is the maximum value.\n\n"
-				+ "Only works with the oobabooga web UI.",
-			1, 1, 1.5, 0.01, ValueDisplay.DECIMAL);
-	
-	public final SliderSetting encoderRepetitionPenalty =
-		new SliderSetting("Encoder repetition penalty",
-			"Similar to frequency penalty, but uses a different algorithm.\n\n"
-				+ "1.0 means no penalty, 0.8 behaves like a negative value and"
-				+ " 1.5 is the maximum value.\n\n"
-				+ "Only works with the oobabooga web UI.",
-			1, 0.8, 1.5, 0.01, ValueDisplay.DECIMAL);
-	
 	public final EnumSetting<StopSequence> stopSequence = new EnumSetting<>(
 		"Stop sequence",
 		"Controls how AutoComplete detects the end of a chat message.\n\n"
@@ -170,7 +152,7 @@
 			+ " predictions.\n\n"
 			+ "Higher values improve the quality of predictions, but also"
 			+ " increase the time it takes to generate them, as well as cost"
-			+ " (for OpenAI API users) or RAM usage (for oobabooga users).",
+			+ " (for APIs like OpenAI) or RAM usage (for self-hosted models).",
 		10, 0, 100, 1, ValueDisplay.INTEGER);
 	
 	public final CheckboxSetting filterServerMessages =
@@ -182,6 +164,47 @@
 				+ " etc.",
 			false);
 	
+	public final TextFieldSetting customModel = new TextFieldSetting(
+		"Custom model",
+		"If set, this model will be used instead of the one specified in the"
+			+ " \"OpenAI model\" setting.\n\n"
+			+ "Use this if you have a fine-tuned OpenAI model or if you are"
+			+ " using a custom endpoint that is OpenAI-compatible but offers"
+			+ " different models.",
+		"");
+	
+	public final EnumSetting<CustomModelType> customModelType =
+		new EnumSetting<>("Custom model type", "Whether the custom"
+			+ " model should use the chat endpoint or the legacy endpoint.\n\n"
+			+ "If \"Custom model\" is left blank, this setting is ignored.",
+			CustomModelType.values(), CustomModelType.CHAT);
+	
+	public enum CustomModelType
+	{
+		CHAT("Chat", true),
+		LEGACY("Legacy", false);
+		
+		private final String name;
+		private final boolean chat;
+		
+		private CustomModelType(String name, boolean chat)
+		{
+			this.name = name;
+			this.chat = chat;
+		}
+		
+		public boolean isChat()
+		{
+			return chat;
+		}
+		
+		@Override
+		public String toString()
+		{
+			return name;
+		}
+	}
+	
 	public final TextFieldSetting openaiChatEndpoint = new TextFieldSetting(
 		"OpenAI chat endpoint", "Endpoint for OpenAI's chat completion API.",
 		"https://api.openai.com/v1/chat/completions");
@@ -191,19 +214,11 @@
 		"Endpoint for OpenAI's legacy completion API.",
 		"https://api.openai.com/v1/completions");
 	
-	public final TextFieldSetting oobaboogaEndpoint =
-		new TextFieldSetting("Oobabooga endpoint",
-			"Endpoint for your Oobabooga web UI instance.\n"
-				+ "Remember to start the Oobabooga server with the"
-				+ " \u00a7e--extensions api\u00a7r flag.",
-			"http://127.0.0.1:5000/api/v1/generate");
-	
 	private final List<Setting> settings =
 		Collections.unmodifiableList(Arrays.asList(openAiModel, maxTokens,
-			temperature, topP, presencePenalty, frequencyPenalty,
-			repetitionPenalty, encoderRepetitionPenalty, stopSequence,
-			contextLength, filterServerMessages, openaiChatEndpoint,
-			openaiLegacyEndpoint, oobaboogaEndpoint));
+			temperature, topP, presencePenalty, frequencyPenalty, stopSequence,
+			contextLength, filterServerMessages, customModel, customModelType,
+			openaiChatEndpoint, openaiLegacyEndpoint));
 	
 	public void forEach(Consumer<Setting> action)
 	{
diff --git a/src/main/java/net/wurstclient/hacks/autocomplete/OobaboogaMessageCompleter.java b/src/main/java/net/wurstclient/hacks/autocomplete/OobaboogaMessageCompleter.java
deleted file mode 100644
index efbd66c9..00000000
--- a/src/main/java/net/wurstclient/hacks/autocomplete/OobaboogaMessageCompleter.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Copyright (c) 2014-2024 Wurst-Imperium and contributors.
- *
- * This source code is subject to the terms of the GNU General Public
- * License, version 3. If a copy of the GPL was not distributed with this
- * file, You can obtain one at: https://www.gnu.org/licenses/gpl-3.0.txt
- */
-package net.wurstclient.hacks.autocomplete;
-
-import java.io.IOException;
-import java.io.OutputStream;
-import java.net.HttpURLConnection;
-import java.net.URI;
-import java.net.URL;
-
-import com.google.gson.JsonArray;
-import com.google.gson.JsonObject;
-
-import net.wurstclient.util.json.JsonException;
-import net.wurstclient.util.json.JsonUtils;
-import net.wurstclient.util.json.WsonObject;
-
-public final class OobaboogaMessageCompleter extends MessageCompleter
-{
-	public OobaboogaMessageCompleter(ModelSettings modelSettings)
-	{
-		super(modelSettings);
-	}
-	
-	@Override
-	protected JsonObject buildParams(String prompt, int maxSuggestions)
-	{
-		JsonObject params = new JsonObject();
-		params.addProperty("prompt", prompt);
-		params.addProperty("max_length", modelSettings.maxTokens.getValueI());
-		params.addProperty("temperature", modelSettings.temperature.getValue());
-		params.addProperty("top_p", modelSettings.topP.getValue());
-		params.addProperty("repetition_penalty",
-			modelSettings.repetitionPenalty.getValue());
-		params.addProperty("encoder_repetition_penalty",
-			modelSettings.encoderRepetitionPenalty.getValue());
-		JsonArray stoppingStrings = new JsonArray();
-		stoppingStrings
-			.add(modelSettings.stopSequence.getSelected().getSequence());
-		params.add("stopping_strings", stoppingStrings);
-		return params;
-	}
-	
-	@Override
-	protected WsonObject requestCompletions(JsonObject parameters)
-		throws IOException, JsonException
-	{
-		// set up the API request
-		URL url =
-			URI.create(modelSettings.oobaboogaEndpoint.getValue()).toURL();
-		HttpURLConnection conn = (HttpURLConnection)url.openConnection();
-		conn.setRequestMethod("POST");
-		conn.setRequestProperty("Content-Type", "application/json");
-		
-		// set the request body
-		conn.setDoOutput(true);
-		try(OutputStream os = conn.getOutputStream())
-		{
-			os.write(JsonUtils.GSON.toJson(parameters).getBytes());
-			os.flush();
-		}
-		
-		// parse the response
-		return JsonUtils.parseConnectionToObject(conn);
-	}
-	
-	@Override
-	protected String[] extractCompletions(WsonObject response)
-		throws JsonException
-	{
-		// extract completion from response
-		String completion =
-			response.getArray("results").getObject(0).getString("text");
-		
-		// remove newlines
-		completion = completion.replace("\n", " ");
-		
-		return new String[]{completion};
-	}
-}
diff --git a/src/main/java/net/wurstclient/hacks/autocomplete/OpenAiMessageCompleter.java b/src/main/java/net/wurstclient/hacks/autocomplete/OpenAiMessageCompleter.java
index 3b8ea5ce..1140a536 100644
--- a/src/main/java/net/wurstclient/hacks/autocomplete/OpenAiMessageCompleter.java
+++ b/src/main/java/net/wurstclient/hacks/autocomplete/OpenAiMessageCompleter.java
@@ -35,8 +35,6 @@ public final class OpenAiMessageCompleter extends MessageCompleter
 		JsonObject params = new JsonObject();
 		params.addProperty("stop",
 			modelSettings.stopSequence.getSelected().getSequence());
-		params.addProperty("model",
-			"" + modelSettings.openAiModel.getSelected());
 		params.addProperty("max_tokens", modelSettings.maxTokens.getValueI());
 		params.addProperty("temperature", modelSettings.temperature.getValue());
 		params.addProperty("top_p", modelSettings.topP.getValue());
@@ -46,8 +44,19 @@ public final class OpenAiMessageCompleter extends MessageCompleter
 			modelSettings.frequencyPenalty.getValue());
 		params.addProperty("n", maxSuggestions);
 		
-		// add the prompt, depending on the model
-		if(modelSettings.openAiModel.getSelected().isChatModel())
+		// determine model name and type
+		boolean customModel = !modelSettings.customModel.getValue().isBlank();
+		String modelName = customModel ? modelSettings.customModel.getValue()
+			: "" + modelSettings.openAiModel.getSelected();
+		boolean chatModel =
+			customModel ? modelSettings.customModelType.getSelected().isChat()
+				: modelSettings.openAiModel.getSelected().isChatModel();
+		
+		// add the model name
+		params.addProperty("model", modelName);
+		
+		// add the prompt, depending on model type
+		if(chatModel)
 		{
 			JsonArray messages = new JsonArray();
 			JsonObject systemMessage = new JsonObject();