0
0
mirror of https://github.com/florisboard/florisboard.git synced 2024-09-19 19:42:20 +02:00

Implement spell checking support using the new NLP core

This commit is contained in:
Patrick Goldinger 2023-05-31 12:46:29 +02:00
parent a21114ad03
commit eb23dc8ba1
No known key found for this signature in database
GPG Key ID: 533467C3DC7B9262
6 changed files with 201 additions and 105 deletions

View File

@ -61,3 +61,22 @@ Java_dev_patrickgold_florisboard_ime_nlp_latin_LatinNlpSession_00024CXX_nativeLo
session->loadConfigFromFile(config_path);
});
}
/**
 * JNI entry point for the Kotlin-side `nativeSpell` declaration of `LatinNlpSession.CXX`.
 *
 * Converts the incoming word and previous-words list to their std equivalents, runs
 * `spell()` on the session addressed by `native_ptr`, and hands the outcome back to the
 * JVM as a JSON document with the keys `suggestionAttributes` and `suggestions`.
 *
 * @param env JNI environment used for the string/list conversions.
 * @param native_ptr Address of the `fl::nlp::LatinNlpSession` to query. It is cast without
 *  any validation, so the caller must pass a pointer to a live session.
 * @param j_word The word to spell check.
 * @param j_prev_words The words preceding `j_word`.
 * @param flags Request flags, forwarded to `spell()` unchanged.
 * @return The serialized JSON document as a native string.
 */
extern "C" JNIEXPORT fl::jni::NativeStr JNICALL
Java_dev_patrickgold_florisboard_ime_nlp_latin_LatinNlpSession_00024CXX_nativeSpell( //
    JNIEnv* env,
    jobject,
    jlong native_ptr,
    fl::jni::NativeStr j_word,
    fl::jni::NativeList j_prev_words,
    jint flags
) {
    auto* nlp_session = reinterpret_cast<fl::nlp::LatinNlpSession*>(native_ptr);

    // Bring the JVM-side arguments into std types before touching the session.
    auto current_word = fl::jni::j2std_string(env, j_word);
    auto preceding_words = fl::jni::j2std_list<std::string>(env, j_prev_words);

    auto result = nlp_session->spell(current_word, preceding_words, flags);

    // The key names are read back on the Kotlin side when the payload is decoded.
    nlohmann::json payload;
    payload["suggestionAttributes"] = result.suggestion_attributes;
    payload["suggestions"] = result.suggestions;
    return fl::jni::std2j_string(env, payload.dump());
}

View File

