public class TradNaiveBayesClassifier extends Object implements JointClassifier<CharSequence>, ObjectHandler<Classified<CharSequence>>, Serializable, Compilable
TradNaiveBayesClassifier implements a traditional
token-based approach to naive Bayes text classification. It wraps
a tokenization factory to convert character sequences into
sequences of tokens. This implementation supports several
enhancements to simple naive Bayes: priors, length normalization,
and semi-supervised training with EM.
It is the token counts (aka "bag of words") sequence that is actually being classified, not the raw character sequence input. So any character sequences that produce the same bags of tokens are considered equal.
Naive Bayes is trainable online, meaning that it can be given
training instances one at a time, and at any point can be used as a
classifier. Training cases consist of a character sequence and
classification, as dictated by the interface ObjectHandler<Classified<CharSequence>>.
Given a character sequence, a naive Bayes classifier returns
joint probability estimates of categories and tokens; this is
reflected in its implementing the Classifier<CharSequence,JointClassification>
interface. Note that
this is the joint probability of the token counts, so sums of
probabilities over all input character sequences will exceed 1.0.
Typically, only the conditional probability estimates are used in
practice.
If there is length normalization, the joint probabilities will not sum to 1.0 over all inputs and outputs. The conditional probabilities will always sum to 1.0.
Conditional probabilities are derived by applying Bayes's rule to invert the probability calculation:

    p(cat|tokens) = p(tokens,cat) / p(tokens)
                  = p(tokens|cat) * p(cat) / p(tokens)

The joint probability is defined by the chain rule:

    p(tokens,cat) = p(tokens|cat) * p(cat)

The tokens are assumed to be independent (this is the "naive" step):

    p(tokens|cat) = p(tokens[0]|cat) * ... * p(tokens[tokens.length-1]|cat)
                  = Π_{i < tokens.length} p(tokens[i]|cat)

Finally, an explicit marginalization allows us to compute the marginal distribution of tokens:

    p(tokens) = Σ_{cat'} p(tokens,cat') = Σ_{cat'} p(tokens|cat') * p(cat')
These definitions express p(cat|tokens) in terms of two distributions: the conditional
probability of a token given a category, p(token|cat), and the
marginal probability of a category, p(cat) (sometimes called
the category's prior probability, though this shouldn't be confused
with the usual Bayesian prior on model parameters).
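As a concrete check of the derivation above, the inversion can be computed numerically. The sketch below is illustrative only, not LingPipe code; the per-token conditionals and category priors are invented values.

```java
public class BayesInversionSketch {
    public static void main(String[] args) {
        // Invented model: two categories, a two-token input.
        double[] pCat = {0.6, 0.4};                   // p(cat)
        double[][] pTok = {{0.7, 0.3}, {0.2, 0.8}};   // pTok[cat][i] = p(tokens[i]|cat)

        // Joint: p(tokens,cat) = p(tokens|cat) * p(cat), with naive independence
        double[] joint = new double[2];
        for (int c = 0; c < 2; c++)
            joint[c] = pCat[c] * pTok[c][0] * pTok[c][1];

        // Marginal: p(tokens) = sum over categories of the joint
        double marginal = joint[0] + joint[1];   // 0.126 + 0.064 = 0.19

        // Conditional: p(cat|tokens) = p(tokens,cat) / p(tokens)
        double p0 = joint[0] / marginal;   // ~0.663
        double p1 = joint[1] / marginal;   // ~0.337
        System.out.println(p0 + " " + p1); // the conditionals sum to 1
    }
}
```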
Traditional naive Bayes uses a maximum a posteriori (MAP)
estimate of the multinomial distributions: p(cat) over the
set of categories, and, for each category cat, the
multinomial distribution p(token|cat) over the set of tokens.
Traditional naive Bayes employs the Dirichlet conjugate prior for
multinomials, which is straightforward to compute by adding a fixed
"prior count" to each count in the training data; hence the
technique's traditional name, "additive smoothing".
Two sets of counts are sufficient for estimating a traditional
naive Bayes classifier. The first is tokenCount(w,c), the
number of times token w appeared as a token in a training
case for category c. The second is caseCount(c), the
number of training cases for category c.
We assume prior counts α
for the case counts
and β
for the token counts. These values are supplied
in the constructor for this class.
The estimates for category and token probabilities p'
are most easily understood as proportions:

    p'(w|c) ∝ tokenCount(w,c) + β
    p'(c) ∝ caseCount(c) + α

The probability estimates p' are obtained through the
usual normalization:

    p'(w|c) = ( tokenCount(w,c) + β ) / Σ_{w'} ( tokenCount(w',c) + β )
    p'(c) = ( caseCount(c) + α ) / Σ_{c'} ( caseCount(c') + α )
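The smoothed estimates can be computed directly from the two count tables. The following sketch is illustrative only, not LingPipe code; the counts and the prior values α = β = 0.5 are invented for the example.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class MapEstimateSketch {
    public static void main(String[] args) {
        double alpha = 0.5; // prior count for categories (invented)
        double beta = 0.5;  // prior count for tokens within a category (invented)

        // tokenCount(w,c) for a single category c
        Map<String, Double> tokenCount = new LinkedHashMap<>();
        tokenCount.put("goal", 3.0);
        tokenCount.put("match", 1.0);

        // p'(w|c) = (tokenCount(w,c) + beta) / sum_w' (tokenCount(w',c) + beta)
        double tokenDenom = 0.0;
        for (double n : tokenCount.values())
            tokenDenom += n + beta;
        double pGoal = (tokenCount.get("goal") + beta) / tokenDenom;
        System.out.println(pGoal); // (3 + 0.5) / 5.0 = 0.7

        // p'(c) = (caseCount(c) + alpha) / sum_c' (caseCount(c') + alpha)
        double[] caseCount = {7.0, 3.0}; // two categories (invented counts)
        double catDenom = (caseCount[0] + alpha) + (caseCount[1] + alpha);
        double pCat0 = (caseCount[0] + alpha) / catDenom;
        System.out.println(pCat0); // 7.5 / 11.0 ≈ 0.68
    }
}
```

With α = β = 0 the same code yields the maximum likelihood estimates discussed below.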
Although not traditionally used for naive Bayes, maximum
likelihood estimates arise from setting the prior counts equal to
zero (α = β = 0). The prior counts drop
out of the equations to yield the maximum likelihood estimates p*:

    p*(w|c) = tokenCount(w,c) / Σ_{w'} tokenCount(w',c)
    p*(c) = caseCount(c) / Σ_{c'} caseCount(c')
Unlike traditional naive Bayes implementations, this class allows weighted training, including training directly from a conditional classification. When training using a conditional classification, each category is weighted according to its conditional probability.
Weights may be negative, allowing counts to be decremented (e.g. for Gibbs sampling).
Because of the (almost always faulty) token-independence
assumption underlying the naive Bayes classifier, the conditional
probability estimates tend toward either 0.0 or 1.0 as the input
grows longer. In practice, it sometimes helps to length-normalize
the documents; that is, to treat each document as if it were a given
number of tokens long, lengthNorm.
Length normalization can be computed directly on the linear scale:

    p^n(tokens|cat) = p(tokens|cat)^(lengthNorm/tokens.length)

but is more easily understood on the log scale, where we multiply the length norm by the log probability normalized per token:

    log_2 p^n(tokens|cat) = lengthNorm * log_2 p(tokens|cat) / tokens.length

The length normalization parameter is supplied in the constructor, with a Double.NaN value indicating that no length normalization should be done.
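On the log scale, length normalization is a single rescaling. This sketch is illustrative only, not LingPipe code; the log probability value is invented.

```java
public class LengthNormSketch {
    // log_2 p^n(tokens|cat) = lengthNorm * log_2 p(tokens|cat) / tokens.length;
    // Double.NaN means no normalization, matching the convention above.
    static double lengthNormalize(double log2Prob, int numTokens, double lengthNorm) {
        if (Double.isNaN(lengthNorm))
            return log2Prob;
        return lengthNorm * log2Prob / numTokens;
    }

    public static void main(String[] args) {
        double log2Prob = -40.0; // log2 p(tokens|cat) for a 20-token doc (invented)
        System.out.println(lengthNormalize(log2Prob, 20, 10.0));       // -20.0
        System.out.println(lengthNormalize(log2Prob, 20, Double.NaN)); // -40.0
    }
}
```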
Length normalization will be applied during training, too. Length normalization may be changed using the set method. For instance, this allows training to skip length normalization and classification to use length normalization.
EM is controlled by epoch. Each epoch consists of an expectation (E) step, followed by a maximization (M) step. The expectation step computes expectations which are then fed in as training weights to the maximization step.
The version of EM implemented in this class allows a mixture of supervised and unsupervised data.
The supervised training data is
in the form of a corpus of classified character sequences, implementing
Corpus<ObjectHandler<Classified<CharSequence>>>.
Unsupervised data is in the form of a corpus of texts, implementing
Corpus<ObjectHandler<CharSequence>>.
The method also requires a factory with which to produce a new
classifier in each epoch, namely an implementation of Factory<TradNaiveBayesClassifier>
. And it also takes an initial
classifier, which may be different than the classifiers generated
by the factory.
EM works by iteratively training better and better classifiers using the previous classifier to label unlabeled data to use for training.
    set lastClassifier to initialClassifier
    for (epoch = 0; epoch < maxEpochs; ++epoch) {
        create classifier using factory
        train classifier on supervised items
        for (x in unsupervised items) {
            compute p(c|x) with lastClassifier
            for (c in categories)
                train classifier on c weighted by p(c|x)
        }
        evaluate corpus and model probability under classifier
        set lastClassifier to classifier
        break if converged
    }
    return lastClassifier
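The heart of each epoch is the weighted training of the unsupervised items. The sketch below shows just that step, with the classifier reduced to its raw count tables and the conditional probabilities p(c|x) supplied directly; everything here is invented for illustration and is not LingPipe code.

```java
import java.util.HashMap;
import java.util.Map;

public class EmEpochSketch {
    public static void main(String[] args) {
        // Two unsupervised items and their p(c|x) under lastClassifier (invented)
        String[][] items = {{"a", "b"}, {"b", "c"}};
        double[][] pCatGivenX = {{0.9, 0.1}, {0.2, 0.8}};

        // New classifier's sufficient statistics: caseCount(c) and tokenCount(w,c)
        double[] caseCount = new double[2];
        Map<String, double[]> tokenCount = new HashMap<>();

        for (int i = 0; i < items.length; i++) {
            for (int c = 0; c < 2; c++) {
                double weight = pCatGivenX[i][c]; // train on c weighted by p(c|x)
                caseCount[c] += weight;
                for (String tok : items[i])
                    tokenCount.computeIfAbsent(tok, k -> new double[2])[c] += weight;
            }
        }
        System.out.println(caseCount[0]);           // ~1.1 (0.9 + 0.2)
        System.out.println(tokenCount.get("b")[1]); // ~0.9 (0.1 + 0.8)
    }
}
```

The M-step then normalizes these weighted counts exactly as in the supervised case.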
Note that in each round, the new classifier is trained on the supervised items.
In general, we have found that EM training works best if the initial classifier does more smoothing than the classifiers returned by the factory.
Annealing, of a sort, may be built in by having the factory return a sequence of classifiers with ever longer length normalizations and/or lower prior counts, both of which attenuate the posterior predictions of a naive Bayes classifier. With a short length normalization, classifications are driven closer to uniform; with longer length normalizations they are more peaky.
It is possible to train a classifier in a completely unsupervised fashion by having the initial classifier assign categories at random. Only the number of categories must be fixed. The algorithm is exactly the same, and the result after convergence or the maximum number of epochs is a classifier.
Now take the trained classifier and run it over the texts in the unsupervised text corpus. This will assign probabilities of the text belonging to each of the categories. This is known as a soft clustering, and the algorithm overall is known as EM clustering. If we assign each item to its most likely category, the result is then a hard clustering.
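Going from the soft clustering to a hard clustering is just an argmax over the conditional probabilities. A minimal sketch with invented probabilities (not LingPipe code):

```java
public class HardClusteringSketch {
    // Index of the most likely category under p(c|x)
    static int mostLikelyCategory(double[] pCatGivenX) {
        int best = 0;
        for (int c = 1; c < pCatGivenX.length; c++)
            if (pCatGivenX[c] > pCatGivenX[best])
                best = c;
        return best;
    }

    public static void main(String[] args) {
        double[][] soft = {{0.9, 0.1}, {0.3, 0.7}}; // p(c|x) per item (invented)
        for (double[] p : soft)
            System.out.println(mostLikelyCategory(p)); // prints 0 then 1
    }
}
```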
A naive Bayes classifier may be serialized. The object read back in will behave just as the naive Bayes classifier that was serialized. The tokenizer factory must be serializable in order to serialize the classifier.
A naive Bayes classifier may be compiled. In order to be
compiled, the tokenizer factory must be either serializable or
compilable. The object read back in will implement ConditionalClassifier<CharSequence>
if the compiled classifier is
binary (i.e., has exactly two categories) and JointClassifier<CharSequence>
if the compiled classifier has more
than two categories. The ability to compute joint probabilities in
the binary case is lost due to an optimization in the compiler, so
the resulting class only implements conditional classifier.
A compiled classifier may not be trained.
NaiveBayesClassifier differs from this version in smoothing the
token estimates with character language model estimates.
TradNaiveBayesClassifier must be synchronized externally
using read/write synchronization (e.g., with a ReadWriteLock). The write methods
include handle(Classified), train(CharSequence,Classification,double),
trainConditional(CharSequence,ConditionalClassification,double,double),
and setLengthNorm(double). All other methods are read
methods.
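One way to arrange the recommended external synchronization is with a java.util.concurrent.locks.ReentrantReadWriteLock. In this sketch a simple count map stands in for the classifier so the example is self-contained; the guarding pattern is what carries over to the write methods (handle, train, setLengthNorm) and read methods (classify, probToken, etc.).

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

public class GuardedClassifierSketch {
    private final ReadWriteLock lock = new ReentrantReadWriteLock();
    private final Map<String, Integer> counts = new HashMap<>();

    // Write method (analogous to handle/train/setLengthNorm)
    void train(String token) {
        lock.writeLock().lock();
        try {
            counts.merge(token, 1, Integer::sum);
        } finally {
            lock.writeLock().unlock();
        }
    }

    // Read method (analogous to classify/probToken)
    int count(String token) {
        lock.readLock().lock();
        try {
            return counts.getOrDefault(token, 0);
        } finally {
            lock.readLock().unlock();
        }
    }

    public static void main(String[] args) {
        GuardedClassifierSketch g = new GuardedClassifierSketch();
        g.train("goal");
        g.train("goal");
        System.out.println(g.count("goal")); // prints 2
    }
}
```

A read-write lock lets any number of classifications proceed concurrently while guaranteeing that training has exclusive access.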
A compiled classifier is completely thread safe.
Constructor and Description 

TradNaiveBayesClassifier(Set<String> categorySet,
TokenizerFactory tokenizerFactory)
Constructs a naive Bayes classifier over the specified
categories, using the specified tokenizer factory.

TradNaiveBayesClassifier(Set<String> categorySet,
TokenizerFactory tokenizerFactory,
double categoryPrior,
double tokenInCategoryPrior,
double lengthNorm)
Constructs a naive Bayes classifier over the specified
categories, using the specified tokenizer factory, priors and
length normalization.

Modifier and Type  Method and Description 

Set<String> 
categorySet()
Returns a set of categories for this classifier.

JointClassification 
classify(CharSequence in)
Return the classification of the specified character sequence.

void 
compileTo(ObjectOutput out)
Compile this classifier to the specified object output.

static Iterator<TradNaiveBayesClassifier> 
emIterator(TradNaiveBayesClassifier initialClassifier,
Factory<TradNaiveBayesClassifier> classifierFactory,
Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
Corpus<ObjectHandler<CharSequence>> unlabeledData,
double minTokenCount)
Apply the expectation maximization (EM) algorithm to train a traditional
naive Bayes classifier using the specified labeled and unlabeled data,
initial classifier, and factory for creating subsequent classifiers.

static TradNaiveBayesClassifier 
emTrain(TradNaiveBayesClassifier initialClassifier,
Factory<TradNaiveBayesClassifier> classifierFactory,
Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
Corpus<ObjectHandler<CharSequence>> unlabeledData,
double minTokenCount,
int maxEpochs,
double minImprovement,
Reporter reporter)
Apply the expectation maximization (EM) algorithm to train a traditional
naive Bayes classifier using the specified labeled and unlabeled data,
initial classifier, factory for creating subsequent classifiers,
maximum number of epochs, minimum improvement per epoch, and reporter
to which progress reports are sent.

void 
handle(Classified<CharSequence> classifiedObject)
Trains the classifier with the specified classified character
sequence.

boolean 
isKnownToken(String token)
Returns
true if the token has been seen in
training data. 
Set<String> 
knownTokenSet()
Returns an unmodifiable view of the set of tokens.

double 
lengthNorm()
Returns the length normalization factor for this
classifier.

double 
log2CaseProb(CharSequence input)
Returns the log (base 2) marginal probability of the specified
input.

double 
log2ModelProb()
Returns the log (base 2) of the probability density of this
model in the Dirichlet prior specified by this classifier.

double 
probCat(String cat)
Returns the probability estimate for the specified
category.

double 
probToken(String token,
String cat)
Returns the probability of the specified token
in the specified category.

void 
setLengthNorm(double lengthNorm)
Set the length normalization factor to the specified value.

String 
toString()
Return a string representation of this classifier.

void 
train(CharSequence cSeq,
Classification classification,
double count)
Trains the classifier with the specified case consisting of
a character sequence and conditional classification with the
specified count.

void 
trainConditional(CharSequence cSeq,
ConditionalClassification classification,
double countMultiplier,
double minCount)
Trains this classifier using tokens extracted from the
specified character sequence, using category count multipliers
derived by multiplying the specified count multiplier by the
conditional probability of a category in the specified
classification.

public TradNaiveBayesClassifier(Set<String> categorySet, TokenizerFactory tokenizerFactory)

Constructs a naive Bayes classifier over the specified categories, using the specified tokenizer factory and no length normalization (a length norm of Double.NaN). See the class documentation above for more information.

Parameters:
categorySet - Categories for classification.
tokenizerFactory - Factory to convert char sequences to tokens.
Throws:
IllegalArgumentException - If there are fewer than two categories.

public TradNaiveBayesClassifier(Set<String> categorySet, TokenizerFactory tokenizerFactory, double categoryPrior, double tokenInCategoryPrior, double lengthNorm)
Parameters:
categorySet - Categories for classification.
tokenizerFactory - Factory to convert char sequences to tokens.
categoryPrior - Prior count for categories.
tokenInCategoryPrior - Prior count for tokens per category.
lengthNorm - A positive, finite length norm, or Double.NaN if no length normalization is to be done.
Throws:
IllegalArgumentException - If either prior is negative or not finite, if there are fewer than two categories, or if the length normalization constant is negative, zero, or infinite.

public String toString()
public Set<String> categorySet()
public void setLengthNorm(double lengthNorm)

Parameters:
lengthNorm - Length normalization, or Double.NaN to turn off normalization.
Throws:
IllegalArgumentException - If the length norm is infinite, zero, or negative.

public JointClassification classify(CharSequence in)
classify in interface BaseClassifier<CharSequence>
classify in interface ConditionalClassifier<CharSequence>
classify in interface JointClassifier<CharSequence>
classify in interface RankedClassifier<CharSequence>
classify in interface ScoredClassifier<CharSequence>
Parameters:
in - Character sequence being classified.

public double lengthNorm()
public boolean isKnownToken(String token)

Returns true if the token has been seen in training data.

Parameters:
token - Token to test.
Returns:
true if the token has been seen in training data.

public Set<String> knownTokenSet()
public double probToken(String token, String cat)

Throws:
IllegalArgumentException - If the category is not known or the token is not known.

public void compileTo(ObjectOutput out) throws IOException
compileTo in interface Compilable

Parameters:
out - Object output to which this classifier is compiled.
Throws:
IOException - If there is an underlying I/O error during the write.

public double probCat(String cat)
Parameters:
cat - Category whose probability is returned.
Throws:
IllegalArgumentException - If the category is not known.

public void handle(Classified<CharSequence> classifiedObject)
See trainConditional(CharSequence,ConditionalClassification,double,double).

handle in interface ObjectHandler<Classified<CharSequence>>

Parameters:
classifiedObject - Classified character sequence.

public void trainConditional(CharSequence cSeq, ConditionalClassification classification, double countMultiplier, double minCount)
Parameters:
cSeq - Character sequence being trained.
classification - Conditional classification to train.
countMultiplier - Count multiplier of training instance.
minCount - Minimum count for which a category is trained for this character sequence.
Throws:
IllegalArgumentException - If the countMultiplier is not finite and non-negative, or if the min count is below zero or not a number.

public void train(CharSequence cSeq, Classification classification, double count)
If the count value is negative, counts are subtracted rather than added. If any of the counts fall below zero, an illegal argument exception will be thrown and the classifier will be reverted to the counts in place before the method was called. Cleanup after errors requires the tokenizer factory to return the same tokenizer given the same string, but no check is made that it does.
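The decrement-and-revert behavior described above can be sketched with a plain count map standing in for the classifier's token counts (illustrative only, not LingPipe code):

```java
import java.util.HashMap;
import java.util.Map;

public class RevertingTrainSketch {
    final Map<String, Double> tokenCounts = new HashMap<>();

    // Add count to each token; if any count would fall below zero, restore
    // the counts in place before the call and throw, as described above.
    void train(String[] tokens, double count) {
        Map<String, Double> backup = new HashMap<>(tokenCounts);
        for (String tok : tokens) {
            double updated = tokenCounts.getOrDefault(tok, 0.0) + count;
            if (updated < 0.0) {
                tokenCounts.clear();
                tokenCounts.putAll(backup); // revert to pre-call counts
                throw new IllegalArgumentException("count fell below zero for token=" + tok);
            }
            tokenCounts.put(tok, updated);
        }
    }

    public static void main(String[] args) {
        RevertingTrainSketch t = new RevertingTrainSketch();
        t.train(new String[] {"a", "b"}, 2.0);  // a=2.0, b=2.0
        t.train(new String[] {"a"}, -1.0);      // decrement: a=1.0
        try {
            t.train(new String[] {"a", "b"}, -1.5); // a would go negative
        } catch (IllegalArgumentException e) {
            System.out.println(t.tokenCounts.get("a")); // prints 1.0 (reverted)
        }
    }
}
```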
Parameters:
cSeq - Character sequence on which to train.
classification - Classification to train with the character sequence.
count - How many instances the classification will count as for training purposes.
Throws:
IllegalArgumentException - If the count is negative and increments cause accumulated counts to fall below zero.

public double log2CaseProb(CharSequence input)
    p(x) = Σ_{c in cats} p(c,x)

Note that this value is normalized by the number of tokens in the input, so that

    Σ_{length(x) = n} p(x) = 1.0

Parameters:
input - Input character sequence.

public double log2ModelProb()
The result is the sum of the log density of the multinomial over categories and the log density of the per-category multinomials over tokens.
For a definition of the probability function for each
category's multinomial and the overall category multinomial,
see Statistics.dirichletLog2Prob(double,double[])
.
public static Iterator<TradNaiveBayesClassifier> emIterator(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData, double minTokenCount) throws IOException
This method lets the client take control over assessing convergence, so there are no convergence-related arguments.
Parameters:
initialClassifier - Initial classifier to bootstrap.
classifierFactory - Factory for creating subsequent classifiers.
labeledData - Labeled data for supervised training.
unlabeledData - Unlabeled data for unsupervised training.
minTokenCount - Min count for a word to not be pruned.
Throws:
IOException
public static TradNaiveBayesClassifier emTrain(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData, double minTokenCount, int maxEpochs, double minImprovement, Reporter reporter) throws IOException
Parameters:
initialClassifier - Initial classifier to bootstrap.
classifierFactory - Factory for creating subsequent classifiers.
labeledData - Labeled data for supervised training.
unlabeledData - Unlabeled data for unsupervised training.
minTokenCount - Min count for a word to not be pruned.
maxEpochs - Maximum number of epochs to run training.
minImprovement - Minimum relative improvement per epoch.
reporter - Reporter to which intermediate results are reported, or null for no reporting.
Throws:
IOException