adaptivesplit.base package

Submodules

adaptivesplit.base.learning_curve module

class adaptivesplit.base.learning_curve.LearningCurve(df=None, ns=None, data=None, scoring=None, curve_names=None, curve_type=None, description=None)[source]

Bases: object

Class for storing and handling learning curves. Wraps a pandas DataFrame with rows for learning curves and columns for sample sizes. It can handle multiple learning curves (e.g. for different bootstrap samples).

Args:
df (array-like, pandas.DataFrame):

A DataFrame already in the correct format, used to initialize the object directly. Defaults to None.

ns (int):

The sample sizes (numbers of samples). Defaults to None.

data (np.ndarray):

Array of shape (ns, n_curves). Contains score value per sample size. Defaults to None.

scoring (str, callable, list, tuple or dict):

Scikit-learn-like score to evaluate the performance of the cross-validated model on the test set. If scoring represents a single score, one can use:

  • a single string (see The scoring parameter: defining model evaluation rules);

  • a callable (see Defining your scoring strategy from metric functions) that returns a single value.

If scoring represents multiple scores, one can use:

  • a list or tuple of unique strings;

  • a callable returning a dictionary where the keys are the metric names and the values are the metric scores;

  • a dictionary with metric names as keys and callables as values.

If None, the estimator’s score method is used. Defaults to None.

curve_names (str, optional):

Curve names metadata. Defaults to None.

curve_type (str, optional):

Curve type metadata. Defaults to None.

description (str, optional):

Curve description. Defaults to None.

dump()[source]

Dump learning curve data.

stat(mid='mean', ci='95%')[source]

Return simple descriptive learning curve stats (e.g. for plotting purposes). Stat names should be given according to pandas.DataFrame.describe(). Most common choices: ‘mean’, ‘50%’, ‘std’, ‘95%’.

Args:
mid (str):

Stat used to calculate the mid value. Defaults to ‘mean’.

ci (str):

Confidence level. Defaults to ‘95%’.

Extra keywords:

Passing ‘stderr’ for ci gives mid ± stderr/2.
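The DataFrame layout described above, and the kind of aggregation stat() performs, can be illustrated with plain pandas. This is a conceptual sketch with made-up curve values, not the class's own code:

```python
import numpy as np
import pandas as pd

# Rows: one learning curve per bootstrap sample; columns: sample sizes.
ns = [10, 20, 40, 80]
rng = np.random.default_rng(0)
data = np.clip(
    rng.normal(loc=[0.5, 0.6, 0.7, 0.75], scale=0.05, size=(100, 4)), 0, 1
)
df = pd.DataFrame(data, columns=ns)

# stat(mid='mean', ci='95%') boils down to describe()-style aggregates:
mid = df.mean()             # one mid value per sample size
lower = df.quantile(0.025)  # 95% CI lower bound
upper = df.quantile(0.975)  # 95% CI upper bound
print(mid.round(2).tolist())
```

Each aggregate is a Series indexed by sample size, which is convenient for plotting a curve with a shaded confidence band.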

adaptivesplit.base.power module

class adaptivesplit.base.power.PowerEstimatorBootstrap(power_stat_fun, stratify=None, total_sample_size=None, alpha=0.001, bootstrap_samples=100, n_jobs=None, verbose=True, message='Estimating Power with bootstrap')[source]

Bases: _PowerEstimatorBase

Calculate power using bootstrap.

Args:
power_stat_fun (callable):

Statistical function used to calculate power.

stratify (int):

For classification tasks. If not None, use stratified sampling to account for class label imbalance. Defaults to None.

total_sample_size (int):

The total number of samples in the data given as input. Defaults to None.

alpha (float):

Statistical threshold used to calculate power. Defaults to 0.001.

bootstrap_samples (int):

Number of samples selected during bootstrapping. Defaults to 100.

n_jobs (int):

Number of jobs to run in parallel. Power estimation is parallelized over the bootstrap samples. None means 1 unless in a joblib.parallel_backend context; -1 means using all processors. Defaults to None.

verbose (bool):

Prints progress. Defaults to True.

message (str, optional):

Message shown when power estimation starts. Defaults to ‘Estimating Power with bootstrap’.
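The idea behind bootstrap power estimation can be sketched without the class itself: repeatedly resample observed/predicted pairs at a given sample size, apply a significance test as the power statistic, and report the fraction of resamples that reach significance at alpha. The correlation-based statistic and the Fisher-z critical value below are illustrative assumptions standing in for power_stat_fun, not the class's actual implementation:

```python
import numpy as np
from statistics import NormalDist

def bootstrap_power(y_obs, y_pred, sample_size, alpha=0.001, n_boot=100, seed=0):
    """Fraction of bootstrap samples where the obs/pred correlation is significant."""
    rng = np.random.default_rng(seed)
    # Fisher-z critical value for |r| at level alpha (normal approximation).
    z_crit = NormalDist().inv_cdf(1 - alpha / 2)
    r_crit = np.tanh(z_crit / np.sqrt(sample_size - 3))
    hits = 0
    for _ in range(n_boot):
        idx = rng.integers(0, len(y_obs), size=sample_size)  # sample with replacement
        r = np.corrcoef(y_obs[idx], y_pred[idx])[0, 1]
        hits += abs(r) > r_crit
    return hits / n_boot

rng = np.random.default_rng(42)
y_obs = rng.normal(size=500)
y_pred = y_obs + rng.normal(scale=0.5, size=500)  # strong true association
print(bootstrap_power(y_obs, y_pred, sample_size=100))
```

With a strong association as above, nearly every bootstrap sample is significant, so the estimated power approaches 1.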

adaptivesplit.base.resampling module

class adaptivesplit.base.resampling.PermTest(stat_fun, num_samples=1000, n_jobs=-1, compare=operator.ge, verbose=True, message='Permutation test')[source]

Bases: _ResampleBase

Implements a permutation test.

Args:
stat_fun (callable):

Statistical function used to evaluate statistical significance.

num_samples (int):

Number of samples generated during permutation. Defaults to 1000.

n_jobs (int, optional):

Number of jobs to run in parallel. Computations are parallelized over the permutations. None means 1 unless in a joblib.parallel_backend context; -1 means using all processors. Defaults to -1.

compare (callable):

User defined comparison function. Defaults to operator.ge.

verbose (bool, optional):

Prints progress. Defaults to True.

message (str, optional):

Message shown when permutation test starts. Defaults to “Permutation test”.
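A permutation test of this shape can be sketched in a few lines: compute the statistic on the observed data, recompute it under random permutations of one variable, and take the fraction of permuted statistics that compare (via the compare callable, operator.ge by default) against the observed one. This is a minimal illustration with an assumed correlation statistic, not the PermTest class itself:

```python
import operator
import numpy as np

def perm_test(x, y, stat_fun, num_samples=1000, compare=operator.ge, seed=0):
    """P-value: share of permutations whose statistic compares >= the observed one."""
    rng = np.random.default_rng(seed)
    observed = stat_fun(x, y)
    count = sum(
        compare(stat_fun(x, rng.permutation(y)), observed)
        for _ in range(num_samples)
    )
    return (count + 1) / (num_samples + 1)  # add-one smoothing

corr = lambda x, y: np.corrcoef(x, y)[0, 1]
rng = np.random.default_rng(1)
x = rng.normal(size=60)
y = x + rng.normal(scale=0.8, size=60)  # real association -> small p-value
print(perm_test(x, y, corr))
```

Supplying a different compare callable (e.g. operator.le) flips the direction of the test for statistics where smaller values are more extreme.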

class adaptivesplit.base.resampling.Resample(stat_fun, sample_size, stratify=None, num_samples=1000, replacement=True, first_unshuffled=False, n_jobs=-1, verbose=True, message='Resampling')[source]

Bases: _ResampleBase

Implements re-sampling.

Args:
stat_fun (callable):

User defined statistical function.

sample_size (int):

Current sample size.

stratify (int):

For classification tasks. If not None, use stratified sampling to account for class label imbalance. Defaults to None.

num_samples (int, optional):

Number of samples. Defaults to 1000.

replacement (bool, optional):

Whether or not to sample with replacement. Defaults to True.

first_unshuffled (bool, optional):

Whether to leave the first sample unshuffled (i.e. preserving the original order). Defaults to False.

n_jobs (int, optional):

Number of jobs to run in parallel. Computations are parallelized over the resampling iterations. None means 1 unless in a joblib.parallel_backend context; -1 means using all processors. Defaults to -1.

verbose (bool, optional):

Print progress. Defaults to True.

message (str, optional):

Message shown when resampling starts. Defaults to “Resampling”.
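The core of such a resampler is drawing repeated (sub)samples of a fixed size and collecting a statistic over them. The sketch below shows that core loop only (stratification and the first_unshuffled option are omitted), and is not the Resample class's own code:

```python
import numpy as np

def resample_stat(data, stat_fun, sample_size, num_samples=1000,
                  replacement=True, seed=0):
    """Distribution of stat_fun over repeated (re)samples of the data."""
    rng = np.random.default_rng(seed)
    out = np.empty(num_samples)
    for i in range(num_samples):
        idx = rng.choice(len(data), size=sample_size, replace=replacement)
        out[i] = stat_fun(data[idx])
    return out

data = np.random.default_rng(2).normal(loc=5.0, size=200)
means = resample_stat(data, np.mean, sample_size=100)
print(round(means.mean(), 2), round(means.std(), 2))
```

With replacement=True this is a bootstrap; with replacement=False it becomes subsampling of the available data.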

adaptivesplit.base.split module

adaptivesplit.base.split.estimate_sample_size(y_obs, y_pred, target_power, power_estimator, max_iter=100, rel_pwr_threshold=0.001, learning_rate=0.05)[source]

Estimate the sample size needed to reach the target power.

Args:
y_obs (np.ndarray):

The observed target samples.

y_pred (np.ndarray):

The predicted target samples.

target_power (float):

Target power.

power_estimator (adaptivesplit.power.PowerEstimatorBootstrap):

Estimator to calculate power.

max_iter (int):

Max number of iterations for sample estimation. Defaults to 100.

rel_pwr_threshold (float):

Relative tolerance on the target power used to stop the search. Defaults to 0.001.

learning_rate (float):

The learning rate value. Defaults to 0.05.

Returns:
sample_size (int):

The estimated sample size.
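The role of max_iter, rel_pwr_threshold and learning_rate can be illustrated with a toy iterative search: grow a candidate sample size by the learning rate until an assumed, monotone power function reaches the target (within the relative tolerance). Both the toy power curve and the update rule are illustrative assumptions, not the library's actual algorithm:

```python
import math

def estimate_sample_size(power_fun, target_power, start_n=10, max_iter=100,
                         rel_pwr_threshold=0.001, learning_rate=0.05):
    """Grow the candidate sample size until power_fun(n) reaches the target."""
    n = start_n
    for _ in range(max_iter):
        if power_fun(n) >= target_power * (1 - rel_pwr_threshold):
            break  # close enough to the target power
        n = int(math.ceil(n * (1 + learning_rate)))  # grow by the learning rate
    return n

toy_power = lambda n: 1 - math.exp(-n / 100)  # saturating toy power curve
n = estimate_sample_size(toy_power, target_power=0.8)
print(n, round(toy_power(n), 3))
```

A larger learning_rate converges in fewer iterations but overshoots the minimal sample size by more; max_iter caps the search if the target power is unreachable.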

adaptivesplit.base.split.plot(learning_curve, learning_curve_predicted=None, power_curve=None, power_curve_lower=None, power_curve_upper=None, power_curve_predicted=None, training_curve=None, dummy_curve=None, stop=None, reason=None, ci='95%', grid=True, subplot_kw=None, gridspec_kw=None, **kwargs)[source]

Plot the results.

Args:
learning_curve (adaptivesplit.base.learning_curve.LearningCurve object):

The learning curve calculated during validation.

learning_curve_predicted (adaptivesplit.base.learning_curve.LearningCurve object):

The predicted learning curve. Defaults to None.

power_curve (adaptivesplit.base.learning_curve.LearningCurve object):

The calculated power curve. Defaults to None.

power_curve_lower ():

Power curve confidence intervals lower bound. Defaults to None.

power_curve_upper ():

Power curve confidence intervals upper bound. Defaults to None.

power_curve_predicted (adaptivesplit.base.learning_curve.LearningCurve object):

The predicted power curve. Defaults to None.

training_curve (adaptivesplit.base.learning_curve.LearningCurve object):

Learning curve calculated during training. Defaults to None.

dummy_curve (adaptivesplit.base.learning_curve.LearningCurve object):

Baseline curve calculated using a dummy estimator. Defaults to None.

stop (int):

Sample size where the stopping point lies. Defaults to None.

reason (str or list of str):

Reason or list of reasons describing how the stopping rule found the stopping point. Defaults to None.

ci (str):

Confidence level for the learning curve intervals. Defaults to ‘95%’.

grid (bool):

Whether or not to configure the grid lines. Defaults to True.

subplot_kw (dict):

Dict with keywords passed to matplotlib used to create each subplot. Defaults to None.

gridspec_kw (dict):

Dict with keywords passed to matplotlib used to create the grid the subplots are placed on. Defaults to None.

Returns:
Figure (matplotlib.figure.Figure):

Plot containing the learning and power curves with the estimated stopping point.
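A generic sketch of the kind of figure plot() produces — a learning curve with a shaded confidence band and a vertical line at the stopping point — built with plain matplotlib and made-up curve values (illustrative only, not the function's actual code):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for scripted use
import matplotlib.pyplot as plt
import numpy as np

ns = np.array([10, 20, 40, 80, 160])
mid = 0.8 - 0.6 * np.exp(-ns / 50)     # made-up mean learning curve
half_width = 0.05 * np.ones_like(mid)  # made-up 95% CI half-width
stop = 80                              # made-up stopping point

fig, ax = plt.subplots(subplot_kw=None, gridspec_kw=None)
ax.plot(ns, mid, label="learning curve (mean)")
ax.fill_between(ns, mid - half_width, mid + half_width, alpha=0.3, label="95% CI")
ax.axvline(stop, linestyle="--", color="red", label=f"stop (n={stop})")
ax.set_xlabel("sample size")
ax.set_ylabel("score")
ax.grid(True)
ax.legend()
```

The subplot_kw and gridspec_kw dicts shown in the signature are forwarded to matplotlib's subplot machinery, exactly as in plt.subplots above.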

adaptivesplit.base.utils module

adaptivesplit.base.utils.tqdm_joblib(tqdm_object)[source]

Context manager to patch joblib to report into the tqdm progress bar given as argument. Based on: https://stackoverflow.com/questions/37804279/how-can-we-use-tqdm-in-a-parallel-execution-with-joblib
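The cited Stack Overflow pattern, which this helper is based on, can be reproduced self-contained: temporarily swap joblib's batch-completion callback for one that advances the given tqdm bar, and restore it on exit. This sketch follows that pattern and may differ in detail from the library's implementation:

```python
import contextlib
import math
import joblib
from tqdm import tqdm

@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Patch joblib so each completed batch advances the given tqdm bar."""
    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old_callback
        tqdm_object.close()

# Usage: the bar advances as parallel tasks complete.
with tqdm_joblib(tqdm(desc="sqrt", total=10)):
    results = joblib.Parallel(n_jobs=2, prefer="threads")(
        joblib.delayed(math.sqrt)(i) for i in range(10)
    )
```

Restoring the original callback in the finally block matters: the patch is global to joblib, so leaving it in place would affect unrelated Parallel calls.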