mirror of https://github.com/florisboard/florisboard.git synced 2024-09-19 19:42:20 +02:00

Add initial flest implementation

Patrick Goldinger 2024-07-04 17:16:47 +02:00
parent f1c5b1802b
commit 72c4f7d4d8
No known key found for this signature in database
10 changed files with 636 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
[package]
name = "flest"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
fxhash = "0.2.1"

View File

@@ -0,0 +1,102 @@
use fxhash::FxHashMap;

/// A node in the dynamic trie: children are keyed by the next character,
/// and `value` is set only for nodes that terminate a stored word.
#[derive(Default)]
struct DynTrieNode<V> where V: Default {
    children: FxHashMap<char, Box<DynTrieNode<V>>>,
    value: Option<V>,
}

impl<V> DynTrieNode<V> where V: Default {
    /// Depth-first traversal: `current_word` holds the path from the root,
    /// and `f` is invoked for every node that stores a value.
    fn for_each_recursive<'a, F>(&'a self, current_word: &mut Vec<char>, f: &mut F)
    where F: FnMut(&[char], &'a V) {
        if let Some(value) = &self.value {
            f(current_word, value);
        }
        for (letter, node) in &self.children {
            current_word.push(*letter);
            node.for_each_recursive(current_word, f);
            current_word.pop();
        }
    }
}
/// A trie keyed by `char` sequences, mapping words to values of type `V`.
#[derive(Default)]
pub struct DynTrie<V> where V: Default {
    root: DynTrieNode<V>,
}

impl<V> DynTrie<V>
where V: Default {
    /// Returns the value stored for exactly `word`, if any.
    pub fn find(&self, word: &[char]) -> Option<&V> {
        let mut current_node = &self.root;
        for letter in word {
            current_node = current_node.children.get(letter)?;
        }
        current_node.value.as_ref()
    }
    /// Scores how well two whole words match. Positions beyond a word's end
    /// compare as spaces; equal chars score 1.0, case-insensitive matches 0.5,
    /// and mismatches accrue a penalty (doubled at the first position).
    fn str_fuzzy_match_whole(str1: &[char], str2: &[char]) -> f64 {
        let len1 = str1.len();
        let len2 = str2.len();
        let max_len = std::cmp::max(len1, len2);
        let mut score: f64 = 0.0;
        let mut penalty: f64 = 0.0;
        for i in 0..max_len {
            let ch1 = str1.get(i).unwrap_or(&' ');
            let ch2 = str2.get(i).unwrap_or(&' ');
            if ch1 == ch2 {
                score += 1.0;
            } else if ch1.to_lowercase().eq(ch2.to_lowercase()) {
                score += 0.5;
            } else {
                penalty += if i == 0 { 2.0 } else { 1.0 };
            }
        }
        f64::max(0.0, score - penalty)
    }
    // TODO: optimization: we do not need to iterate over the whole trie;
    // we can predict when the score can never be > 0 and skip the
    // entire subtree.
    /// Returns all stored words whose fuzzy match score against `word`
    /// is positive, together with their values.
    pub fn find_many(&self, word: &[char]) -> Vec<(Vec<char>, &V)> {
        let mut results = Vec::new();
        self.for_each(&mut |current_word, value| {
            let score = Self::str_fuzzy_match_whole(word, current_word);
            if score > 0.0 {
                results.push((current_word.to_owned(), value));
            }
        });
        results
    }
    /// Walks (creating nodes as needed) to `word` and returns a mutable
    /// reference to its value, initializing it with `value` if absent.
    pub fn find_or_insert(&mut self, word: &[char], value: V) -> &mut V {
        let mut current_node = &mut self.root;
        for letter in word {
            current_node = current_node.children.entry(*letter)
                .or_insert_with(|| Box::new(DynTrieNode::default()));
        }
        if current_node.value.is_none() {
            current_node.value = Some(value);
        }
        current_node.value.as_mut().unwrap()
    }

    /// Inserts `value` at `word`, overwriting any existing value.
    #[allow(dead_code)]
    fn insert(&mut self, word: &[char], value: V) {
        let mut current_node = &mut self.root;
        for letter in word {
            current_node = current_node.children.entry(*letter)
                .or_insert_with(|| Box::new(DynTrieNode::default()));
        }
        current_node.value = Some(value);
    }
    /// Visits every stored (word, value) pair in depth-first order.
    pub fn for_each<'a, F>(&'a self, f: &mut F)
    where F: FnMut(&[char], &'a V) {
        let mut current_word: Vec<char> = Vec::new();
        self.root.for_each_recursive(&mut current_word, f);
    }
}
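Usage sketch for DynTrie (not part of the commit; the word and the u64 counter are illustrative):

// Hypothetical usage of DynTrie; word and count values are made up.
let mut trie: DynTrie<u64> = DynTrie::default();
let word: Vec<char> = "hello".chars().collect();
*trie.find_or_insert(&word, 0) += 1; // creates the entry with 0, then bumps it
assert_eq!(trie.find(&word), Some(&1));
// A case difference scores 0.5 per char, so this fuzzy lookup still hits:
let query: Vec<char> = "Hello".chars().collect();
assert_eq!(trie.find_many(&query).len(), 1);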

View File

@@ -0,0 +1,4 @@
mod dyntrie;
mod ngrammodel;
pub use ngrammodel::*;

View File

@@ -0,0 +1,212 @@
use std::collections::HashMap;

use crate::dyntrie::DynTrie;

/// A node in the n-gram model: children are the possible next tokens,
/// `time` is the logical timestamp of the last user input that reached this
/// node (0 for dataset training), and `usage` counts how often it was seen.
#[derive(Default)]
struct NgramModelNode {
    children: DynTrie<Box<NgramModelNode>>,
    time: u64,
    usage: u64,
}
impl NgramModelNode {
    /// Follows `ngram` token by token with exact matching.
    fn find(&self, ngram: &[&str]) -> Option<&NgramModelNode> {
        if ngram.is_empty() {
            return None;
        }
        let token: Vec<char> = ngram[0].chars().collect();
        let child = self.children.find(&token)?;
        if ngram.len() == 1 {
            return Some(child);
        }
        child.find(&ngram[1..])
    }
    /// Follows `ngram` with fuzzy matching at every level and returns the
    /// nodes matching the final token, keyed by the matched word.
    fn find_many(&self, ngram: &[&str]) -> Vec<(Vec<char>, &NgramModelNode)> {
        if ngram.is_empty() {
            return Vec::new();
        }
        let token: Vec<char> = ngram[0].chars().collect();
        let ret = self.children.find_many(&token);
        if ngram.len() == 1 {
            return ret
                .into_iter()
                .map(|node| (node.0, node.1.as_ref()))
                .collect();
        }
        let mut ret2 = Vec::new();
        for (_, child) in &ret {
            ret2.extend(child.find_many(&ngram[1..]));
        }
        ret2
    }
    /// Walks (creating nodes as needed) down the `ngram` path and updates the
    /// leaf. A nonzero `current_time` marks the leaf as recent user input.
    fn train(&mut self, ngram: &[&str], current_time: u64) {
        if ngram.is_empty() {
            panic!("ngram must not be empty");
        }
        let token: Vec<char> = ngram[0].chars().collect();
        let child = self.children.find_or_insert(&token, Box::new(NgramModelNode::default()));
        if ngram.len() == 1 {
            if current_time != 0 {
                child.time = current_time;
            }
            child.usage += 1;
        } else {
            child.train(&ngram[1..], current_time);
        }
    }

    fn debug_print(&self, _indent: usize) {
        // Disabled for now: nodes no longer store their own token, so the
        // old recursive printout below cannot work as written.
        // println!("{}{}{}", " ".repeat(indent), self.token, if self.time > 0 { "*" } else { "" });
        // for child in &self.children {
        //     child.debug_print(indent + 1);
        // }
    }
}
/// An n-gram language model backed by a fuzzy-matching trie at each level.
#[derive(Default)]
pub struct NgramModel {
    root: NgramModelNode,
    time: u64,
}

impl NgramModel {
    #[allow(dead_code)]
    fn find(&self, ngram: &[&str]) -> Option<&NgramModelNode> {
        self.root.find(ngram)
    }

    fn find_many(&self, ngram: &[&str]) -> Vec<(Vec<char>, &NgramModelNode)> {
        self.root.find_many(ngram)
    }

    /// Trains from a static dataset; `time` stays 0 so recency is unaffected.
    pub fn train_dataset(&mut self, token_list: &[&str]) {
        self.root.train(token_list, 0);
    }

    /// Trains from live user input, advancing the logical clock for recency.
    pub fn train_input(&mut self, token_list: &[&str]) {
        self.time += 1;
        self.root.train(token_list, self.time);
    }

    pub fn debug_print(&self) {
        self.root.debug_print(0);
    }
    /// Predicts completions for `history`, whose last element is the
    /// (possibly empty) word currently being typed.
    pub fn predict(&self, history: &[&str]) -> Vec<(String, f64)> {
        // Observed ranges of recency and usage, for normalization below.
        let mut tmin = u64::MAX;
        let mut tmax = u64::MIN;
        let mut umin = u64::MAX;
        let mut umax = u64::MIN;
        let nmin = 1;
        let nmax = 3;
        let mut candidate_nodes: Vec<(Vec<char>, &NgramModelNode, f64)> = Vec::new();
        let user_input_word = history.last().unwrap_or(&"");
        // Gather candidates for every context length n; the context is the
        // n-1 tokens preceding the word being typed, and longer contexts get
        // a slightly higher base weight.
        for n in nmin..=std::cmp::min(history.len(), nmax) {
            let nweight = 1.0 - (nmax - n) as f64 * 0.1;
            let ngram = &history[history.len() - n..history.len() - 1];
            let nodes = self.find_many(ngram);
            for (_, node) in nodes {
                node.children.for_each(&mut |curr_word, child| {
                    candidate_nodes.push((curr_word.to_owned(), child, nweight));
                    tmin = tmin.min(child.time);
                    tmax = tmax.max(child.time);
                    umin = umin.min(child.usage);
                    umax = umax.max(child.usage);
                });
            }
        }
        // Blend the context weight with normalized recency and usage.
        candidate_nodes = candidate_nodes
            .into_iter()
            .map(|(word, node, nweight)| {
                (
                    word,
                    node,
                    nweight
                        * norm_weight(node.time, tmin, tmax)
                        * norm_weight(node.usage, umin, umax),
                )
            })
            .collect();
        if !user_input_word.is_empty() {
            // Keep only candidates that fuzzily match the word being typed,
            // re-weighting them mostly by match quality.
            let user_input_word: Vec<char> = user_input_word.chars().collect();
            let mut filtered_nodes = Vec::new();
            for (word, node, weight) in candidate_nodes {
                let score_len = std::cmp::min(
                    (word.len() + user_input_word.len()) / 2,
                    user_input_word.len(),
                ) as f64;
                let score = str_fuzzy_match_live(&word, &user_input_word);
                if score > 0.0 {
                    let new_weight = 0.95 * (score / score_len) + 0.05 * weight;
                    filtered_nodes.push((word, node, new_weight));
                }
            }
            // Also offer context-free completions from the unigram level, at
            // a lower weight (the `0.25 * 0.0` term is a zero context weight,
            // kept to make the blend's shape explicit).
            self.root.children.for_each(&mut |word, node| {
                let score_len = std::cmp::min(
                    (word.len() + user_input_word.len()) / 2,
                    user_input_word.len(),
                ) as f64;
                let score = str_fuzzy_match_live(&word, &user_input_word);
                if score > 0.0 {
                    let new_weight = 0.75 * (score / score_len) + 0.25 * 0.0;
                    filtered_nodes.push((word.to_owned(), node, new_weight));
                }
            });
            candidate_nodes = filtered_nodes;
        }
        candidate_nodes.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
        // Deduplicate by word; candidates are already sorted by weight, so
        // or_insert keeps the best weight for each word.
        let mut predictions: HashMap<String, f64> = HashMap::new();
        for (word, _, weight) in candidate_nodes {
            predictions
                .entry(word.iter().collect())
                .or_insert(weight);
        }
        let mut predictions_vec: Vec<(String, f64)> = predictions.into_iter().collect();
        predictions_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        predictions_vec.into_iter().take(8).collect()
    }
}
/// Clamped min-max normalization with a concave response curve (2x - x^2),
/// which boosts mid-range values; e.g. a value halfway between xmin and xmax
/// maps to 0.75 rather than 0.5.
fn norm_weight(x: u64, xmin: u64, xmax: u64) -> f64 {
    if x <= xmin {
        return 0.0;
    }
    if x >= xmax {
        return 1.0;
    }
    let xnorm = (x - xmin) as f64 / (xmax - xmin) as f64;
    2.0 * xnorm - xnorm.powi(2)
}
/// Scores `word` against the prefix the user has typed so far: only the typed
/// positions are compared, case-insensitive matches score slightly lower, and
/// mismatches are penalized quadratically.
fn str_fuzzy_match_live(word: &[char], current_word: &[char]) -> f64 {
    let len2 = current_word.len();
    let mut score = 0.0;
    let mut penalty: f64 = 0.0;
    for i in 0..len2 {
        let ch1 = word.get(i).unwrap_or(&' ');
        let ch2 = current_word.get(i).unwrap_or(&' ');
        if ch1 == ch2 {
            score += 1.0;
        } else if ch1.to_lowercase().eq(ch2.to_lowercase()) {
            score += 0.9;
        } else {
            penalty += if i == 0 { 2.0 } else { 1.0 };
        }
    }
    f64::max(0.0, score - 0.125 * penalty.powi(2))
}

View File

@@ -0,0 +1,13 @@
[package]
name = "textutils"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
icu_segmenter = "1.5.0"
itertools = "0.13.0"
lazy_static = "1.5.0"
linkify = "0.10.0"
regex = "1.10.5"

View File

@@ -0,0 +1,20 @@
use lazy_static::lazy_static;
use linkify::{self, LinkFinder};
use regex::Regex;

lazy_static! {
    static ref LINK_FINDER: LinkFinder = LinkFinder::new();
    static ref REDDIT_REGEX: Regex = Regex::new(r"\/?(r\/[a-zA-Z0-9_]{3}[a-zA-Z0-9_]{0,18}|u\/[a-zA-Z0-9_-]{3}[a-zA-Z0-9_-]{0,17})").unwrap();
}

/// Removes content that should not be fed into the model: detected links are
/// spliced out, then Reddit r/subreddit and u/user handles are stripped.
pub fn preprocess_auto(text: &str) -> String {
    let mut cleaned_text = String::new();
    let mut begin_cleaned_index = 0;
    for span in LINK_FINDER.links(text) {
        // Copy the text between the previous link and this one, then skip
        // over the link itself.
        cleaned_text.push_str(&text[begin_cleaned_index..span.start()]);
        begin_cleaned_index = span.end();
    }
    cleaned_text.push_str(&text[begin_cleaned_index..]);
    cleaned_text = REDDIT_REGEX.replace_all(&cleaned_text, "").to_string();
    cleaned_text
}
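To illustrate the splice in preprocess_auto (not part of the commit; the span indices are hard-coded purely for illustration):

// The span 7..26 stands in for what LINK_FINDER would report here.
let text = "before https://example.com after";
let (start, end) = (7usize, 26);
let cleaned = format!("{}{}", &text[..start], &text[end..]);
assert_eq!(cleaned, "before  after"); // the doubled space is left behind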

View File

@@ -0,0 +1,52 @@
mod filter;
mod segment;

pub use filter::*;
pub use segment::*;

#[cfg(test)]
mod tests {
    use icu_segmenter::{SentenceSegmenter, WordSegmenter};

    use super::*;

    #[test]
    fn segment_sentences_simple() {
        let text = "Hello, world! How are you? I'm fine.";
        let segmenter = SentenceSegmenter::new();
        let sentences = split_sentences(text, &segmenter);
        assert_eq!(&sentences, &["Hello, world!", "How are you?", "I'm fine."]);
    }

    #[test]
    fn segment_words_simple() {
        let text = "Hello, world! How are you? I'm fine.";
        let segmenter = WordSegmenter::new_auto();
        let words = split_words(text, &segmenter);
        assert_eq!(&words, &["Hello", "world", "How", "are", "you", "I'm", "fine"]);
    }

    #[test]
    fn preprocess_auto_simple() {
        let text = "Hello, world! How are you? I'm fine. https://example.com and more";
        let cleaned_text = preprocess_auto(text);
        // Splicing out the link leaves the surrounding whitespace intact,
        // hence the double space below.
        assert_eq!(&cleaned_text, "Hello, world! How are you? I'm fine.  and more");
    }

    #[test]
    fn preprocess_reddit_ids() {
        let text = "have a look at r/cats, user u/example posted a cute cat in there";
        let cleaned_text = preprocess_auto(text);
        // The stripped handles likewise leave their surrounding spaces behind.
        assert_eq!(&cleaned_text, "have a look at , user  posted a cute cat in there");
    }

    #[test]
    fn preprocess_url_markdown() {
        let text = "You can find an example [in the documentation](https://example.com) or on GitHub";
        let cleaned_text = preprocess_auto(text);
        assert_eq!(&cleaned_text, "You can find an example [in the documentation]() or on GitHub");
        let segmenter = WordSegmenter::new_auto();
        let words = split_words(&cleaned_text, &segmenter);
        assert_eq!(&words, &["You", "can", "find", "an", "example", "in", "the", "documentation", "or", "on", "GitHub"]);
    }
}

View File

@@ -0,0 +1,63 @@
use icu_segmenter::{GraphemeClusterSegmenter, SentenceSegmenter, WordSegmenter};
use itertools::Itertools;

/// Caches the ICU segmenters, which are relatively expensive to construct,
/// so they can be reused across calls.
pub struct IcuSegmenterCache {
    sentence_segmenter: SentenceSegmenter,
    word_segmenter: WordSegmenter,
    grapheme_cluster_segmenter: GraphemeClusterSegmenter,
}

impl IcuSegmenterCache {
    pub fn new_auto() -> Self {
        let sentence_segmenter = SentenceSegmenter::new();
        let word_segmenter = WordSegmenter::new_auto();
        let grapheme_cluster_segmenter = GraphemeClusterSegmenter::new();
        Self {
            sentence_segmenter,
            word_segmenter,
            grapheme_cluster_segmenter,
        }
    }

    pub fn split_sentences<'t>(&self, text: &'t str) -> Vec<&'t str> {
        split_sentences(text, &self.sentence_segmenter)
    }

    pub fn split_words<'t>(&self, text: &'t str) -> Vec<&'t str> {
        split_words(text, &self.word_segmenter)
    }

    pub fn split_grapheme_clusters<'t>(&self, text: &'t str) -> Vec<&'t str> {
        split_grapheme_clusters(text, &self.grapheme_cluster_segmenter)
    }
}

/// Splits `text` into trimmed, non-empty sentences. The segmenter yields
/// break indices, so consecutive index pairs delimit one sentence each.
pub fn split_sentences<'t>(text: &'t str, segmenter: &SentenceSegmenter) -> Vec<&'t str> {
    segmenter
        .segment_str(text)
        .tuple_windows()
        .map(|(i, j)| text[i..j].trim())
        .filter(|sentence| !sentence.is_empty())
        .collect()
}

/// Splits `text` into word-like segments, dropping punctuation and spaces.
pub fn split_words<'t>(text: &'t str, segmenter: &WordSegmenter) -> Vec<&'t str> {
    segmenter
        .segment_str(text)
        .iter_with_word_type()
        .tuple_windows()
        .filter(|(_, (_, segment_type))| segment_type.is_word_like())
        .map(|((i, _), (j, _))| &text[i..j])
        .collect()
}

/// Splits `text` into grapheme clusters (user-perceived characters).
pub fn split_grapheme_clusters<'t>(text: &'t str, segmenter: &GraphemeClusterSegmenter) -> Vec<&'t str> {
    segmenter
        .segment_str(text)
        .tuple_windows()
        .map(|(i, j)| &text[i..j])
        .collect()
}
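For intuition: segment_str yields break indices (including 0 and the text length), and tuple_windows pairs consecutive indices into ranges. A standalone sketch (not part of the commit; the break positions are made up):

use itertools::Itertools;

fn main() {
    let text = "Hi. Bye.";
    let breaks = vec![0usize, 4, 8]; // what a sentence segmenter might report
    let parts: Vec<&str> = breaks
        .into_iter()
        .tuple_windows()
        .map(|(i, j)| text[i..j].trim())
        .collect();
    assert_eq!(parts, vec!["Hi.", "Bye."]);
}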

View File

@@ -0,0 +1,13 @@
[package]
name = "flesttools"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
flest = { path = "../../libnative/flest" }
textutils = { path = "../../libnative/textutils" }
pancurses = { version = "0.17.0", features = ["wide"] }
serde = "1.0.203"
serde_json = "1.0.120"

View File

@@ -0,0 +1,148 @@
use flest::NgramModel;
use textutils::IcuSegmenterCache;

use pancurses::Input;
use std::env;
use std::fs;
use std::io::BufRead;
use std::io::BufReader;

/// Pseudo-token that marks sentence boundaries in the training data.
const TOKEN_SENTENCE_SEPARATOR: &str = "\\sep";

/// Splits `text` into word tokens, framing every sentence with the sentence
/// separator token. (The segmenters are rebuilt per call; fine for this
/// debug tool.)
fn tokenize_text(text: &str) -> Vec<&str> {
    let segmenters = IcuSegmenterCache::new_auto();
    let sentences = segmenters.split_sentences(text);
    let mut tokens: Vec<&str> = Vec::new();
    tokens.push(TOKEN_SENTENCE_SEPARATOR);
    for sentence in sentences {
        let words = segmenters.split_words(sentence);
        for word in words {
            tokens.push(word);
        }
        tokens.push(TOKEN_SENTENCE_SEPARATOR);
    }
    tokens
}
/// Preprocesses `text` and trains the model on all 2-, 3- and 4-gram windows
/// of the resulting token list.
fn train_model(text: &str, model: &mut NgramModel) {
    let text = textutils::preprocess_auto(text);
    let text = text.trim();
    if text.is_empty() {
        return;
    }
    let tokens = tokenize_text(text);
    let n_values = [2, 3, 4];
    for &n in &n_values {
        if n > tokens.len() {
            continue;
        }
        for i in 0..=tokens.len() - n {
            model.train_dataset(&tokens[i..(i + n)]);
        }
    }
}

fn train_from_plain_text(path: &str, model: &mut NgramModel) {
    let text = fs::read_to_string(path).expect("Failed to read file");
    train_model(&text, model);
}
/// Trains from a Reddit comment dump in JSON-Lines format, skipping bot
/// comments and capping the run at roughly 10k lines for quick experiments.
fn train_from_reddit_comments(path: &str, model: &mut NgramModel) {
    let file = fs::File::open(path).expect("Failed to open file");
    let reader = BufReader::new(file);
    let mut line_count = 0;
    for line in reader.lines() {
        if let Ok(line) = line {
            let json: serde_json::Value = serde_json::from_str(&line).expect("Failed to parse JSON");
            if let Some(author) = json.get("author").and_then(|it| it.as_str()) {
                if author == "AutoModerator" {
                    continue;
                }
            }
            if let Some(body) = json.get("body").and_then(|it| it.as_str()) {
                train_model(body, model);
            }
        }
        line_count += 1;
        if line_count > 10000 {
            break;
        }
    }
}
fn main() {
    let args: Vec<String> = env::args().collect();
    if args.len() != 2 {
        eprintln!("Usage: {} <file_path>", args[0]);
        return;
    }
    let path = &args[1];
    let mut model = NgramModel::default();
    if path.ends_with(".reddit.jsonl") {
        train_from_reddit_comments(path, &mut model);
    } else {
        train_from_plain_text(path, &mut model);
    }

    let window = pancurses::initscr();
    let mut input_text = String::new();
    pancurses::noecho();
    window.keypad(true);
    loop {
        // Rebuild the token history from the current input line; a trailing
        // space (or an empty line) means a fresh, empty word is being typed.
        let mut words: Vec<&str> = input_text.split_whitespace().collect();
        words.insert(0, TOKEN_SENTENCE_SEPARATOR);
        if input_text.ends_with(' ') || words.last() == Some(&TOKEN_SENTENCE_SEPARATOR) {
            words.push("");
        }
        let predictions = model.predict(&words);
        window.clear();
        window.addstr("N-gram model debug frontend\n");
        window.addstr(" demo tokenizer only supports single-line sentences in the input text!\n\n");
        window.addstr(format!("enter text: {}\n", input_text));
        window.addstr(format!("detected words: {:?}\n\n", words));
        window.addstr("predictions:\n");
        for (i, (word, weight)) in predictions.iter().enumerate() {
            // Highlight a confident top prediction in bold.
            if i == 0 && *weight > 0.9 {
                window.attron(pancurses::A_BOLD);
            }
            window.addstr(format!("  {}. {} (c={:.2})\n", i + 1, word, weight));
            if i == 0 && *weight > 0.9 {
                window.attroff(pancurses::A_BOLD);
            }
        }
        if predictions.is_empty() {
            window.addstr("  (none)\n");
        }
        // Put the cursor back at the end of the input line
        // (row 3, after the 12-char "enter text: " prompt).
        window.mv(3, 12 + input_text.len() as i32);
        window.refresh();
        match window.getch().unwrap() {
            Input::KeyF10 => {
                break;
            }
            Input::KeyBackspace => {
                input_text.pop();
            }
            Input::Character('\n') => {
                // Enter trains the model on the current input line.
                train_model(&input_text, &mut model);
            }
            Input::Character(ch) => {
                input_text.push(ch);
            }
            _ => {}
        }
    }
    pancurses::endwin();
}
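As a closing illustration of the sliding windows in train_model (not part of the commit; tokens are made up, and slice::windows is equivalent to the index loop used above):

fn main() {
    let tokens = ["\\sep", "hi", "there", "\\sep"];
    for n in [2usize, 3, 4] {
        if n > tokens.len() {
            continue;
        }
        for ngram in tokens.windows(n) {
            println!("{:?}", ngram); // ["\\sep", "hi"], ["hi", "there"], ...
        }
    }
}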