0
0
mirror of https://github.com/Wurst-Imperium/Wurst7.git synced 2024-09-19 17:02:13 +02:00

Remove support for legacy oobabooga API, add custom model settings

This should make AutoComplete work with any OpenAI-compatible API,
including the new oobabooga one.
This commit is contained in:
Alexander01998 2024-05-14 19:23:25 +02:00
parent 5b4ade795e
commit 49dd9cca63
5 changed files with 61 additions and 177 deletions

View File

@@ -17,10 +17,8 @@ import net.wurstclient.SearchTags;
import net.wurstclient.events.ChatOutputListener;
import net.wurstclient.events.UpdateListener;
import net.wurstclient.hack.Hack;
import net.wurstclient.hacks.autocomplete.ApiProviderSetting;
import net.wurstclient.hacks.autocomplete.MessageCompleter;
import net.wurstclient.hacks.autocomplete.ModelSettings;
import net.wurstclient.hacks.autocomplete.OobaboogaMessageCompleter;
import net.wurstclient.hacks.autocomplete.OpenAiMessageCompleter;
import net.wurstclient.hacks.autocomplete.SuggestionHandler;
import net.wurstclient.util.ChatUtils;
@@ -32,7 +30,6 @@ public final class AutoCompleteHack extends Hack
{
private final ModelSettings modelSettings = new ModelSettings();
private final SuggestionHandler suggestionHandler = new SuggestionHandler();
private final ApiProviderSetting apiProvider = new ApiProviderSetting();
private MessageCompleter completer;
private String draftMessage;
@@ -47,7 +44,6 @@ public final class AutoCompleteHack extends Hack
super("AutoComplete");
setCategory(Category.CHAT);
addSetting(apiProvider);
modelSettings.forEach(this::addSetting);
suggestionHandler.getSettings().forEach(this::addSetting);
}
@@ -55,11 +51,7 @@ public final class AutoCompleteHack extends Hack
@Override
protected void onEnable()
{
completer = switch(apiProvider.getSelected())
{
case OPENAI -> new OpenAiMessageCompleter(modelSettings);
case OOBABOOGA -> new OobaboogaMessageCompleter(modelSettings);
};
completer = new OpenAiMessageCompleter(modelSettings);
if(completer instanceof OpenAiMessageCompleter
&& System.getenv("WURST_OPENAI_KEY") == null)

View File

@@ -1,47 +0,0 @@
/*
* Copyright (c) 2014-2024 Wurst-Imperium and contributors.
*
* This source code is subject to the terms of the GNU General Public
* License, version 3. If a copy of the GPL was not distributed with this
* file, You can obtain one at: https://www.gnu.org/licenses/gpl-3.0.txt
*/
package net.wurstclient.hacks.autocomplete;
import net.wurstclient.settings.EnumSetting;
public final class ApiProviderSetting
	extends EnumSetting<ApiProviderSetting.ApiProvider>
{
	/**
	 * Combined GUI description covering both provider options. Section sign
	 * codes (\u00a7l / \u00a7r) toggle Minecraft's bold formatting.
	 */
	private static final String DESCRIPTION =
		"\u00a7lOpenAI\u00a7r lets you use models like ChatGPT, but requires an"
			+ " account with API access, costs money to use and sends your chat"
			+ " history to their servers. The name is a lie - it's closed"
			+ " source.\n\n"
			+ "\u00a7loobabooga\u00a7r lets you use models like LLaMA and many"
			+ " others. It's a true open source alternative to OpenAI that you"
			+ " can run locally on your own computer. It's free to use and does"
			+ " not send your chat history to any servers.";
	
	/**
	 * Creates the "API provider" setting, defaulting to the oobabooga
	 * provider.
	 */
	public ApiProviderSetting()
	{
		super("API provider", DESCRIPTION, ApiProvider.values(),
			ApiProvider.OOBABOOGA);
	}
	
	/** The backends AutoComplete can send completion requests to. */
	public enum ApiProvider
	{
		OPENAI("OpenAI"),
		OOBABOOGA("oobabooga");
		
		// human-readable label shown in the settings GUI
		private final String displayName;
		
		private ApiProvider(String displayName)
		{
			this.displayName = displayName;
		}
		
		@Override
		public String toString()
		{
			return displayName;
		}
	}
}

View File

@@ -98,8 +98,7 @@ public final class ModelSettings
+ " history.\n\n"
+ "Positive values encourage the model to use synonyms and"
+ " talk about different topics. Negative values encourage the"
+ " model to repeat the same word over and over again.\n\n"
+ "Only works with OpenAI models.",
+ " model to repeat the same word over and over again.",
0, -2, 2, 0.01, ValueDisplay.DECIMAL);
public final SliderSetting frequencyPenalty =
@@ -108,26 +107,9 @@ public final class ModelSettings
+ " appears in the chat history.\n\n"
+ "Positive values encourage the model to use synonyms and"
+ " talk about different topics. Negative values encourage the"
+ " model to repeat existing chat messages.\n\n"
+ "Only works with OpenAI models.",
+ " model to repeat existing chat messages.",
0, -2, 2, 0.01, ValueDisplay.DECIMAL);
public final SliderSetting repetitionPenalty =
new SliderSetting("Repetition penalty",
"Similar to presence penalty, but uses a different algorithm.\n\n"
+ "1.0 means no penalty, negative values are not possible and"
+ " 1.5 is the maximum value.\n\n"
+ "Only works with the oobabooga web UI.",
1, 1, 1.5, 0.01, ValueDisplay.DECIMAL);
public final SliderSetting encoderRepetitionPenalty =
new SliderSetting("Encoder repetition penalty",
"Similar to frequency penalty, but uses a different algorithm.\n\n"
+ "1.0 means no penalty, 0.8 behaves like a negative value and"
+ " 1.5 is the maximum value.\n\n"
+ "Only works with the oobabooga web UI.",
1, 0.8, 1.5, 0.01, ValueDisplay.DECIMAL);
public final EnumSetting<StopSequence> stopSequence = new EnumSetting<>(
"Stop sequence",
"Controls how AutoComplete detects the end of a chat message.\n\n"
@@ -170,7 +152,7 @@ public final class ModelSettings
+ " predictions.\n\n"
+ "Higher values improve the quality of predictions, but also"
+ " increase the time it takes to generate them, as well as cost"
+ " (for OpenAI API users) or RAM usage (for oobabooga users).",
+ " (for APIs like OpenAI) or RAM usage (for self-hosted models).",
10, 0, 100, 1, ValueDisplay.INTEGER);
public final CheckboxSetting filterServerMessages =
@@ -182,6 +164,47 @@ public final class ModelSettings
+ " etc.",
false);
public final TextFieldSetting customModel = new TextFieldSetting(
"Custom model",
"If set, this model will be used instead of the one specified in the"
+ " \"OpenAI model\" setting.\n\n"
+ "Use this if you have a fine-tuned OpenAI model or if you are"
+ " using a custom endpoint that is OpenAI-compatible but offers"
+ " different models.",
"");
public final EnumSetting<CustomModelType> customModelType =
new EnumSetting<>("Custom model type", "Whether the custom"
+ " model should use the chat endpoint or the legacy endpoint.\n\n"
+ "If \"Custom model\" is left blank, this setting is ignored.",
CustomModelType.values(), CustomModelType.CHAT);
/**
 * Selects which completion endpoint a user-supplied custom model should be
 * sent to.
 */
