adaptivesplit.base package
Submodules
adaptivesplit.base.learning_curve module
- class adaptivesplit.base.learning_curve.LearningCurve(df=None, ns=None, data=None, scoring=None, curve_names=None, curve_type=None, description=None)[source]
Bases:
object
Class for storing and handling learning curves. Wraps a pandas DataFrame with rows for learning curves and columns for sample sizes. It can handle multiple learning curves (e.g. for different bootstrap samples).
- Args:
- df (array-like, pandas.DataFrame):
Initialize from a DataFrame that is already in the correct format. Defaults to None.
- ns (int):
Sample sizes (number of samples). Defaults to None.
- data (np.ndarray):
Array of shape (ns, n_curves) containing the score value for each sample size. Defaults to None.
- scoring (str, callable, list, tuple or dict):
Scikit-learn-like score to evaluate the performance of the cross-validated model on the test set. If scoring represents a single score, one can use:
- a single string (see The scoring parameter: defining model evaluation rules);
- a callable (see Defining your scoring strategy from metric functions) that returns a single value.
If scoring represents multiple scores, one can use:
- a list or tuple of unique strings;
- a callable returning a dictionary where the keys are the metric names and the values are the metric scores;
- a dictionary with metric names as keys and callables as values.
If None, the estimator’s score method is used. Defaults to None.
- curve_names (str, optional):
Curve names metadata. Defaults to None.
- curve_type (str, optional):
Curve type metadata. Defaults to None.
- description (str, optional):
Curve description. Defaults to None.
- stat(mid='mean', ci='95%')[source]
Return simple descriptive learning curve stats (e.g. for plotting purposes). Stat names should be given according to pandas.describe(). Most common choices: ‘mean’, ‘50%’, ‘std’, ‘95%’.
- Args:
- mid (str):
Statistic used to compute the mid value. Defaults to ‘mean’.
- ci (str):
Confidence level. Defaults to ‘95%’.
- Extra keywords:
‘stderr’ gives mid ± stderr/2.
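A minimal usage sketch with toy data (passing the sample sizes as a list and random scores is an assumption for illustration only; the return format of stat() is described above):

```python
import numpy as np
from adaptivesplit.base.learning_curve import LearningCurve

# Toy example: 3 bootstrap curves scored at 4 sample sizes.
sample_sizes = [20, 40, 60, 80]
scores = np.random.rand(len(sample_sizes), 3)   # shape (ns, n_curves)

lc = LearningCurve(ns=sample_sizes,
                   data=scores,
                   scoring='r2',
                   curve_type='validation',
                   description='toy learning curve')

# Descriptive stats (e.g. mean with a 95% confidence band) for plotting.
lc_stats = lc.stat(mid='mean', ci='95%')
```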
adaptivesplit.base.power module
- class adaptivesplit.base.power.PowerEstimatorBootstrap(power_stat_fun, stratify=None, total_sample_size=None, alpha=0.001, bootstrap_samples=100, n_jobs=None, verbose=True, message='Estimating Power with bootstrap')[source]
Bases:
_PowerEstimatorBase
Calculate power using bootstrap.
- Args:
- power_stat_fun (callable):
Statistical function used to calculate power.
- stratify (int):
For classification tasks. If not None, use stratified sampling to account for class label imbalance. Defaults to None.
- total_sample_size (int):
The total number of samples in the data given as input. Defaults to None.
- alpha (float):
Statistical threshold used to calculate power. Defaults to 0.001.
- bootstrap_samples (int):
Number of samples selected during bootstrapping. Defaults to 100.
- n_jobs (int):
Number of jobs to run in parallel. Defaults to None. Power calculations are parallelized over the cross-validation splits. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors.
- verbose (bool):
Prints progress. Defaults to True.
- message (str, optional):
Message shown when power estimation starts. Defaults to ‘Estimating Power with bootstrap’.
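A minimal construction sketch. The exact signature expected of power_stat_fun is not documented above, so the placeholder below is purely illustrative:

```python
from adaptivesplit.base.power import PowerEstimatorBootstrap

def my_power_stat(*args, **kwargs):
    # Placeholder for a user-supplied statistical function;
    # the signature required by the package may differ.
    ...

power_estimator = PowerEstimatorBootstrap(
    power_stat_fun=my_power_stat,
    total_sample_size=500,    # total number of samples in the input data
    alpha=0.001,              # significance threshold for the power calculation
    bootstrap_samples=100,    # number of bootstrap resamples
    n_jobs=-1,                # use all processors
    verbose=True,
)
```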
adaptivesplit.base.resampling module
- class adaptivesplit.base.resampling.PermTest(stat_fun, num_samples=1000, n_jobs=-1, compare=operator.ge, verbose=True, message='Permutation test')[source]
Bases:
_ResampleBase
Implements a permutation test.
- Args:
- stat_fun (callable):
Statistical function used to evaluate significance.
- num_samples (int):
Number of samples generated during permutation. Defaults to 1000.
- n_jobs (int, optional):
Number of jobs to run in parallel. Defaults to -1. None means 1 unless in a joblib.parallel_backend context; -1 means using all processors.
- compare (callable):
User-defined comparison function. Defaults to operator.ge.
- verbose (bool, optional):
Prints progress. Defaults to True.
- message (str, optional):
Message shown when permutation test starts. Defaults to “Permutation test”.
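A construction sketch with a hypothetical test statistic (how the permutation test is subsequently invoked is not shown above):

```python
import operator
from adaptivesplit.base.resampling import PermTest

def my_stat(*args, **kwargs):
    # Placeholder statistic; the signature expected by the package may differ.
    ...

perm_test = PermTest(
    stat_fun=my_stat,
    num_samples=1000,      # number of permutations
    n_jobs=-1,             # use all processors
    compare=operator.ge,   # default comparison
    verbose=True,
)
```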
- class adaptivesplit.base.resampling.Resample(stat_fun, sample_size, stratify=None, num_samples=1000, replacement=True, first_unshuffled=False, n_jobs=-1, verbose=True, message='Resampling')[source]
Bases:
_ResampleBase
Implements re-sampling.
- Args:
- stat_fun (callable):
User-defined statistical function.
- sample_size (int):
Current sample size.
- stratify (int):
For classification tasks. If not None, use stratified sampling to account for class label imbalance. Defaults to None.
- num_samples (int, optional):
Number of samples. Defaults to 1000.
- replacement (bool, optional):
Whether or not to sample with replacement. Defaults to True.
- first_unshuffled (bool, optional):
If True, the first sample is left unshuffled (i.e., kept in its original order). Defaults to False.
- n_jobs (int, optional):
Number of jobs to run in parallel. Defaults to -1. None means 1 unless in a joblib.parallel_backend context; -1 means using all processors.
- verbose (bool, optional):
Print progress. Defaults to True.
- message (str, optional):
Message shown when resampling starts. Defaults to “Resampling”.
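A construction sketch with a placeholder statistic (the concrete values are illustrative assumptions):

```python
from adaptivesplit.base.resampling import Resample

def my_stat(*args, **kwargs):
    # Placeholder statistic; the signature expected by the package may differ.
    ...

resampler = Resample(
    stat_fun=my_stat,
    sample_size=100,        # current sample size
    stratify=None,          # set for classification tasks with imbalanced labels
    num_samples=1000,       # number of resamples
    replacement=True,       # bootstrap-style sampling with replacement
    first_unshuffled=False,
    n_jobs=-1,
)
```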
adaptivesplit.base.split module
- adaptivesplit.base.split.estimate_sample_size(y_obs, y_pred, target_power, power_estimator, max_iter=100, rel_pwr_threshold=0.001, learning_rate=0.05)[source]
Estimate the sample size required to reach the target power.
- Args:
- y_obs (np.ndarray):
The observed target samples.
- y_pred (np.ndarray):
The predicted target samples.
- target_power (float):
Target power.
- power_estimator (adaptivesplit.base.power.PowerEstimatorBootstrap):
Estimator to calculate power.
- max_iter (int):
Max number of iterations for sample estimation. Defaults to 100.
- rel_pwr_threshold (float):
Relative tolerance on the power threshold. Defaults to 0.001.
- learning_rate (float):
The learning rate value. Defaults to 0.05.
- Returns:
- sample_size (int):
The estimated sample size.
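A sketch of a call with toy regression targets; the placeholder power statistic and the data are illustrative assumptions (see the PowerEstimatorBootstrap example above):

```python
import numpy as np
from adaptivesplit.base.power import PowerEstimatorBootstrap
from adaptivesplit.base.split import estimate_sample_size

def my_power_stat(*args, **kwargs):
    # Placeholder power statistic; the real signature may differ.
    ...

power_estimator = PowerEstimatorBootstrap(power_stat_fun=my_power_stat,
                                          total_sample_size=200, alpha=0.001)

rng = np.random.default_rng(42)
y_obs = rng.normal(size=200)                       # toy observed targets
y_pred = y_obs + rng.normal(scale=0.5, size=200)   # toy noisy predictions

sample_size = estimate_sample_size(
    y_obs, y_pred,
    target_power=0.8,
    power_estimator=power_estimator,
    max_iter=100,
    rel_pwr_threshold=0.001,
    learning_rate=0.05,
)
```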
- adaptivesplit.base.split.plot(learning_curve, learning_curve_predicted=None, power_curve=None, power_curve_lower=None, power_curve_upper=None, power_curve_predicted=None, training_curve=None, dummy_curve=None, stop=None, reason=None, ci='95%', grid=True, subplot_kw=None, gridspec_kw=None, **kwargs)[source]
Plot the results.
- Args:
- learning_curve (adaptivesplit.base.learning_curve.LearningCurve object):
The learning curve calculated during validation.
- learning_curve_predicted (adaptivesplit.base.learning_curve.LearningCurve object):
The predicted learning curve. Defaults to None.
- power_curve (adaptivesplit.base.learning_curve.LearningCurve object):
The calculated power curve. Defaults to None.
- power_curve_lower ():
Lower bound of the power curve confidence interval. Defaults to None.
- power_curve_upper ():
Upper bound of the power curve confidence interval. Defaults to None.
- power_curve_predicted (adaptivesplit.base.learning_curve.LearningCurve object):
The predicted power curve. Defaults to None.
- training_curve (adaptivesplit.base.learning_curve.LearningCurve object):
Learning curve calculated during training. Defaults to None.
- dummy_curve (adaptivesplit.base.learning_curve.LearningCurve object):
Baseline curve calculated using a dummy estimator. Defaults to None.
- stop (int):
Sample size where the stopping point lies. Defaults to None.
- reason (str or list of str):
Reason or list of reasons describing how the stopping rule found the stopping point. Defaults to None.
- ci (str):
Confidence interval for the learning curve. Defaults to ‘95%’.
- grid (bool):
Whether or not to show grid lines. Defaults to True.
- subplot_kw (dict):
Dict with keywords passed to matplotlib used to create each subplot. Defaults to None.
- gridspec_kw (dict):
Dict with keywords passed to matplotlib used to create the grid the subplots are placed on. Defaults to None.
- Returns:
- Figure (matplotlib.pyplot.figure):
Plot containing the learning and power curves with the estimated stopping point.
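A plotting sketch using a toy validation curve (the curve data, stopping point and reason string are assumptions for illustration; the returned figure is saved with matplotlib):

```python
import numpy as np
from adaptivesplit.base.learning_curve import LearningCurve
from adaptivesplit.base.split import plot

# Toy validation curve (see the LearningCurve example above).
sizes = [20, 40, 60, 80]
val_curve = LearningCurve(ns=sizes, data=np.random.rand(len(sizes), 3),
                          curve_type='validation')

fig = plot(
    learning_curve=val_curve,
    stop=60,                            # hypothetical stopping sample size
    reason='example stopping reason',   # hypothetical reason string
    ci='95%',
    grid=True,
)
fig.savefig('adaptivesplit_curves.png')
```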
adaptivesplit.base.utils module
- adaptivesplit.base.utils.tqdm_joblib(tqdm_object)[source]
Context manager that patches joblib to report progress into the tqdm progress bar given as argument. Based on: https://stackoverflow.com/questions/37804279/how-can-we-use-tqdm-in-a-parallel-execution-with-joblib
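A usage sketch following the pattern from the linked Stack Overflow answer:

```python
import math
import joblib
from tqdm import tqdm
from adaptivesplit.base.utils import tqdm_joblib

# Each task completed by joblib advances the tqdm bar by one step.
with tqdm_joblib(tqdm(desc='computing', total=10)):
    results = joblib.Parallel(n_jobs=2)(
        joblib.delayed(math.sqrt)(i) for i in range(10)
    )
```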