mirror of
https://github.com/Wurst-Imperium/Wurst7.git
synced 2024-09-20 01:12:13 +02:00
Remove support for legacy oobabooga API, add custom model settings
This should make AutoComplete work with any OpenAI-compatible API, including the new oobabooga one.
This commit is contained in:
parent
5b4ade795e
commit
49dd9cca63
@ -17,10 +17,8 @@ import net.wurstclient.SearchTags;
|
|||||||
import net.wurstclient.events.ChatOutputListener;
|
import net.wurstclient.events.ChatOutputListener;
|
||||||
import net.wurstclient.events.UpdateListener;
|
import net.wurstclient.events.UpdateListener;
|
||||||
import net.wurstclient.hack.Hack;
|
import net.wurstclient.hack.Hack;
|
||||||
import net.wurstclient.hacks.autocomplete.ApiProviderSetting;
|
|
||||||
import net.wurstclient.hacks.autocomplete.MessageCompleter;
|
import net.wurstclient.hacks.autocomplete.MessageCompleter;
|
||||||
import net.wurstclient.hacks.autocomplete.ModelSettings;
|
import net.wurstclient.hacks.autocomplete.ModelSettings;
|
||||||
import net.wurstclient.hacks.autocomplete.OobaboogaMessageCompleter;
|
|
||||||
import net.wurstclient.hacks.autocomplete.OpenAiMessageCompleter;
|
import net.wurstclient.hacks.autocomplete.OpenAiMessageCompleter;
|
||||||
import net.wurstclient.hacks.autocomplete.SuggestionHandler;
|
import net.wurstclient.hacks.autocomplete.SuggestionHandler;
|
||||||
import net.wurstclient.util.ChatUtils;
|
import net.wurstclient.util.ChatUtils;
|
||||||
@ -32,7 +30,6 @@ public final class AutoCompleteHack extends Hack
|
|||||||
{
|
{
|
||||||
private final ModelSettings modelSettings = new ModelSettings();
|
private final ModelSettings modelSettings = new ModelSettings();
|
||||||
private final SuggestionHandler suggestionHandler = new SuggestionHandler();
|
private final SuggestionHandler suggestionHandler = new SuggestionHandler();
|
||||||
private final ApiProviderSetting apiProvider = new ApiProviderSetting();
|
|
||||||
|
|
||||||
private MessageCompleter completer;
|
private MessageCompleter completer;
|
||||||
private String draftMessage;
|
private String draftMessage;
|
||||||
@ -47,7 +44,6 @@ public final class AutoCompleteHack extends Hack
|
|||||||
super("AutoComplete");
|
super("AutoComplete");
|
||||||
setCategory(Category.CHAT);
|
setCategory(Category.CHAT);
|
||||||
|
|
||||||
addSetting(apiProvider);
|
|
||||||
modelSettings.forEach(this::addSetting);
|
modelSettings.forEach(this::addSetting);
|
||||||
suggestionHandler.getSettings().forEach(this::addSetting);
|
suggestionHandler.getSettings().forEach(this::addSetting);
|
||||||
}
|
}
|
||||||
@ -55,11 +51,7 @@ public final class AutoCompleteHack extends Hack
|
|||||||
@Override
|
@Override
|
||||||
protected void onEnable()
|
protected void onEnable()
|
||||||
{
|
{
|
||||||
completer = switch(apiProvider.getSelected())
|
completer = new OpenAiMessageCompleter(modelSettings);
|
||||||
{
|
|
||||||
case OPENAI -> new OpenAiMessageCompleter(modelSettings);
|
|
||||||
case OOBABOOGA -> new OobaboogaMessageCompleter(modelSettings);
|
|
||||||
};
|
|
||||||
|
|
||||||
if(completer instanceof OpenAiMessageCompleter
|
if(completer instanceof OpenAiMessageCompleter
|
||||||
&& System.getenv("WURST_OPENAI_KEY") == null)
|
&& System.getenv("WURST_OPENAI_KEY") == null)
|
||||||
|
@ -1,47 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright (c) 2014-2024 Wurst-Imperium and contributors.
|
|
||||||
*
|
|
||||||
* This source code is subject to the terms of the GNU General Public
|
|
||||||
* License, version 3. If a copy of the GPL was not distributed with this
|
|
||||||
* file, You can obtain one at: https://www.gnu.org/licenses/gpl-3.0.txt
|
|
||||||
*/
|
|
||||||
package net.wurstclient.hacks.autocomplete;
|
|
||||||
|
|
||||||
import net.wurstclient.settings.EnumSetting;
|
|
||||||
|
|
||||||
public final class ApiProviderSetting
|
|
||||||
extends EnumSetting<ApiProviderSetting.ApiProvider>
|
|
||||||
{
|
|
||||||
public ApiProviderSetting()
|
|
||||||
{
|
|
||||||
super("API provider",
|
|
||||||
"\u00a7lOpenAI\u00a7r lets you use models like ChatGPT, but requires an"
|
|
||||||
+ " account with API access, costs money to use and sends your chat"
|
|
||||||
+ " history to their servers. The name is a lie - it's closed"
|
|
||||||
+ " source.\n\n"
|
|
||||||
+ "\u00a7loobabooga\u00a7r lets you use models like LLaMA and many"
|
|
||||||
+ " others. It's a true open source alternative to OpenAI that you"
|
|
||||||
+ " can run locally on your own computer. It's free to use and does"
|
|
||||||
+ " not send your chat history to any servers.",
|
|
||||||
ApiProvider.values(), ApiProvider.OOBABOOGA);
|
|
||||||
}
|
|
||||||
|
|
||||||
public enum ApiProvider
|
|
||||||
{
|
|
||||||
OPENAI("OpenAI"),
|
|
||||||
OOBABOOGA("oobabooga");
|
|
||||||
|
|
||||||
private final String name;
|
|
||||||
|
|
||||||
private ApiProvider(String name)
|
|
||||||
{
|
|
||||||
this.name = name;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString()
|
|
||||||
{
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -98,8 +98,7 @@ public final class ModelSettings
|
|||||||
+ " history.\n\n"
|
+ " history.\n\n"
|
||||||
+ "Positive values encourage the model to use synonyms and"
|
+ "Positive values encourage the model to use synonyms and"
|
||||||
+ " talk about different topics. Negative values encourage the"
|
+ " talk about different topics. Negative values encourage the"
|
||||||
+ " model to repeat the same word over and over again.\n\n"
|
+ " model to repeat the same word over and over again.",
|
||||||
+ "Only works with OpenAI models.",
|
|
||||||
0, -2, 2, 0.01, ValueDisplay.DECIMAL);
|
0, -2, 2, 0.01, ValueDisplay.DECIMAL);
|
||||||
|
|
||||||
public final SliderSetting frequencyPenalty =
|
public final SliderSetting frequencyPenalty =
|
||||||
@ -108,26 +107,9 @@ public final class ModelSettings
|
|||||||
+ " appears in the chat history.\n\n"
|
+ " appears in the chat history.\n\n"
|
||||||
+ "Positive values encourage the model to use synonyms and"
|
+ "Positive values encourage the model to use synonyms and"
|
||||||
+ " talk about different topics. Negative values encourage the"
|
+ " talk about different topics. Negative values encourage the"
|
||||||
+ " model to repeat existing chat messages.\n\n"
|
+ " model to repeat existing chat messages.",
|
||||||
+ "Only works with OpenAI models.",
|
|
||||||
0, -2, 2, 0.01, ValueDisplay.DECIMAL);
|
0, -2, 2, 0.01, ValueDisplay.DECIMAL);
|
||||||
|
|
||||||
public final SliderSetting repetitionPenalty =
|
|
||||||
new SliderSetting("Repetition penalty",
|
|
||||||
"Similar to presence penalty, but uses a different algorithm.\n\n"
|
|
||||||
+ "1.0 means no penalty, negative values are not possible and"
|
|
||||||
+ " 1.5 is the maximum value.\n\n"
|
|
||||||
+ "Only works with the oobabooga web UI.",
|
|
||||||
1, 1, 1.5, 0.01, ValueDisplay.DECIMAL);
|
|
||||||
|
|
||||||
public final SliderSetting encoderRepetitionPenalty =
|
|
||||||
new SliderSetting("Encoder repetition penalty",
|
|
||||||
"Similar to frequency penalty, but uses a different algorithm.\n\n"
|
|
||||||
+ "1.0 means no penalty, 0.8 behaves like a negative value and"
|
|
||||||
+ " 1.5 is the maximum value.\n\n"
|
|
||||||
+ "Only works with the oobabooga web UI.",
|
|
||||||
1, 0.8, 1.5, 0.01, ValueDisplay.DECIMAL);
|
|
||||||
|
|
||||||
public final EnumSetting<StopSequence> stopSequence = new EnumSetting<>(
|
public final EnumSetting<StopSequence> stopSequence = new EnumSetting<>(
|
||||||
"Stop sequence",
|
"Stop sequence",
|
||||||
"Controls how AutoComplete detects the end of a chat message.\n\n"
|
"Controls how AutoComplete detects the end of a chat message.\n\n"
|
||||||
@ -170,7 +152,7 @@ public final class ModelSettings
|
|||||||
+ " predictions.\n\n"
|
+ " predictions.\n\n"
|
||||||
+ "Higher values improve the quality of predictions, but also"
|
+ "Higher values improve the quality of predictions, but also"
|
||||||
+ " increase the time it takes to generate them, as well as cost"
|
+ " increase the time it takes to generate them, as well as cost"
|
||||||
+ " (for OpenAI API users) or RAM usage (for oobabooga users).",
|
+ " (for APIs like OpenAI) or RAM usage (for self-hosted models).",
|
||||||
10, 0, 100, 1, ValueDisplay.INTEGER);
|
10, 0, 100, 1, ValueDisplay.INTEGER);
|
||||||
|
|
||||||
public final CheckboxSetting filterServerMessages =
|
public final CheckboxSetting filterServerMessages =
|
||||||
@ -182,6 +164,47 @@ public final class ModelSettings
|
|||||||
+ " etc.",
|
+ " etc.",
|
||||||
false);
|
false);
|
||||||
|
|
||||||
|
public final TextFieldSetting customModel = new TextFieldSetting(
|
||||||
|
"Custom model",
|
||||||
|
"If set, this model will be used instead of the one specified in the"
|
||||||
|
+ " \"OpenAI model\" setting.\n\n"
|
||||||
|
+ "Use this if you have a fine-tuned OpenAI model or if you are"
|
||||||
|
+ " using a custom endpoint that is OpenAI-compatible but offers"
|
||||||
|
+ " different models.",
|
||||||
|
"");
|
||||||
|
|
||||||
|
public final EnumSetting<CustomModelType> customModelType =
|
||||||
|
new EnumSetting<>("Custom model type", "Whether the custom"
|
||||||
|
+ " model should use the chat endpoint or the legacy endpoint.\n\n"
|
||||||
|
+ "If \"Custom model\" is left blank, this setting is ignored.",
|
||||||
|
CustomModelType.values(), CustomModelType.CHAT);
|
||||||
|
|
||||||
|
public enum CustomModelType
|
||||||
|
{
|
||||||
|
CHAT("Chat", true),
|
||||||
|
LEGACY("Legacy", false);
|
||||||
|
|
||||||
|
private final String name;
|
||||||
|
private final boolean chat;
|
||||||
|
|
||||||
|
private CustomModelType(String name, boolean chat)
|
||||||
|
{
|
||||||
|
this.name = name;
|
||||||
|
this.chat = chat;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isChat()
|
||||||
|
{
|
||||||
|
return chat;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString()
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public final TextFieldSetting openaiChatEndpoint = new TextFieldSetting(
|
public final TextFieldSetting openaiChatEndpoint = new TextFieldSetting(
|
||||||
"OpenAI chat endpoint", "Endpoint for OpenAI's chat completion API.",
|
"OpenAI chat endpoint", "Endpoint for OpenAI's chat completion API.",
|
||||||
"https://api.openai.com/v1/chat/completions");
|
"https://api.openai.com/v1/chat/completions");
|
||||||
@ -191,19 +214,11 @@ public final class ModelSettings
|
|||||||
"Endpoint for OpenAI's legacy completion API.",
|
"Endpoint for OpenAI's legacy completion API.",
|
||||||
"https://api.openai.com/v1/completions");
|
"https://api.openai.com/v1/completions");
|
||||||
|
|
||||||
public final TextFieldSetting oobaboogaEndpoint =
|
|
||||||
new TextFieldSetting("Oobabooga endpoint",
|
|
||||||
"Endpoint for your Oobabooga web UI instance.\n"
|
|
||||||
+ "Remember to start the Oobabooga server with the"
|
|
||||||
+ " \u00a7e--extensions api\u00a7r flag.",
|
|
||||||
"http://127.0.0.1:5000/api/v1/generate");
|
|
||||||
|
|
||||||
private final List<Setting> settings =
|
private final List<Setting> settings =
|
||||||
Collections.unmodifiableList(Arrays.asList(openAiModel, maxTokens,
|
Collections.unmodifiableList(Arrays.asList(openAiModel, maxTokens,
|
||||||
temperature, topP, presencePenalty, frequencyPenalty,
|
temperature, topP, presencePenalty, frequencyPenalty, stopSequence,
|
||||||
repetitionPenalty, encoderRepetitionPenalty, stopSequence,
|
contextLength, filterServerMessages, customModel, customModelType,
|
||||||
contextLength, filterServerMessages, openaiChatEndpoint,
|
openaiChatEndpoint, openaiLegacyEndpoint));
|
||||||
openaiLegacyEndpoint, oobaboogaEndpoint));
|
|
||||||
|
|
||||||
public void forEach(Consumer<Setting> action)
|
public void forEach(Consumer<Setting> action)
|
||||||
{
|
{
|
||||||
|
@ -1,85 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright (c) 2014-2024 Wurst-Imperium and contributors.
|
|
||||||
*
|
|
||||||
* This source code is subject to the terms of the GNU General Public
|
|
||||||
* License, version 3. If a copy of the GPL was not distributed with this
|
|
||||||
* file, You can obtain one at: https://www.gnu.org/licenses/gpl-3.0.txt
|
|
||||||
*/
|
|
||||||
package net.wurstclient.hacks.autocomplete;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.OutputStream;
|
|
||||||
import java.net.HttpURLConnection;
|
|
||||||
import java.net.URI;
|
|
||||||
import java.net.URL;
|
|
||||||
|
|
||||||
import com.google.gson.JsonArray;
|
|
||||||
import com.google.gson.JsonObject;
|
|
||||||
|
|
||||||
import net.wurstclient.util.json.JsonException;
|
|
||||||
import net.wurstclient.util.json.JsonUtils;
|
|
||||||
import net.wurstclient.util.json.WsonObject;
|
|
||||||
|
|
||||||
public final class OobaboogaMessageCompleter extends MessageCompleter
|
|
||||||
{
|
|
||||||
public OobaboogaMessageCompleter(ModelSettings modelSettings)
|
|
||||||
{
|
|
||||||
super(modelSettings);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected JsonObject buildParams(String prompt, int maxSuggestions)
|
|
||||||
{
|
|
||||||
JsonObject params = new JsonObject();
|
|
||||||
params.addProperty("prompt", prompt);
|
|
||||||
params.addProperty("max_length", modelSettings.maxTokens.getValueI());
|
|
||||||
params.addProperty("temperature", modelSettings.temperature.getValue());
|
|
||||||
params.addProperty("top_p", modelSettings.topP.getValue());
|
|
||||||
params.addProperty("repetition_penalty",
|
|
||||||
modelSettings.repetitionPenalty.getValue());
|
|
||||||
params.addProperty("encoder_repetition_penalty",
|
|
||||||
modelSettings.encoderRepetitionPenalty.getValue());
|
|
||||||
JsonArray stoppingStrings = new JsonArray();
|
|
||||||
stoppingStrings
|
|
||||||
.add(modelSettings.stopSequence.getSelected().getSequence());
|
|
||||||
params.add("stopping_strings", stoppingStrings);
|
|
||||||
return params;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected WsonObject requestCompletions(JsonObject parameters)
|
|
||||||
throws IOException, JsonException
|
|
||||||
{
|
|
||||||
// set up the API request
|
|
||||||
URL url =
|
|
||||||
URI.create(modelSettings.oobaboogaEndpoint.getValue()).toURL();
|
|
||||||
HttpURLConnection conn = (HttpURLConnection)url.openConnection();
|
|
||||||
conn.setRequestMethod("POST");
|
|
||||||
conn.setRequestProperty("Content-Type", "application/json");
|
|
||||||
|
|
||||||
// set the request body
|
|
||||||
conn.setDoOutput(true);
|
|
||||||
try(OutputStream os = conn.getOutputStream())
|
|
||||||
{
|
|
||||||
os.write(JsonUtils.GSON.toJson(parameters).getBytes());
|
|
||||||
os.flush();
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse the response
|
|
||||||
return JsonUtils.parseConnectionToObject(conn);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected String[] extractCompletions(WsonObject response)
|
|
||||||
throws JsonException
|
|
||||||
{
|
|
||||||
// extract completion from response
|
|
||||||
String completion =
|
|
||||||
response.getArray("results").getObject(0).getString("text");
|
|
||||||
|
|
||||||
// remove newlines
|
|
||||||
completion = completion.replace("\n", " ");
|
|
||||||
|
|
||||||
return new String[]{completion};
|
|
||||||
}
|
|
||||||
}
|
|
@ -35,8 +35,6 @@ public final class OpenAiMessageCompleter extends MessageCompleter
|
|||||||
JsonObject params = new JsonObject();
|
JsonObject params = new JsonObject();
|
||||||
params.addProperty("stop",
|
params.addProperty("stop",
|
||||||
modelSettings.stopSequence.getSelected().getSequence());
|
modelSettings.stopSequence.getSelected().getSequence());
|
||||||
params.addProperty("model",
|
|
||||||
"" + modelSettings.openAiModel.getSelected());
|
|
||||||
params.addProperty("max_tokens", modelSettings.maxTokens.getValueI());
|
params.addProperty("max_tokens", modelSettings.maxTokens.getValueI());
|
||||||
params.addProperty("temperature", modelSettings.temperature.getValue());
|
params.addProperty("temperature", modelSettings.temperature.getValue());
|
||||||
params.addProperty("top_p", modelSettings.topP.getValue());
|
params.addProperty("top_p", modelSettings.topP.getValue());
|
||||||
@ -46,8 +44,19 @@ public final class OpenAiMessageCompleter extends MessageCompleter
|
|||||||
modelSettings.frequencyPenalty.getValue());
|
modelSettings.frequencyPenalty.getValue());
|
||||||
params.addProperty("n", maxSuggestions);
|
params.addProperty("n", maxSuggestions);
|
||||||
|
|
||||||
// add the prompt, depending on the model
|
// determine model name and type
|
||||||
if(modelSettings.openAiModel.getSelected().isChatModel())
|
boolean customModel = !modelSettings.customModel.getValue().isBlank();
|
||||||
|
String modelName = customModel ? modelSettings.customModel.getValue()
|
||||||
|
: "" + modelSettings.openAiModel.getSelected();
|
||||||
|
boolean chatModel =
|
||||||
|
customModel ? modelSettings.customModelType.getSelected().isChat()
|
||||||
|
: modelSettings.openAiModel.getSelected().isChatModel();
|
||||||
|
|
||||||
|
// add the model name
|
||||||
|
params.addProperty("model", modelName);
|
||||||
|
|
||||||
|
// add the prompt, depending on model type
|
||||||
|
if(chatModel)
|
||||||
{
|
{
|
||||||
JsonArray messages = new JsonArray();
|
JsonArray messages = new JsonArray();
|
||||||
JsonObject systemMessage = new JsonObject();
|
JsonObject systemMessage = new JsonObject();
|
||||||
|
Loading…
Reference in New Issue
Block a user