public enum CustomModelType
{
	CHAT("Chat", true),
	LEGACY("Legacy", false);
	
	// label shown in the settings GUI
	private final String displayName;
	// true = chat completions endpoint, false = legacy completions endpoint
	private final boolean chat;
	
	private CustomModelType(String displayName, boolean chat)
	{
		this.displayName = displayName;
		this.chat = chat;
	}
	
	/** Returns true if this type targets the chat endpoint. */
	public boolean isChat()
	{
		return chat;
	}
	
	@Override
	public String toString()
	{
		return displayName;
	}
}
public final TextFieldSetting openaiChatEndpoint = new TextFieldSetting(
"OpenAI chat endpoint", "Endpoint for OpenAI's chat completion API.",
"https://api.openai.com/v1/chat/completions");
@@ -191,19 +214,11 @@ public final class ModelSettings
"Endpoint for OpenAI's legacy completion API.",
"https://api.openai.com/v1/completions");
public final TextFieldSetting oobaboogaEndpoint =
new TextFieldSetting("Oobabooga endpoint",
"Endpoint for your Oobabooga web UI instance.\n"
+ "Remember to start the Oobabooga server with the"
+ " \u00a7e--extensions api\u00a7r flag.",
"http://127.0.0.1:5000/api/v1/generate");
private final List<Setting> settings =
Collections.unmodifiableList(Arrays.asList(openAiModel, maxTokens,
temperature, topP, presencePenalty, frequencyPenalty,
repetitionPenalty, encoderRepetitionPenalty, stopSequence,
contextLength, filterServerMessages, openaiChatEndpoint,
openaiLegacyEndpoint, oobaboogaEndpoint));
temperature, topP, presencePenalty, frequencyPenalty, stopSequence,
contextLength, filterServerMessages, customModel, customModelType,
openaiChatEndpoint, openaiLegacyEndpoint));
public void forEach(Consumer<Setting> action)
{

View File

@@ -1,85 +0,0 @@
/*
* Copyright (c) 2014-2024 Wurst-Imperium and contributors.
*
* This source code is subject to the terms of the GNU General Public
* License, version 3. If a copy of the GPL was not distributed with this
* file, You can obtain one at: https://www.gnu.org/licenses/gpl-3.0.txt
*/
package net.wurstclient.hacks.autocomplete;
import java.io.IOException;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URL;
import java.nio.charset.StandardCharsets;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;

import net.wurstclient.util.json.JsonException;
import net.wurstclient.util.json.JsonUtils;
import net.wurstclient.util.json.WsonObject;
/**
 * MessageCompleter backed by the oobabooga web UI's legacy
 * {@code /api/v1/generate} endpoint.
 */
public final class OobaboogaMessageCompleter extends MessageCompleter
{
	public OobaboogaMessageCompleter(ModelSettings modelSettings)
	{
		super(modelSettings);
	}
	
	/**
	 * Builds the JSON request body for the generate API from the user's
	 * model settings.
	 *
	 * @param prompt
	 *            the chat-history prompt to complete
	 * @param maxSuggestions
	 *            ignored — this API returns a single result per request
	 */
	@Override
	protected JsonObject buildParams(String prompt, int maxSuggestions)
	{
		JsonObject params = new JsonObject();
		params.addProperty("prompt", prompt);
		params.addProperty("max_length", modelSettings.maxTokens.getValueI());
		params.addProperty("temperature", modelSettings.temperature.getValue());
		params.addProperty("top_p", modelSettings.topP.getValue());
		params.addProperty("repetition_penalty",
			modelSettings.repetitionPenalty.getValue());
		params.addProperty("encoder_repetition_penalty",
			modelSettings.encoderRepetitionPenalty.getValue());
		
		// stop generating as soon as the configured stop sequence appears
		JsonArray stoppingStrings = new JsonArray();
		stoppingStrings
			.add(modelSettings.stopSequence.getSelected().getSequence());
		params.add("stopping_strings", stoppingStrings);
		
		return params;
	}
	
	/**
	 * POSTs the given parameters to the configured oobabooga endpoint and
	 * parses the JSON response.
	 *
	 * @throws IOException
	 *             if the connection or write fails
	 * @throws JsonException
	 *             if the response is not valid JSON
	 */
	@Override
	protected WsonObject requestCompletions(JsonObject parameters)
		throws IOException, JsonException
	{
		// set up the API request
		URL url =
			URI.create(modelSettings.oobaboogaEndpoint.getValue()).toURL();
		HttpURLConnection conn = (HttpURLConnection)url.openConnection();
		conn.setRequestMethod("POST");
		conn.setRequestProperty("Content-Type", "application/json");
		
		// set the request body
		// Encode explicitly as UTF-8: a bare getBytes() uses the platform
		// default charset and can corrupt non-ASCII prompts, while JSON is
		// expected to be UTF-8 (RFC 8259).
		conn.setDoOutput(true);
		try(OutputStream os = conn.getOutputStream())
		{
			os.write(JsonUtils.GSON.toJson(parameters)
				.getBytes(StandardCharsets.UTF_8));
			os.flush();
		}
		
		// parse the response
		return JsonUtils.parseConnectionToObject(conn);
	}
	
	/**
	 * Extracts the single completion string from the API response.
	 */
	@Override
	protected String[] extractCompletions(WsonObject response)
		throws JsonException
	{
		// extract completion from response
		String completion =
			response.getArray("results").getObject(0).getString("text");
		
		// remove newlines so the suggestion fits on a single chat line
		completion = completion.replace("\n", " ");
		
		return new String[]{completion};
	}
}

View File

@@ -35,8 +35,6 @@ public final class OpenAiMessageCompleter extends MessageCompleter
JsonObject params = new JsonObject();
params.addProperty("stop",
modelSettings.stopSequence.getSelected().getSequence());
params.addProperty("model",
"" + modelSettings.openAiModel.getSelected());
params.addProperty("max_tokens", modelSettings.maxTokens.getValueI());
params.addProperty("temperature", modelSettings.temperature.getValue());
params.addProperty("top_p", modelSettings.topP.getValue());
@@ -46,8 +44,19 @@ public final class OpenAiMessageCompleter extends MessageCompleter
modelSettings.frequencyPenalty.getValue());
params.addProperty("n", maxSuggestions);
// add the prompt, depending on the model
if(modelSettings.openAiModel.getSelected().isChatModel())
// determine model name and type
boolean customModel = !modelSettings.customModel.getValue().isBlank();
String modelName = customModel ? modelSettings.customModel.getValue()
: "" + modelSettings.openAiModel.getSelected();
boolean chatModel =
customModel ? modelSettings.customModelType.getSelected().isChat()
: modelSettings.openAiModel.getSelected().isChatModel();
// add the model name
params.addProperty("model", modelName);
// add the prompt, depending on model type
if(chatModel)
{
JsonArray messages = new JsonArray();
JsonObject systemMessage = new JsonObject();