mirror of
https://github.com/modAL-python/modAL.git
synced 2022-05-17 00:31:33 +03:00
68 lines
2.5 KiB
Python
68 lines
2.5 KiB
Python
"""
|
|
In this example the use of ActiveLearner is demonstrated on the iris dataset in a pool-based sampling setting.
|
|
For more information on the iris dataset, see https://en.wikipedia.org/wiki/Iris_flower_data_set
|
|
For its scikit-learn interface, see http://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html
|
|
"""
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.decomposition import PCA
|
|
from sklearn.datasets import load_iris
|
|
from sklearn.neighbors import KNeighborsClassifier
|
|
from modAL.models import ActiveLearner
|
|
|
|
# loading the iris dataset
iris = load_iris()

# project the features onto two principal components; the projection is kept
# in `pca` because the later prediction plots reuse the same 2-D coordinates
pca = PCA(n_components=2).fit_transform(iris['data'])

# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names, so prefer the new name whenever it is available
# and fall back to the legacy one on older installations
plot_style = 'seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white'

# visualizing the classes
with plt.style.context(plot_style):
    plt.figure(figsize=(7, 7))
    plt.scatter(x=pca[:, 0], y=pca[:, 1], c=iris['target'], cmap='viridis', s=50)
    plt.title('The iris dataset')
    plt.show()
|
|
|
|
# indices of the instances that seed the learner's initial training set
train_idx = [0, 50, 100]
X_train, y_train = iris['data'][train_idx], iris['target'][train_idx]

# the pool is the dataset with the seed instances removed
X_pool = np.delete(iris['data'], train_idx, axis=0)
y_pool = np.delete(iris['target'], train_idx)
|
|
|
|
# the active learner wraps a 3-nearest-neighbours classifier and is handed
# the initial training data at construction time
knn = KNeighborsClassifier(n_neighbors=3)
learner = ActiveLearner(estimator=knn, X_training=X_train, y_training=y_train)
|
|
|
|
# visualizing initial prediction

# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names, so prefer the new name whenever it is available
# and fall back to the legacy one on older installations
plot_style = 'seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white'

# score once and reuse the value for both the plot title and the printout
initial_accuracy = learner.score(iris['data'], iris['target'])

with plt.style.context(plot_style):
    plt.figure(figsize=(7, 7))
    prediction = learner.predict(iris['data'])
    # points are drawn at their PCA coordinates and coloured by predicted class
    plt.scatter(x=pca[:, 0], y=pca[:, 1], c=prediction, cmap='viridis', s=50)
    plt.title('Initial accuracy: %f' % initial_accuracy)
    plt.show()

print('Accuracy before active learning: %f' % initial_accuracy)
|
|
|
|
# pool-based sampling: repeatedly let the learner choose an instance from the
# pool, teach it that instance with its label, and drop it from the pool
n_queries = 20
for idx in range(n_queries):
    query_idx, query_instance = learner.query(X_pool)

    # reshape the selected instance/label to the (1, n_features) / (1,) form
    queried_X = X_pool[query_idx].reshape(1, -1)
    queried_y = y_pool[query_idx].reshape(1, )
    learner.teach(X=queried_X, y=queried_y)

    # remove queried instance from pool
    X_pool = np.delete(X_pool, query_idx, axis=0)
    y_pool = np.delete(y_pool, query_idx)

    accuracy = learner.score(iris['data'], iris['target'])
    print('Accuracy after query no. %d: %f' % (idx+1, accuracy))
|
|
|
|
# plotting final prediction

# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names, so prefer the new name whenever it is available
# and fall back to the legacy one on older installations
plot_style = 'seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white'

with plt.style.context(plot_style):
    plt.figure(figsize=(7, 7))
    prediction = learner.predict(iris['data'])
    # points are drawn at their PCA coordinates and coloured by predicted class
    plt.scatter(x=pca[:, 0], y=pca[:, 1], c=prediction, cmap='viridis', s=50)
    plt.title('Classification accuracy after %i queries: %f' % (n_queries, learner.score(iris['data'], iris['target'])))
    plt.show()