
Scikit-Learn BernoulliRBM Model

The Bernoulli Restricted Boltzmann Machine (RBM) is an unsupervised model with a layer of binary visible units and a layer of binary hidden units. It is used for feature learning and dimensionality reduction on binary-valued data.
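For reference, the scikit-learn user guide describes the model in terms of visible units $v_i$, hidden units $h_j$, weights $w_{ij}$, visible biases $b_i$, and hidden biases $c_j$, with the parameters estimated by Stochastic Maximum Likelihood (also known as Persistent Contrastive Divergence). The energy function and the resulting conditional probabilities are:

$$
E(\mathbf{v}, \mathbf{h}) = -\sum_i \sum_j w_{ij} v_i h_j - \sum_i b_i v_i - \sum_j c_j h_j
$$

$$
P(v_i = 1 \mid \mathbf{h}) = \sigma\Big(\sum_j w_{ij} h_j + b_i\Big), \qquad
P(h_j = 1 \mid \mathbf{v}) = \sigma\Big(\sum_i w_{ij} v_i + c_j\Big)
$$

where $\sigma(x) = 1 / (1 + e^{-x})$ is the logistic sigmoid. The second conditional, evaluated for every hidden unit, is what transform() returns for an input $\mathbf{v}$.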

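The key hyperparameters of BernoulliRBM are n_components (the number of binary hidden units), learning_rate (the step size for the weight updates), batch_size (the number of examples per minibatch), and n_iter (the number of passes over the training data).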
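A quick way to see the starting point for tuning is to read the constructor defaults with get_params(). This is a minimal sketch; the values shown are the defaults in recent scikit-learn releases and may differ in other versions.

```python
from sklearn.neural_network import BernoulliRBM

# Read the default values of the four key hyperparameters
params = BernoulliRBM().get_params()
for name in ["n_components", "learning_rate", "batch_size", "n_iter"]:
    print(name, params[name])
# n_components 256
# learning_rate 0.1
# batch_size 10
# n_iter 10
```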

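The algorithm is appropriate for datasets whose features are binary or lie in the range [0, 1], such as black-and-white images or binary occurrence data. A value between 0 and 1 is interpreted as the probability that the feature is on.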
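Data that is not already binary can be rescaled into [0, 1] before fitting. The sketch below assumes the bundled digits dataset (pixel intensities 0 to 16) and uses minmax_scale() per feature; the choice of 32 hidden units and 5 iterations is arbitrary and only for illustration.

```python
from sklearn.datasets import load_digits
from sklearn.neural_network import BernoulliRBM
from sklearn.preprocessing import minmax_scale

# Load 8x8 digit images flattened to 64 features with values 0..16
X, _ = load_digits(return_X_y=True)

# Rescale each feature to [0, 1] so it can be read as an "on" probability
X_scaled = minmax_scale(X, feature_range=(0, 1))
print(X_scaled.min(), X_scaled.max())

# Fit an RBM and extract hidden-unit probabilities for every image
rbm = BernoulliRBM(n_components=32, n_iter=5, random_state=0)
H = rbm.fit_transform(X_scaled)
print(H.shape)
```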

from sklearn.neural_network import BernoulliRBM
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# generate a synthetic dataset (features are continuous at this point)
X, _ = make_classification(n_samples=100, n_features=10, n_informative=5, n_redundant=0, n_classes=2, random_state=1)
# binarize: 1 where a feature is positive, 0 otherwise
X = (X > 0).astype(int)

# split into train and test sets
X_train, X_test = train_test_split(X, test_size=0.2, random_state=1)

# create model
model = BernoulliRBM(n_components=5, learning_rate=0.1, batch_size=10, n_iter=100, random_state=1)

# fit model
model.fit(X_train)

# transform test data into hidden-unit activation probabilities
X_test_transformed = model.transform(X_test)
print(X_test_transformed.shape)

# transform a new sample
row = [[0, 1, 1, 0, 1, 0, 0, 1, 1, 0]]
transformed = model.transform(row)
print('Transformed: %s' % transformed)

Running the example gives an output like:

(20, 5)
Transformed: [[0.3271421  0.29290098 0.31111151 0.34068186 0.31298058]]

The steps are as follows:

  1. First, a synthetic dataset is generated with the make_classification() function and binarized by thresholding each feature at zero. The dataset is then split into training and test sets using train_test_split().

  2. Next, a BernoulliRBM model is instantiated with 5 hidden units, a learning rate of 0.1, minibatches of 10 samples, and 100 passes over the training data. The model is then fit on the training data using the fit() method.

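  3. The test data is transformed using the fitted model's transform() method, which returns, for each sample, the probability that each hidden unit is on. The result has one column per hidden unit, so the 20 test samples map to a (20, 5) array.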

  4. A single new sample can be transformed by passing it to the transform() method as a 2D array-like with one row.
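To make step 3 concrete, the sketch below recomputes transform() by hand from the learned parameters: components_ holds one weight row per hidden unit and intercept_hidden_ holds the hidden biases. It reuses the example's data and settings but, for brevity, fits on the full dataset rather than a training split; expit() from SciPy is the logistic sigmoid.

```python
import numpy as np
from scipy.special import expit
from sklearn.datasets import make_classification
from sklearn.neural_network import BernoulliRBM

# Same binarized synthetic data as in the example
X, _ = make_classification(n_samples=100, n_features=10, n_informative=5,
                           n_redundant=0, n_classes=2, random_state=1)
X = (X > 0).astype(float)

model = BernoulliRBM(n_components=5, learning_rate=0.1, batch_size=10,
                     n_iter=100, random_state=1)
model.fit(X)

# Learned parameters: weights (hidden x visible) and the two bias vectors
print(model.components_.shape)         # (5, 10)
print(model.intercept_hidden_.shape)   # (5,)
print(model.intercept_visible_.shape)  # (10,)

# P(h_j = 1 | v) = sigmoid(v . w_j + c_j), computed for all samples at once
manual = expit(X @ model.components_.T + model.intercept_hidden_)
print(np.allclose(manual, model.transform(X)))  # True
```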

This example demonstrates how to fit a BernoulliRBM on binary-valued data in scikit-learn and use its transform() method to map each sample to a lower-dimensional vector of hidden-unit activations (from 10 features to 5 here).

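The model learns a compressed representation of the binary input data in its hidden layer (5 hidden units for 10 input features here). These hidden activations can be used as input features for another model, such as a classifier. The model can also produce a stochastic reconstruction of an input by one step of Gibbs sampling (visible to hidden to visible) with its gibbs() method.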
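Both uses can be sketched on the example's data. The RBM is chained with a classifier in a Pipeline; LogisticRegression is an assumed choice of downstream model, not something the original specifies, and labels y come from the same make_classification() call. The fitted RBM step is then used with gibbs(), which returns a sampled (not deterministic) reconstruction, so the matching fraction varies with the random state.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline

# Same binarized data as the example, this time keeping the class labels
X, y = make_classification(n_samples=100, n_features=10, n_informative=5,
                           n_redundant=0, n_classes=2, random_state=1)
X = (X > 0).astype(float)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# The RBM's hidden-unit probabilities become the classifier's input features
pipe = Pipeline([
    ("rbm", BernoulliRBM(n_components=5, learning_rate=0.1, batch_size=10,
                         n_iter=100, random_state=1)),
    ("clf", LogisticRegression()),
])
pipe.fit(X_train, y_train)
acc = pipe.score(X_test, y_test)
print("Test accuracy: %.3f" % acc)

# One Gibbs step (v -> sampled h -> sampled v') gives a binary reconstruction
rbm = pipe.named_steps["rbm"]
X_recon = rbm.gibbs(X_test)
print(X_recon.shape)
print("Fraction of bits matching the input: %.3f" % (X_recon == X_test).mean())
```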