1 Executive summary

Many modern mobile keyboards offer an option to autocomplete a word or phrase the user has started typing, in order to save keystrokes. These days they usually use a recurrent neural network (RNN) model to make predictions; while this approach provides better accuracy, an RNN is computationally expensive to train and requires a lot of training data.

In this notebook, I instead build an ngram model with 9.5 million ngrams that predicts the next word using a simplified stupid backoff algorithm. It reaches 22.8% top-3 accuracy on a community benchmark and is deployed as a proof-of-concept Shiny web app.

2 Data and exploratory analysis

2.1 Overview and sampling

This project was initially done as the capstone for the Johns Hopkins University Data Science Specialization. The author thanks SwiftKey for providing the data set.

The data came as a set of .txt files harvested from Twitter, news and blog websites, all in English.

file.info('data/en_US/en_US.blogs.txt', 'data/en_US/en_US.news.txt', 'data/en_US/en_US.twitter.txt')[1]
##                                   size
## data/en_US/en_US.blogs.txt   210160014
## data/en_US/en_US.news.txt    205811885
## data/en_US/en_US.twitter.txt 167105338

With files ranging from 159 to 200 MB in size, or from 900 thousand to 2.3 million lines each, it makes sense to use a small sample for the exploratory analysis before training the model on the full data set. The code for the createSample() function is listed in the Appendix (as are the other helper functions). Here the calls are not evaluated, as I already have the sample files.

source('createSample.R')
# Create ~50,000 line train samples for each of the en_US original files
set.seed(2020)
createSample('./data/en_US/en_US.twitter.txt', './data/en_US/sample.twitter.txt')
createSample('./data/en_US/en_US.blogs.txt', './data/en_US/sample.blogs.txt')
createSample('./data/en_US/en_US.news.txt', './data/en_US/sample.news.txt')
file.info('./data/en_US/sample.blogs.txt', './data/en_US/sample.news.txt', './data/en_US/sample.twitter.txt')[1]
##                                     size
## ./data/en_US/sample.blogs.txt   11234872
## ./data/en_US/sample.news.txt     9616921
## ./data/en_US/sample.twitter.txt  3455372

Now the file sizes are between 3 and 11 MB, which is much more manageable. For further processing, the following libraries will be needed:

library(ggplot2) # plotting
library(scales) # labels for plotting
library(wordcloud) # plotting word clouds
library(quanteda) # text mining
library(dplyr) # data processing
library(data.table) # data storage and search optimization

2.2 Exploratory analysis

2.2.1 View the data

Just a glimpse of what the first few lines in the Twitter sample look like:

readLines('./data/en_US/sample.twitter.txt', n=5)
## [1] "SDG&E LED Holiday Light Exchange Brings Energy Savings to Balboa Park, Chula ... - MarketWatch (press release) SDG&E LED Holiday Ligh"
## [2] "Manners. .MF. do you have those?"                                                                                                     
## [3] "Also Welcome new followers"                                                                                                           
## [4] "I need driver anger management"                                                                                                       
## [5] "Ready for valentines day! wonder who my first kiss will be?"

2.2.2 Preprocessing and building a corpus

It would be convenient to work with a series of words, all lowercase, without punctuation marks and without any weird symbols. To that end, I will use a quanteda corpus object split into one-sentence documents, which is further tokenized into separate words. The rationale is that the system will predict separate words, and it will predict them within the context of a sentence.

All these tasks are done by two helper functions: files2sentences() and str2tokens(). See Appendix for code.

source('file2sentences.R')
source('str2tokens.R')
train.corpus <- files2sentences(c('./data/en_US/sample.twitter.txt'
                                  , './data/en_US/sample.blogs.txt'
                                  , './data/en_US/sample.news.txt'))
head(train.corpus)
## Corpus consisting of 6 documents.
## text1.1 :
## "SDG&E LED Holiday Light Exchange Brings Energy Savings to Ba..."
## 
## text2.1 :
## "Manners. ."
## 
## text2.2 :
## "MF. do you have those?"
## 
## text3.1 :
## "Also Welcome new followers"
## 
## text4.1 :
## "I need driver anger management"
## 
## text5.1 :
## "Ready for valentines day!"
train.tokens <- str2tokens(train.corpus)
rm(train.corpus)
head(train.tokens)
## Tokens consisting of 6 documents.
## text1.1 :
##  [1] "SDG"      "E"        "LED"      "Holiday"  "Light"    "Exchange"
##  [7] "Brings"   "Energy"   "Savings"  "to"       "Balboa"   "Park"    
## [ ... and 9 more ]
## 
## text2.1 :
## [1] "Manners"
## 
## text2.2 :
## [1] "MF"    "do"    "you"   "have"  "those"
## 
## text3.1 :
## [1] "Also"      "Welcome"   "new"       "followers"
## 
## text4.1 :
## [1] "I"          "need"       "driver"     "anger"      "management"
## 
## text5.1 :
## [1] "Ready"      "for"        "valentines" "day"

2.2.3 Word coverage and dictionary

An ngram is simply a sequence of n words in a text. The simplest case is a unigram, which is just a single word. Let’s build a frequency table for individual words and see how many of them are needed to cover 50%, 90%, 95% and 99% of all text.

source('nFreq.R')
words <- nFreq(1, train.tokens)
words$perc <- words$frequency / sum(words$frequency)
words$cumperc <- cumsum(words$perc)

cover <- vector(mode = 'integer')
cover["50"] <- min(which(words$cumperc>=.5))
cover["90"] <- min(which(words$cumperc>=.9))
cover["95"] <- min(which(words$cumperc>=.95))
cover["99"] <- min(which(words$cumperc>=.99))

ggplot(data=words, aes(x=as.integer(rownames(words)), y=cumperc)) +
        geom_point(color='steelblue') + 
        scale_x_continuous(
                labels = comma, 
                trans="log10", 
                breaks=c(
                        10, cover["50"], 1000, cover["90"], cover["95"], cover["99"]
                        ), 
                minor_breaks = NULL
                ) +
        theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust=1)) +
        labs(x='Words count (log scale)', y='Cumulative % covered') +
        geom_hline(yintercept=.5) +
        geom_hline(yintercept=.9) +
        geom_hline(yintercept=.95) +
        geom_hline(yintercept=.99) +
        scale_y_continuous(
                breaks = c(.25, .75, as.numeric(names(cover))/100), 
                labels = label_percent(accuracy = 1L), 
                minor_breaks = NULL
                )

This plot shows what percentage of the text can be covered by a given number of the most frequent words. It follows a well-known theoretical distribution described by Zipf’s law. This fact is tangential to the further analysis, so I will not bother with proper statistical tests (like a QQ plot) to confirm it. The important takeaway is that

  • 50% of the text can be covered with only 138 words
  • 90% of the text is covered by 8021 words
  • 95% needs 19523 words
  • 99% needs 82999 words
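
For the curious, a quick informal check of the Zipf claim (an illustration, not a formal test) is to regress log frequency on log rank; Zipf’s law implies an approximately linear relationship with a slope close to -1.

# Informal Zipf check (illustration only): under Zipf's law,
# log(frequency) is roughly linear in log(rank) with slope near -1
zipf <- data.frame(rank = seq_len(nrow(words)), freq = words$frequency)
zipf_fit <- lm(log(freq) ~ log(rank), data = zipf)
coef(zipf_fit)["log(rank)"] # slope estimate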

Based on this, it looks reasonable to use the top 20,000 words as a dictionary, which will cover a bit over 95% of all text. I initially tried to be more ambitious with a 50,000-word dictionary that covered 98% of the text, but the resulting model took almost 4 GB of RAM before pruning (more on that later) and was not even better in terms of accuracy.

dict20 <- words[1:20000]$feature

2.2.4 Most frequent words

Unsurprisingly, English text consists mostly of function words like “the”, “to”, “and”, “a” etc., with a few personal pronouns and modal verbs sneaking into the top-30 chart.

ggplot(data=words[1:30,], aes(x=reorder(feature, -perc), y=perc)) +
        geom_bar(stat='identity', fill='chocolate3') + 
        theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust=1)) +
        xlab('word') +
        scale_y_continuous(name = '% of text', labels = label_percent(accuracy = 1L))

2.2.5 Ngram frequencies visualization

Now that we have a dictionary, it is time to build ngram frequency tables, that is, to count how often not just single words but longer bigrams, trigrams, etc. are observed in the corpus.

ngrams <- lapply(1:6, nFreq, train.tokens) # Calculate raw ngram frequencies
MB <- round(object.size(ngrams)/2^20) # Size of all the tables in MB
count <- round(sum(sapply(ngrams, nrow))/10^6, 1) # Total number of ngrams in millions
head(ngrams[[3]]) # Top trigrams, as an example
##       feature frequency
## 1  one of the      1500
## 2    a lot of      1268
## 3     to be a       740
## 4  out of the       725
## 5  the end of       692
## 6 going to be       683

As of now, the entire model takes 1951 MB in memory and contains 12.8 million different ngrams. Instead of plotting frequencies, I will show word clouds for uni-, bi- and trigrams on the tabs below.

2.2.5.1 Unigram cloud

wordcloud(ngrams[[1]]$feature, ngrams[[1]]$frequency, max.words=200, random.order=FALSE, rot.per=0.35,            colors=brewer.pal(8, "Dark2"))

2.2.5.2 Bigram cloud

wordcloud(ngrams[[2]]$feature, ngrams[[2]]$frequency, max.words=120, random.order=FALSE, rot.per=0.35,            colors=brewer.pal(8, "Dark2"))

2.2.5.3 Trigram cloud

wordcloud(ngrams[[3]]$feature, ngrams[[3]]$frequency, max.words=50, random.order=FALSE, rot.per=0.35,            colors=brewer.pal(8, "Dark2"))

3 Toy model

In this section I will build a toy model using only the sampled data (about 2% of the original data), evaluate it and make some predictions. The next section is dedicated to building a full production model on the entire data set, which requires batch processing and a few hours on my laptop.

3.1 Possible approaches

3.1.1 Building a probability distribution

There are many ways of predicting the next word based on a known sequence of preceding words. If you google ngram language models, you will find many articles aimed at constructing a probabilistic model, that is, one that produces a probability distribution over what the n-th word may be, given the preceding n-1 words:

P(Wn | W1W2..Wn-1)

  • If the ngram W1W2..Wn is observed in the data, we estimate the required probability with its maximum likelihood estimate (a one-line illustration follows the list).
  • If it is not observed, statisticians have come up with many clever ways to still infer its probability using some sort of smoothing: either by discounting some probability mass from observed to unobserved ngrams, or by interpolating unobserved probabilities from observed ones.
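
For the observed case, the maximum likelihood estimate is simply a ratio of counts, P(Wn | W1..Wn-1) = count(W1..Wn) / count(W1..Wn-1). A one-line illustration, where the trigram count comes from the table above and the bigram count is made up for the example:

# MLE illustration: P("a" | "to be") = count("to be a") / count("to be")
count_ngram   <- 740   # count("to be a"), from the trigram table above
count_history <- 2000  # hypothetical count("to be"), not taken from the data
p_mle <- count_ngram / count_history # = 0.37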

I initially tried one such fancy method, called Katz backoff.

However, in order to actually predict the next word, one has to calculate probabilities for all the words in the dictionary and choose the highest one. This can be optimized in several ways, e.g. by keeping track of the not-yet-calculated probability mass and stopping when it is lower than the highest probability calculated so far. But even with optimizations, this technique needed minutes to produce a single prediction. It also required storing all the observed frequencies as integers, which consumed a lot of memory.

My conclusion was that while probabilistic models are unavoidable when, say, building a text generation algorithm (which has to sample from a probability distribution), they are overkill for predicting just the single next word.

3.1.2 Stupid backoff

I then resorted to a much simpler algorithm called stupid backoff. When it was first suggested, it was called “stupid” because nobody believed such a simple scheme would actually work well, but it turned out that with enough data it approaches the performance of Kneser-Ney smoothing.

My implementation is further simplified and works like this. Suppose we have ngram frequency tables for orders 1 through n and an input string split into tokens (separate words). Then:

  1. Take the n-1 last words of the input and try to find an n-gram that starts with them. If one is found, return its most frequent last word.
  2. If nothing is found, back off to a smaller order: use only the n-2 last words of the input and look them up in the (n-1)-gram table.
  3. Continue doing so until something is found.
  4. If nothing is found at any order, predict the most frequent unigram (that is, the most frequent word in the language).

Now you can see that we do not even need to store the entire frequency tables: for each combination of n-1 history words, we are only interested in the top prediction. However, just to give the user some options, we can store, say, the top-3 predictions.
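
To make the scheme concrete, here is a minimal standalone sketch of the lookup, separate from the actual combined_predict() implementation listed in the Appendix. It assumes plain data.frames with word columns followed by a frequency column (a hypothetical layout, not the keyed data.table format used later), and it returns the top predictions from the first ngram order where anything is found.

# Minimal sketch of the simplified stupid backoff lookup (illustration only).
# freq_tables[[n]] is assumed to be a data.frame of n-grams with word columns
# w1..wn followed by a frequency column -- a hypothetical layout.
backoff_predict <- function(input_words, freq_tables, top = 3) {
        max_n <- length(freq_tables)
        start_n <- min(max_n, length(input_words) + 1)
        if (start_n >= 2) {
                for (n in start_n:2) {
                        history <- tail(input_words, n - 1)
                        tbl <- freq_tables[[n]]
                        # rows whose first n-1 words match the history exactly
                        hit <- rowSums(
                                tbl[, 1:(n - 1), drop = FALSE] ==
                                matrix(history, nrow(tbl), n - 1, byrow = TRUE)
                        ) == n - 1
                        found <- tbl[hit, , drop = FALSE]
                        if (nrow(found) > 0) {
                                found <- found[order(-found$frequency), ]
                                return(head(found[[n]], top)) # most frequent last words
                        }
                }
        }
        # nothing found at any order: predict the most frequent words overall
        uni <- freq_tables[[1]]
        head(uni[order(-uni$frequency), 1], top)
}

Unlike this sketch, combined_predict() keeps backing off until it has collected three distinct candidate words, so lower-order ngrams can top up an incomplete list of predictions.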

3.2 Building the model

3.2.1 Removing out of dictionary words

The ngrams we currently have are raw, that is, they contain out-of-dictionary (OOD) words. The following code, using the removeOOD() helper function (a simplified sketch of it is shown at the end of this subsection),

  1. Splits a character ngram into separate words
  2. Replaces OOD words with the <UNK> token
  3. Recalculates frequencies by collapsing equivalent ngrams (e.g. “I love XXX” and “I love YYY” both become “I love <UNK>” if XXX and YYY are out of dictionary, so their frequencies should be counted together)
  4. Transforms the data.frame to a data.table
  5. Sets all the ngram word columns as keys (but not the frequency column) to enable fast binary search
source('removeOOD.R')

# Steps 1-4: split ngrams into words, replace OOD words, recalculate frequencies, transform to data.table
ngrams <- lapply(ngrams, removeOOD, dict20) 

# Step 5: set keys (all columns except the last one which is frequency)
ngrams <- lapply(ngrams, function(x) 
        setkeyv(x, colnames(x)[1:(ncol(x) - 1)])
       )
setnames(ngrams[[1]], 1, 'X1') # Fix column name for unigrams
MB <- round(object.size(ngrams)/2^20) # Size of all the tables in MB
count <- round(sum(sapply(ngrams, nrow))/10^6,1) # Total number of ngrams in millions
ngrams[[3]] # Top trigrams again
##             X1      X2           X3 frequency
##       1:    -d mission accomplished         1
##       2:  -the    boys         were         1
##       3:  -the classic         diva         1
##       4:  -the islamic    worldview         1
##       5:  -the   jokes         seth         1
##      ---                                     
## 2343628: <UNK>   <UNK>          zen         1
## 2343629: <UNK>   <UNK>         zero         1
## 2343630: <UNK>   <UNK>          zip         1
## 2343631: <UNK>   <UNK>         zone         1
## 2343632: <UNK>   <UNK>         zuma         2

In the output above, the ngrams are now sorted not by descending frequency but alphabetically by their component words. That is a side effect of setting the keys: keyed columns let data.table use binary search, which will make ngram lookups much faster when predicting on the full data set. See the data.table vignette for details.
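
As an aside, this is what such a keyed lookup looks like (a throwaway illustration, not part of the model-building pipeline): supplying values for the leading key columns triggers a binary search instead of a full table scan.

# Throwaway illustration of a keyed lookup: all trigrams starting with "one of",
# located by binary search on the key columns X1, X2
ngrams[[3]][.("one", "of")]
# the equivalent unkeyed filter would scan every row instead
ngrams[[3]][X1 == "one" & X2 == "of"]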

An update on the model size: now it takes 970 MB in memory and contains 11.7 million ngrams.

That’s it, this is basically my toy model. Before evaluating it, and when building the production model on the full data set, I will keep only the top-3 predictions for each (n-1)-word history, but for now I keep all the possible predictions just to show how the lookup works.
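
Since removeOOD() is only sourced from removeOOD.R and not reproduced in the Appendix, here is a minimal sketch of the idea behind it, assuming the raw input is a data.frame with feature and frequency columns as produced by nFreq(); this is a simplified illustration, not the actual function.

# Simplified sketch of the OOD replacement step (illustration only;
# the actual removeOOD() is sourced from removeOOD.R and not listed here)
remove_ood_sketch <- function(freqs, dict) {
        # freqs: data.frame with columns feature ("w1 w2 ... wn") and frequency
        dt <- as.data.table(tstrsplit(freqs$feature, " ", fixed = TRUE))
        setnames(dt, paste0("X", seq_along(dt)))
        # replace out-of-dictionary words with the <UNK> token
        for (col in names(dt)) {
                set(dt, i = which(!dt[[col]] %chin% dict), j = col, value = "<UNK>")
        }
        dt[, frequency := freqs$frequency]
        # collapse ngrams that became identical after the replacement
        dt[, .(frequency = sum(frequency)), by = setdiff(names(dt), "frequency")]
}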

3.2.2 Making predictions

The combined_predict() function implements the algorithm described in the Stupid backoff section. The code, as always, is listed in the Appendix. Let’s use it to make some predictions and see where they come from (that is, which original ngrams, with their frequencies, contributed to each prediction).

source("combined_predict.R")
p <- list()
p[[1]] <- combined_predict("Don't", ngrams, dict20, verbose=TRUE)
p[[2]] <- combined_predict("Bill loves", ngrams, dict20, verbose=TRUE)
p[[3]] <- combined_predict("You shall not", ngrams, dict20, verbose=TRUE)
p[[4]] <- combined_predict("Never have I ever", ngrams, dict20, verbose=TRUE)

With verbose=TRUE, combined_predict() returns a list of two elements:

  1. A character vector of the top-3 predictions
  2. A list of observed ngrams (with frequencies) that start with the same n-1 words. Only 3 of them will have contributed to the top-3 predictions, but it is interesting to see what the other options were.

3.2.2.1 Don’t

p[[1]]
## [[1]]
## [1] "know" "have" "want"
## 
## [[2]]
## [[2]][[1]]
##         X1      X2 frequency
##   1: don't    know       470
##   2: don't    have       390
##   3: don't    want       291
##   4: don't   think       284
##   5: don't    like       140
##  ---                        
## 480: don't witness         1
## 481: don't   would         1
## 482: don't   wreck         1
## 483: don't   young         1
## 484: don't    your         1
## 
## [[2]][[2]]
##             X1 frequency
##     1:     the    209797
##     2:      to    115046
##     3:     and    108883
##     4:       a    101389
##     5:      of     90537
##    ---                  
## 19996:    zinc         9
## 19997:    zine         9
## 19998:  zipped         9
## 19999: zooming         9
## 20000:   zumba         9

This one is easy: there are a lot of bigrams starting with “don’t”, so we just use the top 3.

3.2.2.2 Bill loves

p[[2]]
## [[1]]
## [1] "to"  "you" "me" 
## 
## [[2]]
## [[2]][[1]]
##         X1       X2 frequency
##   1: loves       to        42
##   2: loves      you        30
##   3: loves       me        20
##   4: loves      the        14
##   5: loves       it        10
##  ---                         
## 112: loves watching         1
## 113: loves     will         1
## 114: loves      won         1
## 115: loves  working         1
## 116: loves  writing         1
## 
## [[2]][[2]]
##             X1 frequency
##     1:     the    209797
##     2:      to    115046
##     3:     and    108883
##     4:       a    101389
##     5:      of     90537
##    ---                  
## 19996:    zinc         9
## 19997:    zine         9
## 19998:  zipped         9
## 19999: zooming         9
## 20000:   zumba         9

Apparently there are no trigrams starting with exactly "Bill loves", so the system backs off to bigrams and selects the top 3 from the many starting with "loves".

3.2.2.3 You shall not

p[[3]]
## [[1]]
## [1] "go"     "be"     "resist"
## 
## [[2]]
## [[2]][[1]]
##     X1    X2  X3 X4 frequency
## 1: you shall not go         1
## 
## [[2]][[2]]
##       X1  X2     X3 frequency
## 1: shall not     be         7
## 2: shall not resist         2
## 3: shall not    die         1
## 4: shall not     go         1
## 5: shall not return         1
## 6: shall not suffer         1

A more interesting case: there is a quadgram "you shall not go", but only one. As we want top-3 predictions, the model also looks into trigrams and finds six starting with "shall not", two of which contribute to the prediction.

3.2.2.4 Never have I ever

p[[4]]
## [[1]]
## [1] "felt"      "mentioned" "seen"     
## 
## [[2]]
## [[2]][[1]]
##       X1   X2 X3   X4   X5 frequency
## 1: never have  i ever felt         1
## 
## [[2]][[2]]
##      X1 X2   X3        X4 frequency
## 1: have  i ever mentioned         2
## 2: have  i ever      felt         1
## 3: have  i ever      seen         1

Similar to the previous one: there is exactly one quintagram that uses the entire provided input, "never have i ever felt", so the model needs to look for quadgrams starting with "have i ever" to fill in the remaining top-3 predictions.

3.2.3 Grooming the model

Now that we’ve seen some original counts, I can remove unneeded information:

  1. Keep only the top-3 predictions for each (n-1)-word combination

  2. Replace frequencies with prediction ranks (a factor with levels 1, 2, 3), as we are no longer interested in the fact that prediction A was observed 1500 times and prediction B 666 times, only that A was more frequent than B.

  3. Pruning: it is often advisable to remove ngrams with counts below a certain minimum. I will try two options and compare them:

     - no pruning
     - remove all singletons (ngrams observed only once)
source('keep3.R')
ngrams_pruned <- lapply(ngrams, function(x) x[frequency>1])
ngrams_pruned <- lapply(ngrams_pruned, keep3)
ngrams <- lapply(ngrams, keep3)

Just a peek at what the new ngram tables look like, using bigrams as an example:

ngrams[[2]]
##              X1       X2 rank
##     1:       -d  mission    1
##     2:     -the   united    1
##     3:     -the  thought    2
##     4:     -the segments    3
##     5: #believe       in    1
##    ---                       
## 59810:  zumwalt    north    2
## 59811:  zumwalt     east    3
## 59812:    <UNK>      and    1
## 59813:    <UNK>      the    2
## 59814:    <UNK>       of    3
ngrams_pruned[[2]]
##              X1    X2 rank
##     1: #believe    in    1
##     2: #brewers  need    1
##     3: #brewers  game    2
##     4:    #caps  have    1
##     5:    #caps  fans    2
##    ---                    
## 41556:  zumwalt north    2
## 41557:  zumwalt  east    3
## 41558:    <UNK>   and    1
## 41559:    <UNK>   the    2
## 41560:    <UNK>    of    3

The unpruned bigrams now have 59814 rows, that is, almost exactly 3 predictions for each starter word. The pruned ones have only 41560 rows.

The usual update on model size:

MB <- round(object.size(ngrams)/2^20)
MB_pruned <- round(object.size(ngrams_pruned)/2^20)
count <- round(sum(sapply(ngrams, nrow))/10^6,1)
count_pruned <- round(sum(sapply(ngrams_pruned, nrow))/10^6,1)
  • Unpruned model takes 229 MB in memory and contains 9.3 million ngrams
  • Pruned model takes 32 MB in memory and contains 0.3 million ngrams.

3.2.4 Evaluating model

How good is my toy model? If I had a probability distribution, I could calculate perplexity, but for my model that is not applicable. So I am left with good old accuracy: in what percentage of cases does the model predict the next word correctly?

Luckily, previous generations of data scientists left a convenient benchmark framework that runs a prediction function on a set of test data and returns the scores. Both the benchmark code and the test data are available in its GitHub repo; the password for the archive is “capstone4” (without quotes).

The benchmarking may take up to half an hour, so the calls are commented out, but I saved the results and you can see them below.

# benchmark(combined_predict, model=ngrams_pruned, dict=dict20, sent.list = list('tweets' = tweets, 'blogs' = blogs), ext.output = FALSE)
# benchmark(combined_predict, model=ngrams, dict=dict20, sent.list = list('tweets' = tweets, 'blogs' = blogs), ext.output = FALSE)

bench_results <- read.csv("benchmark_comparison.csv")
bench_results[1:4, c(1,4,5)]
##                 Algorithm Top.3.accuracy Avg.runtime..msec.
## 1                Baseline          8.11%               0.07
## 2      Unpruned toy model         18.11%              30.94
## 3        Pruned toy model         19.71%              30.88
## 4 Best results from forum         21.75%              25.29
  1. The baseline model always returns the three most common words: the, on, a
  2. The pruned model clearly shows some improvement over the unpruned one
  3. The best results from the forum are due to Khaled Al-Shamaa

A side note on how the benchmark works: it actually provides three metrics:

  1. Top-3 score, which gives 3 points for a correct 1st prediction, 2 for the 2nd and 1 for the 3rd
  2. Top-1 accuracy: in what % of cases the 1st prediction was the correct one
  3. Top-3 accuracy: in what % of cases one of the top-3 predictions was the correct one

The benchmark’s author calls the last two metrics “precision”, but judging by how they are calculated in the code, they are in fact accuracy. As it is always useful to have a single real-number metric to compare different algorithms, I chose top-3 accuracy. The rationale is that the goal is to save the user’s keystrokes, and as long as we can display the desired next word among 3 options to tap, we have a happy user.
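
To make the chosen metric concrete, here is a small sketch (my own illustration, not the benchmark’s actual code) of how top-1 and top-3 accuracy could be computed over a set of test pairs of (history, true next word):

# Illustration of top-1 / top-3 accuracy (not the benchmark's actual code).
# test_pairs: data.frame with columns history and truth;
# predict_fun(history) returns a character vector of up to 3 predicted words
score_accuracy <- function(test_pairs, predict_fun) {
        top1 <- top3 <- 0
        for (i in seq_len(nrow(test_pairs))) {
                preds <- predict_fun(test_pairs$history[i])
                if (length(preds) > 0 && !is.na(preds[1]) &&
                    preds[1] == test_pairs$truth[i]) top1 <- top1 + 1
                if (test_pairs$truth[i] %in% preds) top3 <- top3 + 1
        }
        c(top1.accuracy = top1 / nrow(test_pairs),
          top3.accuracy = top3 / nrow(test_pairs))
}
# e.g. score_accuracy(test_pairs, function(h) combined_predict(h, ngrams, dict20))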

4 Final model

Now that I have a working toy model and know that it does not lag too far behind its peers, it is time to build a fully functional one using the entire data set. The process is very similar to the one for the toy model, except that I have to process the data in batches, because the entire data set up to hexagrams will not fit into memory.

4.1 Building

Batch processing of the entire data set took a few hours on my laptop with 16 GB of RAM, and it was an interactive process of trial and error. Here I will just load my saved model, but the basic idea was to

  1. Split data into batches

     a. For 1-4-grams, the 3 sources (Twitter, news and blogs) can each be considered a single batch
     b. For 5-6-grams, even that is too much, so I took half of each source
  2. For each batch, calculate all ngrams separately and save to disk

  3. Iteratively load several batches, rbind them and recalculate the combined frequencies, for each value of n separately (a sketch of this step follows the list)

  4. Keep only top-3 predictions
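
A minimal sketch of the merge in step 3, assuming each batch’s raw trigram counts were saved with saveRDS() as feature/frequency tables under hypothetical file names (the real batch files are not part of this report):

# Sketch of merging per-batch trigram counts (hypothetical file names)
batch_files <- c('ngrams3_twitter.rds', 'ngrams3_blogs.rds', 'ngrams3_news.rds')
merged <- rbindlist(lapply(batch_files, readRDS))
# the same ngram can appear in several batches, so sum its counts
merged <- merged[, .(frequency = sum(frequency)), by = feature]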

rm(ngrams, ngrams_pruned, dict20)
load('model20_with_dict')
MB <- round(object.size(model20)/(2^20))
count <- round(sum(sapply(model20, nrow))/10^6, 1)

The final model takes 219 MB in memory and contains 9.5 million different ngrams.

4.2 Making predictions

Just for fun, let’s make the same predictions we did with the toy model, now using the full model, and see if there is any difference.

p[[1]] <- combined_predict("Don't", model20, dict20, verbose=TRUE)
p[[2]] <- combined_predict("Bill loves", model20, dict20, verbose=TRUE)
p[[3]] <- combined_predict("You shall not", model20, dict20, verbose=TRUE)
p[[4]] <- combined_predict("Never have I ever", model20, dict20, verbose=TRUE)

4.2.1 Don’t

p[[1]]
## [[1]]
## [1] "know" "have" "want"
## 
## [[2]]
## [[2]][[1]]
##       X1   X2 rank
## 1: don't know    1
## 2: don't have    2
## 3: don't want    3
## 
## [[2]][[2]]
##     X1 rank
## 1: the    1
## 2:  to    2
## 3: and    3

Same as the toy model.

4.2.2 Bill loves

p[[2]]
## [[1]]
## [1] "to"  "you" "me" 
## 
## [[2]]
## [[2]][[1]]
##       X1  X2 rank
## 1: loves  to    1
## 2: loves you    2
## 3: loves  me    3
## 
## [[2]][[2]]
##     X1 rank
## 1: the    1
## 2:  to    2
## 3: and    3

Same as the toy model.

4.2.3 You shall not

p[[3]]
## [[1]]
## [1] "pass"   "be"     "escape"
## 
## [[2]]
## [[2]][[1]]
##     X1    X2  X3     X4 rank
## 1: you shall not   pass    1
## 2: you shall not     be    2
## 3: you shall not escape    3

Ah finally! The famous “you shall not pass” is here and on the top row!

4.2.4 Never have I ever

p[[4]]
## [[1]]
## [1] "mentioned" "told"      "been"     
## 
## [[2]]
## [[2]][[1]]
##      X1 X2   X3        X4 rank
## 1: have  i ever mentioned    1
## 2: have  i ever      told    2
## 3: have  i ever      been    3

Now all the predictions come from quadgrams, while the toy model also took into account the quintagram "never have i ever felt". However, that quintagram was observed only once, so it must have been pruned out of the final model.

4.3 Evaluating the model

Now I can show the full comparison of the different models. Before building this final one, I also tried a model with a 50k-word dictionary and no pruning (the one that needed almost 4 GB of RAM); I list it here too for reference.

bench_results[,c(1,4,5)]
##                     Algorithm Top.3.accuracy Avg.runtime..msec.
## 1                    Baseline          8.11%               0.07
## 2          Unpruned toy model         18.11%              30.94
## 3            Pruned toy model         19.71%              30.88
## 4     Best results from forum         21.75%              25.29
## 5 Unpruned 50k dict full data         21.31%              24.65
## 6   Pruned 20k dict full data         22.80%              22.23

Results are clear:

  1. Pruning is good
  2. Having a bigger dictionary does not necessarily help
  3. My model beats the competition by over 1 percentage point of top-3 accuracy!

For a fuller picture, it could be useful to also try pruning the 50k model, but the current one looks good enough as it is.

5 Productionalizing (building an app)

Creating a fully functional mobile keyboard is far outside the scope of this project, but what I did build is a simple proof-of-concept web app with a UI where you can type any text and get 3 options for the next word, along with the original ngrams these predictions come from. It is available here:

https://dmitrytoda.shinyapps.io/SwiftPredict/

Try typing any text and see what it suggests! If you open it on mobile, you can compare my predictions with the ones your favourite mobile keyboard provides.

6 Next steps

How can this model be improved? One obvious way is to collect more data. However, a more modern approach would be to use a recurrent neural network (RNN), or more specifically its Coupled Input-Forget Gate (CIFG) variant. In this article, the authors compare an ngram baseline model with a server-trained CIFG and a federated CIFG (one trained on users’ devices using their own data). They list the following top-3 accuracies (which they refer to as “recall”; in this particular case the two are the same):

  1. Ngram model - 22.1%
  2. Server-trained CIFG - 27.1%
  3. Federated CIFG - 27.0%

So while my model is on par with or even a bit better than their baseline ngram model (22.8% vs 22.1%), it lags far behind the RNNs. The logical improvement path would be to switch to an RNN as well.

7 Appendix (functions’ code)

createSample <- function(infile, outfile, lines=50000, block=100){
        # Creates a smaller sample file based on a large input file
        # by reading random blocks of lines from the input
        # Arguments:
        # infile, outfile - names of input and output files
        # lines - number of lines in the resulting file
        # block - size of a block to be read from infile
        
        in.con <- file(infile, 'r')
        file.length <- length(readLines(in.con))
        close(in.con)
        
        in.con <- file(infile, 'r')
        out.con <- file(outfile, 'w')
        
        in.blocks <- as.integer(file.length/block)
        out.blocks <- as.integer(lines/block)
        keep.blocks <- sort(sample(1:in.blocks, out.blocks))
        
        i_prev <- 0  # index of the last block already read from the connection
        print(keep.blocks)
        for(i in keep.blocks) {
                # skip the unsampled blocks between the previously read block and block i
                tmp <- scan(in.con, what='character', nmax=block, skip = block*(i-i_prev-1), sep='\n')
                write(tmp, out.con, append=TRUE, sep='\n')
                i_prev <- i
        }
        
        close(in.con)
        close(out.con)
}

files2sentences <- function(filenames) {
        # input: vector of file names
        # output: quanteda corpus reshaped to sentences
        
        require(quanteda)
        texts <- sapply(filenames, readLines, skipNul = TRUE)
        texts <- sapply(texts, enc2utf8)
        texts <- unname(unlist(texts))
        my_corpus <- corpus(texts)
        corpus_reshape(my_corpus, to="sentences")
        
}

str2tokens <- function(string, res='tokens', dict) {
        # converts a string into tokens
        # res=tokens returns tokens for further processing
        # res=vector returns a character vector to make predictions upon
        
        require(quanteda)
        stopifnot(res %in% c('tokens', 'vector'))
        my_tokens <- tokens(
                string,
                what = "word",
                remove_punct = TRUE,
                remove_symbols = TRUE,
                remove_numbers = TRUE,
                remove_url = TRUE,
                remove_separators = TRUE,
                split_hyphens = FALSE,
                include_docvars = FALSE,
                padding = FALSE
        )
        
        # keep only tokens that contain at least one letter
        my_tokens <- tokens_select(
                my_tokens, 
                "[a-z]+", 
                valuetype = "regex", 
                selection = "keep",  
                case_insensitive = TRUE)
        
        # remove tokens that contain weird characters,
        # i.e. anything but letters, digits and #'.-’ signs
        my_tokens <- tokens_select(
                my_tokens, 
                "[^\\#a-z0-9\\'\\.\\-]+", 
                valuetype = "regex", 
                selection = "remove",  
                case_insensitive = TRUE)
        
        if(res=='tokens') return(my_tokens)
        
        res <- tolower(my_tokens[[1]])
        res[! res %in% dict] <- '<UNK>'
        res
}

nFreq <- function(n, my_tokens) {
        # function to create n-gram frequency table
        
        my_ngrams <- tokens_ngrams(my_tokens, n, concatenator = " ")
        my_ngrams <- dfm(my_ngrams)
        my_ngrams <- textstat_frequency(my_ngrams)
        my_ngrams[,1:2]
}

combined_predict <- function(starter, model, dict, verbose=FALSE) {
        # starter - any string
        # model - list of ngram data.tables, with either frequency or rank as last column
        # dict - character vector with dictionary
        # verbose - whether to include path into output
        
        require(data.table)
        stopifnot(sapply(model, is.data.table))
        max_n <- length(model)
        starter <- str2tokens(starter, "vector", dict)
        
        # for empty starter, predict the unigrams
        if(length(starter)==0) {
                if(verbose) {
                        return(list(as.character(model[[1]]$X1), model[[1]]))
                } else return(as.character(model[[1]]$X1))
        }
        
        # take only max_n-1 last words for prediction
        if (length(starter)>max_n-1) starter <- starter[(length(starter)-max_n+2):length(starter)]
        
        res <- vector(mode='character')
        path <- list()
        
        while(length(unique(res))<3) {
                n <- length(starter)+1
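                # my_cond() (helper sourced elsewhere, not listed in this Appendix) presumably
                # builds the data.table lookup expression matching the starter words
                # against the key columns of the n-gram table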
                cond <- my_cond(starter)
                curr <- model[[n]][eval(cond)]
                if(!is.na(curr[[n+1]][1])) { # if something was found at current n
                        if(colnames(curr)[n+1]=='rank') {
                                # rank models are already sorted
                                res <- c(res, as.character(curr[[n]])) 
                                path[[n]] <- curr
                        } else if(colnames(curr)[n+1]=='frequency') {
                                # frequency models need to be sorted by desc frequency
                                res <- c(res, as.character(curr[order(-frequency)][[n]])) 
                                path[[n]] <- curr[order(-frequency)]
                        } else stop("Predict error: last column is neither rank nor frequency")
                        
                }
                if(n == 2) {
                        # bigram level exhausted: fall back to the unigram
                        # predictions, as for an empty starter
                        res <- c(res, as.character(model[[1]]$X1))
                        break
                }
                starter <- starter[2:(n-1)]
        }
        
        if(verbose) 
                list(unique(res)[1:3], rev(path[!sapply(path, is.null)]))
        else unique(res)[1:3]
}

keep3 <- function(ngrams) {
        # keeps top-3 prediction for each starter, replacing frequencies
        # with a 1-2-3 factor rank
        
        # set keys to everything except Xn, which equals
        # X1, X2..Xn-1, frequency
        n <- ncol(ngrams)-1
        setkeyv(ngrams, cols = colnames(ngrams)[-n])
        
        # keep top 3 predictions (which are at the bottom) 
        # for each starter = X1..Xi-1 combination
        ngrams <- ngrams[
                ,
                tail(.SD,3), 
                by=eval(if(n>1) colnames(ngrams)[1:(n-1)] else NULL)
                ]
        
        # add rank column, largest count = rank 1
        ngrams[
                ,
                rank:=.N:1,
                by=eval(if(n>1) colnames(ngrams)[1:(n-1)] else NULL)
                ]
        ngrams$rank <- as.factor(ngrams$rank)
        ngrams$frequency <- NULL           
        
        # set new keys: starter + rank
        setkeyv(ngrams, colnames(ngrams)[-n])
        ngrams
}