mirror of
https://github.com/modAL-python/modAL.git
synced 2022-05-17 00:31:33 +03:00
Merge branch 'BoyanH-20-pandas-support' into dev
This commit is contained in:
@@ -100,12 +100,11 @@ import numpy as np
|
||||
X = np.random.choice(np.linspace(0, 20, 10000), size=200, replace=False).reshape(-1, 1)
|
||||
y = np.sin(X) + np.random.normal(scale=0.3, size=X.shape)
|
||||
```
|
||||
For active learning, we shall define a custom query strategy tailored to Gaussian processes. In a nutshell, a *query strategy* in modAL is a function taking (at least) two arguments (an estimator object and a pool of examples), outputting the index of the queried instance and the instance itself. In our case, the arguments are ```regressor``` and ```X```.
|
||||
For active learning, we shall define a custom query strategy tailored to Gaussian processes. In a nutshell, a *query strategy* in modAL is a function taking (at least) two arguments (an estimator object and a pool of examples), outputting the index of the queried instance. In our case, the arguments are ```regressor``` and ```X```.
|
||||
```python
|
||||
def GP_regression_std(regressor, X):
|
||||
_, std = regressor.predict(X, return_std=True)
|
||||
query_idx = np.argmax(std)
|
||||
return query_idx, X[query_idx]
|
||||
return np.argmax(std)
|
||||
```
|
||||
After setting up the query strategy and the data, the active learner can be initialized.
|
||||
```python
|
||||
|
||||
@@ -70,7 +70,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Uncertainty measure and query strategy for Gaussian processes\n",
|
||||
"For active learning, we shall define a custom query strategy tailored to Gaussian processes. More information on how to write your custom query strategies can be found at the page [Extending modAL](https://cosmic-cortex.github.io/modAL/Extending-modAL). In a nutshell, a *query strategy* in modAL is a function taking (at least) two arguments (an estimator object and a pool of examples), outputting the index of the queried instance and the instance itself. In our case, the arguments are ```regressor``` and ```X```."
|
||||
"For active learning, we shall define a custom query strategy tailored to Gaussian processes. More information on how to write your custom query strategies can be found at the page [Extending modAL](https://cosmic-cortex.github.io/modAL/Extending-modAL). In a nutshell, a *query strategy* in modAL is a function taking (at least) two arguments (an estimator object and a pool of examples), outputting the index of the queried instance. In our case, the arguments are ```regressor``` and ```X```."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -81,8 +81,7 @@
|
||||
"source": [
|
||||
"def GP_regression_std(regressor, X):\n",
|
||||
" _, std = regressor.predict(X, return_std=True)\n",
|
||||
" query_idx = np.argmax(std)\n",
|
||||
" return query_idx, X[query_idx]"
|
||||
" return np.argmax(std)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -234,4 +233,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
@@ -27,11 +27,8 @@
|
||||
" # measure the utility of each instance in the pool\n",
|
||||
" utility = utility_measure(classifier, X)\n",
|
||||
"\n",
|
||||
" # select the indices of the instances to be queried\n",
|
||||
" query_idx = select_instances(utility)\n",
|
||||
"\n",
|
||||
" # return the indices and the instances\n",
|
||||
" return query_idx, X[query_idx]"
|
||||
" # select and return the indices of the instances to be queried\n",
|
||||
" return select_instances(utility)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -213,8 +210,7 @@
|
||||
"# classifier uncertainty and classifier margin\n",
|
||||
"def custom_query_strategy(classifier, X, n_instances=1):\n",
|
||||
" utility = linear_combination(classifier, X)\n",
|
||||
" query_idx = multi_argmax(utility, n_instances=n_instances)\n",
|
||||
" return query_idx, X[query_idx]\n",
|
||||
" return multi_argmax(utility, n_instances=n_instances)\n",
|
||||
"\n",
|
||||
"custom_query_learner = ActiveLearner(\n",
|
||||
" estimator=GaussianProcessClassifier(1.0 * RBF(1.0)),\n",
|
||||
@@ -299,4 +295,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
@@ -118,15 +118,13 @@ the *noisy sine* function:
|
||||
For active learning, we shall define a custom query strategy tailored to
|
||||
Gaussian processes. In a nutshell, a *query strategy* in modAL is a
|
||||
function taking (at least) two arguments (an estimator object and a pool
|
||||
of examples), outputting the index of the queried instance and the
|
||||
instance itself. In our case, the arguments are ``regressor`` and ``X``.
|
||||
of examples), outputting the index of the queried instance. In our case, the arguments are ``regressor`` and ``X``.
|
||||
|
||||
.. code:: python
|
||||
|
||||
def GP_regression_std(regressor, X):
|
||||
_, std = regressor.predict(X, return_std=True)
|
||||
query_idx = np.argmax(std)
|
||||
return query_idx, X[query_idx]
|
||||
return np.argmax(std)
|
||||
|
||||
After setting up the query strategy and the data, the active learner can
|
||||
be initialized.
|
||||
|
||||
@@ -12,8 +12,7 @@ from modAL.models import ActiveLearner
|
||||
# query strategy for regression
|
||||
def GP_regression_std(regressor, X):
|
||||
_, std = regressor.predict(X, return_std=True)
|
||||
query_idx = np.argmax(std)
|
||||
return query_idx, X[query_idx]
|
||||
return np.argmax(std)
|
||||
|
||||
|
||||
# generating the data
|
||||
|
||||
@@ -5,18 +5,16 @@ Template for query strategies
|
||||
|
||||
The first two arguments of a query strategy function are always the estimator and the pool
|
||||
of instances to be queried from. Additional arguments are accepted as keyword arguments.
|
||||
A valid query strategy function always returns a tuple of the indices of the queried
|
||||
instances and the instances themselves.
|
||||
A valid query strategy function always returns indices of the queried
|
||||
instances.
|
||||
|
||||
def custom_query_strategy(classifier, X, a_keyword_argument=42):
|
||||
# measure the utility of each instance in the pool
|
||||
utility = utility_measure(classifier, X)
|
||||
|
||||
# select the indices of the instances to be queried
|
||||
query_idx = select_instances(utility)
|
||||
# select and return the indices of the instances to be queried
|
||||
return select_instances(utility)
|
||||
|
||||
# return the indices and the instances
|
||||
return query_idx, X[query_idx]
|
||||
|
||||
This function can be used in the active learning workflow.
|
||||
|
||||
@@ -97,8 +95,7 @@ with plt.style.context('seaborn-white'):
|
||||
# classifier uncertainty and classifier margin
|
||||
def custom_query_strategy(classifier, X, n_instances=1):
|
||||
utility = linear_combination(classifier, X)
|
||||
query_idx = multi_argmax(utility, n_instances=n_instances)
|
||||
return query_idx, X[query_idx]
|
||||
return multi_argmax(utility, n_instances=n_instances)
|
||||
|
||||
custom_query_learner = ActiveLearner(
|
||||
estimator=GaussianProcessClassifier(1.0 * RBF(1.0)),
|
||||
|
||||
@@ -62,12 +62,10 @@ def max_entropy(learner, X, n_instances=1, T=100):
|
||||
expected_p = np.mean(MC_samples, axis=0)
|
||||
acquisition = - np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1) # [batch size]
|
||||
idx = (-acquisition).argsort()[:n_instances]
|
||||
query_idx = random_subset[idx]
|
||||
return query_idx, X[query_idx]
|
||||
return random_subset[idx]
|
||||
|
||||
def uniform(learner, X, n_instances=1):
|
||||
query_idx = np.random.choice(range(len(X)), size=n_instances, replace=False)
|
||||
return query_idx, X[query_idx]
|
||||
return np.random.choice(range(len(X)), size=n_instances, replace=False)
|
||||
|
||||
"""
|
||||
Training the ActiveLearner
|
||||
|
||||
@@ -57,8 +57,7 @@ final_prediction = learner.predict_proba(X_full)[:, 1].reshape(im_height, im_wid
|
||||
|
||||
|
||||
def random_sampling(classsifier, X):
|
||||
query_idx = np.random.randint(len(X))
|
||||
return query_idx, X[query_idx]
|
||||
return np.random.randint(len(X))
|
||||
|
||||
|
||||
X_pool = deepcopy(X_full)
|
||||
|
||||
@@ -104,7 +104,7 @@ Query strategies using acquisition functions
|
||||
|
||||
|
||||
def max_PI(optimizer: BaseLearner, X: modALinput, tradeoff: float = 0,
|
||||
n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
|
||||
n_instances: int = 1) -> np.ndarray:
|
||||
"""
|
||||
Maximum PI query strategy. Selects the instance with highest probability of improvement.
|
||||
|
||||
@@ -118,13 +118,11 @@ def max_PI(optimizer: BaseLearner, X: modALinput, tradeoff: float = 0,
|
||||
The indices of the instances from X chosen to be labelled.
|
||||
"""
|
||||
pi = optimizer_PI(optimizer, X, tradeoff=tradeoff)
|
||||
query_idx = multi_argmax(pi, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return multi_argmax(pi, n_instances=n_instances)
|
||||
|
||||
|
||||
def max_EI(optimizer: BaseLearner, X: modALinput, tradeoff: float = 0,
|
||||
n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
|
||||
n_instances: int = 1) -> np.ndarray:
|
||||
"""
|
||||
Maximum EI query strategy. Selects the instance with highest expected improvement.
|
||||
|
||||
@@ -138,13 +136,11 @@ def max_EI(optimizer: BaseLearner, X: modALinput, tradeoff: float = 0,
|
||||
The indices of the instances from X chosen to be labelled.
|
||||
"""
|
||||
ei = optimizer_EI(optimizer, X, tradeoff=tradeoff)
|
||||
query_idx = multi_argmax(ei, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return multi_argmax(ei, n_instances=n_instances)
|
||||
|
||||
|
||||
def max_UCB(optimizer: BaseLearner, X: modALinput, beta: float = 1,
|
||||
n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
|
||||
n_instances: int = 1) -> np.ndarray:
|
||||
"""
|
||||
Maximum UCB query strategy. Selects the instance with highest upper confidence bound.
|
||||
|
||||
@@ -158,6 +154,4 @@ def max_UCB(optimizer: BaseLearner, X: modALinput, beta: float = 1,
|
||||
The indices of the instances from X chosen to be labelled.
|
||||
"""
|
||||
ucb = optimizer_UCB(optimizer, X, beta=beta)
|
||||
query_idx = multi_argmax(ucb, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return multi_argmax(ucb, n_instances=n_instances)
|
||||
|
||||
@@ -114,7 +114,7 @@ def select_instance(
|
||||
unlabeled_indices = [i for i in range(n_pool) if mask[i]]
|
||||
best_instance_index = unlabeled_indices[best_instance_index_in_unlabeled]
|
||||
mask[best_instance_index] = 0
|
||||
return best_instance_index, np.expand_dims(X_pool[best_instance_index], axis=0), mask
|
||||
return best_instance_index, X_pool[[best_instance_index]], mask
|
||||
|
||||
|
||||
def ranked_batch(classifier: Union[BaseLearner, BaseCommittee],
|
||||
@@ -142,11 +142,16 @@ def ranked_batch(classifier: Union[BaseLearner, BaseCommittee],
|
||||
"""
|
||||
# Make a local copy of our classifier's training data.
|
||||
# Define our record container and record the best cold start instance in the case of cold start.
|
||||
|
||||
# transform unlabeled data if needed
|
||||
if classifier.on_transformed:
|
||||
unlabeled = classifier.transform_without_estimating(unlabeled)
|
||||
|
||||
if classifier.X_training is None:
|
||||
best_coldstart_instance_index, labeled = select_cold_start_instance(X=unlabeled, metric=metric, n_jobs=n_jobs)
|
||||
instance_index_ranking = [best_coldstart_instance_index]
|
||||
elif classifier.X_training.shape[0] > 0:
|
||||
labeled = classifier.X_training[:]
|
||||
labeled = classifier.Xt_training[:] if classifier.on_transformed else classifier.X_training[:]
|
||||
instance_index_ranking = []
|
||||
|
||||
# The maximum number of records to sample.
|
||||
@@ -180,7 +185,7 @@ def uncertainty_batch_sampling(classifier: Union[BaseLearner, BaseCommittee],
|
||||
metric: Union[str, Callable] = 'euclidean',
|
||||
n_jobs: Optional[int] = None,
|
||||
**uncertainty_measure_kwargs
|
||||
) -> Tuple[np.ndarray, Union[np.ndarray, sp.csr_matrix]]:
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Batch sampling query strategy. Selects the least sure instances for labelling.
|
||||
|
||||
@@ -206,6 +211,6 @@ def uncertainty_batch_sampling(classifier: Union[BaseLearner, BaseCommittee],
|
||||
Indices of the instances from `X` chosen to be labelled.
|
||||
"""
|
||||
uncertainty = classifier_uncertainty(classifier, X, **uncertainty_measure_kwargs)
|
||||
query_indices = ranked_batch(classifier, unlabeled=X, uncertainty_scores=uncertainty,
|
||||
return ranked_batch(classifier, unlabeled=X, uncertainty_scores=uncertainty,
|
||||
n_instances=n_instances, metric=metric, n_jobs=n_jobs)
|
||||
return query_indices, X[query_indices]
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ def KL_max_disagreement(committee: BaseCommittee, X: modALinput, **predict_proba
|
||||
|
||||
def vote_entropy_sampling(committee: BaseCommittee, X: modALinput,
|
||||
n_instances: int = 1, random_tie_break=False,
|
||||
**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
|
||||
**disagreement_measure_kwargs) -> np.ndarray:
|
||||
"""
|
||||
Vote entropy sampling strategy.
|
||||
|
||||
@@ -124,16 +124,14 @@ def vote_entropy_sampling(committee: BaseCommittee, X: modALinput,
|
||||
disagreement = vote_entropy(committee, X, **disagreement_measure_kwargs)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(disagreement, n_instances=n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(disagreement, n_instances=n_instances)
|
||||
return multi_argmax(disagreement, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return shuffled_argmax(disagreement, n_instances=n_instances)
|
||||
|
||||
|
||||
def consensus_entropy_sampling(committee: BaseCommittee, X: modALinput,
|
||||
n_instances: int = 1, random_tie_break=False,
|
||||
**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
|
||||
**disagreement_measure_kwargs) -> np.ndarray:
|
||||
"""
|
||||
Consensus entropy sampling strategy.
|
||||
|
||||
@@ -153,16 +151,14 @@ def consensus_entropy_sampling(committee: BaseCommittee, X: modALinput,
|
||||
disagreement = consensus_entropy(committee, X, **disagreement_measure_kwargs)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(disagreement, n_instances=n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(disagreement, n_instances=n_instances)
|
||||
return multi_argmax(disagreement, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return shuffled_argmax(disagreement, n_instances=n_instances)
|
||||
|
||||
|
||||
def max_disagreement_sampling(committee: BaseCommittee, X: modALinput,
|
||||
n_instances: int = 1, random_tie_break=False,
|
||||
**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
|
||||
**disagreement_measure_kwargs) -> np.ndarray:
|
||||
"""
|
||||
Maximum disagreement sampling strategy.
|
||||
|
||||
@@ -182,16 +178,14 @@ def max_disagreement_sampling(committee: BaseCommittee, X: modALinput,
|
||||
disagreement = KL_max_disagreement(committee, X, **disagreement_measure_kwargs)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(disagreement, n_instances=n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(disagreement, n_instances=n_instances)
|
||||
return multi_argmax(disagreement, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return shuffled_argmax(disagreement, n_instances=n_instances)
|
||||
|
||||
|
||||
def max_std_sampling(regressor: BaseEstimator, X: modALinput,
|
||||
n_instances: int = 1, random_tie_break=False,
|
||||
**predict_kwargs) -> Tuple[np.ndarray, modALinput]:
|
||||
**predict_kwargs) -> np.ndarray:
|
||||
"""
|
||||
Regressor standard deviation sampling strategy.
|
||||
|
||||
@@ -211,8 +205,6 @@ def max_std_sampling(regressor: BaseEstimator, X: modALinput,
|
||||
std = std.reshape(X.shape[0], )
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(std, n_instances=n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(std, n_instances=n_instances)
|
||||
return multi_argmax(std, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return shuffled_argmax(std, n_instances=n_instances)
|
||||
|
||||
@@ -10,14 +10,14 @@ from sklearn.base import clone
|
||||
from sklearn.exceptions import NotFittedError
|
||||
|
||||
from modAL.models import ActiveLearner
|
||||
from modAL.utils.data import modALinput, data_vstack
|
||||
from modAL.utils.data import modALinput, data_vstack, enumerate_data, drop_rows, data_shape, add_row
|
||||
from modAL.utils.selection import multi_argmax, shuffled_argmax
|
||||
from modAL.uncertainty import _proba_uncertainty, _proba_entropy
|
||||
|
||||
|
||||
def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str = 'binary',
|
||||
p_subsample: np.float = 1.0, n_instances: int = 1,
|
||||
random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
|
||||
random_tie_break: bool = False) -> np.ndarray:
|
||||
"""
|
||||
Expected error reduction query strategy.
|
||||
|
||||
@@ -38,31 +38,30 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
|
||||
|
||||
|
||||
Returns:
|
||||
The indices of the instances from X chosen to be labelled;
|
||||
the instances from X chosen to be labelled.
|
||||
The indices of the instances from X chosen to be labelled.
|
||||
"""
|
||||
|
||||
assert 0.0 <= p_subsample <= 1.0, 'p_subsample subsampling keep ratio must be between 0.0 and 1.0'
|
||||
assert loss in ['binary', 'log'], 'loss must be \'binary\' or \'log\''
|
||||
|
||||
expected_error = np.zeros(shape=(len(X), ))
|
||||
expected_error = np.zeros(shape=(data_shape(X)[0],))
|
||||
possible_labels = np.unique(learner.y_training)
|
||||
|
||||
try:
|
||||
X_proba = learner.predict_proba(X)
|
||||
except NotFittedError:
|
||||
# TODO: implement a proper cold-start
|
||||
return 0, X[0]
|
||||
return np.array([0])
|
||||
|
||||
cloned_estimator = clone(learner.estimator)
|
||||
|
||||
for x_idx, x in enumerate(X):
|
||||
for x_idx, x in enumerate_data(X):
|
||||
# subsample the data if needed
|
||||
if np.random.rand() <= p_subsample:
|
||||
X_reduced = np.delete(X, x_idx, axis=0)
|
||||
X_reduced = drop_rows(X, x_idx)
|
||||
# estimate the expected error
|
||||
for y_idx, y in enumerate(possible_labels):
|
||||
X_new = data_vstack((learner.X_training, np.expand_dims(x, axis=0)))
|
||||
X_new = add_row(learner.X_training, x)
|
||||
y_new = data_vstack((learner.y_training, np.array(y).reshape(1,)))
|
||||
|
||||
cloned_estimator.fit(X_new, y_new)
|
||||
@@ -78,8 +77,6 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
|
||||
expected_error[x_idx] = np.inf
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(-expected_error, n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(-expected_error, n_instances)
|
||||
return multi_argmax(-expected_error, n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return shuffled_argmax(-expected_error, n_instances)
|
||||
|
||||
@@ -5,14 +5,18 @@ Base classes for active learning algorithms
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Union, Callable, Optional, Tuple, List, Iterator, Any
|
||||
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator
|
||||
from sklearn.ensemble._base import _BaseHeterogeneousEnsemble
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.utils import check_X_y
|
||||
|
||||
from modAL.utils.data import data_vstack, modALinput
|
||||
import scipy.sparse as sp
|
||||
|
||||
from modAL.utils.data import data_vstack, data_hstack, modALinput, retrieve_rows
|
||||
|
||||
if sys.version_info >= (3, 4):
|
||||
ABC = abc.ABC
|
||||
@@ -34,6 +38,8 @@ class BaseLearner(ABC, BaseEstimator):
|
||||
When False, accepts np.nan and np.inf values.
|
||||
bootstrap_init: If initial training data is available, bootstrapping can be done during the first training.
|
||||
Useful when building Committee models with bagging.
|
||||
on_transformed: Whether to transform samples with the pipeline defined by the estimator
|
||||
when applying the query strategy.
|
||||
**fit_kwargs: keyword arguments.
|
||||
|
||||
Attributes:
|
||||
@@ -49,6 +55,7 @@ class BaseLearner(ABC, BaseEstimator):
|
||||
X_training: Optional[modALinput] = None,
|
||||
y_training: Optional[modALinput] = None,
|
||||
bootstrap_init: bool = False,
|
||||
on_transformed: bool = False,
|
||||
force_all_finite: bool = True,
|
||||
**fit_kwargs
|
||||
) -> None:
|
||||
@@ -56,11 +63,14 @@ class BaseLearner(ABC, BaseEstimator):
|
||||
|
||||
self.estimator = estimator
|
||||
self.query_strategy = query_strategy
|
||||
self.on_transformed = on_transformed
|
||||
|
||||
self.X_training = X_training
|
||||
self.Xt_training = None
|
||||
self.y_training = y_training
|
||||
if X_training is not None:
|
||||
self._fit_to_known(bootstrap=bootstrap_init, **fit_kwargs)
|
||||
self.Xt_training = self.transform_without_estimating(self.X_training) if self.on_transformed else None
|
||||
|
||||
assert isinstance(force_all_finite, bool), 'force_all_finite must be a bool'
|
||||
self.force_all_finite = force_all_finite
|
||||
@@ -82,15 +92,61 @@ class BaseLearner(ABC, BaseEstimator):
|
||||
|
||||
if self.X_training is None:
|
||||
self.X_training = X
|
||||
self.Xt_training = self.transform_without_estimating(self.X_training) if self.on_transformed else None
|
||||
self.y_training = y
|
||||
else:
|
||||
try:
|
||||
self.X_training = data_vstack((self.X_training, X))
|
||||
self.Xt_training = data_vstack((
|
||||
self.Xt_training,
|
||||
self.transform_without_estimating(X)
|
||||
)) if self.on_transformed else None
|
||||
self.y_training = data_vstack((self.y_training, y))
|
||||
except ValueError:
|
||||
raise ValueError('the dimensions of the new training data and label must'
|
||||
'agree with the training data and labels provided so far')
|
||||
|
||||
def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.csr_matrix]:
|
||||
"""
|
||||
Transforms the data as supplied to the estimator.
|
||||
|
||||
* In case the estimator is an sklearn pipeline, it applies all pipeline components but the last one.
|
||||
* In case the estimator is an ensemble, it concatenates the transformations for each classifier
|
||||
(pipeline) in the ensemble.
|
||||
* Otherwise returns the non-transformed dataset X
|
||||
Args:
|
||||
X: dataset to be transformed
|
||||
|
||||
Returns:
|
||||
Transformed data set
|
||||
"""
|
||||
Xt = []
|
||||
pipes = [self.estimator]
|
||||
|
||||
if isinstance(self.estimator, _BaseHeterogeneousEnsemble):
|
||||
pipes = self.estimator.estimators_
|
||||
|
||||
################################
|
||||
# transform data with pipelines used by estimator
|
||||
for pipe in pipes:
|
||||
if isinstance(pipe, Pipeline):
|
||||
# NOTE: The used pipeline class might be an extension to sklearn's!
|
||||
# Create a new instance of the used pipeline class with all
|
||||
# components but the final estimator, which is replaced by an empty (passthrough) component.
|
||||
# This prevents any special handling of the final transformation pipe, which is usually
|
||||
# expected to be an estimator.
|
||||
transformation_pipe = pipe.__class__(steps=[*pipe.steps[:-1], ('passthrough', 'passthrough')])
|
||||
Xt.append(transformation_pipe.transform(X))
|
||||
|
||||
# in case no transformation pipelines are used by the estimator,
|
||||
# return the original, non-transformed data
|
||||
if not Xt:
|
||||
return X
|
||||
|
||||
################################
|
||||
# concatenate all transformations and return
|
||||
return data_hstack(Xt)
|
||||
|
||||
def _fit_to_known(self, bootstrap: bool = False, **fit_kwargs) -> 'BaseLearner':
|
||||
"""
|
||||
Fits self.estimator to the training data and labels provided to it so far.
|
||||
@@ -157,6 +213,7 @@ class BaseLearner(ABC, BaseEstimator):
|
||||
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True, dtype=None,
|
||||
force_all_finite=self.force_all_finite)
|
||||
self.X_training, self.y_training = X, y
|
||||
self.Xt_training = self.transform_without_estimating(self.X_training) if self.on_transformed else None
|
||||
return self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
|
||||
|
||||
def predict(self, X: modALinput, **predict_kwargs) -> Any:
|
||||
@@ -185,11 +242,12 @@ class BaseLearner(ABC, BaseEstimator):
|
||||
"""
|
||||
return self.estimator.predict_proba(X, **predict_proba_kwargs)
|
||||
|
||||
def query(self, *query_args, **query_kwargs) -> Union[Tuple, modALinput]:
|
||||
def query(self, X_pool, *query_args, **query_kwargs) -> Union[Tuple, modALinput]:
|
||||
"""
|
||||
Finds the n_instances most informative point in the data provided by calling the query_strategy function.
|
||||
|
||||
Args:
|
||||
X_pool: Pool of unlabeled instances to retrieve most informative instances from
|
||||
*query_args: The arguments for the query strategy. For instance, in the case of
|
||||
:func:`~modAL.uncertainty.uncertainty_sampling`, it is the pool of samples from which the query strategy
|
||||
should choose instances to request labels.
|
||||
@@ -200,8 +258,15 @@ class BaseLearner(ABC, BaseEstimator):
|
||||
labelled and the instances themselves. Can be different in other cases, for instance only the instance to be
|
||||
labelled upon query synthesis.
|
||||
"""
|
||||
query_result = self.query_strategy(self, *query_args, **query_kwargs)
|
||||
return query_result
|
||||
query_result = self.query_strategy(self, X_pool, *query_args, **query_kwargs)
|
||||
|
||||
if isinstance(query_result, tuple):
|
||||
warnings.warn("Query strategies should no longer return the selected instances, "
|
||||
"this is now handled by the query method. "
|
||||
"Please return only the indices of the selected instances.", DeprecationWarning)
|
||||
return query_result
|
||||
|
||||
return query_result, retrieve_rows(X_pool, query_result)
|
||||
|
||||
def score(self, X: modALinput, y: modALinput, **score_kwargs) -> Any:
|
||||
"""
|
||||
@@ -229,12 +294,17 @@ class BaseCommittee(ABC, BaseEstimator):
|
||||
Args:
|
||||
learner_list: List of ActiveLearner objects to form committee.
|
||||
query_strategy: Function to query labels.
|
||||
on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
|
||||
when applying the query strategy.
|
||||
"""
|
||||
def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable) -> None:
|
||||
def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable, on_transformed: bool = False) -> None:
|
||||
assert type(learner_list) == list, 'learners must be supplied in a list'
|
||||
|
||||
self.learner_list = learner_list
|
||||
self.query_strategy = query_strategy
|
||||
self.on_transformed = on_transformed
|
||||
# TODO: update training data when using fit() and teach() methods
|
||||
self.X_training = None
|
||||
|
||||
def __iter__(self) -> Iterator[BaseLearner]:
|
||||
for learner in self.learner_list:
|
||||
@@ -301,11 +371,23 @@ class BaseCommittee(ABC, BaseEstimator):
|
||||
|
||||
return self
|
||||
|
||||
def query(self, *query_args, **query_kwargs) -> Union[Tuple, modALinput]:
|
||||
def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.csr_matrix]:
|
||||
"""
|
||||
Transforms the data as supplied to each learner's estimator and concatenates transformations.
|
||||
Args:
|
||||
X: dataset to be transformed
|
||||
|
||||
Returns:
|
||||
Transformed data set
|
||||
"""
|
||||
return data_hstack([learner.transform_without_estimating(X) for learner in self.learner_list])
|
||||
|
||||
def query(self, X_pool, *query_args, **query_kwargs) -> Union[Tuple, modALinput]:
|
||||
"""
|
||||
Finds the n_instances most informative point in the data provided by calling the query_strategy function.
|
||||
|
||||
Args:
|
||||
X_pool: Pool of unlabeled instances to retrieve most informative instances from
|
||||
*query_args: The arguments for the query strategy. For instance, in the case of
|
||||
:func:`~modAL.disagreement.max_disagreement_sampling`, it is the pool of samples from which the query
|
||||
strategy should choose instances to request labels.
|
||||
@@ -316,8 +398,15 @@ class BaseCommittee(ABC, BaseEstimator):
|
||||
be labelled and the instances themselves. Can be different in other cases, for instance only the instance to
|
||||
be labelled upon query synthesis.
|
||||
"""
|
||||
query_result = self.query_strategy(self, *query_args, **query_kwargs)
|
||||
return query_result
|
||||
query_result = self.query_strategy(self, X_pool, *query_args, **query_kwargs)
|
||||
|
||||
if isinstance(query_result, tuple):
|
||||
warnings.warn("Query strategies should no longer return the selected instances, "
|
||||
"this is now handled by the query method. "
|
||||
"Please return only the indices of the selected instances", DeprecationWarning)
|
||||
return query_result
|
||||
|
||||
return query_result, retrieve_rows(X_pool, query_result)
|
||||
|
||||
def rebag(self, **fit_kwargs) -> None:
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ from sklearn.metrics import accuracy_score
|
||||
|
||||
from modAL.models.base import BaseLearner, BaseCommittee
|
||||
from modAL.utils.validation import check_class_labels, check_class_proba
|
||||
from modAL.utils.data import modALinput
|
||||
from modAL.utils.data import modALinput, retrieve_rows
|
||||
from modAL.uncertainty import uncertainty_sampling
|
||||
from modAL.disagreement import vote_entropy_sampling, max_std_sampling
|
||||
from modAL.acquisition import max_EI
|
||||
@@ -30,6 +30,8 @@ class ActiveLearner(BaseLearner):
|
||||
y_training: Initial training labels corresponding to initial training samples.
|
||||
bootstrap_init: If initial training data is available, bootstrapping can be done during the first training.
|
||||
Useful when building Committee models with bagging.
|
||||
on_transformed: Whether to transform samples with the pipeline defined by the estimator
|
||||
when applying the query strategy.
|
||||
**fit_kwargs: keyword arguments.
|
||||
|
||||
Attributes:
|
||||
@@ -73,10 +75,11 @@ class ActiveLearner(BaseLearner):
|
||||
X_training: Optional[modALinput] = None,
|
||||
y_training: Optional[modALinput] = None,
|
||||
bootstrap_init: bool = False,
|
||||
on_transformed: bool = False,
|
||||
**fit_kwargs
|
||||
) -> None:
|
||||
super().__init__(estimator, query_strategy,
|
||||
X_training, y_training, bootstrap_init, **fit_kwargs)
|
||||
X_training, y_training, bootstrap_init, on_transformed, **fit_kwargs)
|
||||
|
||||
def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
|
||||
"""
|
||||
@@ -177,13 +180,14 @@ class BayesianOptimizer(BaseLearner):
|
||||
X_training: Optional[modALinput] = None,
|
||||
y_training: Optional[modALinput] = None,
|
||||
bootstrap_init: bool = False,
|
||||
on_transformed: bool = False,
|
||||
**fit_kwargs) -> None:
|
||||
super(BayesianOptimizer, self).__init__(estimator, query_strategy,
|
||||
X_training, y_training, bootstrap_init, **fit_kwargs)
|
||||
X_training, y_training, bootstrap_init, on_transformed, **fit_kwargs)
|
||||
# setting the maximum value
|
||||
if self.y_training is not None:
|
||||
max_idx = np.argmax(self.y_training)
|
||||
self.X_max = self.X_training[max_idx]
|
||||
self.X_max = retrieve_rows(self.X_training, max_idx)
|
||||
self.y_max = self.y_training[max_idx]
|
||||
else:
|
||||
self.X_max = None
|
||||
@@ -194,7 +198,7 @@ class BayesianOptimizer(BaseLearner):
|
||||
y_max = y[max_idx]
|
||||
if y_max > self.y_max:
|
||||
self.y_max = y_max
|
||||
self.X_max = X[max_idx]
|
||||
self.X_max = retrieve_rows(X, max_idx)
|
||||
|
||||
def get_max(self) -> Tuple:
|
||||
"""
|
||||
@@ -244,6 +248,8 @@ class Committee(BaseCommittee):
|
||||
learner_list: A list of ActiveLearners forming the Committee.
|
||||
query_strategy: Query strategy function. Committee supports disagreement-based query strategies from
|
||||
:mod:`modAL.disagreement`, but uncertainty-based ones from :mod:`modAL.uncertainty` are also supported.
|
||||
on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
|
||||
when applying the query strategy.
|
||||
|
||||
Attributes:
|
||||
classes_: Class labels known by the Committee.
|
||||
@@ -284,8 +290,9 @@ class Committee(BaseCommittee):
|
||||
... y=iris['target'][query_idx].reshape(1, )
|
||||
... )
|
||||
"""
|
||||
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = vote_entropy_sampling) -> None:
|
||||
super().__init__(learner_list, query_strategy)
|
||||
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = vote_entropy_sampling,
|
||||
on_transformed: bool = False) -> None:
|
||||
super().__init__(learner_list, query_strategy, on_transformed)
|
||||
self._set_classes()
|
||||
|
||||
def _set_classes(self):
|
||||
@@ -452,6 +459,8 @@ class CommitteeRegressor(BaseCommittee):
|
||||
Args:
|
||||
learner_list: A list of ActiveLearners forming the CommitteeRegressor.
|
||||
query_strategy: Query strategy function.
|
||||
on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
|
||||
when applying the query strategy.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -481,8 +490,7 @@ class CommitteeRegressor(BaseCommittee):
|
||||
>>> # query strategy for regression
|
||||
>>> def ensemble_regression_std(regressor, X):
|
||||
... _, std = regressor.predict(X, return_std=True)
|
||||
... query_idx = np.argmax(std)
|
||||
... return query_idx, X[query_idx]
|
||||
... return np.argmax(std)
|
||||
>>>
|
||||
>>> # initializing the CommitteeRegressor
|
||||
>>> committee = CommitteeRegressor(
|
||||
@@ -496,8 +504,9 @@ class CommitteeRegressor(BaseCommittee):
|
||||
... query_idx, query_instance = committee.query(X.reshape(-1, 1))
|
||||
... committee.teach(X[query_idx].reshape(-1, 1), y[query_idx].reshape(-1, 1))
|
||||
"""
|
||||
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = max_std_sampling) -> None:
|
||||
super().__init__(learner_list, query_strategy)
|
||||
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = max_std_sampling,
|
||||
on_transformed: bool = False) -> None:
|
||||
super().__init__(learner_list, query_strategy, on_transformed)
|
||||
|
||||
def predict(self, X: modALinput, return_std: bool = False, **predict_kwargs) -> Any:
|
||||
"""
|
||||
|
||||
@@ -43,7 +43,7 @@ def _SVM_loss(multiclass_classifier: ActiveLearner,
|
||||
|
||||
|
||||
def SVM_binary_minimum(classifier: ActiveLearner, X_pool: modALinput,
|
||||
random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
|
||||
random_tie_break: bool = False) -> np.ndarray:
|
||||
"""
|
||||
SVM binary minimum multilabel active learning strategy. For details see the paper
|
||||
Klaus Brinker, On Active Learning in Multi-label Classification
|
||||
@@ -67,15 +67,13 @@ def SVM_binary_minimum(classifier: ActiveLearner, X_pool: modALinput,
|
||||
min_abs_dist = np.min(np.abs(decision_function), axis=1)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = np.argmin(min_abs_dist)
|
||||
else:
|
||||
query_idx = shuffled_argmax(min_abs_dist)
|
||||
return np.argmin(min_abs_dist)
|
||||
|
||||
return query_idx, X_pool[query_idx]
|
||||
return shuffled_argmax(min_abs_dist)
|
||||
|
||||
|
||||
def max_loss(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> np.ndarray:
|
||||
|
||||
"""
|
||||
Max Loss query strategy for SVM multilabel classification.
|
||||
@@ -103,15 +101,13 @@ def max_loss(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
loss = _SVM_loss(classifier, X_pool, most_certain_classes=most_certain_classes)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(loss, n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(loss, n_instances)
|
||||
return multi_argmax(loss, n_instances)
|
||||
|
||||
return query_idx, X_pool[query_idx]
|
||||
return shuffled_argmax(loss, n_instances)
|
||||
|
||||
|
||||
def mean_max_loss(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> np.ndarray:
|
||||
"""
|
||||
Mean Max Loss query strategy for SVM multilabel classification.
|
||||
|
||||
@@ -136,15 +132,13 @@ def mean_max_loss(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
loss = _SVM_loss(classifier, X_pool)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(loss, n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(loss, n_instances)
|
||||
return multi_argmax(loss, n_instances)
|
||||
|
||||
return query_idx, X_pool[query_idx]
|
||||
return shuffled_argmax(loss, n_instances)
|
||||
|
||||
|
||||
def min_confidence(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> np.ndarray:
|
||||
"""
|
||||
MinConfidence query strategy for multilabel classification.
|
||||
|
||||
@@ -167,15 +161,13 @@ def min_confidence(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
classwise_min = np.min(classwise_confidence, axis=1)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(-classwise_min, n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(-classwise_min, n_instances)
|
||||
return multi_argmax(-classwise_min, n_instances)
|
||||
|
||||
return query_idx, X_pool[query_idx]
|
||||
return shuffled_argmax(-classwise_min, n_instances)
|
||||
|
||||
|
||||
def avg_confidence(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> np.ndarray:
|
||||
"""
|
||||
AvgConfidence query strategy for multilabel classification.
|
||||
|
||||
@@ -198,15 +190,13 @@ def avg_confidence(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
classwise_mean = np.mean(classwise_confidence, axis=1)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(classwise_mean, n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(classwise_mean, n_instances)
|
||||
return multi_argmax(classwise_mean, n_instances)
|
||||
|
||||
return query_idx, X_pool[query_idx]
|
||||
return shuffled_argmax(classwise_mean, n_instances)
|
||||
|
||||
|
||||
def max_score(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
n_instances: int = 1, random_tie_break: bool = 1) -> Tuple[np.ndarray, modALinput]:
|
||||
n_instances: int = 1, random_tie_break: bool = 1) -> np.ndarray:
|
||||
"""
|
||||
MaxScore query strategy for multilabel classification.
|
||||
|
||||
@@ -231,15 +221,13 @@ def max_score(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
classwise_max = np.max(classwise_scores, axis=1)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(classwise_max, n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(classwise_max, n_instances)
|
||||
return multi_argmax(classwise_max, n_instances)
|
||||
|
||||
return query_idx, X_pool[query_idx]
|
||||
return shuffled_argmax(classwise_max, n_instances)
|
||||
|
||||
|
||||
def avg_score(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
|
||||
n_instances: int = 1, random_tie_break: bool = False) -> np.ndarray:
|
||||
"""
|
||||
AvgScore query strategy for multilabel classification.
|
||||
|
||||
@@ -264,8 +252,6 @@ def avg_score(classifier: OneVsRestClassifier, X_pool: modALinput,
|
||||
classwise_mean = np.mean(classwise_scores, axis=1)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(classwise_mean, n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(classwise_mean, n_instances)
|
||||
return multi_argmax(classwise_mean, n_instances)
|
||||
|
||||
return query_idx, X_pool[query_idx]
|
||||
return shuffled_argmax(classwise_mean, n_instances)
|
||||
|
||||
@@ -132,7 +132,7 @@ def classifier_entropy(classifier: BaseEstimator, X: modALinput, **predict_proba
|
||||
|
||||
def uncertainty_sampling(classifier: BaseEstimator, X: modALinput,
|
||||
n_instances: int = 1, random_tie_break: bool = False,
|
||||
**uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
|
||||
**uncertainty_measure_kwargs) -> np.ndarray:
|
||||
"""
|
||||
Uncertainty sampling query strategy. Selects the least sure instances for labelling.
|
||||
|
||||
@@ -152,16 +152,14 @@ def uncertainty_sampling(classifier: BaseEstimator, X: modALinput,
|
||||
uncertainty = classifier_uncertainty(classifier, X, **uncertainty_measure_kwargs)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(uncertainty, n_instances=n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(uncertainty, n_instances=n_instances)
|
||||
return multi_argmax(uncertainty, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return shuffled_argmax(uncertainty, n_instances=n_instances)
|
||||
|
||||
|
||||
def margin_sampling(classifier: BaseEstimator, X: modALinput,
|
||||
n_instances: int = 1, random_tie_break: bool = False,
|
||||
**uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
|
||||
**uncertainty_measure_kwargs) -> np.ndarray:
|
||||
"""
|
||||
Margin sampling query strategy. Selects the instances where the difference between
|
||||
the first most likely and second most likely classes are the smallest.
|
||||
@@ -180,16 +178,14 @@ def margin_sampling(classifier: BaseEstimator, X: modALinput,
|
||||
margin = classifier_margin(classifier, X, **uncertainty_measure_kwargs)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(-margin, n_instances=n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(-margin, n_instances=n_instances)
|
||||
return multi_argmax(-margin, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return shuffled_argmax(-margin, n_instances=n_instances)
|
||||
|
||||
|
||||
def entropy_sampling(classifier: BaseEstimator, X: modALinput,
|
||||
n_instances: int = 1, random_tie_break: bool = False,
|
||||
**uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
|
||||
**uncertainty_measure_kwargs) -> np.ndarray:
|
||||
"""
|
||||
Entropy sampling query strategy. Selects the instances where the class probabilities
|
||||
have the largest entropy.
|
||||
@@ -210,8 +206,6 @@ def entropy_sampling(classifier: BaseEstimator, X: modALinput,
|
||||
entropy = classifier_entropy(classifier, X, **uncertainty_measure_kwargs)
|
||||
|
||||
if not random_tie_break:
|
||||
query_idx = multi_argmax(entropy, n_instances=n_instances)
|
||||
else:
|
||||
query_idx = shuffled_argmax(entropy, n_instances=n_instances)
|
||||
return multi_argmax(entropy, n_instances=n_instances)
|
||||
|
||||
return query_idx, X[query_idx]
|
||||
return shuffled_argmax(entropy, n_instances=n_instances)
|
||||
|
||||
@@ -78,7 +78,6 @@ def make_query_strategy(utility_measure: Callable, selector: Callable) -> Callab
|
||||
"""
|
||||
def query_strategy(classifier: BaseEstimator, X: modALinput) -> Tuple:
|
||||
utility = utility_measure(classifier, X)
|
||||
query_idx = selector(utility)
|
||||
return query_idx, X[query_idx]
|
||||
return selector(utility)
|
||||
|
||||
return query_strategy
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from typing import Union, Container
|
||||
from itertools import chain
|
||||
from typing import Union, List, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import scipy.sparse as sp
|
||||
|
||||
|
||||
modALinput = Union[list, np.ndarray, sp.csr_matrix]
|
||||
modALinput = Union[sp.csr_matrix, pd.DataFrame, np.ndarray, list]
|
||||
|
||||
|
||||
def data_vstack(blocks: Container) -> modALinput:
|
||||
def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
|
||||
"""
|
||||
Stack vertically both sparse and dense arrays.
|
||||
Stack vertically sparse/dense arrays and pandas data frames.
|
||||
|
||||
Args:
|
||||
blocks: Sequence of modALinput objects.
|
||||
@@ -18,14 +18,137 @@ def data_vstack(blocks: Container) -> modALinput:
|
||||
Returns:
|
||||
New sequence of vertically stacked elements.
|
||||
"""
|
||||
if isinstance(blocks[0], np.ndarray):
|
||||
if any([sp.issparse(b) for b in blocks]):
|
||||
return sp.vstack(blocks)
|
||||
elif isinstance(blocks[0], pd.DataFrame):
|
||||
return blocks[0].append(blocks[1:])
|
||||
elif isinstance(blocks[0], np.ndarray):
|
||||
return np.concatenate(blocks)
|
||||
elif isinstance(blocks[0], list):
|
||||
return list(chain(blocks))
|
||||
elif sp.issparse(blocks[0]):
|
||||
return sp.vstack(blocks)
|
||||
else:
|
||||
return np.concatenate(blocks).tolist()
|
||||
|
||||
raise TypeError('%s datatype is not supported' % type(blocks[0]))
|
||||
|
||||
|
||||
def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
    """
    Stack horizontally sparse/dense arrays and pandas data frames.

    Args:
        blocks: Sequence of modALinput objects.

    Returns:
        New sequence of horizontally stacked elements.

    Raises:
        TypeError: If the type of the first block is not supported.
    """
    # as soon as one block is sparse, the stacking is delegated to scipy
    if any(sp.issparse(b) for b in blocks):
        return sp.hstack(blocks)
    elif isinstance(blocks[0], pd.DataFrame):
        # the concatenated frame must be returned explicitly, otherwise
        # the function silently evaluates to None for data frames
        return pd.concat(blocks, axis=1)
    elif isinstance(blocks[0], np.ndarray):
        return np.hstack(blocks)
    elif isinstance(blocks[0], list):
        return np.hstack(blocks).tolist()

    # the exception has to be raised, not merely instantiated
    raise TypeError('%s datatype is not supported' % type(blocks[0]))
|
||||
|
||||
|
||||
def add_row(X:modALinput, row: modALinput):
    """
    Appends ``row`` below ``X`` and returns the stacked result, i.e. X' =

    [X

    row]
    """
    if not isinstance(X, (np.ndarray, list)):
        # every other data type is handed to data_vstack, which accepts
        # the matrix as first and the row as second block
        return data_vstack([X, row])

    stacked = np.vstack((X, row))
    # list input is answered with a (nested) list, array input with an array
    return stacked.tolist() if isinstance(X, list) else stacked
|
||||
|
||||
|
||||
def retrieve_rows(X: modALinput,
                  I: Union[int, List[int], np.ndarray]) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
    """
    Returns the rows I from the data set X

    For a single index, the result is as follows:
    * 1xM matrix in case of scipy sparse NxM matrix X
    * pandas series in case of a pandas data frame
    * row in case of list or numpy format

    Args:
        X: The data set to take the rows from.
        I: Positional index or indices of the rows to return.

    Raises:
        TypeError: If the type of X is not supported.
    """
    if sp.issparse(X):
        # Out of the sparse matrix formats (sp.csc_matrix, sp.csr_matrix, sp.bsr_matrix,
        # sp.lil_matrix, sp.dok_matrix, sp.coo_matrix, sp.dia_matrix), only sp.bsr_matrix, sp.coo_matrix
        # and sp.dia_matrix don't support indexing and need to be converted to a sparse format
        # that does support indexing. It seems conversion to CSR is currently most efficient.

        try:
            return X[I]
        except Exception:
            # Exception instead of a bare except, so that KeyboardInterrupt and
            # SystemExit are not swallowed by the CSR fallback
            sp_format = X.getformat()
            return X.tocsr()[I].asformat(sp_format)
    elif isinstance(X, pd.DataFrame):
        return X.iloc[I]
    elif isinstance(X, np.ndarray):
        return X[I]
    elif isinstance(X, list):
        return np.array(X)[I].tolist()

    raise TypeError('%s datatype is not supported' % type(X))
|
||||
|
||||
|
||||
def drop_rows(X: modALinput,
              I: Union[int, List[int], np.ndarray]) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
    """
    Returns X without the row(s) at index/indices I

    Args:
        X: The data set to remove the rows from.
        I: Positional index or indices of the rows to remove.

    Raises:
        TypeError: If the type of X is not supported.
    """
    if sp.issparse(X):
        mask = np.ones(X.shape[0], dtype=bool)
        mask[I] = False
        return retrieve_rows(X, mask)
    elif isinstance(X, pd.DataFrame):
        # select by position with a boolean mask instead of DataFrame.drop,
        # which interprets I as index labels and therefore disagrees with
        # the other branches for data frames without a default index
        mask = np.ones(X.shape[0], dtype=bool)
        mask[I] = False
        return X.iloc[mask]
    elif isinstance(X, np.ndarray):
        return np.delete(X, I, axis=0)
    elif isinstance(X, list):
        return np.delete(X, I, axis=0).tolist()

    raise TypeError('%s datatype is not supported' % type(X))
|
||||
|
||||
|
||||
def enumerate_data(X: modALinput):
    """
    for i, x in enumerate_data(X):

    i is the positional index of the row for every supported data type.
    Depending on the data type of X, x is:

    * A 1xM matrix in case of scipy sparse NxM matrix X
    * pandas series in case of a pandas data frame X
    * row in case of list or numpy format

    Raises:
        TypeError: If the type of X is not supported.
    """
    if sp.issparse(X):
        return enumerate(X.tocsr())
    elif isinstance(X, pd.DataFrame):
        # iterrows() yields (index label, series); the label is replaced by a
        # running counter so that i is positional like in the other branches
        return enumerate(row for _, row in X.iterrows())
    elif isinstance(X, np.ndarray) or isinstance(X, list):
        # numpy arrays and lists can readily be enumerated
        return enumerate(X)

    raise TypeError('%s datatype is not supported' % type(X))
|
||||
|
||||
|
||||
def data_shape(X: modALinput):
    """
    Determines the shape of the data set X.

    Raises a TypeError for data types which are not supported.
    """
    if isinstance(X, list):
        # lists carry no shape attribute, hence the detour over numpy
        return np.array(X).shape

    has_shape = sp.issparse(X) or isinstance(X, (pd.DataFrame, np.ndarray))
    if has_shape:
        # scipy.sparse, pandas and numpy expose .shape directly
        return X.shape

    raise TypeError('%s datatype is not supported' % type(X))
|
||||
|
||||
@@ -3,3 +3,4 @@ scipy
|
||||
scikit-learn
|
||||
ipykernel
|
||||
nbsphinx
|
||||
pandas
|
||||
|
||||
2
setup.py
2
setup.py
@@ -10,5 +10,5 @@ setup(
|
||||
url='https://modAL-python.github.io/',
|
||||
packages=['modAL', 'modAL.models', 'modAL.utils'],
|
||||
classifiers=['Development Status :: 4 - Beta'],
|
||||
install_requires=['numpy>=1.13', 'scikit-learn>=0.18', 'scipy>=0.18'],
|
||||
install_requires=['numpy>=1.13', 'scikit-learn>=0.18', 'scipy>=0.18', 'pandas>=1.1.0'],
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import random
|
||||
import unittest
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import mock
|
||||
import modAL.models.base
|
||||
@@ -26,6 +27,8 @@ from sklearn.exceptions import NotFittedError
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from sklearn.svm import SVC
|
||||
from sklearn.multiclass import OneVsRestClassifier
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import FunctionTransformer
|
||||
from scipy.stats import entropy, norm
|
||||
from scipy.special import ndtr
|
||||
from scipy import sparse as sp
|
||||
@@ -140,8 +143,7 @@ class TestUtils(unittest.TestCase):
|
||||
query_1 = query_strategy(learner, X)
|
||||
query_2 = modAL.uncertainty.uncertainty_sampling(learner, X)
|
||||
|
||||
np.testing.assert_equal(query_1[0], query_2[0])
|
||||
np.testing.assert_almost_equal(query_1[1], query_2[1])
|
||||
np.testing.assert_equal(query_1, query_2)
|
||||
|
||||
def test_data_vstack(self):
|
||||
for n_samples, n_features in product(range(1, 10), range(1, 10)):
|
||||
@@ -455,21 +457,24 @@ class TestDisagreements(unittest.TestCase):
|
||||
class TestEER(unittest.TestCase):
|
||||
def test_eer(self):
|
||||
for n_pool, n_features, n_classes in product(range(5, 10), range(1, 5), range(2, 5)):
|
||||
X_training, y_training = np.random.rand(10, n_features), np.random.randint(0, n_classes, size=10)
|
||||
X_pool, y_pool = np.random.rand(n_pool, n_features), np.random.randint(0, n_classes+1, size=n_pool)
|
||||
X_training_, y_training = np.random.rand(10, n_features).tolist(), np.random.randint(0, n_classes, size=10)
|
||||
X_pool_, y_pool = np.random.rand(n_pool, n_features).tolist(), np.random.randint(0, n_classes+1, size=n_pool)
|
||||
|
||||
learner = modAL.models.ActiveLearner(RandomForestClassifier(n_estimators=2),
|
||||
X_training=X_training, y_training=y_training)
|
||||
for data_type in (sp.csr_matrix, pd.DataFrame, np.array, list):
|
||||
X_training, X_pool = data_type(X_training_), data_type(X_pool_)
|
||||
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool)
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool, random_tie_break=True)
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1)
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool, loss='binary')
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1, loss='log')
|
||||
self.assertRaises(AssertionError, modAL.expected_error.expected_error_reduction,
|
||||
learner, X_pool, p_subsample=1.5)
|
||||
self.assertRaises(AssertionError, modAL.expected_error.expected_error_reduction,
|
||||
learner, X_pool, loss=42)
|
||||
learner = modAL.models.ActiveLearner(RandomForestClassifier(n_estimators=2),
|
||||
X_training=X_training, y_training=y_training)
|
||||
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool)
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool, random_tie_break=True)
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1)
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool, loss='binary')
|
||||
modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1, loss='log')
|
||||
self.assertRaises(AssertionError, modAL.expected_error.expected_error_reduction,
|
||||
learner, X_pool, p_subsample=1.5)
|
||||
self.assertRaises(AssertionError, modAL.expected_error.expected_error_reduction,
|
||||
learner, X_pool, loss=42)
|
||||
|
||||
|
||||
class TestUncertainties(unittest.TestCase):
|
||||
@@ -560,10 +565,10 @@ class TestUncertainties(unittest.TestCase):
|
||||
predict_proba = np.random.rand(n_samples, n_classes)
|
||||
predict_proba[true_query_idx] = max_proba
|
||||
classifier = mock.MockEstimator(predict_proba_return=predict_proba)
|
||||
query_idx, query_instance = modAL.uncertainty.uncertainty_sampling(
|
||||
query_idx = modAL.uncertainty.uncertainty_sampling(
|
||||
classifier, np.random.rand(n_samples, n_classes)
|
||||
)
|
||||
shuffled_query_idx, shuffled_query_instance = modAL.uncertainty.uncertainty_sampling(
|
||||
shuffled_query_idx = modAL.uncertainty.uncertainty_sampling(
|
||||
classifier, np.random.rand(n_samples, n_classes),
|
||||
random_tie_break=True
|
||||
)
|
||||
@@ -577,10 +582,10 @@ class TestUncertainties(unittest.TestCase):
|
||||
predict_proba[:, 0] = 1.0
|
||||
predict_proba[true_query_idx, 0] = 0.0
|
||||
classifier = mock.MockEstimator(predict_proba_return=predict_proba)
|
||||
query_idx, query_instance = modAL.uncertainty.margin_sampling(
|
||||
query_idx = modAL.uncertainty.margin_sampling(
|
||||
classifier, np.random.rand(n_samples, n_classes)
|
||||
)
|
||||
shuffled_query_idx, shuffled_query_instance = modAL.uncertainty.margin_sampling(
|
||||
shuffled_query_idx = modAL.uncertainty.margin_sampling(
|
||||
classifier, np.random.rand(n_samples, n_classes),
|
||||
random_tie_break=True
|
||||
)
|
||||
@@ -595,10 +600,10 @@ class TestUncertainties(unittest.TestCase):
|
||||
predict_proba[:, 0] = 1.0
|
||||
predict_proba[true_query_idx] = max_proba
|
||||
classifier = mock.MockEstimator(predict_proba_return=predict_proba)
|
||||
query_idx, query_instance = modAL.uncertainty.entropy_sampling(
|
||||
query_idx = modAL.uncertainty.entropy_sampling(
|
||||
classifier, np.random.rand(n_samples, n_classes)
|
||||
)
|
||||
shuffled_query_idx, shuffled_query_instance = modAL.uncertainty.entropy_sampling(
|
||||
shuffled_query_idx = modAL.uncertainty.entropy_sampling(
|
||||
classifier, np.random.rand(n_samples, n_classes),
|
||||
random_tie_break=True
|
||||
)
|
||||
@@ -698,7 +703,7 @@ class TestActiveLearner(unittest.TestCase):
|
||||
for n_features in range(1, 10):
|
||||
X = np.random.rand(n_samples, n_features)
|
||||
query_idx = np.random.randint(0, n_samples)
|
||||
mock_query = mock.MockFunction(return_val=(query_idx, X[query_idx]))
|
||||
mock_query = mock.MockFunction(return_val=query_idx)
|
||||
learner = modAL.models.learners.ActiveLearner(
|
||||
estimator=None,
|
||||
query_strategy=mock_query
|
||||
@@ -789,6 +794,68 @@ class TestActiveLearner(unittest.TestCase):
|
||||
query_idx, query_inst = learner.query(X_pool)
|
||||
learner.teach(X_pool[query_idx], y_pool[query_idx])
|
||||
|
||||
def test_on_transformed(self):
    # Smoke test: builds an ActiveLearner with on_transformed=True around a
    # pipeline estimator, then runs one query/teach round on a pandas pool.
    # There are no assertions; the test passes when nothing raises.
    n_samples = 10
    n_features = 5
    query_strategies = [
        modAL.batch.uncertainty_batch_sampling
        # add further strategies which work with instance representations
        # no further ones as of 25.09.2020
    ]
    X_pool = np.random.rand(n_samples, n_features)

    # use pandas data frame as X_pool, which will be transformed back to numpy with sklearn pipeline
    X_pool = pd.DataFrame(X_pool)

    y_pool = np.random.randint(0, 2, size=(n_samples,))
    # two distinct pool rows serve as the initial training set
    train_idx = np.random.choice(range(n_samples), size=2, replace=False)

    for query_strategy in query_strategies:
        learner = modAL.models.learners.ActiveLearner(
            estimator=make_pipeline(
                # first pipeline step converts the data frame to a numpy array
                FunctionTransformer(func=pd.DataFrame.to_numpy),
                RandomForestClassifier(n_estimators=10)
            ),
            query_strategy=query_strategy,
            X_training=X_pool.iloc[train_idx],
            y_training=y_pool[train_idx],
            on_transformed=True
        )
        query_idx, query_inst = learner.query(X_pool)
        # the queried rows are taken from the data frame by position
        learner.teach(X_pool.iloc[query_idx], y_pool[query_idx])
|
||||
|
||||
def test_old_query_strategy_interface(self):
    # Checks that a query strategy returning an (index, instance) tuple is
    # passed through by ActiveLearner.query() unchanged.
    n_samples = 10
    n_features = 5
    X_pool = np.random.rand(n_samples, n_features)
    y_pool = np.random.randint(0, 2, size=(n_samples,))

    # defining a custom query strategy also returning the selected instance
    # make sure even if a query strategy works in some funny way
    # (e.g. instance not matching instance index),
    # the old interface remains unchanged
    query_idx_ = np.random.choice(n_samples, 2)
    # deliberately shifted by one row, so instances and indices do not match
    query_instance_ = X_pool[(query_idx_ + 1) % len(X_pool)]

    def custom_query_strategy(classifier, X):
        # ignores its arguments and always returns the fixed pair from above
        return query_idx_, query_instance_


    train_idx = np.random.choice(range(n_samples), size=2, replace=False)
    custom_query_learner = modAL.models.learners.ActiveLearner(
        estimator=RandomForestClassifier(n_estimators=10),
        query_strategy=custom_query_strategy,
        X_training=X_pool[train_idx], y_training=y_pool[train_idx]
    )

    query_idx, query_instance = custom_query_learner.query(X_pool)
    custom_query_learner.teach(
        X=X_pool[query_idx],
        y=y_pool[query_idx]
    )
    # both parts of the tuple must come back exactly as the strategy returned them
    np.testing.assert_equal(query_idx, query_idx_)
    np.testing.assert_equal(query_instance, query_instance_)
|
||||
|
||||
|
||||
class TestBayesianOptimizer(unittest.TestCase):
|
||||
def test_set_max(self):
|
||||
@@ -898,6 +965,39 @@ class TestBayesianOptimizer(unittest.TestCase):
|
||||
)
|
||||
learner.teach(X, y, bootstrap=bootstrap, only_new=only_new)
|
||||
|
||||
def test_on_transformed(self):
    # Smoke test: builds a BayesianOptimizer with on_transformed=True around a
    # pipeline estimator, then runs one query/teach round on a pandas pool.
    # There are no assertions; the test passes when nothing raises.
    n_samples = 10
    n_features = 5
    query_strategies = [
        # TODO remove, added just to make sure on_transformed doesn't break anything
        # but it has no influence on this strategy, nothing special tested here
        mock.MockFunction(return_val=[np.random.randint(0, n_samples)])

        # add further strategies which work with instance representations
        # no further ones as of 25.09.2020
    ]
    X_pool = np.random.rand(n_samples, n_features)

    # use pandas data frame as X_pool, which will be transformed back to numpy with sklearn pipeline
    X_pool = pd.DataFrame(X_pool)

    # continuous targets, since the estimator is a regressor
    y_pool = np.random.rand(n_samples)
    train_idx = np.random.choice(range(n_samples), size=2, replace=False)

    for query_strategy in query_strategies:
        learner = modAL.models.learners.BayesianOptimizer(
            estimator=make_pipeline(
                # first pipeline step converts the data frame to a numpy array
                FunctionTransformer(func=pd.DataFrame.to_numpy),
                GaussianProcessRegressor()
            ),
            query_strategy=query_strategy,
            X_training=X_pool.iloc[train_idx],
            y_training=y_pool[train_idx],
            on_transformed=True
        )
        query_idx, query_inst = learner.query(X_pool)
        # the queried rows are taken from the data frame by position
        learner.teach(X_pool.iloc[query_idx], y_pool[query_idx])
|
||||
|
||||
|
||||
class TestCommittee(unittest.TestCase):
|
||||
|
||||
@@ -1008,6 +1108,42 @@ class TestCommittee(unittest.TestCase):
|
||||
|
||||
committee.teach(X, y, bootstrap=bootstrap, only_new=only_new)
|
||||
|
||||
def test_on_transformed(self):
    # Smoke test: builds a Committee with on_transformed=True from three
    # pipeline-based learners, then runs one query/teach round on a pandas pool.
    # There are no assertions; the test passes when nothing raises.
    n_samples = 10
    n_features = 5
    query_strategies = [
        modAL.batch.uncertainty_batch_sampling
        # add further strategies which work with instance representations
        # no further ones as of 25.09.2020
    ]
    X_pool = np.random.rand(n_samples, n_features)

    # use pandas data frame as X_pool, which will be transformed back to numpy with sklearn pipeline
    X_pool = pd.DataFrame(X_pool)

    y_pool = np.random.randint(0, 2, size=(n_samples,))
    train_idx = np.random.choice(range(n_samples), size=5, replace=False)

    # learner i is trained on i + 1 entries of train_idx, starting at offset i
    # and wrapping around at the end of train_idx
    learner_list = [modAL.models.learners.ActiveLearner(
        estimator=make_pipeline(
            FunctionTransformer(func=pd.DataFrame.to_numpy),
            RandomForestClassifier(n_estimators=10)
        ),
        # committee learners can contain different amounts of
        # different instances
        X_training=X_pool.iloc[train_idx[(np.arange(i + 1) + i) % len(train_idx)]],
        y_training=y_pool[train_idx[(np.arange(i + 1) + i) % len(train_idx)]],
    ) for i in range(3)]

    for query_strategy in query_strategies:
        # on_transformed is passed to the committee, not to the single learners
        committee = modAL.models.learners.Committee(
            learner_list=learner_list,
            query_strategy=query_strategy,
            on_transformed=True
        )
        query_idx, query_inst = committee.query(X_pool)
        committee.teach(X_pool.iloc[query_idx], y_pool[query_idx])
|
||||
|
||||
|
||||
class TestCommitteeRegressor(unittest.TestCase):
|
||||
|
||||
@@ -1041,6 +1177,45 @@ class TestCommitteeRegressor(unittest.TestCase):
|
||||
vote_output
|
||||
)
|
||||
|
||||
def test_on_transformed(self):
    # Smoke test: builds a CommitteeRegressor with on_transformed=True from three
    # pipeline-based learners, then runs one query/teach round on a pandas pool.
    # There are no assertions; the test passes when nothing raises.
    n_samples = 10
    n_features = 5
    query_strategies = [
        # TODO remove, added just to make sure on_transformed doesn't break anything
        # but it has no influence on this strategy, nothing special tested here
        mock.MockFunction(return_val=[np.random.randint(0, n_samples)])

        # add further strategies which work with instance representations
        # no further ones as of 25.09.2020
    ]
    X_pool = np.random.rand(n_samples, n_features)

    # use pandas data frame as X_pool, which will be transformed back to numpy with sklearn pipeline
    X_pool = pd.DataFrame(X_pool)

    # continuous targets, since the estimators are regressors
    y_pool = np.random.rand(n_samples)
    train_idx = np.random.choice(range(n_samples), size=2, replace=False)

    # learner i is trained on i + 1 entries of train_idx, starting at offset i
    # and wrapping around; with only two entries this repeats instances
    learner_list = [modAL.models.learners.ActiveLearner(
        estimator=make_pipeline(
            FunctionTransformer(func=pd.DataFrame.to_numpy),
            GaussianProcessRegressor()
        ),
        # committee learners can contain different amounts of
        # different instances
        X_training=X_pool.iloc[train_idx[(np.arange(i + 1) + i) % len(train_idx)]],
        y_training=y_pool[train_idx[(np.arange(i + 1) + i) % len(train_idx)]],
    ) for i in range(3)]

    for query_strategy in query_strategies:
        # on_transformed is passed to the committee, not to the single learners
        committee = modAL.models.learners.CommitteeRegressor(
            learner_list=learner_list,
            query_strategy=query_strategy,
            on_transformed=True
        )
        query_idx, query_inst = committee.query(X_pool)
        committee.teach(X_pool.iloc[query_idx], y_pool[query_idx])
|
||||
|
||||
|
||||
class TestMultilabel(unittest.TestCase):
|
||||
def test_SVM_loss(self):
|
||||
@@ -1107,4 +1282,3 @@ class TestExamples(unittest.TestCase):
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
0
|
||||
@@ -42,8 +42,7 @@ product = make_product(
|
||||
# classifier uncertainty and classifier margin
|
||||
def custom_query_strategy(classifier, X, n_instances=1):
|
||||
utility = linear_combination(classifier, X)
|
||||
query_idx = multi_argmax(utility, n_instances=n_instances)
|
||||
return query_idx, X[query_idx]
|
||||
return multi_argmax(utility, n_instances=n_instances)
|
||||
|
||||
custom_query_learner = ActiveLearner(
|
||||
estimator=GaussianProcessClassifier(1.0 * RBF(1.0)),
|
||||
|
||||
Reference in New Issue
Block a user