
Scikit-Learn ComplementNB Model

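ComplementNB is a variant of the multinomial Naive Bayes classifier designed for imbalanced data. Rather than estimating each class's weights from that class's own documents, it estimates them from the complement of the class, that is, from all documents that do not belong to it. It is particularly useful for text classification tasks.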
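To make the complement idea concrete, here is a minimal pure-Python sketch of the weight estimate that the scikit-learn documentation describes for ComplementNB, leaving out the optional norm step. The function names complement_weights and predict, and the toy counts, are illustrative assumptions and not part of scikit-learn.

```python
import math


def complement_weights(counts, labels, alpha=1.0):
    """Return {class: [log-weight per feature]} estimated from complement counts."""
    n_features = len(counts[0])
    weights = {}
    for c in sorted(set(labels)):
        # Sum feature counts over every document that is NOT in class c
        comp = [0.0] * n_features
        for row, y in zip(counts, labels):
            if y != c:
                for i, value in enumerate(row):
                    comp[i] += value
        # Additive smoothing: alpha per feature in the numerator,
        # alpha * n_features in the denominator
        total = sum(comp) + alpha * n_features
        weights[c] = [math.log((value + alpha) / total) for value in comp]
    return weights


def predict(weights, doc):
    # Choose the class whose complement matches the document worst
    scores = {c: sum(t * w for t, w in zip(doc, ws)) for c, ws in weights.items()}
    return min(scores, key=scores.get)


# Toy term counts (rows are documents, columns are terms), made up for illustration
counts = [
    [3, 0, 1],
    [2, 1, 0],
    [0, 4, 1],
]
labels = [0, 0, 1]

w = complement_weights(counts, labels)
pred = predict(w, [2, 0, 0])
print(pred)
```

Because each class's weights come from all the other classes' documents, a class with few training documents still gets its estimate from a large pool of counts, which is the intuition behind the method's suitability for imbalanced data.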

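The key hyperparameters of ComplementNB include alpha (additive smoothing parameter), norm (whether to apply a second normalization of the weights), and class_prior (prior probabilities of the classes, which the scikit-learn documentation lists as not used by this estimator).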
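As a rough illustration of setting these hyperparameters, the sketch below fits one model with the defaults and one with alpha and norm changed. The four-document corpus and the chosen values (alpha=0.5, norm=True) are assumptions made up for this example, not recommendations.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import ComplementNB

# Toy corpus and labels, made up for illustration only
docs = ["gpu card driver", "gpu memory card", "hockey game score", "hockey team win"]
labels = [0, 0, 1, 1]

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(docs)

# Defaults: alpha=1.0 (additive smoothing), norm=False
default_model = ComplementNB().fit(X, labels)

# Lighter smoothing plus the second normalization of the weights
tuned_model = ComplementNB(alpha=0.5, norm=True).fit(X, labels)

sample = vectorizer.transform(["gpu driver"])
default_pred = default_model.predict(sample)[0]
tuned_pred = tuned_model.predict(sample)[0]
print(default_pred, tuned_pred)
```

In practice, alpha is usually chosen by cross-validation on the training data instead of being fixed by hand.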

This algorithm is appropriate for text classification, especially when dealing with imbalanced datasets.

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import ComplementNB
from sklearn.metrics import accuracy_score

# load the dataset
data = fetch_20newsgroups(subset='all')
X, y = data.data, data.target

# convert text data to TF-IDF features
vectorizer = TfidfVectorizer()
X_tfidf = vectorizer.fit_transform(X)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X_tfidf, y, test_size=0.2, random_state=1)

# create model
model = ComplementNB()

# fit model
model.fit(X_train, y_train)

# evaluate model
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# make a prediction
sample = ["The GPU performance in the latest model is outstanding."]
sample_tfidf = vectorizer.transform(sample)
yhat = model.predict(sample_tfidf)
print('Predicted category: %d' % yhat[0])

Running the example gives an output like:

Accuracy: 0.907
Predicted category: 9

The steps are as follows:

  1. First, load a text classification dataset using the fetch_20newsgroups() function. With subset='all', this returns both the training and test portions of the 20 newsgroups corpus, a multi-class dataset with 20 topic categories.

  2. Convert the text data to TF-IDF features using TfidfVectorizer(). This prepares the text data for input into the ComplementNB model.

  3. Split the dataset into training and test sets using train_test_split().

  4. Instantiate a ComplementNB model with default hyperparameters.

  5. Fit the model on the training data using the fit() method.

  6. Evaluate the model’s performance by predicting the test set and calculating the accuracy score.

  7. Make a prediction on a new sample text by transforming it to TF-IDF and passing it to the predict() method.
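One caveat about the listing: the TfidfVectorizer is fit on the full corpus before the split, so vocabulary and IDF statistics from the test documents influence the features used in training. A possible alternative, sketched here on a tiny made-up corpus, is to split the raw text first and wrap the vectorizer and classifier in a scikit-learn pipeline so that both are fit on the training portion only. The texts and labels below are hypothetical stand-ins for data.data and data.target.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import ComplementNB
from sklearn.pipeline import make_pipeline

# Tiny illustrative corpus (hypothetical); swap in data.data and data.target
texts = [
    "gpu driver update", "new gpu memory", "graphics card driver", "gpu card benchmark",
    "hockey game tonight", "team wins hockey match", "hockey score update", "great game from team",
]
labels = [0, 0, 0, 0, 1, 1, 1, 1]

# Split the raw text first so the vectorizer never sees the test documents
X_train, X_test, y_train, y_test = train_test_split(
    texts, labels, test_size=0.25, random_state=1, stratify=labels
)

# The pipeline fits TF-IDF and ComplementNB on the training text only
pipeline = make_pipeline(TfidfVectorizer(), ComplementNB())
pipeline.fit(X_train, y_train)

# Raw strings go straight in; the pipeline applies the fitted vectorizer
yhat = pipeline.predict(X_test)
sample_pred = pipeline.predict(["The GPU performance in the latest model is outstanding."])
print(yhat, sample_pred)
```

With a pipeline, new text can be passed to predict() directly, without a separate call to vectorizer.transform().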

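This example demonstrates the process of using the ComplementNB algorithm for text classification. Note that the 20 newsgroups classes are roughly equal in size, so the example shows the workflow rather than the algorithm's behavior on imbalanced data.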
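When the classes really are imbalanced, plain accuracy can hide poor performance on the minority class. The short sketch below uses made-up labels and a deliberately naive set of predictions to show how accuracy_score and balanced_accuracy_score from sklearn.metrics can disagree. The label arrays are hypothetical and are not output from the model above.

```python
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# Hypothetical imbalanced labels: nine majority-class items, one minority-class item
y_true = [0] * 9 + [1]

# A degenerate classifier that always predicts the majority class
y_pred = [0] * 10

acc = accuracy_score(y_true, y_pred)
bal_acc = balanced_accuracy_score(y_true, y_pred)  # mean of per-class recall

print('Accuracy: %.3f' % acc)
print('Balanced accuracy: %.3f' % bal_acc)
```

Here accuracy is high because nine of the ten items are in the majority class, while balanced accuracy drops because the single minority-class item is missed.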


