mirror of
https://github.com/florisboard/florisboard.git
synced 2024-09-19 19:42:20 +02:00
Add initial flest implementation
This commit is contained in:
parent
f1c5b1802b
commit
72c4f7d4d8
9
libnative/flest/Cargo.toml
Normal file
9
libnative/flest/Cargo.toml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
[package]
|
||||||
|
name = "flest"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
fxhash = "0.2.1"
|
102
libnative/flest/src/dyntrie.rs
Normal file
102
libnative/flest/src/dyntrie.rs
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
use fxhash::FxHashMap;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct DynTrieNode<V> where V: Default {
|
||||||
|
children: FxHashMap<char, Box<DynTrieNode<V>>>,
|
||||||
|
value: Option<V>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V> DynTrieNode<V> where V: Default {
|
||||||
|
fn for_each_recursive<'a, F>(&'a self, current_word: &mut Vec<char>, f: &mut F)
|
||||||
|
where F: FnMut(&[char], &'a V) {
|
||||||
|
if let Some(value) = &self.value {
|
||||||
|
f(¤t_word, value);
|
||||||
|
}
|
||||||
|
for (letter, node) in &self.children {
|
||||||
|
current_word.push(*letter);
|
||||||
|
node.for_each_recursive(current_word, f);
|
||||||
|
current_word.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct DynTrie<V> where V: Default {
|
||||||
|
root: DynTrieNode<V>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V> DynTrie<V>
|
||||||
|
where V: Default {
|
||||||
|
pub fn find(&self, word: &[char]) -> Option<&V> {
|
||||||
|
let mut current_node = &self.root;
|
||||||
|
for letter in word {
|
||||||
|
match current_node.children.get(letter) {
|
||||||
|
Some(node) => current_node = node,
|
||||||
|
None => return None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return current_node.value.as_ref();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn str_fuzzy_match_whole(str1: &[char], str2: &[char]) -> f64 {
|
||||||
|
let len1 = str1.len();
|
||||||
|
let len2 = str2.len();
|
||||||
|
let max_len = std::cmp::max(len1, len2);
|
||||||
|
let mut score: f64 = 0.0;
|
||||||
|
let mut penalty: f64 = 0.0;
|
||||||
|
for i in 0..max_len {
|
||||||
|
let ch1 = str1.get(i).unwrap_or(&' ');
|
||||||
|
let ch2 = str2.get(i).unwrap_or(&' ');
|
||||||
|
if ch1 == ch2 {
|
||||||
|
score += 1.0;
|
||||||
|
} else if ch1.to_lowercase().eq(ch2.to_lowercase()) {
|
||||||
|
score += 0.5;
|
||||||
|
} else {
|
||||||
|
penalty += if i == 0 { 2.0 } else { 1.0 };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return f64::max(0.0, score - penalty)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: optimization: we do not need to iterate over all
|
||||||
|
// the trie, we can predict if the score will never be >= 0
|
||||||
|
// and skip the whole subtree
|
||||||
|
pub fn find_many(&self, word: &[char]) -> Vec<(Vec<char>, &V)> {
|
||||||
|
let mut results = Vec::new();
|
||||||
|
self.for_each(&mut |current_word, value| {
|
||||||
|
let score = Self::str_fuzzy_match_whole(word, current_word);
|
||||||
|
if score > 0.0 {
|
||||||
|
results.push((current_word.to_owned(), value));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn find_or_insert(&mut self, word: &[char], value: V) -> &mut V {
|
||||||
|
let mut current_node = &mut self.root;
|
||||||
|
for letter in word {
|
||||||
|
current_node = current_node.children.entry(*letter)
|
||||||
|
.or_insert_with(|| Box::new(DynTrieNode::default()));
|
||||||
|
}
|
||||||
|
if current_node.value.is_none() {
|
||||||
|
current_node.value = Some(value);
|
||||||
|
}
|
||||||
|
return current_node.value.as_mut().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn insert(&mut self, word: &[char], value: V) {
|
||||||
|
let mut current_node = &mut self.root;
|
||||||
|
for letter in word {
|
||||||
|
current_node = current_node.children.entry(*letter)
|
||||||
|
.or_insert_with(|| Box::new(DynTrieNode::default()));
|
||||||
|
}
|
||||||
|
current_node.value = Some(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn for_each<'a, F>(&'a self, f: &mut F)
|
||||||
|
where F: FnMut(&[char], &'a V) {
|
||||||
|
let mut current_word: Vec<char> = Vec::new();
|
||||||
|
self.root.for_each_recursive(&mut current_word, f);
|
||||||
|
}
|
||||||
|
}
|
4
libnative/flest/src/lib.rs
Normal file
4
libnative/flest/src/lib.rs
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
mod dyntrie;
|
||||||
|
mod ngrammodel;
|
||||||
|
|
||||||
|
pub use ngrammodel::*;
|
212
libnative/flest/src/ngrammodel.rs
Normal file
212
libnative/flest/src/ngrammodel.rs
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::dyntrie::DynTrie;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct NgramModelNode {
|
||||||
|
children: DynTrie<Box<NgramModelNode>>,
|
||||||
|
time: u64,
|
||||||
|
usage: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NgramModelNode {
|
||||||
|
fn find(&self, ngram: &[&str]) -> Option<&NgramModelNode> {
|
||||||
|
if ngram.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let token: Vec<char> = ngram[0].chars().collect();
|
||||||
|
let child = self.children.find(&token);
|
||||||
|
if child.is_none() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let child = child.unwrap();
|
||||||
|
if ngram.len() == 1 {
|
||||||
|
return Some(child);
|
||||||
|
}
|
||||||
|
return child.find(&ngram[1..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_many(&self, ngram: &[&str]) -> Vec<(Vec<char>, &NgramModelNode)> {
|
||||||
|
if ngram.is_empty() {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
let token: Vec<char> = ngram[0].chars().collect();
|
||||||
|
let ret = self.children.find_many(&token);
|
||||||
|
if ngram.len() == 1 {
|
||||||
|
return ret
|
||||||
|
.into_iter()
|
||||||
|
.map(|node| (node.0, node.1.as_ref()))
|
||||||
|
.collect();
|
||||||
|
}
|
||||||
|
let mut ret2 = Vec::new();
|
||||||
|
for (_, child) in &ret {
|
||||||
|
ret2.extend(child.find_many(&ngram[1..]));
|
||||||
|
}
|
||||||
|
return ret2;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn train(&mut self, ngram: &[&str], current_time: u64) {
|
||||||
|
if ngram.is_empty() {
|
||||||
|
panic!("ngram must not be empty");
|
||||||
|
}
|
||||||
|
let token: Vec<char> = ngram[0].chars().collect();
|
||||||
|
let child = self.children.find_or_insert(&token, Box::new(NgramModelNode::default()));
|
||||||
|
if ngram.len() == 1 {
|
||||||
|
if current_time != 0 {
|
||||||
|
child.time = current_time;
|
||||||
|
}
|
||||||
|
child.usage += 1;
|
||||||
|
} else {
|
||||||
|
child.train(&ngram[1..], current_time);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn debug_print(&self, _indent: usize) {
|
||||||
|
// println!("{}{}{}", " ".repeat(indent), self.token, if self.time > 0 { "*" } else { "" });
|
||||||
|
// for child in &self.children {
|
||||||
|
// child.debug_print(indent + 1);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct NgramModel {
|
||||||
|
root: NgramModelNode,
|
||||||
|
time: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NgramModel {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn find(&self, ngram: &[&str]) -> Option<&NgramModelNode> {
|
||||||
|
self.root.find(ngram)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_many(&self, ngram: &[&str]) -> Vec<(Vec<char>, &NgramModelNode)> {
|
||||||
|
self.root.find_many(ngram)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train_dataset(&mut self, token_list: &[&str]) {
|
||||||
|
self.root.train(token_list, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn train_input(&mut self, token_list: &[&str]) {
|
||||||
|
self.time += 1;
|
||||||
|
self.root.train(token_list, self.time);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn debug_print(&self) {
|
||||||
|
self.root.debug_print(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn predict(&self, history: &Vec<&str>) -> Vec<(String, f64)> {
|
||||||
|
let mut tmin = u64::MAX;
|
||||||
|
let mut tmax = u64::MIN;
|
||||||
|
let mut umin = u64::MAX;
|
||||||
|
let mut umax = u64::MIN;
|
||||||
|
let nmin = 1;
|
||||||
|
let nmax = 3;
|
||||||
|
let mut candidate_nodes: Vec<(Vec<char>, &NgramModelNode, f64)> = Vec::new();
|
||||||
|
|
||||||
|
let user_input_word = history.last().unwrap_or(&"");
|
||||||
|
|
||||||
|
for n in nmin..=std::cmp::min(history.len(), nmax) {
|
||||||
|
let nweight = 1.0 - (nmax - n) as f64 * 0.1;
|
||||||
|
let ngram = &history[history.len() - n..history.len() - 1];
|
||||||
|
let nodes = self.find_many(ngram);
|
||||||
|
for (_, node) in nodes {
|
||||||
|
node.children.for_each(&mut |curr_word, child| {
|
||||||
|
candidate_nodes.push((curr_word.to_owned(), child, nweight));
|
||||||
|
tmin = tmin.min(child.time);
|
||||||
|
tmax = tmax.max(child.time);
|
||||||
|
umin = umin.min(child.usage);
|
||||||
|
umax = umax.max(child.usage);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
candidate_nodes = candidate_nodes
|
||||||
|
.into_iter()
|
||||||
|
.map(|(word, node, nweight)| {
|
||||||
|
(
|
||||||
|
word,
|
||||||
|
node,
|
||||||
|
nweight
|
||||||
|
* norm_weight(node.time, tmin, tmax)
|
||||||
|
* norm_weight(node.usage, umin, umax),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if !user_input_word.is_empty() {
|
||||||
|
let user_input_word: Vec<char> = user_input_word.chars().collect();
|
||||||
|
let mut filtered_nodes = Vec::new();
|
||||||
|
for (word, node, weight) in candidate_nodes {
|
||||||
|
let score_len = std::cmp::min(
|
||||||
|
(word.len() + user_input_word.len()) / 2,
|
||||||
|
user_input_word.len(),
|
||||||
|
) as f64;
|
||||||
|
let score = str_fuzzy_match_live(&word, &user_input_word);
|
||||||
|
if score > 0.0 {
|
||||||
|
let new_weight = 0.95 * (score / score_len) + 0.05 * weight;
|
||||||
|
filtered_nodes.push((word, node, new_weight));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.root.children.for_each(&mut |word, node| {
|
||||||
|
let score_len = std::cmp::min(
|
||||||
|
(word.len() + user_input_word.len()) / 2,
|
||||||
|
user_input_word.len(),
|
||||||
|
) as f64;
|
||||||
|
let score = str_fuzzy_match_live(&word, &user_input_word);
|
||||||
|
if score > 0.0 {
|
||||||
|
let new_weight = 0.75 * (score / score_len) + 0.25 * 0.0;
|
||||||
|
filtered_nodes.push((word.to_owned(), node, new_weight));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
candidate_nodes = filtered_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
candidate_nodes.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
|
||||||
|
|
||||||
|
let mut predictions: HashMap<String, f64> = HashMap::new();
|
||||||
|
for (word, _, weight) in candidate_nodes {
|
||||||
|
predictions
|
||||||
|
.entry(word.iter().collect())
|
||||||
|
.or_insert(weight);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut predictions_vec: Vec<(String, f64)> = predictions.into_iter().collect();
|
||||||
|
predictions_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||||
|
|
||||||
|
predictions_vec.into_iter().take(8).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Normalizes `x` into [0.0, 1.0] relative to the observed `[xmin, xmax]`
/// range, then applies the ease-out curve `2t - t²` so differences near the
/// bottom of the range matter more than those near the top.
///
/// Values at or below `xmin` map to 0.0, at or above `xmax` to 1.0 (this also
/// covers the degenerate `xmin == xmax` case without dividing by zero).
fn norm_weight(x: u64, xmin: u64, xmax: u64) -> f64 {
    if x <= xmin {
        0.0
    } else if x >= xmax {
        1.0
    } else {
        let xnorm = (x - xmin) as f64 / (xmax - xmin) as f64;
        2.0 * xnorm - xnorm.powi(2)
    }
}
|
||||||
|
|
||||||
|
/// Scores how well the candidate `word` completes the partially typed
/// `current_word`. Only the typed prefix length is compared: an exact
/// character match scores 1.0, a case-insensitive match 0.9, and a mismatch
/// accumulates penalty (2.0 at position 0, otherwise 1.0). Positions past the
/// end of `word` compare against ' '. The penalty is squared and scaled so
/// isolated typos are forgiven but repeated mismatches kill the score.
/// Result is clamped to be non-negative.
fn str_fuzzy_match_live(word: &[char], current_word: &[char]) -> f64 {
    let mut score: f64 = 0.0;
    let mut penalty: f64 = 0.0;
    for (i, ch2) in current_word.iter().enumerate() {
        let ch1 = word.get(i).unwrap_or(&' ');
        if ch1 == ch2 {
            score += 1.0;
        } else if ch1.to_lowercase().eq(ch2.to_lowercase()) {
            score += 0.9;
        } else {
            penalty += if i == 0 { 2.0 } else { 1.0 };
        }
    }
    (score - 0.125 * penalty.powi(2)).max(0.0)
}
|
13
libnative/textutils/Cargo.toml
Normal file
13
libnative/textutils/Cargo.toml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
[package]
|
||||||
|
name = "textutils"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
icu_segmenter = "1.5.0"
|
||||||
|
itertools = "0.13.0"
|
||||||
|
lazy_static = "1.5.0"
|
||||||
|
linkify = "0.10.0"
|
||||||
|
regex = "1.10.5"
|
20
libnative/textutils/src/filter.rs
Normal file
20
libnative/textutils/src/filter.rs
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
use lazy_static::lazy_static;
|
||||||
|
use linkify::{self, LinkFinder};
|
||||||
|
use regex::Regex;
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
static ref LINK_FINDER: LinkFinder = LinkFinder::new();
|
||||||
|
static ref REDDIT_REGEX: Regex = Regex::new(r"\/?(r\/[a-zA-Z0-9_]{3}[a-zA-Z0-9_]{0,18}|u\/[a-zA-Z0-9_-]{3}[a-zA-Z0-9_-]{0,17})").unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn preprocess_auto(text: &str) -> String {
|
||||||
|
let mut cleaned_text = String::new();
|
||||||
|
let mut begin_cleaned_index = 0;
|
||||||
|
for span in LINK_FINDER.links(text) {
|
||||||
|
cleaned_text.push_str(&text[begin_cleaned_index..span.start()]);
|
||||||
|
begin_cleaned_index = span.end();
|
||||||
|
}
|
||||||
|
cleaned_text.push_str(&text[begin_cleaned_index..]);
|
||||||
|
cleaned_text = REDDIT_REGEX.replace_all(&cleaned_text, "").to_string();
|
||||||
|
return cleaned_text;
|
||||||
|
}
|
52
libnative/textutils/src/lib.rs
Normal file
52
libnative/textutils/src/lib.rs
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
mod filter;
|
||||||
|
mod segment;
|
||||||
|
|
||||||
|
pub use filter::*;
|
||||||
|
pub use segment::*;
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use icu_segmenter::{SentenceSegmenter, WordSegmenter};

    use super::*;

    #[test]
    fn segment_sentences_simple() {
        let segmenter = SentenceSegmenter::new();
        let sentences = split_sentences("Hello, world! How are you? I'm fine.", &segmenter);
        assert_eq!(&sentences, &["Hello, world!", "How are you?", "I'm fine."]);
    }

    #[test]
    fn segment_words_simple() {
        let segmenter = WordSegmenter::new_auto();
        let words = split_words("Hello, world! How are you? I'm fine.", &segmenter);
        assert_eq!(&words, &["Hello", "world", "How", "are", "you", "I'm", "fine"]);
    }

    #[test]
    fn preprocess_auto_simple() {
        let cleaned =
            preprocess_auto("Hello, world! How are you? I'm fine. https://example.com and more");
        assert_eq!(&cleaned, "Hello, world! How are you? I'm fine. and more");
    }

    #[test]
    fn preprocess_reddit_ids() {
        let cleaned =
            preprocess_auto("have a look at r/cats, user u/example posed a cute cat in there");
        assert_eq!(&cleaned, "have a look at , user posed a cute cat in there");
    }

    #[test]
    fn preprocess_url_markdown() {
        let cleaned = preprocess_auto(
            "You can find an example [in the documentation](https://example.com) or on GitHub",
        );
        assert_eq!(&cleaned, "You can find an example [in the documentation]() or on GitHub");
        // The leftover "[...]()" markdown must not confuse word segmentation.
        let segmenter = WordSegmenter::new_auto();
        let words = split_words(&cleaned, &segmenter);
        assert_eq!(&words, &["You", "can", "find", "an", "example", "in", "the", "documentation", "or", "on", "GitHub"]);
    }
}
|
63
libnative/textutils/src/segment.rs
Normal file
63
libnative/textutils/src/segment.rs
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
use icu_segmenter::{GraphemeClusterSegmenter, SentenceSegmenter, WordSegmenter};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
pub struct IcuSegmenterCache {
|
||||||
|
sentence_segmenter: SentenceSegmenter,
|
||||||
|
word_segmenter: WordSegmenter,
|
||||||
|
grapheme_cluster_segmenter: GraphemeClusterSegmenter,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IcuSegmenterCache {
|
||||||
|
pub fn new_auto() -> Self {
|
||||||
|
let sentence_segmenter = SentenceSegmenter::new();
|
||||||
|
let word_segmenter = WordSegmenter::new_auto();
|
||||||
|
let grapheme_cluster_segmenter = GraphemeClusterSegmenter::new();
|
||||||
|
return Self {
|
||||||
|
sentence_segmenter,
|
||||||
|
word_segmenter,
|
||||||
|
grapheme_cluster_segmenter,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split_sentences<'t>(&self, text: &'t str) -> Vec<&'t str> {
|
||||||
|
return split_sentences(text, &self.sentence_segmenter);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split_words<'t>(&self, text: &'t str) -> Vec<&'t str> {
|
||||||
|
return split_words(text, &self.word_segmenter);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split_grapheme_clusters<'t>(&self, text: &'t str) -> Vec<&'t str> {
|
||||||
|
return split_grapheme_clusters(text, &self.grapheme_cluster_segmenter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split_sentences<'t>(text: &'t str, segmenter: &SentenceSegmenter) -> Vec<&'t str> {
|
||||||
|
let sentences: Vec<&str> = segmenter
|
||||||
|
.segment_str(text)
|
||||||
|
.tuple_windows()
|
||||||
|
.map(|(i, j)| text[i..j].trim())
|
||||||
|
.filter(|sentence| !sentence.is_empty())
|
||||||
|
.collect();
|
||||||
|
return sentences;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split_words<'t>(text: &'t str, segmenter: &WordSegmenter) -> Vec<&'t str> {
|
||||||
|
let words: Vec<&str> = segmenter
|
||||||
|
.segment_str(text)
|
||||||
|
.iter_with_word_type()
|
||||||
|
.tuple_windows()
|
||||||
|
.filter(|(_, (_, segment_type))| segment_type.is_word_like())
|
||||||
|
.map(|((i, _), (j, _))| &text[i..j])
|
||||||
|
.collect();
|
||||||
|
return words;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split_grapheme_clusters<'t>(text: &'t str, segmenter: &GraphemeClusterSegmenter) -> Vec<&'t str> {
|
||||||
|
let grapheme_clusters: Vec<&str> = segmenter
|
||||||
|
.segment_str(text)
|
||||||
|
.tuple_windows()
|
||||||
|
.map(|(i, j)| &text[i..j])
|
||||||
|
.collect();
|
||||||
|
return grapheme_clusters;
|
||||||
|
}
|
13
utils/flesttools/Cargo.toml
Normal file
13
utils/flesttools/Cargo.toml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
[package]
|
||||||
|
name = "flesttools"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
flest = { path = "../../libnative/flest" }
|
||||||
|
textutils = { path = "../../libnative/textutils" }
|
||||||
|
pancurses = { version = "0.17.0", features = ["wide"] }
|
||||||
|
serde = "1.0.203"
|
||||||
|
serde_json = "1.0.120"
|
148
utils/flesttools/src/main.rs
Normal file
148
utils/flesttools/src/main.rs
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
use flest::NgramModel;
|
||||||
|
use textutils::IcuSegmenterCache;
|
||||||
|
use pancurses::Input;
|
||||||
|
use std::env;
|
||||||
|
use std::fs;
|
||||||
|
use std::io::BufRead;
|
||||||
|
use std::io::BufReader;
|
||||||
|
|
||||||
|
const TOKEN_SENTENCE_SEPARATOR: &str = "\\sep";
|
||||||
|
|
||||||
|
fn tokenize_text(text: &str) -> Vec<&str> {
|
||||||
|
let segmenters = IcuSegmenterCache::new_auto();
|
||||||
|
let sentences = segmenters.split_sentences(text);
|
||||||
|
let mut tokens: Vec<&str> = Vec::new();
|
||||||
|
|
||||||
|
tokens.push(TOKEN_SENTENCE_SEPARATOR);
|
||||||
|
for sentence in sentences {
|
||||||
|
let words = segmenters.split_words(sentence);
|
||||||
|
for word in words {
|
||||||
|
tokens.push(word);
|
||||||
|
}
|
||||||
|
tokens.push(TOKEN_SENTENCE_SEPARATOR);
|
||||||
|
}
|
||||||
|
|
||||||
|
//println!("Tokens: {:?}", tokens);
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn train_model(text: &str, model: &mut NgramModel) {
|
||||||
|
let text = textutils::preprocess_auto(text);
|
||||||
|
let text = text.trim();
|
||||||
|
if text.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let tokens = tokenize_text(&text);
|
||||||
|
//println!("Tokens: {:?}", tokens);
|
||||||
|
let n_values = [2, 3, 4];
|
||||||
|
|
||||||
|
for &n in &n_values {
|
||||||
|
if n > tokens.len() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for i in 0..tokens.len() - n + 1 {
|
||||||
|
model.train_dataset(&tokens[i..(i + n)]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn train_from_plain_text(path: &str, model: &mut NgramModel) {
|
||||||
|
let text = fs::read_to_string(path).expect("Failed to read file");
|
||||||
|
train_model(&text, model);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn train_from_reddit_comments(path: &str, model: &mut NgramModel) {
|
||||||
|
let file = fs::File::open(path).expect("Failed to open file");
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
let mut line_count = 0;
|
||||||
|
for line in reader.lines() {
|
||||||
|
if let Ok(line) = line {
|
||||||
|
let json: serde_json::Value = serde_json::from_str(&line).expect("Failed to parse JSON");
|
||||||
|
|
||||||
|
if let Some(author) = json.get("author").and_then(|it| it.as_str()) {
|
||||||
|
if author == "AutoModerator" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(body) = json.get("body").and_then(|it| it.as_str()) {
|
||||||
|
train_model(body, model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
line_count += 1;
|
||||||
|
if line_count > 10000 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let args: Vec<String> = env::args().collect();
|
||||||
|
if args.len() != 2 {
|
||||||
|
eprintln!("Usage: {} <file_path>", args[0]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let path = &args[1];
|
||||||
|
let mut model = NgramModel::default();
|
||||||
|
|
||||||
|
if path.ends_with(".reddit.jsonl") {
|
||||||
|
train_from_reddit_comments(path, &mut model);
|
||||||
|
} else {
|
||||||
|
train_from_plain_text(path, &mut model);
|
||||||
|
}
|
||||||
|
|
||||||
|
let window = pancurses::initscr();
|
||||||
|
let mut input_text = String::new();
|
||||||
|
|
||||||
|
pancurses::noecho();
|
||||||
|
window.keypad(true);
|
||||||
|
loop {
|
||||||
|
let mut words: Vec<&str> = input_text.split_whitespace().collect();
|
||||||
|
words.insert(0, TOKEN_SENTENCE_SEPARATOR);
|
||||||
|
|
||||||
|
if input_text.ends_with(' ') || words.last() == Some(&TOKEN_SENTENCE_SEPARATOR) {
|
||||||
|
words.push("");
|
||||||
|
}
|
||||||
|
|
||||||
|
let predictions = model.predict(&words);
|
||||||
|
|
||||||
|
window.clear();
|
||||||
|
window.addstr("N-gram model debug frontend\n");
|
||||||
|
window.addstr(" demo tokenizer only supports single-line sentence in input text!\n\n");
|
||||||
|
window.addstr(format!("enter text: {}\n", input_text));
|
||||||
|
window.addstr(format!("detected words: {:?}\n\n", words));
|
||||||
|
window.addstr("predictions:\n");
|
||||||
|
for (i, (word, weight)) in predictions.iter().enumerate() {
|
||||||
|
if i == 0 && *weight > 0.9 {
|
||||||
|
window.attron(pancurses::A_BOLD);
|
||||||
|
}
|
||||||
|
window.addstr(format!(" {}. {} (c={:.2})\n", i + 1, word, weight));
|
||||||
|
if i == 0 && *weight > 0.9 {
|
||||||
|
window.attroff(pancurses::A_BOLD);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if predictions.is_empty() {
|
||||||
|
window.addstr(" (none)\n");
|
||||||
|
}
|
||||||
|
window.mv(3, 12 + input_text.len() as i32);
|
||||||
|
window.refresh();
|
||||||
|
|
||||||
|
match window.getch().unwrap() {
|
||||||
|
Input::KeyF10 => {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
Input::KeyBackspace => {
|
||||||
|
input_text.pop();
|
||||||
|
}
|
||||||
|
Input::Character('\n') => {
|
||||||
|
train_model(&input_text, &mut model)
|
||||||
|
}
|
||||||
|
Input::Character(ch) => {
|
||||||
|
input_text.push(ch)
|
||||||
|
}
|
||||||
|
_ => { () }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pancurses::endwin();
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user