mirror of
https://github.com/Wurst-Imperium/Wurst7.git
synced 2024-09-19 17:02:13 +02:00
Remove support for legacy oobabooga API, add custom model settings
This should make AutoComplete work with any OpenAI-compatible API, including the new oobabooga one.
This commit is contained in:
parent
5b4ade795e
commit
49dd9cca63
@ -17,10 +17,8 @@ import net.wurstclient.SearchTags;
|
||||
import net.wurstclient.events.ChatOutputListener;
|
||||
import net.wurstclient.events.UpdateListener;
|
||||
import net.wurstclient.hack.Hack;
|
||||
import net.wurstclient.hacks.autocomplete.ApiProviderSetting;
|
||||
import net.wurstclient.hacks.autocomplete.MessageCompleter;
|
||||
import net.wurstclient.hacks.autocomplete.ModelSettings;
|
||||
import net.wurstclient.hacks.autocomplete.OobaboogaMessageCompleter;
|
||||
import net.wurstclient.hacks.autocomplete.OpenAiMessageCompleter;
|
||||
import net.wurstclient.hacks.autocomplete.SuggestionHandler;
|
||||
import net.wurstclient.util.ChatUtils;
|
||||
@ -32,7 +30,6 @@ public final class AutoCompleteHack extends Hack
|
||||
{
|
||||
private final ModelSettings modelSettings = new ModelSettings();
|
||||
private final SuggestionHandler suggestionHandler = new SuggestionHandler();
|
||||
private final ApiProviderSetting apiProvider = new ApiProviderSetting();
|
||||
|
||||
private MessageCompleter completer;
|
||||
private String draftMessage;
|
||||
@ -47,7 +44,6 @@ public final class AutoCompleteHack extends Hack
|
||||
super("AutoComplete");
|
||||
setCategory(Category.CHAT);
|
||||
|
||||
addSetting(apiProvider);
|
||||
modelSettings.forEach(this::addSetting);
|
||||
suggestionHandler.getSettings().forEach(this::addSetting);
|
||||
}
|
||||
@ -55,11 +51,7 @@ public final class AutoCompleteHack extends Hack
|
||||
@Override
|
||||
protected void onEnable()
|
||||
{
|
||||
completer = switch(apiProvider.getSelected())
|
||||
{
|
||||
case OPENAI -> new OpenAiMessageCompleter(modelSettings);
|
||||
case OOBABOOGA -> new OobaboogaMessageCompleter(modelSettings);
|
||||
};
|
||||
completer = new OpenAiMessageCompleter(modelSettings);
|
||||
|
||||
if(completer instanceof OpenAiMessageCompleter
|
||||
&& System.getenv("WURST_OPENAI_KEY") == null)
|
||||
|
@ -1,47 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2014-2024 Wurst-Imperium and contributors.
|
||||
*
|
||||
* This source code is subject to the terms of the GNU General Public
|
||||
* License, version 3. If a copy of the GPL was not distributed with this
|
||||
* file, You can obtain one at: https://www.gnu.org/licenses/gpl-3.0.txt
|
||||
*/
|
||||
package net.wurstclient.hacks.autocomplete;
|
||||
|
||||
import net.wurstclient.settings.EnumSetting;
|
||||
|
||||
public final class ApiProviderSetting
|
||||
extends EnumSetting<ApiProviderSetting.ApiProvider>
|
||||
{
|
||||
public ApiProviderSetting()
|
||||
{
|
||||
super("API provider",
|
||||
"\u00a7lOpenAI\u00a7r lets you use models like ChatGPT, but requires an"
|
||||
+ " account with API access, costs money to use and sends your chat"
|
||||
+ " history to their servers. The name is a lie - it's closed"
|
||||
+ " source.\n\n"
|
||||
+ "\u00a7loobabooga\u00a7r lets you use models like LLaMA and many"
|
||||
+ " others. It's a true open source alternative to OpenAI that you"
|
||||
+ " can run locally on your own computer. It's free to use and does"
|
||||
+ " not send your chat history to any servers.",
|
||||
ApiProvider.values(), ApiProvider.OOBABOOGA);
|
||||
}
|
||||
|
||||
public enum ApiProvider
|
||||
{
|
||||
OPENAI("OpenAI"),
|
||||
OOBABOOGA("oobabooga");
|
||||
|
||||
private final String name;
|
||||
|
||||
private ApiProvider(String name)
|
||||
{
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString()
|
||||
{
|
||||
return name;
|
||||
}
|
||||
}
|
||||
}
|
@ -98,8 +98,7 @@ public final class ModelSettings
|
||||
+ " history.\n\n"
|
||||
+ "Positive values encourage the model to use synonyms and"
|
||||
+ " talk about different topics. Negative values encourage the"
|
||||
+ " model to repeat the same word over and over again.\n\n"
|
||||
+ "Only works with OpenAI models.",
|
||||
+ " model to repeat the same word over and over again.",
|
||||
0, -2, 2, 0.01, ValueDisplay.DECIMAL);
|
||||
|
||||
public final SliderSetting frequencyPenalty =
|
||||
@ -108,26 +107,9 @@ public final class ModelSettings
|
||||
+ " appears in the chat history.\n\n"
|
||||
+ "Positive values encourage the model to use synonyms and"
|
||||
+ " talk about different topics. Negative values encourage the"
|
||||
+ " model to repeat existing chat messages.\n\n"
|
||||
+ "Only works with OpenAI models.",
|
||||
+ " model to repeat existing chat messages.",
|
||||
0, -2, 2, 0.01, ValueDisplay.DECIMAL);
|
||||
|
||||
public final SliderSetting repetitionPenalty =
|
||||
new SliderSetting("Repetition penalty",
|
||||
"Similar to presence penalty, but uses a different algorithm.\n\n"
|
||||
+ "1.0 means no penalty, negative values are not possible and"
|
||||
+ " 1.5 is the maximum value.\n\n"
|
||||
+ "Only works with the oobabooga web UI.",
|
||||
1, 1, 1.5, 0.01, ValueDisplay.DECIMAL);
|
||||
|
||||
public final SliderSetting encoderRepetitionPenalty =
|
||||
new SliderSetting("Encoder repetition penalty",
|
||||
"Similar to frequency penalty, but uses a different algorithm.\n\n"
|
||||
+ "1.0 means no penalty, 0.8 behaves like a negative value and"
|
||||
+ " 1.5 is the maximum value.\n\n"
|
||||
+ "Only works with the oobabooga web UI.",
|
||||
1, 0.8, 1.5, 0.01, ValueDisplay.DECIMAL);
|
||||
|
||||
public final EnumSetting<StopSequence> stopSequence = new EnumSetting<>(
|
||||
"Stop sequence",
|
||||
"Controls how AutoComplete detects the end of a chat message.\n\n"
|
||||
@ -170,7 +152,7 @@ public final class ModelSettings
|
||||
+ " predictions.\n\n"
|
||||
+ "Higher values improve the quality of predictions, but also"
|
||||
+ " increase the time it takes to generate them, as well as cost"
|
||||
+ " (for OpenAI API users) or RAM usage (for oobabooga users).",
|
||||
+ " (for APIs like OpenAI) or RAM usage (for self-hosted models).",
|
||||
10, 0, 100, 1, ValueDisplay.INTEGER);
|
||||
|
||||
public final CheckboxSetting filterServerMessages =
|
||||
@ -182,6 +164,47 @@ public final class ModelSettings
|
||||
+ " etc.",
|
||||
false);
|
||||
|
||||
public final TextFieldSetting customModel = new TextFieldSetting(
|
||||
"Custom model",
|
||||
"If set, this model will be used instead of the one specified in the"
|
||||
+ " \"OpenAI model\" setting.\n\n"
|
||||
+ "Use this if you have a fine-tuned OpenAI model or if you are"
|
||||
+ " using a custom endpoint that is OpenAI-compatible but offers"
|
||||
+ " different models.",
|
||||
"");
|
||||
|
||||
public final EnumSetting<CustomModelType> customModelType =
|
||||
new EnumSetting<>("Custom model type", "Whether the custom"
|
||||
+ " model should use the chat endpoint or the legacy endpoint.\n\n"
|
||||
+ "If \"Custom model\" is left blank, this setting is ignored.",
|
||||
CustomModelType.values(), CustomModelType.CHAT);
|
||||
|
||||
public enum CustomModelType
|
||||
{
|
||||
CHAT("Chat", true),
|
||||
LEGACY("Legacy", false);
|
||||
|
||||
private final String name;
|
||||
private final boolean chat;
|
||||
|
||||
private CustomModelType(String name, boolean chat)
|
||||
{
|
||||
this.name = name;
|
||||
this.chat = chat;
|
||||
}
|
||||
|
||||
public boolean isChat()
|
||||
{
|
||||
return chat;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString()
|
||||
{
|
||||
return name;
|
||||
}
|
||||
}
|
||||
|
||||
public final TextFieldSetting openaiChatEndpoint = new TextFieldSetting(
|
||||
"OpenAI chat endpoint", "Endpoint for OpenAI's chat completion API.",
|
||||
"https://api.openai.com/v1/chat/completions");
|
||||
@ -191,19 +214,11 @@ public final class ModelSettings
|
||||
"Endpoint for OpenAI's legacy completion API.",
|
||||
"https://api.openai.com/v1/completions");
|
||||
|
||||
public final TextFieldSetting oobaboogaEndpoint =
|
||||
new TextFieldSetting("Oobabooga endpoint",
|
||||
"Endpoint for your Oobabooga web UI instance.\n"
|
||||
+ "Remember to start the Oobabooga server with the"
|
||||
+ " \u00a7e--extensions api\u00a7r flag.",
|
||||
"http://127.0.0.1:5000/api/v1/generate");
|
||||
|
||||
private final List<Setting> settings =
|
||||
Collections.unmodifiableList(Arrays.asList(openAiModel, maxTokens,
|
||||
temperature, topP, presencePenalty, frequencyPenalty,
|
||||
repetitionPenalty, encoderRepetitionPenalty, stopSequence,
|
||||
contextLength, filterServerMessages, openaiChatEndpoint,
|
||||
openaiLegacyEndpoint, oobaboogaEndpoint));
|
||||
temperature, topP, presencePenalty, frequencyPenalty, stopSequence,
|
||||
contextLength, filterServerMessages, customModel, customModelType,
|
||||
openaiChatEndpoint, openaiLegacyEndpoint));
|
||||
|
||||
public void forEach(Consumer<Setting> action)
|
||||
{
|
||||
|
@ -1,85 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2014-2024 Wurst-Imperium and contributors.
|
||||
*
|
||||
* This source code is subject to the terms of the GNU General Public
|
||||
* License, version 3. If a copy of the GPL was not distributed with this
|
||||
* file, You can obtain one at: https://www.gnu.org/licenses/gpl-3.0.txt
|
||||
*/
|
||||
package net.wurstclient.hacks.autocomplete;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
|
||||
import com.google.gson.JsonArray;
|
||||
import com.google.gson.JsonObject;
|
||||
|
||||
import net.wurstclient.util.json.JsonException;
|
||||
import net.wurstclient.util.json.JsonUtils;
|
||||
import net.wurstclient.util.json.WsonObject;
|
||||
|
||||
public final class OobaboogaMessageCompleter extends MessageCompleter
|
||||
{
|
||||
public OobaboogaMessageCompleter(ModelSettings modelSettings)
|
||||
{
|
||||
super(modelSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected JsonObject buildParams(String prompt, int maxSuggestions)
|
||||
{
|
||||
JsonObject params = new JsonObject();
|
||||
params.addProperty("prompt", prompt);
|
||||
params.addProperty("max_length", modelSettings.maxTokens.getValueI());
|
||||
params.addProperty("temperature", modelSettings.temperature.getValue());
|
||||
params.addProperty("top_p", modelSettings.topP.getValue());
|
||||
params.addProperty("repetition_penalty",
|
||||
modelSettings.repetitionPenalty.getValue());
|
||||
params.addProperty("encoder_repetition_penalty",
|
||||
modelSettings.encoderRepetitionPenalty.getValue());
|
||||
JsonArray stoppingStrings = new JsonArray();
|
||||
stoppingStrings
|
||||
.add(modelSettings.stopSequence.getSelected().getSequence());
|
||||
params.add("stopping_strings", stoppingStrings);
|
||||
return params;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected WsonObject requestCompletions(JsonObject parameters)
|
||||
throws IOException, JsonException
|
||||
{
|
||||
// set up the API request
|
||||
URL url =
|
||||
URI.create(modelSettings.oobaboogaEndpoint.getValue()).toURL();
|
||||
HttpURLConnection conn = (HttpURLConnection)url.openConnection();
|
||||
conn.setRequestMethod("POST");
|
||||
conn.setRequestProperty("Content-Type", "application/json");
|
||||
|
||||
// set the request body
|
||||
conn.setDoOutput(true);
|
||||
try(OutputStream os = conn.getOutputStream())
|
||||
{
|
||||
os.write(JsonUtils.GSON.toJson(parameters).getBytes());
|
||||
os.flush();
|
||||
}
|
||||
|
||||
// parse the response
|
||||
return JsonUtils.parseConnectionToObject(conn);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String[] extractCompletions(WsonObject response)
|
||||
throws JsonException
|
||||
{
|
||||
// extract completion from response
|
||||
String completion =
|
||||
response.getArray("results").getObject(0).getString("text");
|
||||
|
||||
// remove newlines
|
||||
completion = completion.replace("\n", " ");
|
||||
|
||||
return new String[]{completion};
|
||||
}
|
||||
}
|
@ -35,8 +35,6 @@ public final class OpenAiMessageCompleter extends MessageCompleter
|
||||
JsonObject params = new JsonObject();
|
||||
params.addProperty("stop",
|
||||
modelSettings.stopSequence.getSelected().getSequence());
|
||||
params.addProperty("model",
|
||||
"" + modelSettings.openAiModel.getSelected());
|
||||
params.addProperty("max_tokens", modelSettings.maxTokens.getValueI());
|
||||
params.addProperty("temperature", modelSettings.temperature.getValue());
|
||||
params.addProperty("top_p", modelSettings.topP.getValue());
|
||||
@ -46,8 +44,19 @@ public final class OpenAiMessageCompleter extends MessageCompleter
|
||||
modelSettings.frequencyPenalty.getValue());
|
||||
params.addProperty("n", maxSuggestions);
|
||||
|
||||
// add the prompt, depending on the model
|
||||
if(modelSettings.openAiModel.getSelected().isChatModel())
|
||||
// determine model name and type
|
||||
boolean customModel = !modelSettings.customModel.getValue().isBlank();
|
||||
String modelName = customModel ? modelSettings.customModel.getValue()
|
||||
: "" + modelSettings.openAiModel.getSelected();
|
||||
boolean chatModel =
|
||||
customModel ? modelSettings.customModelType.getSelected().isChat()
|
||||
: modelSettings.openAiModel.getSelected().isChatModel();
|
||||
|
||||
// add the model name
|
||||
params.addProperty("model", modelName);
|
||||
|
||||
// add the prompt, depending on model type
|
||||
if(chatModel)
|
||||
{
|
||||
JsonArray messages = new JsonArray();
|
||||
JsonObject systemMessage = new JsonObject();
|
||||
|
Loading…
Reference in New Issue
Block a user