@ -92,7 +92,7 @@ class FlorisSpellCheckerService : SpellCheckerService() {
): Array<SpellingResult> = runBlocking {
val retInfos = Array(textInfos.size) { n ->
val word = textInfos[n].text ?: ""
async { nlpManager.spell(spellingSubtype, word, emptyList(), emptyList(), suggestionsLimit) }
async { nlpManager.spell(spellingSubtype, word, emptyList(), suggestionsLimit) }
}
Array(textInfos.size) { n ->
retInfos[n].await().apply {
@ -110,7 +110,7 @@ class FlorisSpellCheckerService : SpellCheckerService() {
return runBlocking {
nlpManager
.spell(spellingSubtype, textInfo.text, emptyList(), emptyList(), suggestionsLimit)
.spell(spellingSubtype, textInfo.text, emptyList(), suggestionsLimit)
.sendToDebugOverlayIfEnabled(textInfo)
.suggestionsInfo
}

View File

@ -37,6 +37,7 @@ import dev.patrickgold.florisboard.ime.core.ComputedSubtype
import dev.patrickgold.florisboard.ime.core.Subtype
import dev.patrickgold.florisboard.ime.editor.EditorContent
import dev.patrickgold.florisboard.ime.editor.EditorRange
import dev.patrickgold.florisboard.ime.input.InputShiftState
import dev.patrickgold.florisboard.keyboardManager
import dev.patrickgold.florisboard.lib.devtools.flogDebug
import dev.patrickgold.florisboard.lib.devtools.flogError
@ -57,6 +58,7 @@ import java.util.*
import java.util.concurrent.atomic.AtomicInteger
import kotlin.properties.Delegates
// TODO: VERY IMPORTANT: This class is the definition of spaghetti code and chaos, clean up or rewrite this class
class NlpManager(context: Context) {
private val prefs by florisPreferenceModel()
private val clipboardManager by context.clipboardManager()
@ -163,20 +165,15 @@ class NlpManager(context: Context) {
suspend fun spell(
subtype: Subtype,
word: String,
precedingWords: List<String>,
followingWords: List<String>,
prevWords: List<String>,
maxSuggestionCount: Int,
): SpellingResult {
//return nlpProviderRegistry.getSpellingProvider(subtype).spell(
// subtype = subtype,
// word = word,
// precedingWords = precedingWords,
// followingWords = followingWords,
// maxSuggestionCount = maxSuggestionCount,
// allowPossiblyOffensive = !prefs.suggestion.blockPossiblyOffensive.get(),
// isPrivateSession = keyboardManager.activeState.isIncognitoMode,
//)
return SpellingResult.unspecified()
return plugins.getOrNull(subtype.nlpProviders.spelling)?.spell(
subtypeId = subtype.id,
word = word,
prevWords = prevWords,
flags = activeSuggestionRequestFlags(maxSuggestionCount),
) ?: SpellingResult.unspecified()
}
suspend fun determineLocalComposing(
@ -212,22 +209,22 @@ class NlpManager(context: Context) {
prefs.suggestion.enabled.get() || providerForcesSuggestionOn(subtypeManager.activeSubtype)
fun suggest(subtype: Subtype, content: EditorContent) {
/*val reqTime = SystemClock.uptimeMillis()
val reqTime = SystemClock.uptimeMillis()
scope.launch {
val suggestions = nlpProviderRegistry.getSuggestionProvider(subtype).suggest(
subtype = subtype,
content = content,
maxCandidateCount = 8,
allowPossiblyOffensive = !prefs.suggestion.blockPossiblyOffensive.get(),
isPrivateSession = keyboardManager.activeState.isIncognitoMode,
)
val suggestions = plugins.getOrNull(subtype.nlpProviders.spelling)?.spell(
subtypeId = subtype.id,
word = content.composingText,
prevWords = content.textBeforeSelection.split(" "), // TODO this split is incorrect
flags = activeSuggestionRequestFlags(),
) ?: SpellingResult.unspecified()
val candidates = suggestions.suggestions().map { WordSuggestionCandidate(it) }
flogDebug { "candidates: $candidates" }
internalSuggestionsGuard.withLock {
if (internalSuggestions.first < reqTime) {
internalSuggestions = reqTime to suggestions
internalSuggestions = reqTime to candidates
}
}
}*/
return
}
}
fun suggestDirectly(suggestions: List<SuggestionCandidate>) {
@ -263,16 +260,27 @@ class NlpManager(context: Context) {
}
}
/**
 * Builds the [SuggestionRequestFlags] for a spelling or suggestion request from the
 * current preferences and keyboard state.
 *
 * @param maxSuggestionCount Upper bound for the number of suggestions; 8 is used when
 *  this is `null`.
 * @return The assembled request flags.
 */
private fun activeSuggestionRequestFlags(maxSuggestionCount: Int? = null): SuggestionRequestFlags {
    val suggestionLimit = maxSuggestionCount ?: 8 // TODO make dynamic
    // The preference stores "block", the flag expects "allow", hence the negation.
    val offensiveAllowed = !prefs.suggestion.blockPossiblyOffensive.get()
    val privateSession = keyboardManager.activeState.isIncognitoMode
    return SuggestionRequestFlags.new(
        maxSuggestionCount = suggestionLimit,
        issStart = InputShiftState.UNSHIFTED, // TODO evaluate correctly
        issCurrent = InputShiftState.UNSHIFTED, // TODO evaluate correctly
        maxNgramLevel = 3, // TODO make dynamic
        allowPossiblyOffensive = offensiveAllowed,
        overrideHiddenFlag = false, // TODO make dynamic
        isPrivateSession = privateSession,
    )
}
private fun assembleCandidates() {
runBlocking {
/*val candidates = when {
val candidates = when {
isSuggestionOn() -> {
clipboardSuggestionProvider.suggest(
subtype = Subtype.FALLBACK,
content = editorInstance.activeContent,
maxCandidateCount = 8,
allowPossiblyOffensive = !prefs.suggestion.blockPossiblyOffensive.get(),
isPrivateSession = keyboardManager.activeState.isIncognitoMode,
subtypeId = Subtype.FALLBACK.id,
word = editorInstance.activeContent.currentWordText,
prevWords = listOf(),
flags = activeSuggestionRequestFlags(),
).ifEmpty {
buildList {
internalSuggestionsGuard.withLock {
@ -284,7 +292,7 @@ class NlpManager(context: Context) {
else -> emptyList()
}
activeCandidates = candidates
autoExpandCollapseSmartbarActions(candidates, inlineSuggestions.value)*/
autoExpandCollapseSmartbarActions(candidates, inlineSuggestions.value)
}
}
@ -363,20 +371,19 @@ class NlpManager(context: Context) {
inner class ClipboardSuggestionProvider internal constructor() : SuggestionProvider {
private var lastClipboardItemId: Long = -1
override fun create() {
override suspend fun create() {
// Do nothing
}
override fun preload(subtype: ComputedSubtype) {
override suspend fun preload(subtype: ComputedSubtype) {
// Do nothing
}
override fun suggest(
override suspend fun suggest(
subtypeId: Long,
flags: SuggestionRequestFlags,
word: String,
precedingWords: List<String>,
followingWords: List<String>
prevWords: List<String>,
flags: SuggestionRequestFlags,
): List<SuggestionCandidate> {
// Check if enabled
if (!prefs.suggestion.clipboardContentEnabled.get()) return emptyList()
@ -421,17 +428,17 @@ class NlpManager(context: Context) {
}
}
override fun notifySuggestionAccepted(subtypeId: Long, candidate: SuggestionCandidate) {
override suspend fun notifySuggestionAccepted(subtypeId: Long, candidate: SuggestionCandidate) {
if (candidate is ClipboardSuggestionCandidate) {
lastClipboardItemId = candidate.clipboardItem.id
}
}
override fun notifySuggestionReverted(subtypeId: Long, candidate: SuggestionCandidate) {
override suspend fun notifySuggestionReverted(subtypeId: Long, candidate: SuggestionCandidate) {
// Do nothing
}
override fun removeSuggestion(subtypeId: Long, candidate: SuggestionCandidate): Boolean {
override suspend fun removeSuggestion(subtypeId: Long, candidate: SuggestionCandidate): Boolean {
if (candidate is ClipboardSuggestionCandidate) {
lastClipboardItemId = candidate.clipboardItem.id
return true
@ -439,7 +446,7 @@ class NlpManager(context: Context) {
return false
}
override fun destroy() {
override suspend fun destroy() {
// Do nothing
}
}

View File

@ -19,6 +19,9 @@ package dev.patrickgold.florisboard.ime.nlp.latin
import dev.patrickgold.florisboard.extensionManager
import dev.patrickgold.florisboard.ime.core.ComputedSubtype
import dev.patrickgold.florisboard.ime.keyboard.KeyProximityChecker
import dev.patrickgold.florisboard.ime.nlp.SpellingProvider
import dev.patrickgold.florisboard.ime.nlp.SpellingResult
import dev.patrickgold.florisboard.ime.nlp.SuggestionRequestFlags
import dev.patrickgold.florisboard.lib.FlorisLocale
import dev.patrickgold.florisboard.lib.io.subFile
import dev.patrickgold.florisboard.lib.io.writeJson
@ -27,10 +30,6 @@ import dev.patrickgold.florisboard.native.NativeStr
import dev.patrickgold.florisboard.native.toNativeStr
import dev.patrickgold.florisboard.plugin.FlorisPluginService
import dev.patrickgold.florisboard.subtypeManager
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch
private val DEFAULT_PREDICTION_WEIGHTS = LatinPredictionWeights(
lookup = LatinPredictionLookupWeights(
@ -62,7 +61,7 @@ private data class LatinNlpSessionWrapper(
var session: LatinNlpSession,
)
class LatinLanguageProviderService : FlorisPluginService() {
class LatinLanguageProviderService : FlorisPluginService(), SpellingProvider {
companion object {
const val NlpSessionConfigFileName = "nlp_session_config.json"
const val UserDictionaryFileName = "user_dict.fldic"
@ -72,53 +71,62 @@ class LatinLanguageProviderService : FlorisPluginService() {
private val extensionManager by extensionManager()
private val subtypeManager by subtypeManager()
private val scope = CoroutineScope(Dispatchers.Main + SupervisorJob())
private val cachedSessionWrappers = guardedByLock {
mutableListOf<LatinNlpSessionWrapper>()
}
override fun create() {
override suspend fun create() {
// Do nothing
}
override fun preload(subtype: ComputedSubtype) {
override suspend fun preload(subtype: ComputedSubtype) {
if (subtype.isFallback()) return
scope.launch {
cachedSessionWrappers.withLock { sessionWrappers ->
var sessionWrapper = sessionWrappers.find { it.subtype.id == subtype.id }
if (sessionWrapper == null || sessionWrapper.subtype != subtype) {
if (sessionWrapper == null) {
sessionWrapper = LatinNlpSessionWrapper(
subtype = subtype,
session = LatinNlpSession(),
)
sessionWrappers.add(sessionWrapper)
} else {
sessionWrapper.subtype = subtype
}
val cacheDir = subtypeManager.cacheDirFor(subtype)
val filesDir = subtypeManager.filesDirFor(subtype)
val configFile = cacheDir.subFile(NlpSessionConfigFileName)
val userDictFile = filesDir.subFile(UserDictionaryFileName)
if (!userDictFile.exists()) {
nativeInitEmptyDictionary(userDictFile.absolutePath.toNativeStr())
}
val config = LatinNlpSessionConfig(
primaryLocale = subtype.primaryLocale,
secondaryLocales = subtype.secondaryLocales,
baseDictionaryPaths = getBaseDictionaryPaths(subtype),
userDictionaryPath = userDictFile.absolutePath,
predictionWeights = DEFAULT_PREDICTION_WEIGHTS,
keyProximityChecker = DEFAULT_KEY_PROXIMITY_CHECKER,
cachedSessionWrappers.withLock { sessionWrappers ->
var sessionWrapper = sessionWrappers.find { it.subtype.id == subtype.id }
if (sessionWrapper == null || sessionWrapper.subtype != subtype) {
if (sessionWrapper == null) {
sessionWrapper = LatinNlpSessionWrapper(
subtype = subtype,
session = LatinNlpSession(),
)
configFile.writeJson(config)
sessionWrapper.session.loadFromConfigFile(configFile)
sessionWrappers.add(sessionWrapper)
} else {
sessionWrapper.subtype = subtype
}
val cacheDir = subtypeManager.cacheDirFor(subtype)
val filesDir = subtypeManager.filesDirFor(subtype)
val configFile = cacheDir.subFile(NlpSessionConfigFileName)
val userDictFile = filesDir.subFile(UserDictionaryFileName)
if (!userDictFile.exists()) {
nativeInitEmptyDictionary(userDictFile.absolutePath.toNativeStr())
}
val config = LatinNlpSessionConfig(
primaryLocale = subtype.primaryLocale,
secondaryLocales = subtype.secondaryLocales,
baseDictionaryPaths = getBaseDictionaryPaths(subtype),
userDictionaryPath = userDictFile.absolutePath,
predictionWeights = DEFAULT_PREDICTION_WEIGHTS,
keyProximityChecker = DEFAULT_KEY_PROXIMITY_CHECKER,
)
configFile.writeJson(config)
sessionWrapper.session.loadFromConfigFile(configFile)
}
}
}
override fun destroy() {
/**
 * Spell checks [word] with the cached NLP session whose subtype id equals [subtypeId].
 *
 * The lookup and the session call both happen while the session cache lock is held.
 *
 * @param subtypeId Id of the subtype whose cached session should be used.
 * @param word The word to check.
 * @param prevWords The words preceding [word].
 * @param flags Request flags forwarded to the session.
 * @return The session's result, or [SpellingResult.unspecified] when no session is cached
 *  for [subtypeId].
 */
override suspend fun spell(
    subtypeId: Long,
    word: String,
    prevWords: List<String>,
    flags: SuggestionRequestFlags,
): SpellingResult = cachedSessionWrappers.withLock { wrappers ->
    wrappers
        .find { wrapper -> wrapper.subtype.id == subtypeId }
        ?.session
        ?.spell(word, prevWords, flags)
        ?: SpellingResult.unspecified()
}
override suspend fun destroy() {
//
}

View File

@ -16,16 +16,25 @@
package dev.patrickgold.florisboard.ime.nlp.latin
import android.view.textservice.SuggestionsInfo
import dev.patrickgold.florisboard.ime.keyboard.KeyProximityChecker
import dev.patrickgold.florisboard.ime.nlp.SpellingResult
import dev.patrickgold.florisboard.ime.nlp.SuggestionRequestFlags
import dev.patrickgold.florisboard.lib.io.FsFile
import dev.patrickgold.florisboard.lib.kotlin.tryOrNull
import dev.patrickgold.florisboard.native.NativeInstanceWrapper
import dev.patrickgold.florisboard.native.NativePtr
import dev.patrickgold.florisboard.native.NativeStr
import dev.patrickgold.florisboard.native.NativeList
import dev.patrickgold.florisboard.native.toJavaString
import dev.patrickgold.florisboard.native.toNativeList
import dev.patrickgold.florisboard.native.toNativeStr
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json
@Serializable
data class LatinNlpSessionConfig(
@ -47,6 +56,30 @@ value class LatinNlpSession(private val _nativePtr: NativePtr = nativeInit()) :
}
}
/**
 * Spell checks [word] through the native session and converts the outcome into a
 * [SpellingResult].
 *
 * The native call and the decoding of its JSON reply run on [Dispatchers.IO].
 *
 * @param word The word to check.
 * @param prevWords The words preceding [word].
 * @param flags Request flags; passed to the native side as an integer.
 * @return The decoded result, or [SpellingResult.unspecified] when the native reply cannot
 *  be turned into a result.
 */
suspend fun spell(
    word: String,
    prevWords: List<String>,
    flags: SuggestionRequestFlags,
): SpellingResult {
    return withContext(Dispatchers.IO) {
        val nativeSpellingResultStr = nativeSpell(
            nativePtr = _nativePtr,
            word = word.toNativeStr(),
            prevWords = prevWords.toNativeList(),
            flags = flags.toInt(),
        ).toJavaString()
        // Decoding is the step that can realistically fail (malformed JSON, missing key),
        // so it has to sit inside the guard for the fallback below to be reachable.
        // NOTE(review): presumes tryOrNull yields null when its block throws - confirm.
        tryOrNull {
            val nativeSpellingResult = Json.decodeFromString<NativeSpellingResult>(nativeSpellingResultStr)
            SpellingResult(
                SuggestionsInfo(
                    nativeSpellingResult.suggestionAttributes,
                    nativeSpellingResult.suggestions.toTypedArray(),
                )
            )
        } ?: SpellingResult.unspecified()
    }
}
override fun nativePtr(): NativePtr {
return _nativePtr
}
@ -55,12 +88,23 @@ value class LatinNlpSession(private val _nativePtr: NativePtr = nativeInit()) :
nativeDispose(_nativePtr)
}
/**
 * Decoded form of the JSON string returned by [nativeSpell].
 *
 * The property names double as the JSON keys, so they must stay in sync with the keys
 * written by the native implementation.
 */
@Serializable
private data class NativeSpellingResult(
// Passed as the first (attributes) argument when the SuggestionsInfo is constructed.
val suggestionAttributes: Int,
// Suggested replacement words; converted to an array for SuggestionsInfo.
val suggestions: List<String>,
)
companion object CXX {
external fun nativeInit(): NativePtr
external fun nativeDispose(nativePtr: NativePtr)
external fun nativeLoadFromConfigFile(nativePtr: NativePtr, configPath: NativeStr)
//external fun nativeSpell(word: NativeStr, prevWords: List<NativeStr>, flags: Int): SpellingResult
external fun nativeSpell(
nativePtr: NativePtr,
word: NativeStr,
prevWords: NativeList,
flags: Int,
): NativeStr
//external fun nativeSuggest(word: NativeStr, prevWords: List<NativeStr>, flags: Int)
//external fun nativeTrain(sentence: List<NativeStr>, maxPrevWords: Int)
}

View File

@ -16,6 +16,7 @@
package dev.patrickgold.florisboard.plugin
import android.annotation.SuppressLint
import android.content.BroadcastReceiver
import android.content.ComponentName
import android.content.Context
@ -27,10 +28,12 @@ import android.os.Handler
import android.os.IBinder
import android.os.Message
import android.os.Messenger
import android.view.textservice.SuggestionsInfo
import dev.patrickgold.florisboard.BuildConfig
import dev.patrickgold.florisboard.ime.core.ComputedSubtype
import dev.patrickgold.florisboard.ime.nlp.SpellingProvider
import dev.patrickgold.florisboard.ime.nlp.SpellingResult
import dev.patrickgold.florisboard.ime.nlp.SuggestionRequest
import dev.patrickgold.florisboard.ime.nlp.SuggestionRequestFlags
import dev.patrickgold.florisboard.lib.devtools.flogDebug
import dev.patrickgold.florisboard.lib.io.FlorisRef
@ -41,10 +44,12 @@ import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.serialization.json.Json
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
@ -62,7 +67,7 @@ class FlorisPluginIndexer(private val context: Context) {
pluginIndex.withLock { pluginIndex ->
val newPluginIndex = mutableListOf<IndexedPlugin>()
fun registerPlugin(
suspend fun registerPlugin(
serviceName: ComponentName,
state: IndexedPluginState,
metadata: FlorisPluginMetadata = FlorisPluginMetadata(""),
@ -123,7 +128,8 @@ class FlorisPluginIndexer(private val context: Context) {
context.registerReceiver(receiver, IntentFilter("android.intent.action.PACKAGE_CHANGED"))
}
suspend fun getOrNull(pluginId: String): IndexedPlugin? {
suspend fun getOrNull(pluginId: String?): IndexedPlugin? {
if (pluginId == null) return null
return pluginIndex.withLock { pluginIndex ->
pluginIndex.find { it.metadata.id == pluginId }
}
@ -139,12 +145,12 @@ class IndexedPlugin(
private var messageIdGenerator = AtomicInteger(1)
private var connection = IndexedPluginConnection(context)
override fun create() {
override suspend fun create() {
if (isValidAndBound()) return
connection.bindService(serviceName)
}
override fun preload(subtype: ComputedSubtype) {
override suspend fun preload(subtype: ComputedSubtype) {
val message = FlorisPluginMessage.requestToService(
action = FlorisPluginMessage.ACTION_PRELOAD,
id = messageIdGenerator.getAndIncrement(),
@ -153,17 +159,27 @@ class IndexedPlugin(
connection.sendMessage(message)
}
override fun spell(
override suspend fun spell(
subtypeId: Long,
flags: SuggestionRequestFlags,
word: String,
precedingWords: List<String>,
followingWords: List<String>
prevWords: List<String>,
flags: SuggestionRequestFlags,
): SpellingResult {
TODO("Not yet implemented")
val request = SuggestionRequest(subtypeId, word, prevWords, flags)
val message = FlorisPluginMessage.requestToService(
action = FlorisPluginMessage.ACTION_SPELL,
id = messageIdGenerator.getAndIncrement(),
data = Json.encodeToString(SuggestionRequest.serializer(), request),
)
connection.sendMessage(message)
return withTimeoutOrNull(5000L) {
val replyMessage = connection.replyMessages.first { it.id == message.id }
val resultObj = replyMessage.obj as? SuggestionsInfo ?: return@withTimeoutOrNull null
SpellingResult(resultObj)
} ?: SpellingResult.unspecified()
}
override fun destroy() {
override suspend fun destroy() {
if (!isValidAndBound()) return
connection.unbindService()
}
@ -225,13 +241,20 @@ class IndexedPlugin(
class IndexedPluginConnection(private val context: Context) {
private val scope = CoroutineScope(Dispatchers.Main + SupervisorJob())
private var serviceMessenger = MutableStateFlow<Messenger?>(null)
private val consumerMessenger = Messenger(IncomingHandler(context))
private val consumerMessenger = Messenger(IncomingHandler())
private var isBound = AtomicBoolean(false)
private val stagedOutgoingMessages = MutableSharedFlow<FlorisPluginMessage>(
replay = 8,
extraBufferCapacity = 8,
onBufferOverflow = BufferOverflow.DROP_OLDEST,
)
private val _replyMessages = MutableSharedFlow<FlorisPluginMessage>(
replay = 8,
extraBufferCapacity = 8,
onBufferOverflow = BufferOverflow.DROP_OLDEST,
)
val replyMessages = _replyMessages.asSharedFlow()
private val serviceConnection = object : ServiceConnection {
override fun onServiceConnected(name: ComponentName?, binder: IBinder?) {
flogDebug { "$name, $binder" }
@ -269,7 +292,7 @@ class IndexedPluginConnection(private val context: Context) {
scope.launch {
stagedOutgoingMessages.collect { message ->
val messenger = serviceMessenger.first { it != null }!!
messenger.send(message.msg.also { it.replyTo = consumerMessenger })
messenger.send(message.also { it.replyTo = consumerMessenger }.toAndroidMessage())
}
}
}
@ -298,21 +321,16 @@ class IndexedPluginConnection(private val context: Context) {
stagedOutgoingMessages.emit(message)
}
class IncomingHandler(context: Context) : Handler(context.mainLooper) {
@SuppressLint("HandlerLeak")
inner class IncomingHandler : Handler(context.mainLooper) {
override fun handleMessage(msg: Message) {
val message = FlorisPluginMessage(msg)
val (source, type, action) = message.metadata()
if (source != FlorisPluginMessage.SOURCE_SERVICE) {
val message = FlorisPluginMessage.fromAndroidMessage(msg)
val (source, type, _) = message.metadata()
if (source != FlorisPluginMessage.SOURCE_SERVICE || type != FlorisPluginMessage.TYPE_RESPONSE) {
return
}
when (type) {
FlorisPluginMessage.TYPE_RESPONSE -> when (action) {
FlorisPluginMessage.ACTION_SPELL -> {
}
FlorisPluginMessage.ACTION_SUGGEST -> {
}
}
runBlocking {
_replyMessages.emit(message)
}
}
}