scmkl
Single-cell analysis using Multiple Kernel Learning, scMKL, is a binary classification algorithm utilizing prior information to group features to enhance classification and aid understanding of distinguishing features in multi-omic data sets.
Installation
Conda install
Conda is the recommended method to install scMKL:
conda create -n scMKL python=3.12
conda activate scMKL
conda install -c conda-forge ivango17::scmkl
Pip install
First, create a virtual environment with python>=3.11.1,<3.13
.
Then, install scMKL with:
# activate your new env with python>=3.11.1 and <3.13
pip install scmkl
If wheels do not build correctly, ensure gcc
and g++
are installed and up to date. They can be installed with sudo apt install gcc
and sudo apt install g++
.
Requirements
scMKL takes advantage of AnnData objects and can be implemented with just four pieces of data:
1) scRNA and/or scATAC matrices (can be scipy.sparse
matrix)
2) An array of cell labels
3) An array of feature names (eg. gene symbols for RNA or peaks for ATAC)
4) A grouping dictionary where {'group_1' : [feature_5, feature_16], 'group_2' : [feature_1, feature_4, feature_9]}
For implementing scMKL and learning how to get meaningful feature groupings, see our examples for your use case in examples.
Links
Repo: https://github.com/ohsu-cedar-comp-hub/scMKL
PyPI: https://pypi.org/project/scmkl/
Anaconda: https://anaconda.org/ivango17/scmkl
API: https://ohsu-cedar-comp-hub.github.io/scMKL/
Publication
If you use scMKL in your research, please cite using:
Kupp, S., VanGordon, I., Gönen, M., Esener, S., Eksi, S., Ak, C. Interpretable and integrative analysis of single-cell multiomics with scMKL. Commun Biol 8, 1160 (2025). https://doi.org/10.1038/s42003-025-08533-7
Our Shiny for Python application for viewing data produced from this work can be found here: scMKL_analysis
Issues
Please report bugs here.
Examples
Here are helpful examples for running scMKL, and getting the data required to run scMKL.
File | Description |
---|---|
getting_gene_groupings.ipynb | Different ways to get gene sets in a usable format with scMKL |
getting_region_groupings.ipynb | How to group genomic regions using a gene set library and a GTF file |
RNA_analysis.ipynb | Running scMKL using only single-cell RNA data |
ATAC_analysis.ipynb | Running scMKL using only single-cell ATAC data |
multimodal_analysis.ipynb | Running scMKL using single-cell multi-omic data (RNA + ATAC) |
scMKL Documentation
1""" 2.. include:: ../README.md 3.. include:: ../example/README.md 4 5---------------------------- 6 7## **scMKL Documentation** 8""" 9 10 11__all__ = ['calculate_z', 12 'calculate_d', 13 'create_adata', 14 'data_processing', 15 'dataframes', 16 'estimate_sigma', 17 'extract_results', 18 'find_candidates', 19 'format_adata', 20 'format_group_names', 21 'get_gene_groupings', 22 'get_metrics', 23 'get_region_groupings', 24 'get_selection', 25 'get_summary', 26 'get_weights', 27 'groups_per_alpha', 28 'group_umap', 29 'multimodal_processing', 30 'one_v_rest', 31 'optimize_alpha', 32 'optimize_sparsity', 33 'parse_metrics', 34 'parse_weights', 35 'plotting', 36 'plot_metric', 37 'plot_conf_mat', 38 'projections', 39 'read_files', 40 'read_gtf', 41 'run', 42 'sort_groups', 43 'test', 44 'tfidf_normalize', 45 'train_model', 46 'weights_barplot', 47 'weights_dotplot', 48 'weights_heatmap' 49 ] 50 51from scmkl._checks import * 52from scmkl.calculate_z import * 53from scmkl.create_adata import * 54from scmkl.data_processing import * 55from scmkl.dataframes import * 56from scmkl.estimate_sigma import * 57from scmkl.get_gene_groupings import * 58from scmkl.get_region_groupings import * 59from scmkl.multimodal_processing import * 60from scmkl.one_v_rest import * 61from scmkl.optimize_alpha import * 62from scmkl.optimize_sparsity import * 63from scmkl.plotting import * 64from scmkl.projections import * 65from scmkl.run import * 66from scmkl.test import * 67from scmkl.tfidf_normalize import * 68from scmkl.train_model import * 69from scmkl.projections import *
78def calculate_z(adata, n_features=5000, batches=10, 79 batch_size=100) -> ad.AnnData: 80 """ 81 Function to calculate Z matrices for all groups in both training 82 and testing data. 83 84 Parameters 85 ---------- 86 adata : ad.AnnData 87 created by `scmkl.create_adata()` with `adata.uns.keys()`: 88 `'train_indices'`, and `'test_indices'`. 89 90 n_features : int 91 Number of random feature to use when calculating Z; used for 92 scalability. 93 94 batches : int 95 The number of batches to use for the distance calculation. 96 This will average the result of `batches` distance calculations 97 of `batch_size` randomly sampled cells. More batches will converge 98 to population distance values at the cost of scalability. 99 100 batch_size : int 101 The number of cells to include per batch for distance 102 calculations. Higher batch size will converge to population 103 distance values at the cost of scalability. 104 If `batches*batch_size > num_training_cells`, 105 `batch_size` will be reduced to 106 `int(num_training_cells / batches)`. 107 108 Returns 109 ------- 110 adata : ad.AnnData 111 `adata` with Z matrices accessible with `adata.uns['Z_train']` 112 and `adata.uns['Z_test']`. 113 114 Examples 115 -------- 116 >>> adata = scmkl.estimate_sigma(adata) 117 >>> adata = scmkl.calculate_z(adata) 118 >>> adata.uns.keys() 119 dict_keys(['Z_train', 'Z_test', 'sigmas', 'train_indices', 120 'test_indices']) 121 """ 122 # Number of groupings taking from group_dict 123 n_pathway = len(adata.uns['group_dict'].keys()) 124 D = adata.uns['D'] 125 126 sq_i_d = np.sqrt(1/D) 127 128 # Capturing training and testing sizes 129 train_len = len(adata.uns['train_indices']) 130 test_len = len(adata.uns['test_indices']) 131 132 if batch_size * batches > len(adata.uns['train_indices']): 133 old_batch_size = batch_size 134 batch_size = int(len(adata.uns['train_indices'])/batches) 135 print("Specified batch size required too many cells for " 136 "independent batches. 
Reduced batch size from " 137 f"{old_batch_size} to {batch_size}") 138 139 if 'sigma' not in adata.uns.keys(): 140 n_samples = np.min((2000, adata.uns['train_indices'].shape[0])) 141 sample_range = np.arange(n_samples) 142 batch_idx = get_batches(sample_range, adata.uns['seed_obj'], 143 batches=batches, batch_size=batch_size) 144 sigma_indices = sample_cells(adata.uns['train_indices'], n_samples, adata.uns['seed_obj']) 145 146 # Create Arrays to store concatenated group Zs 147 # Each group of features will have a corresponding entry in each array 148 n_cols = 2*adata.uns['D']*n_pathway 149 Z_train = np.zeros((train_len, n_cols)) 150 Z_test = np.zeros((test_len, n_cols)) 151 152 153 # Setting kernel function 154 match adata.uns['kernel_type'].lower(): 155 case 'gaussian': 156 proj_func = gaussian_trans 157 case 'laplacian': 158 proj_func = laplacian_trans 159 case 'cauchy': 160 proj_func = cauchy_trans 161 162 163 # Loop over each of the groups and creating Z for each 164 sigma_list = list() 165 for m, group_features in enumerate(adata.uns['group_dict'].values()): 166 167 n_group_features = len(group_features) 168 169 X_train, X_test = get_group_mat(adata, n_features, group_features, 170 n_group_features, process_test=True) 171 172 if adata.uns['tfidf']: 173 X_train, X_test = tfidf_train_test(X_train, X_test) 174 175 # Data filtering, and transformation according to given data_type 176 # Will remove low variance (< 1e5) features regardless of data_type 177 # If scale_data will log scale and z-score the data 178 X_train, X_test = process_data(X_train=X_train, X_test=X_test, 179 scale_data=adata.uns['scale_data'], 180 return_dense=True) 181 182 # Getting sigma 183 if 'sigma' in adata.uns.keys(): 184 sigma = adata.uns['sigma'][m] 185 else: 186 sigma = est_group_sigma(adata, X_train, n_group_features, 187 n_features, batch_idx=batch_idx) 188 sigma_list.append(sigma) 189 190 assert sigma > 0, "Sigma must be more than 0" 191 train_projection, test_projection = calc_groupz(X_train, X_test, 192 adata, D, sigma, 193 proj_func) 194 195 # Store group Z in whole-Z object 196 # Preserves order to be able to extract meaningful groups 197 cos_idx, sin_idx = get_z_indices(m, D) 198 199 Z_train[0:, cos_idx] = np.cos(train_projection) 200 Z_train[0:, sin_idx] = np.sin(train_projection) 201 202 Z_test[0:, cos_idx] = np.cos(test_projection) 203 Z_test[0:, sin_idx] = np.sin(test_projection) 204 205 adata.uns['Z_train'] = Z_train*sq_i_d 206 adata.uns['Z_test'] = Z_test*sq_i_d 207 208 if 'sigma' not in adata.uns.keys(): 209 adata.uns['sigma'] = np.array(sigma_list) 210 211 return adata
Function to calculate Z matrices for all groups in both training and testing data.
Parameters
- adata (ad.AnnData):
created by
scmkl.create_adata
withadata.uns.keys()
:'train_indices'
, and'test_indices'
. - n_features (int): Number of random feature to use when calculating Z; used for scalability.
- batches (int):
The number of batches to use for the distance calculation.
This will average the result of
batches
distance calculations ofbatch_size
randomly sampled cells. More batches will converge to population distance values at the cost of scalability. - batch_size (int):
The number of cells to include per batch for distance
calculations. Higher batch size will converge to population
distance values at the cost of scalability.
If
batches*batch_size > num_training_cells
,batch_size
will be reduced toint(num_training_cells / batches)
.
Returns
- adata (ad.AnnData):
adata
with Z matrices accessible withadata.uns['Z_train']
andadata.uns['Z_test']
.
Examples
>>> adata = scmkl.estimate_sigma(adata)
>>> adata = scmkl.calculate_z(adata)
>>> adata.uns.keys()
dict_keys(['Z_train', 'Z_test', 'sigmas', 'train_indices',
'test_indices'])
198def calculate_d(num_samples : int): 199 """ 200 This function calculates the optimal number of dimensions for 201 performance. See https://doi.org/10.48550/arXiv.1806.09178 for more 202 information. 203 204 Parameters 205 ---------- 206 num_samples : int 207 The number of samples in the data set including both training 208 and testing sets. 209 210 Returns 211 ------- 212 d : int 213 The optimal number of dimensions to run scMKL with the given 214 data set. 215 216 Examples 217 -------- 218 >>> raw_counts = scipy.sparse.load_npz('MCF7_counts.npz') 219 >>> 220 >>> num_cells = raw_counts.shape[0] 221 >>> d = scmkl.calculate_d(num_cells) 222 >>> d 223 161 224 """ 225 d = int(np.sqrt(num_samples)*np.log(np.log(num_samples))) 226 227 return int(np.max([d, 100]))
This function calculates the optimal number of dimensions for performance. See https://doi.org/10.48550/arXiv.1806.09178 for more information.
Parameters
- num_samples (int): The number of samples in the data set including both training and testing sets.
Returns
- d (int): The optimal number of dimensions to run scMKL with the given data set.
Examples
>>> raw_counts = scipy.sparse.load_npz('MCF7_counts.npz')
>>>
>>> num_cells = raw_counts.shape[0]
>>> d = scmkl.calculate_d(num_cells)
>>> d
161
264def create_adata(X: scipy.sparse._csc.csc_matrix | np.ndarray | pd.DataFrame, 265 feature_names: np.ndarray, cell_labels: np.ndarray, 266 group_dict: dict, obs_names: None | np.ndarray=None, 267 scale_data: bool=True, split_data: np.ndarray | None=None, 268 D: int | None=None, remove_features: bool=True, 269 train_ratio: float=0.8, distance_metric: str='euclidean', 270 kernel_type: str='Gaussian', random_state: int=1, 271 allow_multiclass: bool = False, 272 class_threshold: str | int = 'median', 273 reduction: str | None = None, tfidf: bool = False): 274 """ 275 Function to create an AnnData object to carry all relevant 276 information going forward. 277 278 Parameters 279 ---------- 280 X : scipy.sparse.csc_matrix | np.ndarray | pd.DataFrame 281 A data matrix of cells by features (sparse array 282 recommended for large datasets). 283 284 feature_names : np.ndarray 285 Array of feature names corresponding with the features 286 in `X`. 287 288 cell_labels : np.ndarray 289 A numpy array of cell phenotypes corresponding with 290 the cells in `X`. 291 292 group_dict : dict 293 Dictionary containing feature grouping information (i.e. 294 `{geneset1: np.array([gene_1, gene_2, ..., gene_n]), geneset2: 295 np.array([...]), ...}`. 296 297 obs_names : None | np.ndarray 298 The cell names corresponding to `X` to be assigned to output 299 object `.obs_names` attribute. 300 301 scale_data : bool 302 If `True`, data matrix is log transformed and standard 303 scaled. 304 305 split_data : None | np.ndarray 306 If `None`, data will be split stratified by cell labels. 307 Else, is an array of precalculated train/test split 308 corresponding to samples. Can include labels for entire 309 dataset to benchmark performance or for only training 310 data to classify unknown cell types (i.e. `np.array(['train', 311 'test', ..., 'train'])`. 312 313 D : int 314 Number of Random Fourier Features used to calculate Z. 315 Should be a positive integer. Higher values of D will 316 increase classification accuracy at the cost of computation 317 time. If set to `None`, will be calculated given number of 318 samples. 319 320 remove_features : bool 321 If `True`, will remove features from `X` and `feature_names` 322 not in `group_dict` and remove features from groupings not in 323 `feature_names`. 324 325 train_ratio : float 326 Ratio of number of training samples to entire data set. Note: 327 if a threshold is applied, the ratio training samples may 328 decrease depending on class balance and `class_threshold` 329 parameter if `allow_multiclass = True`. 330 331 distance_metric : str 332 The pairwise distance metric used to estimate sigma. Must 333 be one of the options used in `scipy.spatial.distance.cdist`. 334 335 kernel_type : str 336 The approximated kernel function used to calculate Zs. 337 Must be one of `'Gaussian'`, `'Laplacian'`, or `'Cauchy'`. 338 339 random_state : int 340 Integer random_state used to set the seed for 341 reproducibilty. 342 343 allow_multiclass : bool 344 If `False`, will ensure that cell labels are binary. 345 346 class_threshold : str | int 347 Number of samples allowed in the training data for each cell 348 class in the training data. If `'median'`, the median number 349 of cells per cell class will be the threshold for number of 350 samples per class. 351 352 reduction: str | None 353 Choose which dimension reduction technique to perform on 354 features within a group. 
'svd' will run 355 `sklearn.decomposition.TruncatedSVD`, 'linear' will multiply 356 by an array of 1s down to 50 dimensions. 357 358 tfidf: bool 359 Whether to calculate TFIDF transformation on peaks within 360 groupings. 361 362 Returns 363 ------- 364 adata : ad.AnnData 365 AnnData with the following attributes and keys: 366 367 `adata.X` (array_like): 368 Data matrix. 369 370 `adata.var_names` (array_like): 371 Feature names corresponding to `adata.X`. 372 373 `adata.obs['labels']` (array_like): 374 cell classes/phenotypes from `cell_labels`. 375 376 `adata.uns['train_indices']` (array_like): 377 Indices for training data. 378 379 `adata.uns['test_indices']` (array_like) 380 Indices for testing data. 381 382 `adata.uns['group_dict']` (dict): 383 Grouping information. 384 385 `adata.uns['seed_obj']` (np.random._generator.Generator): 386 Seed object with seed equal to 100 * `random_state`. 387 388 `adata.uns['D']` (int): 389 Number of dimensions to scMKL with. 390 391 `adata.uns['scale_data']` (bool): 392 Whether or not data is log and z-score transformed. 393 394 `adata.uns['distance_metric']` (str): 395 Distance metric as given. 396 397 `adata.uns['kernel_type']` (str): 398 Kernel function as given. 399 400 `adata.uns['svd']` (bool): 401 Whether to calculate SVD reduction. 402 403 `adata.uns['tfidf']` (bool): 404 Whether to calculate TF-IDF per grouping. 405 406 Examples 407 -------- 408 >>> data_mat = scipy.sparse.load_npz('MCF7_RNA_matrix.npz') 409 >>> gene_names = np.load('MCF7_gene_names.pkl', allow_pickle = True) 410 >>> group_dict = np.load('hallmark_genesets.pkl', 411 >>> allow_pickle = True) 412 >>> 413 >>> adata = scmkl.create_adata(X = data_mat, 414 ... feature_names = gene_names, 415 ... group_dict = group_dict) 416 >>> adata 417 AnnData object with n_obs Ă— n_vars = 1000 Ă— 4341 418 obs: 'labels' 419 uns: 'group_dict', 'seed_obj', 'scale_data', 'D', 'kernel_type', 420 'distance_metric', 'train_indices', 'test_indices' 421 """ 422 423 assert X.shape[1] == len(feature_names), ("Different number of features " 424 "in X than feature names") 425 426 if not allow_multiclass: 427 assert len(np.unique(cell_labels)) == 2, ("cell_labels must contain " 428 "2 classes") 429 if D is not None: 430 assert isinstance(D, int) and D > 0, 'D must be a positive integer' 431 432 kernel_options = ['gaussian', 'laplacian', 'cauchy'] 433 assert kernel_type.lower() in kernel_options, ("Given kernel type not " 434 "implemented. 
Gaussian, " 435 "Laplacian, and Cauchy " 436 "are the acceptable " 437 "types.") 438 439 # Create adata object and add column names 440 adata = ad.AnnData(X) 441 adata.var_names = feature_names 442 443 if isinstance(obs_names, (np.ndarray)): 444 adata.obs_names = obs_names 445 446 filtered_feature_names, group_dict = _filter_features(feature_names, 447 group_dict) 448 449 if remove_features: 450 warnings.filterwarnings('ignore', category = ad.ImplicitModificationWarning) 451 adata = adata[:, filtered_feature_names] 452 453 gc.collect() 454 455 # Add metadata to adata object 456 adata.uns['group_dict'] = group_dict 457 adata.uns['seed_obj'] = np.random.default_rng(100*random_state) 458 adata.uns['scale_data'] = scale_data 459 adata.uns['D'] = D if D is not None else calculate_d(adata.shape[0]) 460 adata.uns['kernel_type'] = kernel_type 461 adata.uns['distance_metric'] = distance_metric 462 adata.uns['reduction'] = reduction if isinstance(reduction, str) else 'None' 463 adata.uns['tfidf'] = tfidf 464 465 if (split_data is None): 466 assert X.shape[0] == len(cell_labels), ("Different number of cells " 467 "than labels") 468 adata.obs['labels'] = cell_labels 469 470 if (allow_multiclass == False): 471 split = _binary_split(cell_labels, 472 seed_obj = adata.uns['seed_obj'], 473 train_ratio = train_ratio) 474 train_indices, test_indices = split 475 476 elif (allow_multiclass == True): 477 split = _multi_class_split(cell_labels, 478 seed_obj = adata.uns['seed_obj'], 479 class_threshold = class_threshold, 480 train_ratio = train_ratio) 481 train_indices, test_indices = split 482 483 adata.uns['labeled_test'] = True 484 485 else: 486 x_eq_labs = X.shape[0] == len(cell_labels) 487 train_eq_labs = X.shape[0] == len(cell_labels) 488 assert x_eq_labs or train_eq_labs, ("Must give labels for all cells " 489 "or only for training cells") 490 491 train_indices = np.where(split_data == 'train')[0] 492 test_indices = np.where(split_data == 'test')[0] 493 494 if len(cell_labels) == len(train_indices): 495 496 padded_cell_labels = np.zeros((X.shape[0])).astype('object') 497 padded_cell_labels[train_indices] = cell_labels 498 padded_cell_labels[test_indices] = 'padded_test_label' 499 500 adata.obs['labels'] = padded_cell_labels 501 adata.uns['labeled_test'] = False 502 503 elif len(cell_labels) == len(split_data): 504 adata.obs['labels'] = cell_labels 505 adata.uns['labeled_test'] = True 506 507 # Ensuring all train samples are first in adata object followed by test 508 sort_idx, train_indices, test_indices = sort_samples(train_indices, 509 test_indices) 510 511 adata = adata[sort_idx] 512 513 if not isinstance(obs_names, (np.ndarray)): 514 adata.obs = adata.obs.reset_index(drop=True) 515 adata.obs.index = adata.obs.index.astype('O') 516 517 adata.uns['train_indices'] = train_indices 518 adata.uns['test_indices'] = test_indices 519 520 if not scale_data: 521 print("WARNING: Data will not be log transformed and scaled. " 522 "To change this behavior, set scale_data to True") 523 524 return adata
Function to create an AnnData object to carry all relevant information going forward.
Parameters
- X (scipy.sparse.csc_matrix | np.ndarray | pd.DataFrame): A data matrix of cells by features (sparse array recommended for large datasets).
- feature_names (np.ndarray):
Array of feature names corresponding with the features
in
X
. - cell_labels (np.ndarray):
A numpy array of cell phenotypes corresponding with
the cells in
X
. - group_dict (dict):
Dictionary containing feature grouping information (i.e.
{geneset1: np.array([gene_1, gene_2, ..., gene_n]), geneset2: np.array([...]), ...}
. - obs_names (None | np.ndarray):
The cell names corresponding to
X
to be assigned to output object.obs_names
attribute. - scale_data (bool):
If
True
, data matrix is log transformed and standard scaled. - split_data (None | np.ndarray):
If
None
, data will be split stratified by cell labels. Else, is an array of precalculated train/test split corresponding to samples. Can include labels for entire dataset to benchmark performance or for only training data to classify unknown cell types (i.e.np.array(['train', 'test', ..., 'train'])
. - D (int):
Number of Random Fourier Features used to calculate Z.
Should be a positive integer. Higher values of D will
increase classification accuracy at the cost of computation
time. If set to
None
, will be calculated given number of samples. - remove_features (bool):
If
True
, will remove features fromX
andfeature_names
not ingroup_dict
and remove features from groupings not infeature_names
. - train_ratio (float):
Ratio of number of training samples to entire data set. Note:
if a threshold is applied, the ratio training samples may
decrease depending on class balance and
class_threshold
parameter ifallow_multiclass = True
. - distance_metric (str):
The pairwise distance metric used to estimate sigma. Must
be one of the options used in
scipy.spatial.distance.cdist
. - kernel_type (str):
The approximated kernel function used to calculate Zs.
Must be one of
'Gaussian'
,'Laplacian'
, or'Cauchy'
. - random_state (int): Integer random_state used to set the seed for reproducibilty.
- allow_multiclass (bool):
If
False
, will ensure that cell labels are binary. - class_threshold (str | int):
Number of samples allowed in the training data for each cell
class in the training data. If
'median'
, the median number of cells per cell class will be the threshold for number of samples per class. - reduction (str | None):
Choose which dimension reduction technique to perform on
features within a group. 'svd' will run
sklearn.decomposition.TruncatedSVD
, 'linear' will multiply by an array of 1s down to 50 dimensions. - tfidf (bool): Whether to calculate TFIDF transformation on peaks within groupings.
Returns
adata (ad.AnnData): AnnData with the following attributes and keys:
adata.X
(array_like): Data matrix.adata.var_names
(array_like): Feature names corresponding toadata.X
.adata.obs['labels']
(array_like): cell classes/phenotypes fromcell_labels
.adata.uns['train_indices']
(array_like): Indices for training data.adata.uns['test_indices']
(array_like) Indices for testing data.adata.uns['group_dict']
(dict): Grouping information.adata.uns['seed_obj']
(np.random._generator.Generator): Seed object with seed equal to 100 *random_state
.adata.uns['D']
(int): Number of dimensions to scMKL with.adata.uns['scale_data']
(bool): Whether or not data is log and z-score transformed.adata.uns['distance_metric']
(str): Distance metric as given.adata.uns['kernel_type']
(str): Kernel function as given.adata.uns['svd']
(bool): Whether to calculate SVD reduction.adata.uns['tfidf']
(bool): Whether to calculate TF-IDF per grouping.
Examples
>>> data_mat = scipy.sparse.load_npz('MCF7_RNA_matrix.npz')
>>> gene_names = np.load('MCF7_gene_names.pkl', allow_pickle = True)
>>> group_dict = np.load('hallmark_genesets.pkl',
>>> allow_pickle = True)
>>>
>>> adata = scmkl.create_adata(X = data_mat,
... feature_names = gene_names,
... group_dict = group_dict)
>>> adata
AnnData object with n_obs Ă— n_vars = 1000 Ă— 4341
obs: 'labels'
uns: 'group_dict', 'seed_obj', 'scale_data', 'D', 'kernel_type',
'distance_metric', 'train_indices', 'test_indices'
164def estimate_sigma(adata: ad.AnnData, 165 n_features: int = 5000, 166 batches: int = 10, 167 batch_size: int = 100) -> ad.AnnData: 168 """ 169 Calculate kernel widths to inform distribution for projection of 170 Fourier Features. Calculates one sigma per group of features. 171 172 Parameters 173 ---------- 174 adata : ad.AnnData 175 Created by `create_adata`. 176 177 n_features : int 178 Number of random features to include when estimating sigma. 179 Will be scaled for the whole pathway set according to a 180 heuristic. Used for scalability. 181 182 batches : int 183 The number of batches to use for the distance calculation. 184 This will average the result of `batches` distance calculations 185 of `batch_size` randomly sampled cells. More batches will converge 186 to population distance values at the cost of scalability. 187 188 batch_size : int 189 The number of cells to include per batch for distance 190 calculations. Higher batch size will converge to population 191 distance values at the cost of scalability. 192 If `batches` * `batch_size` > # training cells, 193 `batch_size` will be reduced to `int(num training cells / 194 batches)`. 195 196 Returns 197 ------- 198 adata : ad.AnnData 199 Key added `adata.uns['sigma']` with grouping kernel widths. 200 201 Examples 202 -------- 203 >>> adata = scmkl.estimate_sigma(adata) 204 >>> adata.uns['sigma'] 205 array([10.4640895 , 10.82011454, 6.16769438, 9.86156855, ...]) 206 """ 207 assert batch_size <= len(adata.uns['train_indices']), ("Batch size must be " 208 "smaller than the " 209 "training set.") 210 211 if batch_size * batches > len(adata.uns['train_indices']): 212 old_batch_size = batch_size 213 batch_size = int(len(adata.uns['train_indices'])/batches) 214 print("Specified batch size required too many cells for " 215 "independent batches. Reduced batch size from " 216 f"{old_batch_size} to {batch_size}") 217 218 if batch_size > 2000: 219 print("Warning: Batch sizes over 2000 may " 220 "result in long run-time.") 221 222 # Getting subsample indices 223 sample_size = np.min((2000, adata.uns['train_indices'].shape[0])) 224 indices = sample_cells(adata.uns['train_indices'], sample_size=sample_size, 225 seed_obj=adata.uns['seed_obj']) 226 227 # Getting batch indices 228 sample_range = np.arange(sample_size) 229 batch_idx = get_batches(sample_range, adata.uns['seed_obj'], 230 batches, batch_size) 231 232 # Loop over every group in group_dict 233 sigma_array = np.zeros((len(adata.uns['group_dict']))) 234 for m, group_features in enumerate(adata.uns['group_dict'].values()): 235 236 n_group_features = len(group_features) 237 238 # Filtering to only features in grouping using filtered view of adata 239 X_train = get_group_mat(adata[indices], n_features=n_features, 240 group_features=group_features, 241 n_group_features=n_group_features) 242 243 if adata.uns['tfidf']: 244 X_train = tfidf(X_train, mode='normalize') 245 246 # Data filtering, and transformation according to given data_type 247 # Will remove low variance (< 1e5) features regardless of data_type 248 # If scale_data will log scale and z-score the data 249 X_train = process_data(X_train=X_train, 250 scale_data=adata.uns['scale_data'], 251 return_dense=True) 252 253 # Estimating sigma 254 sigma = est_group_sigma(adata, X_train, n_group_features, 255 n_features, batch_idx=batch_idx) 256 257 sigma_array[m] = sigma 258 259 adata.uns['sigma'] = sigma_array 260 261 return adata
Calculate kernel widths to inform distribution for projection of Fourier Features. Calculates one sigma per group of features.
Parameters
- adata (ad.AnnData):
Created by
create_adata
. - n_features (int): Number of random features to include when estimating sigma. Will be scaled for the whole pathway set according to a heuristic. Used for scalability.
- batches (int):
The number of batches to use for the distance calculation.
This will average the result of
batches
distance calculations ofbatch_size
randomly sampled cells. More batches will converge to population distance values at the cost of scalability. - batch_size (int):
The number of cells to include per batch for distance
calculations. Higher batch size will converge to population
distance values at the cost of scalability.
If
batches
*batch_size
> # training cells,batch_size
will be reduced toint(num training cells / batches)
.
Returns
- adata (ad.AnnData):
Key added
adata.uns['sigma']
with grouping kernel widths.
Examples
>>> adata = scmkl.estimate_sigma(adata)
>>> adata.uns['sigma']
array([10.4640895 , 10.82011454, 6.16769438, 9.86156855, ...])
264def extract_results(results: dict, metric: str): 265 """ 266 267 """ 268 summary = {'Alpha' : list(), 269 metric : list(), 270 'Number of Selected Groups' : list(), 271 'Top Group' : list()} 272 273 alpha_list = list(results['Metrics'].keys()) 274 275 # Creating summary DataFrame for each model 276 for alpha in alpha_list: 277 cur_alpha_rows = results['Norms'][alpha] 278 top_weight_rows = np.max(results['Norms'][alpha]) 279 top_group_index = np.where(cur_alpha_rows == top_weight_rows) 280 num_selected = len(results['Selected_groups'][alpha]) 281 top_group_name = np.array(results['Group_names'])[top_group_index] 282 283 if 0 == num_selected: 284 top_group_name = ["No groups selected"] 285 286 summary['Alpha'].append(alpha) 287 summary[metric].append(results['Metrics'][alpha][metric]) 288 summary['Number of Selected Groups'].append(num_selected) 289 summary['Top Group'].append(*top_group_name) 290 291 return pd.DataFrame(summary)
203def find_candidates(organism: str='human', key_terms: str | list='', blacklist: str | list | bool=False): 204 """ 205 Given `organism` and `key_terms`, will search for gene 206 groupings that could fit the datasets/classification task. 207 `blacklist` terms undesired in group names. 208 209 Parameters 210 ---------- 211 organism : str 212 The species the gene grouping is for. Options are 213 `{'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}` 214 215 key_terms : str | list 216 The types of cells or other specifiers the gene set is for 217 (example: 'CD4 T'). 218 219 blacklist : str | list | bool 220 Term(s) undesired in group names. Ignored unless provided. 221 222 Returns 223 ------- 224 libraries : list 225 A list of gene set library names that could serve for the 226 dataset/classification task. 227 228 Examples 229 -------- 230 >>> scmkl.find_candidates('human', key_terms=' b ') 231 Library No. Gene Sets 232 0 Azimuth_2023 1241 233 1 Azimuth_Cell_Types_2021 341 234 2 Cancer_Cell_Line_Encyclopedia 967 235 3 CellMarker_2024 1134 236 No. Key Type Matching 237 9 238 9 239 0 240 21 241 """ 242 check_organism(organism) 243 244 if organism.lower() in global_lib_orgs: 245 glo = global_lib_orgs.copy() 246 glo.remove(organism) 247 other_org = glo[0] 248 libs = human_genesets 249 libs = [name for name in libs if not other_org in name.lower()] 250 else: 251 libs = gp.get_library_name(organism) 252 other_org = '' 253 254 libs = {name : gp.get_library(name, organism) 255 for name in libs} 256 257 libs_df, _ = check_libs(libs, key_terms, blacklist, other_org) 258 259 return libs_df
Given organism
and key_terms
, will search for gene
groupings that could fit the datasets/classification task.
blacklist
terms undesired in group names.
Parameters
- organism (str):
The species the gene grouping is for. Options are
{'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}
- key_terms (str | list): The types of cells or other specifiers the gene set is for (example: 'CD4 T').
- blacklist (str | list | bool): Term(s) undesired in group names. Ignored unless provided.
Returns
- libraries (list): A list of gene set library names that could serve for the dataset/classification task.
Examples
>>> scmkl.find_candidates('human', key_terms=' b ')
Library No. Gene Sets
0 Azimuth_2023 1241
1 Azimuth_Cell_Types_2021 341
2 Cancer_Cell_Line_Encyclopedia 967
3 CellMarker_2024 1134
No. Key Type Matching
9
9
0
21
527def format_adata(adata: ad.AnnData | str, cell_labels: np.ndarray | str, 528 group_dict: dict | str, use_raw: bool=False, 529 scale_data: bool=True, split_data: np.ndarray | None=None, 530 D: int | None=None, remove_features: bool=True, 531 train_ratio: float=0.8, distance_metric: str='euclidean', 532 kernel_type: str='Gaussian', random_state: int=1, 533 allow_multiclass: bool = False, 534 class_threshold: str | int = 'median', 535 reduction: str | None = None, tfidf: bool = False): 536 """ 537 Function to format an `ad.AnnData` object to carry all relevant 538 information going forward. `adata.obs_names` will be retained. 539 540 **NOTE: Information not needed for running `scmkl` will be 541 removed.** 542 543 Parameters 544 ---------- 545 adata : ad.AnnData 546 Object with data for `scmkl` to be applied to. Only requirment 547 is that `.var_names` is correct and data matrix is in `adata.X` 548 or `adata.raw.X`. A h5ad file can be provided as a `str` and it 549 will be read in. 550 551 cell_labels : np.ndarray | str 552 If type `str`, the labels for `scmkl` to learn are captured 553 from `adata.obs['cell_labels']`. Else, a `np.ndarray` of cell 554 phenotypes corresponding with the cells in `adata.X`. 555 556 group_dict : dict | str 557 Dictionary containing feature grouping information (i.e. 558 `{geneset1: np.array([gene_1, gene_2, ..., gene_n]), geneset2: 559 np.array([...]), ...}`. A pickle file can be provided as a `str` 560 and it will be read in. 561 562 obs_names : None | np.ndarray 563 The cell names corresponding to `X` to be assigned to output 564 object `.obs_names` attribute. 565 566 use_raw : bool 567 If `False`, will use `adata.X` to create new `adata`. Else, 568 will use `adata.raw.X`. 569 570 scale_data : bool 571 If `True`, data matrix is log transformed and standard 572 scaled. 573 574 split_data : None | np.ndarray 575 If `None`, data will be split stratified by cell labels. 576 Else, is an array of precalculated train/test split 577 corresponding to samples. Can include labels for entire 578 dataset to benchmark performance or for only training 579 data to classify unknown cell types (i.e. `np.array(['train', 580 'test', ..., 'train'])`. 581 582 D : int 583 Number of Random Fourier Features used to calculate Z. 584 Should be a positive integer. Higher values of D will 585 increase classification accuracy at the cost of computation 586 time. If set to `None`, will be calculated given number of 587 samples. 588 589 remove_features : bool 590 If `True`, will remove features from `X` and `feature_names` 591 not in `group_dict` and remove features from groupings not in 592 `feature_names`. 593 594 train_ratio : float 595 Ratio of number of training samples to entire data set. Note: 596 if a threshold is applied, the ratio training samples may 597 decrease depending on class balance and `class_threshold` 598 parameter if `allow_multiclass = True`. 599 600 distance_metric : str 601 The pairwise distance metric used to estimate sigma. Must 602 be one of the options used in `scipy.spatial.distance.cdist`. 603 604 kernel_type : str 605 The approximated kernel function used to calculate Zs. 606 Must be one of `'Gaussian'`, `'Laplacian'`, or `'Cauchy'`. 607 608 random_state : int 609 Integer random_state used to set the seed for 610 reproducibilty. 611 612 allow_multiclass : bool 613 If `False`, will ensure that cell labels are binary. 614 615 class_threshold : str | int 616 Number of samples allowed in the training data for each cell 617 class in the training data. 
If `'median'`, the median number 618 of cells per cell class will be the threshold for number of 619 samples per class. 620 621 reduction: str | None 622 Choose which dimension reduction technique to perform on 623 features within a group. 'svd' will run 624 `sklearn.decomposition.TruncatedSVD`, 'linear' will multiply 625 by an array of 1s down to 50 dimensions. 626 627 tfidf: bool 628 Whether to calculate TFIDF transformation on peaks within 629 groupings. 630 631 Returns 632 ------- 633 adata : ad.AnnData 634 AnnData with the following attributes and keys: 635 636 `adata.X` (array_like): 637 Data matrix. 638 639 `adata.var_names` (array_like): 640 Feature names corresponding to `adata.X`. 641 642 `adata.obs['labels']` (array_like): 643 cell classes/phenotypes from `cell_labels`. 644 645 `adata.uns['train_indices']` (array_like): 646 Indices for training data. 647 648 `adata.uns['test_indices']` (array_like) 649 Indices for testing data. 650 651 `adata.uns['group_dict']` (dict): 652 Grouping information. 653 654 `adata.uns['seed_obj']` (np.random._generator.Generator): 655 Seed object with seed equal to 100 * `random_state`. 656 657 `adata.uns['D']` (int): 658 Number of dimensions to scMKL with. 659 660 `adata.uns['scale_data']` (bool): 661 Whether or not data is log and z-score transformed. 662 663 `adata.uns['distance_metric']` (str): 664 Distance metric as given. 665 666 `adata.uns['kernel_type']` (str): 667 Kernel function as given. 668 669 `adata.uns['svd']` (bool): 670 Whether to calculate SVD reduction. 671 672 `adata.uns['tfidf']` (bool): 673 Whether to calculate TF-IDF per grouping. 674 675 Examples 676 -------- 677 >>> adata = ad.read_h5ad('MCF7_rna.h5ad') 678 >>> group_dict = np.load('hallmark_genesets.pkl', 679 >>> allow_pickle = True) 680 >>> 681 >>> 682 >>> # The labels in adata.obs we want to learn are 'celltypes' 683 >>> adata = scmkl.format_adata(adata, 'celltypes', 684 ... group_dict) 685 >>> adata 686 AnnData object with n_obs Ă— n_vars = 1000 Ă— 4341 687 obs: 'labels' 688 uns: 'group_dict', 'seed_obj', 'scale_data', 'D', 'kernel_type', 689 'distance_metric', 'train_indices', 'test_indices' 690 """ 691 if str == type(adata): 692 adata = ad.read_h5ad(adata) 693 694 if str == type(group_dict): 695 group_dict = np.load(group_dict, allow_pickle=True) 696 697 if str == type(cell_labels): 698 err_msg = f"{cell_labels} is not in `adata.obs`" 699 assert cell_labels in adata.obs.keys(), err_msg 700 cell_labels = adata.obs[cell_labels].to_numpy() 701 702 if use_raw: 703 assert adata.raw, "`adata.raw` is empty, set `use_raw` to `False`" 704 X = adata.raw.X 705 else: 706 X = adata.X 707 708 adata = create_adata(X, adata.var_names.to_numpy().copy(), cell_labels, 709 group_dict, adata.obs_names.to_numpy().copy(), 710 scale_data, split_data, D, remove_features, 711 train_ratio, distance_metric, kernel_type, 712 random_state, allow_multiclass, class_threshold, 713 reduction, tfidf) 714 715 return adata
Function to format an ad.AnnData
object to carry all relevant
information going forward. adata.obs_names
will be retained.
NOTE: Information not needed for running scmkl
will be
removed.
Parameters
- adata (ad.AnnData):
Object with data for
scmkl
to be applied to. Only requirment is that.var_names
is correct and data matrix is inadata.X
oradata.raw.X
. A h5ad file can be provided as astr
and it will be read in. - cell_labels (np.ndarray | str):
If type
str
, the labels forscmkl
to learn are captured fromadata.obs['cell_labels']
. Else, anp.ndarray
of cell phenotypes corresponding with the cells inadata.X
. - group_dict (dict | str):
Dictionary containing feature grouping information (i.e.
{geneset1: np.array([gene_1, gene_2, ..., gene_n]), geneset2: np.array([...]), ...}
. A pickle file can be provided as astr
and it will be read in. - obs_names (None | np.ndarray):
The cell names corresponding to
X
to be assigned to output object.obs_names
attribute. - use_raw (bool):
If
False
, will useadata.X
to create newadata
. Else, will useadata.raw.X
. - scale_data (bool):
If
True
, data matrix is log transformed and standard scaled. - split_data (None | np.ndarray):
If
None
, data will be split stratified by cell labels. Else, is an array of precalculated train/test split corresponding to samples. Can include labels for entire dataset to benchmark performance or for only training data to classify unknown cell types (i.e.np.array(['train', 'test', ..., 'train'])
. - D (int):
Number of Random Fourier Features used to calculate Z.
Should be a positive integer. Higher values of D will
increase classification accuracy at the cost of computation
time. If set to
None
, will be calculated given number of samples. - remove_features (bool):
If
True
, will remove features fromX
andfeature_names
not ingroup_dict
and remove features from groupings not infeature_names
. - train_ratio (float):
Ratio of number of training samples to entire data set. Note:
if a threshold is applied, the ratio training samples may
decrease depending on class balance and
class_threshold
parameter ifallow_multiclass = True
. - distance_metric (str):
The pairwise distance metric used to estimate sigma. Must
be one of the options used in
scipy.spatial.distance.cdist
. - kernel_type (str):
The approximated kernel function used to calculate Zs.
Must be one of
'Gaussian'
,'Laplacian'
, or'Cauchy'
. - random_state (int): Integer random_state used to set the seed for reproducibilty.
- allow_multiclass (bool):
If
False
, will ensure that cell labels are binary. - class_threshold (str | int):
Number of samples allowed in the training data for each cell
class in the training data. If
'median'
, the median number of cells per cell class will be the threshold for number of samples per class. - reduction (str | None):
Choose which dimension reduction technique to perform on
features within a group. 'svd' will run
sklearn.decomposition.TruncatedSVD
, 'linear' will multiply by an array of 1s down to 50 dimensions. - tfidf (bool): Whether to calculate TFIDF transformation on peaks within groupings.
Returns
adata (ad.AnnData): AnnData with the following attributes and keys:
adata.X
(array_like): Data matrix.adata.var_names
(array_like): Feature names corresponding toadata.X
.adata.obs['labels']
(array_like): cell classes/phenotypes fromcell_labels
.adata.uns['train_indices']
(array_like): Indices for training data.adata.uns['test_indices']
(array_like) Indices for testing data.adata.uns['group_dict']
(dict): Grouping information.adata.uns['seed_obj']
(np.random._generator.Generator): Seed object with seed equal to 100 *random_state
.adata.uns['D']
(int): Number of dimensions to scMKL with.adata.uns['scale_data']
(bool): Whether or not data is log and z-score transformed.adata.uns['distance_metric']
(str): Distance metric as given.adata.uns['kernel_type']
(str): Kernel function as given.adata.uns['svd']
(bool): Whether to calculate SVD reduction.adata.uns['tfidf']
(bool): Whether to calculate TF-IDF per grouping.
Examples
>>> adata = ad.read_h5ad('MCF7_rna.h5ad')
>>> group_dict = np.load('hallmark_genesets.pkl',
>>> allow_pickle = True)
>>>
>>>
>>> # The labels in adata.obs we want to learn are 'celltypes'
>>> adata = scmkl.format_adata(adata, 'celltypes',
... group_dict)
>>> adata
AnnData object with n_obs Ă— n_vars = 1000 Ă— 4341
obs: 'labels'
uns: 'group_dict', 'seed_obj', 'scale_data', 'D', 'kernel_type',
'distance_metric', 'train_indices', 'test_indices'
90def format_group_names(group_names: list | pd.Series | np.ndarray, 91 rm_words: list=list()): 92 """ 93 Takes an ArrayLike object of group names and formats them. 94 95 Parameters 96 ---------- 97 group_names : array_like 98 An array of group names to format. 99 100 rm_words : list 101 Words to remove from all group names. 102 103 Returns 104 ------- 105 new_group_names : list 106 Formatted version of the input group names. 107 108 Examples 109 -------- 110 >>> groups = ['HALLMARK_E2F_TARGETS', 'HALLMARK_HYPOXIA'] 111 >>> new_groups = scmkl.format_group_names(groups) 112 >>> new_groups 113 ['Hallmark E2F Targets', 'Hallmark Hypoxia'] 114 """ 115 new_group_names = list() 116 rm_words = [word.lower() for word in rm_words] 117 118 for name in group_names: 119 new_name = list() 120 for word in re.split(r'_|\s', name): 121 if word.isalpha() and (len(word) > 3): 122 word = word.capitalize() 123 if word.lower() not in rm_words: 124 new_name.append(word) 125 new_name = ' '.join(new_name) 126 new_group_names.append(new_name) 127 128 return new_group_names
Takes an ArrayLike object of group names and formats them.
Parameters
- group_names (array_like): An array of group names to format.
- rm_words (list): Words to remove from all group names.
Returns
- new_group_names (list): Formatted version of the input group names.
Examples
>>> groups = ['HALLMARK_E2F_TARGETS', 'HALLMARK_HYPOXIA']
>>> new_groups = scmkl.format_group_names(groups)
>>> new_groups
['Hallmark E2F Targets', 'Hallmark Hypoxia']
262def get_gene_groupings(lib_name: str, organism: str='human', key_terms: str | list='', 263 blacklist: str | list | bool=False, min_overlap: int=2, 264 genes: list | tuple | pd.Series | np.ndarray | set=[]): 265 """ 266 Takes a gene set library name and filters to groups containing 267 element(s) in `key_terms`. If genes is provided, will 268 ensure that there are at least `min_overlap` number of genes in 269 each group. Resulting groups will meet all of the before-mentioned 270 criteria if `isin_logic` is `'and'` | `'or'`. 271 272 Parameters 273 ---------- 274 lib_name : str 275 The desired library name. 276 277 organism : str 278 The species the gene grouping is for. Options are 279 `{'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}`. 280 281 key_terms : str | list 282 The types of cells or other specifiers the gene set is for 283 (example: 'CD4 T'). 284 285 genes : array_like 286 A vector of genes from the reference/query datasets. If not 287 assigned, function will not filter groups based on feature 288 overlap. 289 290 min_overlap : int 291 The minimum number of genes that must be present in a group 292 for it to be kept. If `genes` is not given, ignored. 293 294 Returns 295 ------- 296 lib : dict 297 The filtered library as a `dict` where keys are group names 298 and keys are features. 299 300 Examples 301 -------- 302 >>> dataset_feats = [ 303 ... 'FUCA1', 'CLIC4', 'STMN1', 'SYF2', 'TAS1R1', 304 ... 'NOL9', 'TAS1R3', 'SLC2A5', 'THAP3', 'IGHM', 305 ... 'MARCKS', 'BANK1', 'TNFRSF13B', 'IGKC', 'IGHD', 306 ... 'LINC01857', 'CD24', 'CD37', 'IGHD', 'RALGPS2' 307 ... ] 308 >>> rna_grouping = scmkl.get_gene_groupings( 309 ... 'Azimuth_2023', key_terms=[' b ', 'b cell', 'b '], 310 ... genes=dataset_feats) 311 >>> 312 >>> rna_groupings.keys() 313 dict_keys(['PBMC-L1-B Cell', 'PBMC-L2-Intermediate B Cell', ...]) 314 """ 315 check_organism(organism) 316 317 lib = gp.get_library(lib_name, organism) 318 319 if organism.lower() in global_lib_orgs: 320 glo = global_lib_orgs.copy() 321 glo.remove(organism) 322 other_org = glo[0] 323 else: 324 other_org = '' 325 326 group_names = list(lib.keys()) 327 res = check_groups(group_names, key_terms, blacklist, other_org) 328 del res['num_groups'] 329 330 # Finding groups where group name matches key_terms 331 g_summary = pd.DataFrame(res) 332 333 if key_terms: 334 kept = g_summary['key_terms_in'] 335 kept_groups = g_summary['name'][kept].to_numpy() 336 g_summary = g_summary[kept] 337 else: 338 print("Not filtering with `key_terms` parameter.") 339 kept_groups = g_summary['name'].to_numpy() 340 341 if blacklist: 342 kept = ~g_summary['blacklist_in'] 343 kept_groups = g_summary['name'][kept].to_numpy() 344 else: 345 print("Not filtering with `blacklist` parameter.") 346 347 # Filtering library 348 lib = {group : lib[group] for group in kept_groups} 349 350 if 0 < len(genes): 351 del_groups = list() 352 genes = list(set(genes.copy())) 353 for group, features in lib.items(): 354 overlap = np.isin(features, genes) 355 overlap = np.sum(overlap) 356 if overlap < min_overlap: 357 print(overlap, flush=True) 358 del_groups.append(group) 359 360 # Removing genes without enough overlap 361 for group in del_groups: 362 print(f'Removing {group} from grouping.') 363 del lib[group] 364 365 else: 366 print("Not checking overlap between group and dataset features.") 367 368 return lib
Takes a gene set library name and filters to groups containing
element(s) in key_terms
. If genes is provided, will
ensure that there are at least min_overlap
number of genes in
each group. Resulting groups will meet all of the before-mentioned
criteria if isin_logic
is 'and'
| 'or'
.
Parameters
- lib_name (str): The desired library name.
- organism (str):
The species the gene grouping is for. Options are
{'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}
. - key_terms (str | list): The types of cells or other specifiers the gene set is for (example: 'CD4 T').
- genes (array_like): A vector of genes from the reference/query datasets. If not assigned, function will not filter groups based on feature overlap.
- min_overlap (int):
The minimum number of genes that must be present in a group
for it to be kept. If
genes
is not given, ignored.
Returns
- lib (dict):
The filtered library as a
dict
where keys are group names and keys are features.
Examples
>>> dataset_feats = [
... 'FUCA1', 'CLIC4', 'STMN1', 'SYF2', 'TAS1R1',
... 'NOL9', 'TAS1R3', 'SLC2A5', 'THAP3', 'IGHM',
... 'MARCKS', 'BANK1', 'TNFRSF13B', 'IGKC', 'IGHD',
... 'LINC01857', 'CD24', 'CD37', 'IGHD', 'RALGPS2'
... ]
>>> rna_grouping = scmkl.get_gene_groupings(
... 'Azimuth_2023', key_terms=[' b ', 'b cell', 'b '],
... genes=dataset_feats)
>>>
>>> rna_groupings.keys()
dict_keys(['PBMC-L1-B Cell', 'PBMC-L2-Intermediate B Cell', ...])
398def get_metrics(results: dict, include_as: bool=False) -> pd.DataFrame: 399 """ 400 Takes either a single scMKL result or a dictionary where each 401 entry cooresponds to one result. Returns a dataframe with cols 402 ['Alpha', 'Metric', 'Value']. If `include_as == True`, another 403 col of booleans will be added to indicate whether or not the run 404 respective to that alpha was chosen as optimal via CV. If 405 `include_key == True`, another column will be added with the name 406 of the key to the respective file (only applicable with multiple 407 results). 408 409 Parameters 410 ---------- 411 results : dict | None 412 A dictionary with the results of a single run from 413 `scmkl.run()`. Must be `None` if `rfiles is not None`. 414 415 rfiles : dict | None 416 A dictionary of results dictionaries containing multiple 417 results from `scmkl.run()`. 418 419 include_as : bool 420 When `True`, will add a bool col to output pd.DataFrame 421 where rows with alphas cooresponding to alpha_star will be 422 `True`. 423 424 Returns 425 ------- 426 df : pd.DataFrame 427 A pd.DataFrame containing all of the metrics present from 428 the runs input. 429 430 Examples 431 -------- 432 >>> # For a single file 433 >>> results = scmkl.run(adata) 434 >>> metrics = scmkl.get_metrics(results = results) 435 436 >>> # For multiple runs saved in a dict 437 >>> output_dir = 'scMKL_outputs/' 438 >>> rfiles = scmkl.read_files(output_dir) 439 >>> metrics = scmkl.get_metrics(rfiles=rfiles) 440 """ 441 # Checking which data is being worked with 442 is_mult, is_many = _parse_result_type(results) 443 444 # Initiating col list with minimal columns 445 cols = ['Alpha', 'Metric', 'Value'] 446 447 if include_as: 448 cols.append('Alpha Star') 449 if is_mult: 450 cols.append('Class') 451 452 if is_many: 453 cols.append('Key') 454 df = pd.DataFrame(columns = cols) 455 for key, result in results.items(): 456 cur_df = parse_metrics(results = result, key = key, 457 include_as = include_as) 458 df = pd.concat([df, cur_df.copy()]) 459 460 else: 461 df = parse_metrics(results = results, include_as = include_as) 462 463 return df
Takes either a single scMKL result or a dictionary where each
entry cooresponds to one result. Returns a dataframe with cols
['Alpha', 'Metric', 'Value']. If include_as == True
, another
col of booleans will be added to indicate whether or not the run
respective to that alpha was chosen as optimal via CV. If
include_key == True
, another column will be added with the name
of the key to the respective file (only applicable with multiple
results).
Parameters
- results (dict | None):
A dictionary with the results of a single run from
scmkl.run
. Must beNone
ifrfiles is not None
. - rfiles (dict | None):
A dictionary of results dictionaries containing multiple
results from
scmkl.run
. - include_as (bool):
When
True
, will add a bool col to output pd.DataFrame where rows with alphas cooresponding to alpha_star will beTrue
.
Returns
- df (pd.DataFrame): A pd.DataFrame containing all of the metrics present from the runs input.
Examples
>>> # For a single file
>>> results = scmkl.run(adata)
>>> metrics = scmkl.get_metrics(results = results)
>>> # For multiple runs saved in a dict
>>> output_dir = 'scMKL_outputs/'
>>> rfiles = scmkl.read_files(output_dir)
>>> metrics = scmkl.get_metrics(rfiles=rfiles)
290def get_region_groupings(gene_anno : pd.DataFrame, gene_sets : dict, 291 feature_names : np.ndarray | pd.Series | list | set, 292 len_up : int = 5000, len_down : int = 5000) -> dict: 293 """ 294 Creates a peak set where keys are gene set names from `gene_sets` 295 and values are arrays of features pulled from `feature_names`. 296 Features are added to each peak set given overlap between regions 297 in single-cell data matrix and inferred gene promoter regions in 298 `gene_anno`. 299 300 Parameters 301 ---------- 302 gene_anno : pd.DataFrame 303 Gene annotations in GTF format as a pd.DataFrame with columns 304 `['chr', 'start', 'end', 'strand', 'gene_name']`. 305 306 gene_sets : dict 307 Gene set names as keys and an iterable object of gene names 308 as values. 309 310 feature_names : array_like | set 311 Feature names corresponding to a single_cell epigenetic data 312 matrix. 313 314 Returns 315 ------- 316 epi_grouping : dict 317 Keys are the names from `gene_sets` and values 318 are a list of regions from `feature_names` that overlap with 319 promotor regions respective to genes in `gene_sets` (i.e., if 320 region in `feature_names` overlaps with promotor region from a 321 gene in a gene set from `gene_sets`, that region will be added 322 to the new dictionary under the respective gene set name). 323 324 Examples 325 -------- 326 >>> # Reading in a gene set and the peak names from dataset 327 >>> gene_sets = np.load("data/RNA_hallmark_groupings.pkl", 328 ... allow_pickle = True) 329 >>> peaks = np.load("data/MCF7_region_names.npy", 330 ... allow_pickle = True) 331 >>> 332 >>> # Reading in GTF file 333 >>> gtf_path = "data/hg38_subset_protein_coding.annotation.gtf" 334 >>> gtf = scmkl.read_gtf(gtf_path, filter_to_coding=True) 335 >>> 336 >>> region_grouping = scmkl.get_region_groupings(gene_anno = gtf, 337 ... gene_sets = gene_sets, 338 ... feature_names = peaks) 339 >>> 340 >>> region_grouping.keys() 341 dict_keys(['HALLMARK_TNFA_SIGNALING_VIA_NFKB', ...]) 342 """ 343 # Getting a unique set of gene names from gene_sets 344 all_genes = {gene for group in gene_sets.keys() 345 for gene in gene_sets[group]} 346 347 # Filtering out NaN values 348 all_genes = [gene for gene in all_genes if type(gene) != float] 349 350 # Filtering out annotations for genes not present in gene_sets 351 gene_anno = gene_anno[np.isin(gene_anno['gene_name'], all_genes)] 352 353 # Adjusting start and end columns to represent promotor regions 354 gene_anno = adjust_regions(gene_anno = gene_anno, 355 len_up = len_up, len_down = len_down) 356 357 # Creating a dictionary from assay features where [chr] : (start, end) 358 feature_dict = create_feature_dict(feature_names) 359 360 # Creating data structures from gene_anno for comparing regions 361 peak_gene_dict, ga_regions = create_region_dicts(gene_anno) 362 363 # Capturing the separator type used in assay 364 chr_sep = ':' if ':' in feature_names[0] else '-' 365 366 epi_groupings = compare_regions(feature_dict = feature_dict, 367 ga_regions = ga_regions, 368 peak_gene_dict = peak_gene_dict, 369 gene_sets = gene_sets, 370 chr_sep = chr_sep) 371 372 return epi_groupings
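The overlap test itself happens in internal helpers (`adjust_regions`, `create_feature_dict`, `compare_regions`). As a rough illustration only, the sketch below shows the interval arithmetic for a single peak against a single promoter window; the strand convention and window construction are assumptions based on the `len_up`/`len_down` defaults, not the library's exact implementation.

```python
# Illustrative only: a hand-rolled version of the promoter-overlap test
# that get_region_groupings delegates to its internal helpers.

def peak_overlaps_promoter(peak: str, tss: int, strand: str,
                           len_up: int = 5000, len_down: int = 5000,
                           chr_sep: str = ':') -> bool:
    """Does a peak like 'chr1:99000-101000' overlap an inferred
    promoter window around a gene's TSS? (assumed convention)"""
    _, coords = peak.split(chr_sep, 1)
    start, end = (int(x) for x in coords.split('-')[-2:])

    # Window extends len_up bases upstream and len_down downstream of
    # the TSS, oriented by strand
    if strand == '+':
        prom_start, prom_end = tss - len_up, tss + len_down
    else:
        prom_start, prom_end = tss - len_down, tss + len_up

    # Two intervals overlap when each starts before the other ends
    return start <= prom_end and end >= prom_start

print(peak_overlaps_promoter('chr1:99000-101000', tss=100000, strand='+'))
# True
```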
531def get_selection(weights_df: pd.DataFrame, 532 order_groups: bool=False) -> pd.DataFrame: 533 """ 534 This function takes a pd.DataFrame created by 535 `scmkl.get_weights()` and returns a selection table. Selection 536 refers to how many times a group had a nonzero group weight. To 537 calculate this, a col is added indicating whether the group was 538 selected. Then, the dataframe is grouped by alpha and group. 539 Selection can then be summed returning a dataframe with cols 540 `['Alpha', 'Group', Selection]`. If is the result of multiclass 541 run(s), `'Class'` column must be present and will be in resulting 542 df as well. 543 544 Parameters 545 ---------- 546 weights_df : pd.DataFrame 547 A dataframe output by `scmkl.get_weights()` with cols 548 `['Alpha', 'Group', 'Kernel Weight']`. If is the result of 549 multiclass run(s), `'Class'` column must be present as well. 550 551 order_groups : bool 552 If `True`, the `'Group'` col of the output dataframe will be 553 made into a `pd.Categorical` col ordered by number of times 554 each group was selected in decending order. 555 556 Returns 557 ------- 558 df : pd.DataFrame 559 A dataframe with cols `['Alpha', 'Group', Selection]`. Also, 560 `'Class'` column if is a multiclass result. 561 562 Example 563 ------- 564 >>> # For a single file 565 >>> results = scmkl.run(adata) 566 >>> weights = scmkl.get_weights(results = results) 567 >>> selection = scmkl.get_selection(weights) 568 569 >>> # For multiple runs saved in a dict 570 >>> output_dir = 'scMKL_outputs/' 571 >>> rfiles = scmkl.read_files(output_dir) 572 >>> weights = scmkl.get_weights(rfiles=rfiles) 573 >>> selection = scmkl.get_selection(weights) 574 """ 575 # Adding col indicating whether or not groups have nonzero weight 576 selection = weights_df['Kernel Weight'].apply(lambda x: x > 0) 577 weights_df['Selection'] = selection 578 579 # Summing selection across replications to get selection 580 is_mult = 'Class' in weights_df.columns 581 if is_mult: 582 df = weights_df.groupby(['Alpha', 'Group', 'Class'])['Selection'].sum() 583 else: 584 df = weights_df.groupby(['Alpha', 'Group'])['Selection'].sum() 585 df = df.reset_index() 586 587 # Getting group order 588 if order_groups and not is_mult: 589 order = df.groupby('Group')['Selection'].sum() 590 order = order.reset_index().sort_values(by = 'Selection', 591 ascending = False) 592 order = order['Group'] 593 df['Group'] = pd.Categorical(df['Group'], categories = order) 594 595 596 return df
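A small, self-contained illustration of what the selection count means (toy values, two replicates of the same alpha):

```python
import pandas as pd
import scmkl

# Two replicates' worth of weights for the same alpha (toy values)
weights_df = pd.DataFrame({
    'Alpha': [0.1, 0.1, 0.1, 0.1],
    'Group': ['HALLMARK_A', 'HALLMARK_B', 'HALLMARK_A', 'HALLMARK_B'],
    'Kernel Weight': [0.8, 0.0, 0.5, 0.2],
})

# HALLMARK_A was nonzero in both replicates, HALLMARK_B in one
selection = scmkl.get_selection(weights_df)
print(selection)
#    Alpha       Group  Selection
# 0    0.1  HALLMARK_A          2
# 1    0.1  HALLMARK_B          1
```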
294def get_summary(results: dict, metric: str='AUROC'): 295 """ 296 Takes the results from `scmkl.run()` and generates a dataframe 297 for each model containing columns for alpha, area under the ROC, 298 number of groups with nonzero weights, and highest weighted 299 group. 300 301 Parameters 302 ---------- 303 results : dict 304 A dictionary of results from scMKL generated from 305 `scmkl.run()`. 306 307 metric : str 308 Which metric to include in the summary. Default is AUROC. 309 Options include `'AUROC'`, `'Recall'`, `'Precision'`, 310 `'Accuracy'`, and `'F1-Score'`. 311 312 Returns 313 ------- 314 summary_df : pd.DataFrame 315 A table with columns: `['Alpha', 'AUROC', 316 'Number of Selected Groups', 'Top Group']`. 317 318 Examples 319 -------- 320 >>> results = scmkl.run(adata, alpha_list) 321 >>> summary_df = scmkl.get_summary(results) 322 ... 323 >>> summary_df.head() 324 Alpha AUROC Number of Selected Groups 325 0 2.20 0.8600 3 326 1 1.96 0.9123 4 327 2 1.72 0.9357 5 328 3 1.48 0.9524 7 329 4 1.24 0.9666 9 330 Top Group 331 0 RNA-HALLMARK_E2F_TARGETS 332 1 RNA-HALLMARK_ESTROGEN_RESPONSE_EARLY 333 2 RNA-HALLMARK_ESTROGEN_RESPONSE_EARLY 334 3 RNA-HALLMARK_ESTROGEN_RESPONSE_EARLY 335 4 RNA-HALLMARK_ESTROGEN_RESPONSE_EARLY 336 """ 337 is_multi, is_many = _parse_result_type(results) 338 assert not is_many, "This function only supports single results" 339 340 if is_multi: 341 summaries = list() 342 for ct in results['Classes']: 343 data = extract_results(results[ct], metric) 344 data['Class'] = [ct]*len(data) 345 summaries.append(data.copy()) 346 summary = pd.concat(summaries) 347 348 else: 349 summary = extract_results(results, metric) 350 351 return summary
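One common way to consume the summary table is to pull out the best-performing row; a short sketch assuming the default `'AUROC'` metric column:

```python
# Row of the summary with the best AUROC
best = summary_df.loc[summary_df['AUROC'].idxmax()]
print(best['Alpha'], best['Number of Selected Groups'], best['Top Group'])
```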
466def get_weights(results : dict, include_as : bool = False) -> pd.DataFrame: 467 """ 468 Takes either a single scMKL result or dictionary of results and 469 returns a pd.DataFrame with cols ['Alpha', 'Group', 470 'Kernel Weight']. If `include_as == True`, a fourth col will be 471 added to indicate whether or not the run respective to that alpha 472 was chosen as optimal via cross validation. 473 474 Parameters 475 ---------- 476 results : dict | None 477 A dictionary with the results of a single run from 478 `scmkl.run()`. Must be `None` if `rfiles is not None`. 479 480 rfiles : dict | None 481 A dictionary of results dictionaries containing multiple 482 results from `scmkl.run()`. 483 484 include_as : bool 485 When `True`, will add a bool col to output pd.DataFrame 486 where rows with alphas cooresponding to alpha_star will be 487 `True`. 488 489 Returns 490 ------- 491 df : pd.DataFrame 492 A pd.DataFrame containing all of the groups from each alpha 493 and their cooresponding kernel weights. 494 495 Examples 496 -------- 497 >>> # For a single file 498 >>> results = scmkl.run(adata) 499 >>> weights = scmkl.get_weights(results = results) 500 501 >>> # For multiple runs saved in a dict 502 >>> output_dir = 'scMKL_outputs/' 503 >>> rfiles = scmkl.read_files(output_dir) 504 >>> weights = scmkl.get_weights(rfiles=rfiles) 505 """ 506 # Checking which data is being worked with 507 is_mult, is_many = _parse_result_type(results) 508 509 # Initiating col list with minimal columns 510 cols = ['Alpha', 'Group', 'Kernel Weight'] 511 512 if include_as: 513 cols.append('Alpha Star') 514 if is_mult: 515 cols.append('Class') 516 517 if is_many: 518 cols.append('Key') 519 df = pd.DataFrame(columns = cols) 520 for key, result in results.items(): 521 cur_df = parse_weights(results = result, key = key, 522 include_as = include_as) 523 df = pd.concat([df, cur_df.copy()]) 524 525 else: 526 df = parse_weights(results = results, include_as = include_as) 527 528 return df
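A sketch of a typical follow-up, restricting the weights table to the cross-validated alpha (requires `include_as=True` so the `'Alpha Star'` column exists):

```python
weights = scmkl.get_weights(results, include_as=True)
star_weights = weights[weights['Alpha Star']]

# Ten heaviest groups under the tuned model
top10 = star_weights.sort_values('Kernel Weight', ascending=False).head(10)
```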
599def groups_per_alpha(selection_df: pd.DataFrame) -> dict: 600 """ 601 This function takes a pd.DataFrame from `scmkl.get_selection()` 602 generated from multiple scMKL results and returns a dictionary 603 with keys being alphas from the input dataframe and values being 604 the mean number of selected groups for a given alpha across 605 results. 606 607 Parameters 608 ---------- 609 selection_df : pd.DataFrame 610 A dataframe output by `scmkl.get_selection()` with cols 611 `['Alpha', 'Group', Selection]. 612 613 Returns 614 ------- 615 mean_groups : dict 616 A dictionary with alphas as keys and the mean number of 617 selected groups for that alpha as keys. 618 619 Examples 620 -------- 621 >>> weights = scmkl.get_weights(rfiles) 622 >>> selection = scmkl.get_selection(weights) 623 >>> mean_groups = scmkl.mean_groups_per_alpha(selection) 624 >>> mean_groups = {alpha : np.round(num_selected, 1) 625 ... for alpha, num_selected in mean_groups.items()} 626 >>> 627 >>> print(mean_groups) 628 {0.05 : 50.0, 0.2 : 24.7, 1.1 : 5.3} 629 """ 630 mean_groups = {} 631 for alpha in np.unique(selection_df['Alpha']): 632 633 # Capturing rows for given alpha 634 rows = selection_df['Alpha'] == alpha 635 636 # Adding mean number of groups for alpha 637 mean_groups[alpha] = np.mean(selection_df[rows]['Selection']) 638 639 return mean_groups
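A hypothetical downstream use, picking the alpha whose average model size is closest to a desired number of groups:

```python
mean_groups = scmkl.groups_per_alpha(selection)

# Choose the alpha whose mean number of selected groups is closest
# to a target model size (target value is arbitrary here)
target = 10
alpha_choice = min(mean_groups, key=lambda a: abs(mean_groups[a] - target))
```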
613def group_umap(adata: ad.AnnData, g_name: str | list, is_binary: bool=False, 614 labels: None | np.ndarray | list=None, title: str='', 615 save: str=''): 616 """ 617 Uses a scmkl formatted `ad.AnnData` object to show sample 618 separation using scmkl discovered groupings. 619 620 Parameters 621 ---------- 622 adata : ad.AnnData 623 A scmkl formatted `ad.AnnData` object with `'group_dict'` in 624 `.uns`. 625 626 g_name : str | list 627 The groups who's features should be used to filter `adata`. If 628 is a list, features from multiple groups will be used. 629 630 is_binary : bool 631 If `True`, data will be processed using `muon` which includes 632 TF-IDF normalization and LSI. 633 634 labels : None | np.ndarray | list 635 If `None`, labels in `adata.obs['labels']` will be used to 636 color umap points. Else, provided labels will be used to color 637 points. 638 639 title : str 640 The title of the plot. 641 642 save : str 643 If provided, plot will be saved using `scanpy`'s `save` 644 argument. Should be the desired file name. Output will be 645 `<cwd>/figures/<save>`. 646 647 Returns 648 ------- 649 None 650 651 Examples 652 -------- 653 >>> adata_fp = 'data/_pbmc_rna.h5ad' 654 >>> group_fp = 'data/_RNA_azimuth_pbmc_groupings.pkl' 655 >>> adata = scmkl.format_adata(adata_fp, 'celltypes', group_fp, 656 ... allow_multiclass=True) 657 >>> scmkl.group_umap(adata, 'CD16+ Monocyte Markers') 658 659  660 """ 661 if list == type(g_name): 662 feats = {feature 663 for group in g_name 664 for feature in adata.uns['group_dict'][group]} 665 feats = np.array(list(feats)) 666 else: 667 feats = np.array(list(adata.uns['group_dict'][g_name])) 668 669 if labels: 670 assert len(labels) == adata.shape[0], "`labels` do not match `adata`" 671 adata.obs['labels'] = labels 672 673 var_names = adata.var_names.to_numpy() 674 675 col_filter = np.isin(var_names, feats) 676 adata = adata[:, col_filter].copy() 677 678 if not is_binary: 679 sc.pp.normalize_total(adata) 680 sc.pp.log1p(adata) 681 sc.tl.pca(adata) 682 sc.pp.neighbors(adata) 683 sc.tl.umap(adata, random_state=1) 684 685 else: 686 ac.pp.tfidf(adata, scale_factor=1e4) 687 sc.pp.normalize_total(adata) 688 sc.pp.log1p(adata) 689 ac.tl.lsi(adata) 690 sc.pp.scale(adata) 691 sc.tl.pca(adata) 692 sc.pp.neighbors(adata, n_neighbors=10, n_pcs=30) 693 sc.tl.umap(adata, random_state=1) 694 695 if save: 696 sc.pl.umap(adata, title=title, color='labels', save=save, show=False) 697 698 else: 699 sc.pl.umap(adata, title=title, color='labels') 700 701 return None
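A sketch of a fancier call, assuming a multiclass results dictionary from `scmkl.one_v_rest()` and placeholder group names; it colors only the test cells by their predicted classes:

```python
# Color only the test cells by scMKL's predicted classes
# (group names here are placeholders for groups scMKL selected)
test_cells = adata.uns['test_indices']
scmkl.group_umap(adata[test_cells].copy(),
                 g_name=['CD16+ Monocyte Markers', 'B Cell Markers'],
                 labels=results['Predicted_class'],
                 title='Selected marker groups',
                 save='scmkl_group_umap.png')
```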
101def multimodal_processing(adatas : list[ad.AnnData], names : list[str], 102 tfidf: list[bool], combination: str='concatenate', 103 batches: int=10, batch_size: int=100) -> ad.AnnData: 104 """ 105 Combines and processes a list of `ad.AnnData` objects. 106 107 Parameters 108 ---------- 109 adatas : list[ad.AnnData] 110 List of `ad.AnnData` objects where each object is a different 111 modality. Annotations must match between objects (i.e. same 112 sample order). 113 114 names : list[str] 115 List of strings names for each modality repective to each 116 object in adatas. 117 118 combination: str 119 How to combine the matrices, either `'sum'` or `'concatenate'`. 120 121 tfidf : list[bool] 122 If element `i` is `True`, `adata[i]` will be TF-IDF normalized. 123 124 batches : int 125 The number of batches to use for the distance calculation. 126 This will average the result of `batches` distance calculations 127 of `batch_size` randomly sampled cells. More batches will converge 128 to population distance values at the cost of scalability. 129 130 batch_size : int 131 The number of cells to include per batch for distance 132 calculations. Higher batch size will converge to population 133 distance values at the cost of scalability. 134 If `batches*batch_size > num_training_cells`, `batch_size` 135 will be reduced to `int(num_training_cells / batches)`. 136 137 Returns 138 ------- 139 adata : ad.AnnData 140 Concatenated from objects from `adatas` with Z matrices 141 calculated. 142 143 Examples 144 -------- 145 >>> rna_adata = scmkl.create_adata(X = mcf7_rna_mat, 146 ... feature_names = gene_names, 147 ... scale_data = True, 148 ... cell_labels = cell_labels, 149 ... group_dict = rna_grouping) 150 >>> 151 >>> atac_adata = scmkl.create_adata(X = mcf7_atac_mat, 152 ... feature_names = peak_names, 153 ... scale_data = False, 154 ... cell_labels = cell_labels, 155 ... group_dict = atac_grouping) 156 >>> 157 >>> adatas = [rna_adata, atac_adata] 158 >>> mod_names = ['rna', 'atac'] 159 >>> adata = scmkl.multimodal_processing(adatas = adatas, 160 ... names = mod_names, 161 ... tfidf = [False, True]) 162 >>> 163 >>> adata 164 AnnData object with n_obs Ă— n_vars = 1000 Ă— 12676 165 obs: 'labels' 166 var: 'labels' 167 uns: 'D', 'kernel_type', 'distance_metric', 'train_indices', 168 'test_indices', 'Z_train', 'Z_test', 'group_dict', 'seed_obj' 169 """ 170 import warnings 171 warnings.filterwarnings('ignore') 172 173 diff_num_warn = "Different number of cells present in each object." 
174 assert all([adata.shape[0] for adata in adatas]), diff_num_warn 175 176 # True if all train indices match 177 same_train = np.all([np.array_equal(adatas[0].uns['train_indices'], 178 adatas[i].uns['train_indices']) 179 for i in range(1, len(adatas))]) 180 181 # True if all test indices match 182 same_test = np.all([np.array_equal(adatas[0].uns['test_indices'], 183 adatas[i].uns['test_indices']) 184 for i in range(1, len(adatas))]) 185 186 assert same_train, "Different train indices" 187 assert same_test, "Different test indices" 188 189 # Creates a boolean array for each modality of cells with non-empty rows 190 non_empty_rows = [np.array(sparse_var(adata.X, axis = 1) != 0).ravel() 191 for adata in adatas] 192 193 # Returns a 1d array where sample feature sums 194 # across all modalities are more than 0 195 non_empty_rows = np.logical_and(*non_empty_rows).squeeze() 196 197 # Initializing final train test split array 198 train_test = np.repeat('train', adatas[0].shape[0]) 199 train_test[adatas[0].uns['test_indices']] = 'test' 200 201 # Capturing train test split with empty rows filtered out 202 train_test = train_test[non_empty_rows] 203 train_indices = np.where(train_test == 'train')[0] 204 test_indices = np.where(train_test == 'test')[0] 205 206 # Adding train test split arrays to AnnData objects 207 # and filtering out empty samples 208 for i, adata in enumerate(adatas): 209 adatas[i].uns['train_indices'] = train_indices 210 adatas[i].uns['test_indices'] = test_indices 211 adatas[i] = adata[non_empty_rows, :] 212 # tfidf normalizing if corresponding element in tfidf is True 213 if tfidf[i]: 214 adatas[i] = tfidf_normalize(adata) 215 216 print(f"Estimating sigma and calculating Z for {names[i]}", flush = True) 217 adatas[i] = calculate_z(adata, n_features = 5000, batches=batches, 218 batch_size=batch_size) 219 220 if 'labels' in adatas[0].obs: 221 all_labels = [adata.obs['labels'] for adata in adatas] 222 # Ensuring cell labels for each AnnData object are the same 223 uneq_labs_warn = ("Cell labels between AnnData object in position 0 " 224 "and position {} in adatas do not match") 225 for i in range(1, len(all_labels)): 226 same_labels = np.all(all_labels[0] == all_labels[i]) 227 assert same_labels, uneq_labs_warn.format(i) 228 229 adata = combine_modalities(adatas=adatas, 230 names=names, 231 combination=combination) 232 233 del adatas 234 gc.collect() 235 236 return adata
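`batches` and `batch_size` trade accuracy of the kernel-width estimate against memory; their product is the number of training cells sampled (shrunk automatically when it exceeds the training set, per the parameter docs). A sketch with non-default values:

```python
# Sample 15 batches of 200 cells (3,000 cells total) when estimating
# distances; batch_size shrinks automatically for small training sets
adata = scmkl.multimodal_processing(adatas=[rna_adata, atac_adata],
                                    names=['rna', 'atac'],
                                    tfidf=[False, True],
                                    batches=15, batch_size=200)
```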
206def one_v_rest(adatas : list, names : list, alpha_list : np.ndarray, 207 tfidf : list, batches: int=10, batch_size: int=100, 208 force_balance: bool=False, other_factor: float=1.0) -> dict: 209 """ 210 For each cell class, creates model(s) comparing that class to all 211 others. Then, predicts on the training data using `scmkl.run()`. 212 Only labels in both training and testing will be run. 213 214 Parameters 215 ---------- 216 adatas : list[AnnData] 217 List of `ad.AnnData` objects created by `create_adata()` 218 where each `ad.AnnData` is one modality and composed of both 219 training and testing samples. Requires that `'train_indices'` 220 and `'test_indices'` are the same between all `ad.AnnData`s. 221 222 names : list[str] 223 String variables that describe each modality respective to 224 `adatas` for labeling. 225 226 alpha_list : np.ndarray | float 227 An array of alpha values to create each model with or a float 228 to run with a single alpha. 229 230 tfidf : list[bool] 231 If element `i` is `True`, `adatas[i]` will be TF-IDF 232 normalized. 233 234 batches : int 235 The number of batches to use for the distance calculation. 236 This will average the result of `batches` distance calculations 237 of `batch_size` randomly sampled cells. More batches will 238 converge to population distance values at the cost of 239 scalability. 240 241 batch_size : int 242 The number of cells to include per batch for distance 243 calculations. Higher batch size will converge to population 244 distance values at the cost of scalability. 245 If `batches*batch_size > num_training_cells`, 246 `batch_size` will be reduced to 247 `int(num_training_cells / batches)`. 248 249 force_balance : bool 250 If `True`, training sets will be balanced to reduce class label 251 imbalance. Defaults to `False`. 252 253 other_factor : float 254 The ratio of cells to sample for the other class for each 255 model. For example, if classifying B cells with 100 B cells in 256 training, if `other_factor=1`, 100 cells that are not B cells 257 will be trained on with the B cells. 258 259 Returns 260 ------- 261 results : dict 262 Contains keys for each cell class with results from cell class 263 versus all other samples. See `scmkl.run()` for futher details. 264 Will also include a probablilities table with the predictions 265 from each model. 266 267 Examples 268 -------- 269 >>> adata = scmkl.create_adata(X = data_mat, 270 ... feature_names = gene_names, 271 ... group_dict = group_dict) 272 >>> 273 >>> results = scmkl.one_v_rest(adatas = [adata], names = ['rna'], 274 ... alpha_list = np.array([0.05, 0.1]), 275 ... 
tfidf = [False]) 276 >>> 277 >>> adata.keys() 278 dict_keys(['B cells', 'Monocytes', 'Dendritic cells', ...]) 279 """ 280 # Formatting checks ensuring all adata elements are 281 # AnnData objects and train/test indices are all the same 282 _check_adatas(adatas, check_obs = True, check_uns = True) 283 284 285 # Extracting train and test indices 286 train_indices = adatas[0].uns['train_indices'] 287 test_indices = adatas[0].uns['test_indices'] 288 289 # Checking and capturing cell labels 290 uniq_labels = _eval_labels(cell_labels = adatas[0].obs['labels'], 291 train_indices = train_indices, 292 test_indices = test_indices) 293 294 295 # Calculating Z matrices, method depends on whether there are multiple 296 # adatas (modalities) 297 if (len(adatas) == 1) and ('Z_train' not in adatas[0].uns.keys()): 298 adata = calculate_z(adata, n_features = 5000, batches=batches, batch_size=batch_size) 299 elif len(adatas) > 1: 300 adata = multimodal_processing(adatas = adatas, 301 names = names, 302 tfidf = tfidf, 303 batches=batches, 304 batch_size=batch_size) 305 else: 306 adata = adatas[0].copy() 307 308 del adatas 309 gc.collect() 310 311 # Initializing for capturing model outputs 312 results = dict() 313 314 # Capturing cell labels before overwriting 315 cell_labels = np.array(adata.obs['labels'].copy()) 316 317 # Capturing perfect train/test splits for each class 318 if force_balance: 319 train_idx = get_class_train(adata.uns['train_indices'], 320 cell_labels, 321 adata.uns['seed_obj'], 322 other_factor) 323 324 for label in uniq_labels: 325 326 print(f"Comparing {label} to other types", flush = True) 327 cur_labels = cell_labels.copy() 328 cur_labels[cell_labels != label] = 'other' 329 330 # Replacing cell labels for current cell type vs rest 331 adata.obs['labels'] = cur_labels 332 333 if force_balance: 334 adata.uns['train_indices'] = train_idx[label] 335 336 # Running scMKL 337 results[label] = run(adata, alpha_list, return_probs = True) 338 339 # Getting final predictions 340 alpha = np.min(alpha_list) 341 prob_table, pred_class, low_conf = get_prob_table(results, alpha) 342 macro_f1 = f1_score(cell_labels[adata.uns['test_indices']], 343 pred_class, average='macro') 344 345 model_summary = per_model_summary(results, uniq_labels, alpha) 346 347 results['Per_model_summary'] = model_summary 348 results['Classes'] = uniq_labels 349 results['Probability_table'] = prob_table 350 results['Predicted_class'] = pred_class 351 results['Truth_labels'] = cell_labels[adata.uns['test_indices']] 352 results['Low_confidence'] = low_conf 353 results['Macro_F1-Score'] = macro_f1 354 355 if force_balance: 356 results['Training_indices'] = train_idx 357 358 return results
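A sketch of consuming the multiclass output, using the result keys this function sets; treating `'Probability_table'` as a dataframe follows the docstring's description of a probabilities table:

```python
import numpy as np

results = scmkl.one_v_rest(adatas=[adata], names=['rna'],
                           alpha_list=np.array([0.05, 0.1]),
                           tfidf=[False], force_balance=True)

# Overall multiclass performance and per-cell class calls
print(results['Macro_F1-Score'])
print(results['Predicted_class'][:5])

# Per-class probabilities and cells flagged as low confidence
print(results['Probability_table'].head())
print(results['Low_confidence'][:5])
```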
150def optimize_alpha(adata: ad.AnnData | list[ad.AnnData], 151 group_size: int | None=None, 152 tfidf: bool | list[bool]=False, 153 alpha_array: np.ndarray=np.round(np.linspace(1.9,0.1, 10),2), 154 k: int=4, metric: str='AUROC', 155 batches: int=10, batch_size: int=100): 156 """ 157 Iteratively train a grouplasso model and update alpha to find the 158 parameter yielding best performing sparsity. This function 159 currently only works for binary experiments. 160 161 Parameters 162 ---------- 163 adata : ad.AnnData | list[ad.AnnData] 164 `ad.AnnData`(s) with `'Z_train'` and `'Z_test'` in 165 `adata.uns.keys()`. 166 167 group_size : None | int 168 Argument describing how the features are grouped. If `None`, 169 `2 * adata.uns['D']` will be used. For more information see 170 [celer documentation](https://mathurinm.github.io/celer/ 171 generated/celer.GroupLasso.html). 172 173 tfidf : list | bool 174 If `False`, no data will be TF-IDF transformed. If 175 `type(adata) is list` and TF-IDF transformation is desired for 176 all or some of the data, a bool list corresponding to `adata` 177 must be provided. To simply TF-IDF transform `adata` when 178 `type(adata) is ad.AnnData`, use `True`. 179 180 alpha_array : np.ndarray 181 Array of all alpha values to be tested. 182 183 k : int 184 Number of folds to perform cross validation over. 185 186 metric : str 187 Which metric to use to optimize alpha. Options are `'AUROC'`, 188 `'Accuracy'`, `'F1-Score'`, `'Precision'`, and `'Recall'`. 189 190 batches : int 191 The number of batches to use for the distance calculation. 192 This will average the result of `batches` distance calculations 193 of `batch_size` randomly sampled cells. More batches will converge 194 to population distance values at the cost of scalability. 195 196 batch_size : int 197 The number of cells to include per batch for distance 198 calculations. Higher batch size will converge to population 199 distance values at the cost of scalability. If 200 `batches*batch_size > num_training_cells`, `batch_size` will be 201 reduced to `int(num_training_cells/batches)`. 202 203 Returns 204 ------- 205 alpha_star : float 206 The best performing alpha value from cross validation on 207 training data. 
208 209 Examples 210 -------- 211 >>> alpha_star = scmkl.optimize_alpha(adata) 212 >>> alpha_star 213 0.1 214 """ 215 assert isinstance(k, int) and k > 0, "'k' must be positive" 216 217 import warnings 218 warnings.filterwarnings('ignore') 219 220 if group_size == None: 221 group_size = adata.uns['D']*2 222 223 if type(adata) == list: 224 alpha_star = multimodal_optimize_alpha(adatas = adata, 225 group_size = group_size, 226 tfidf_list = tfidf, 227 alpha_array = alpha_array, 228 metric = metric, 229 batch_size = batch_size, 230 batches = batches) 231 return alpha_star 232 233 y = adata.obs['labels'].iloc[adata.uns['train_indices']].to_numpy() 234 235 # Splits the labels evenly between folds 236 positive_indices = np.where(y == np.unique(y)[0])[0] 237 negative_indices = np.setdiff1d(np.arange(len(y)), positive_indices) 238 239 positive_annotations = np.arange(len(positive_indices)) % k 240 negative_annotations = np.arange(len(negative_indices)) % k 241 242 metric_array = np.zeros((len(alpha_array), k)) 243 244 gc.collect() 245 246 for fold in np.arange(k): 247 248 cv_adata = adata[adata.uns['train_indices'],:] 249 250 if 'sigma' in cv_adata.uns_keys(): 251 del cv_adata.uns['sigma'] 252 253 # Create CV train/test indices 254 fold_train = np.concatenate((positive_indices[np.where(positive_annotations != fold)[0]], 255 negative_indices[np.where(negative_annotations != fold)[0]])) 256 fold_test = np.concatenate((positive_indices[np.where(positive_annotations == fold)[0]], 257 negative_indices[np.where(negative_annotations == fold)[0]])) 258 259 cv_adata.uns['train_indices'] = fold_train 260 cv_adata.uns['test_indices'] = fold_test 261 262 if tfidf: 263 cv_adata = tfidf_normalize(cv_adata, binarize= True) 264 265 # Estimating kernel widths and calculating Zs 266 cv_adata = calculate_z(cv_adata, n_features= 5000, 267 batches = batches, batch_size = batch_size) 268 269 # In train_model we index Z_train for balancing multiclass labels. We just recreate 270 # dummy indices here that are unused for use in the binary case 271 cv_adata.uns['train_indices'] = np.arange(0, len(fold_train)) 272 273 gc.collect() 274 275 for i, alpha in enumerate(alpha_array): 276 277 cv_adata = train_model(cv_adata, group_size, alpha = alpha) 278 _, metrics = predict(cv_adata, metrics = [metric]) 279 metric_array[i, fold] = metrics[metric] 280 281 gc.collect() 282 283 del cv_adata 284 gc.collect() 285 286 # Take AUROC mean across the k folds to find alpha yielding highest AUROC 287 alpha_star = alpha_array[np.argmax(np.mean(metric_array, axis = 1))] 288 gc.collect() 289 290 291 return alpha_star
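The typical flow is to tune on the training split and then fit the final model at the tuned value:

```python
import numpy as np

# Tune alpha by k-fold CV on the training split, then fit/evaluate
# the final model at the tuned value
alpha_star = scmkl.optimize_alpha(adata, k=4, metric='AUROC')
results = scmkl.run(adata, alpha_list=np.array([alpha_star]))
```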
9def optimize_sparsity(adata: ad.AnnData, group_size: int | None=None, starting_alpha = 1.9, 10 increment = 0.2, target = 1, n_iter = 10): 11 """ 12 Iteratively train a grouplasso model and update alpha to find the 13 parameter yielding the desired sparsity. 14 15 Parameters 16 ---------- 17 adata : ad.AnnData 18 `ad.AnnData` with `'Z_train'` and `'Z_test'` in 19 `adata.uns.keys()`. 20 21 group_size : None | int 22 Argument describing how the features are grouped. If `None`, 23 `2 * adata.uns['D']` will be used. For more information see 24 [celer documentation](https://mathurinm.github.io/celer/ 25 generated/celer.GroupLasso.html). 26 27 starting_alpha : float 28 The alpha value to start the search at. 29 30 increment : float 31 Amount to adjust alpha by between iterations. 32 33 target : int 34 The desired number of groups selected by the model. 35 36 n_iter : int 37 The maximum number of iterations to run. 38 39 Returns 40 ------- 41 sparsity_dict : dict 42 Tested alpha as keys and the number of selected groups as 43 the values. 44 45 alpha : float 46 The alpha value yielding the number of selected groups closest 47 to the target. 48 49 Examples 50 -------- 51 >>> sparcity_dict, alpha = scmkl.optimize_sparsity(adata, 52 ... target = 1) 53 >>> 54 >>> alpha 55 0.01 56 57 See Also 58 -------- 59 celer.GroupLasso : https://mathurinm.github.io/celer/ 60 """ 61 assert increment > 0 and increment < starting_alpha, ("Choose a positive " 62 "increment less " 63 "than alpha") 64 assert target > 0 and isinstance(target, int), ("Choose an integer " 65 "target number of groups " 66 "that is greater than 0") 67 assert n_iter > 0 and isinstance(n_iter, int), ("Choose an integer " 68 "number of iterations " 69 "that is greater than 0") 70 71 if group_size == None: 72 group_size = adata.uns['D']*2 73 74 sparsity_dict = {} 75 alpha = starting_alpha 76 77 for _ in np.arange(n_iter): 78 adata = train_model(adata, group_size, alpha) 79 num_selected = len(find_selected_groups(adata)) 80 81 sparsity_dict[np.round(alpha, 4)] = num_selected 82 83 if num_selected < target: 84 #Decreasing alpha will increase the number of selected pathways 85 if alpha - increment in sparsity_dict.keys(): 86 # Make increment smaller so the model can't go back and forth 87 # between alpha values 88 increment/=2 89 # Ensures that alpha will never be negative 90 alpha = np.max([alpha - increment, 1e-3]) 91 92 elif num_selected > target: 93 if alpha + increment in sparsity_dict.keys(): 94 increment/=2 95 96 alpha += increment 97 elif num_selected == target: 98 break 99 100 # Find the alpha that minimizes the difference between target and observed 101 # number of selected groups 102 spar_idx = np.argmin([np.abs(selected - target) 103 for selected in sparsity_dict.values()]) 104 optimal_alpha = list(sparsity_dict.keys())[spar_idx] 105 106 return sparsity_dict, optimal_alpha
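The returned dictionary doubles as a trace of the search. Illustrative values for `target=5` (made-up group counts, following the halving-increment behavior above):

```python
sparsity_dict, alpha = scmkl.optimize_sparsity(adata, target=5)

# Each alpha tried, mapped to the number of groups it selected
print(sparsity_dict)
# {1.9: 0, 1.7: 2, 1.5: 4, 1.3: 6, 1.4: 5}
print(alpha)
# 1.4
```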
131def parse_metrics(results: dict, key: str | None=None, 132 include_as: bool=False) -> pd.DataFrame: 133 """ 134 This function returns a pd.DataFrame for a single scMKL result 135 with performance results. 136 137 Parameters 138 ---------- 139 results : dict 140 A result dictionary from `scmkl.run()`. 141 142 key : str 143 If specified, will add a key column to the output dataframe 144 where each element is `key`. 145 146 include_as : bool 147 If `True`, will add a column indicating which models' used 148 the optimal alphas. 149 150 Returns 151 ------- 152 df : pd.DataFrame 153 A dataframe with columns `['Alpha', 'Metric', 'Value']`. 154 `'Key'` col only added if `key` is not `None`. 155 """ 156 df = { 157 'Alpha' : list(), 158 'Metric' : list(), 159 'Value' : list() 160 } 161 162 # Check if is a multiclass result 163 is_mult, _ = _parse_result_type(results) 164 165 if is_mult: 166 df['Class'] = list() 167 168 # Ensuring results is a scMKL result and checking multiclass 169 if 'Metrics' in results.keys(): 170 for alpha in results['Metrics'].keys(): 171 for metric, value in results['Metrics'][alpha].items(): 172 df['Alpha'].append(alpha) 173 df['Metric'].append(metric) 174 df['Value'].append(value) 175 176 elif 'Classes' in results.keys(): 177 for ct in results['Classes']: 178 for alpha in results[ct]['Metrics'].keys(): 179 for metric, value in results[ct]['Metrics'][alpha].items(): 180 df['Alpha'].append(alpha) 181 df['Metric'].append(metric) 182 df['Value'].append(value) 183 df['Class'].append(ct) 184 185 else: 186 print(f"{key} is not a scMKL result and will be ignored.") 187 188 df = pd.DataFrame(df) 189 190 if include_as: 191 assert 'Alpha_star' in results.keys(), "'Alpha_star' not in results" 192 df['Alpha Star'] = df['Alpha'] == results['Alpha_star'] 193 194 if key is not None: 195 df['Key'] = [key] * df.shape[0] 196 197 return df
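`parse_metrics` has no docstring example; a short usage sketch based on the columns it emits:

```python
# Long-format metrics for one result, tagged with its source file
df = scmkl.parse_metrics(results, key='Rep_1.pkl', include_as=True)

# AUROC of the model chosen by cross validation
auroc = df[(df['Metric'] == 'AUROC') & df['Alpha Star']]['Value']
```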
200def parse_weights(results: dict, include_as: bool=False, 201 key: None | str=None) -> pd.DataFrame: 202 """ 203 This function returns a pd.DataFrame for a single scMKL result 204 with group weights. 205 206 Parameters 207 ---------- 208 results : dict 209 A result dictionary from `scmkl.run()`. 210 211 key : str 212 If specified, will add a key column to the output dataframe 213 where each element is `key`. 214 215 include_as : bool 216 If `True`, will add a column indicating which models' used 217 the optimal alphas. 218 219 Returns 220 ------- 221 df : pd.DataFrame 222 A dataframe with columns `['Alpha', 'Group', 223 'Kernel Weight']`. `'Key'` col only added if `key` is not 224 `None`. 225 """ 226 df = { 227 'Alpha' : list(), 228 'Group' : list(), 229 'Kernel Weight' : list() 230 } 231 232 # Check if is a multiclass result 233 is_mult, _ = _parse_result_type(results) 234 235 if is_mult: 236 df['Class'] = list() 237 238 # Ensuring results is a scMKL result and checking multiclass 239 if 'Norms' in results.keys(): 240 for alpha in results['Norms'].keys(): 241 df['Alpha'].extend([alpha]*len(results['Norms'][alpha])) 242 df['Group'].extend(results['Group_names']) 243 df['Kernel Weight'].extend(results['Norms'][alpha]) 244 245 elif 'Classes' in results.keys(): 246 for ct in results['Classes']: 247 for alpha in results[ct]['Norms'].keys(): 248 df['Alpha'].extend([alpha] * len(results[ct]['Norms'][alpha])) 249 df['Group'].extend(results[ct]['Group_names']) 250 df['Kernel Weight'].extend(results[ct]['Norms'][alpha]) 251 df['Class'].extend([ct]*len(results[ct]['Norms'][alpha])) 252 253 df = pd.DataFrame(df) 254 255 if include_as: 256 df['Alpha Star'] = df['Alpha'] == results['Alpha_star'] 257 258 if key is not None: 259 df['Key'] = [key] * df.shape[0] 260 261 return df
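Likewise for `parse_weights`, a sketch based on its output columns:

```python
df = scmkl.parse_weights(results, include_as=True)

# Groups kept by the tuned model (nonzero kernel weight at alpha star)
kept = df[(df['Kernel Weight'] > 0) & df['Alpha Star']]['Group']
```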
196def plot_metric(summary_df : pd.DataFrame, alpha_star = None, 197 x_axis: str='Alpha', color = 'red'): 198 """ 199 Takes a data frame of model metrics and optionally alpha star and 200 creates a scatter plot given metrics against alpha values. For 201 multiclass results, `alpha_star` is not shown and points are 202 colored by class. 203 204 Parameters 205 ---------- 206 summary_df : pd.DataFrame 207 Dataframe created by `scmkl.get_summary()`. 208 209 alpha_star : None | float 210 If `not None`, a label will be added for tuned `alpha_star` 211 being optimal model parameter for performance from cross 212 validation on the training data. Can be calculated with 213 `scmkl.optimize_alpha()`. Is ignored if `summary_df` is from a 214 multiclass result. 215 216 x_axis : str 217 Must be either `'Alpha'` or `'Number of Selected Groups'`. Is 218 the variable that will be plotted on the x-axis. 219 220 color : str 221 Color to make points on plot. 222 223 Returns 224 ------- 225 metric_plot : plotnine.ggplot.ggplot 226 A plot with alpha values on x-axis and metric on y-axis. 227 228 Examples 229 -------- 230 >>> results = scmkl.run(adata, alpha_list) 231 >>> summary_df = scmkl.get_summary(results) 232 >>> metric_plot = plot_metric(results) 233 >>> 234 >>> metric_plot.save('scMKL_performance.png') 235 236  237 """ 238 # Capturing metric from summary_df 239 metric_options = ['AUROC', 'Accuracy', 'F1-Score', 'Precision', 'Recall'] 240 metric = np.intersect1d(metric_options, summary_df.columns)[0] 241 242 x_axis_ticks = np.unique(summary_df[x_axis]) 243 summary_df['Alpha Star'] = summary_df['Alpha'] == alpha_star 244 245 if 'Class' in summary_df.columns: 246 color_lab = 'Class' 247 else: 248 c_dict = { 249 True : 'gold', 250 False : color 251 } 252 color_lab = 'Alpha Star' 253 254 if np.min(summary_df[metric]) < 0.5: 255 min_y = 0 256 else: 257 min_y = 0.6 258 259 plot = (ggplot(summary_df, aes(x = x_axis, y = metric, color=color_lab)) 260 + geom_point(size=4) 261 + theme_classic() 262 + ylim(min_y, 1) 263 + scale_x_reverse(breaks=x_axis_ticks) 264 + theme( 265 axis_text=element_text(weight='bold', size=10), 266 axis_title=element_text(weight='bold', size=12), 267 legend_title=element_text(weight='bold', size=14), 268 legend_text=element_text(weight='bold', size=12) 269 ) 270 ) 271 272 if not 'Class' in summary_df: 273 plot += scale_color_manual(c_dict) 274 275 return plot
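A sketch of the alternative x-axis, plotting performance against model size instead of alpha:

```python
# Plot performance against the number of selected groups
plot = scmkl.plot_metric(summary_df, alpha_star=alpha_star,
                         x_axis='Number of Selected Groups',
                         color='steelblue')
plot.save('performance_vs_groups.png')
```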
80def plot_conf_mat(results, title = '', cmap = None, normalize = True, 81 alpha = None, save = None) -> None: 82 """ 83 Creates a confusion matrix from the output of scMKL. 84 85 Parameters 86 ---------- 87 results : dict 88 The output from either scmkl.run() or scmkl.one_v_rest() 89 containing results from scMKL. 90 91 title : str 92 The text to display at the top of the matrix. 93 94 cmap : matplotlib.colors.LinearSegmentedColormap 95 The gradient of the values displayed from `matplotlib.pyplot`. 96 If `None`, `'Purples'` is used see matplotlib color map 97 reference for more information. 98 99 normalize : bool 100 If `False`, plot the raw numbers. If `True`, plot the 101 proportions. 102 103 alpha : None | float 104 Alpha that matrix should be created for. If `results` is from 105 `scmkl.one_v_all()`, this is ignored. If `None`, smallest alpha 106 will be used. 107 108 save : None | str 109 File path to save plot. If `None`, plot is not saved. 110 111 Returns 112 ------- 113 None 114 115 Examples 116 -------- 117 >>> # Running scmkl and capturing results 118 >>> results = scmkl.run(adata = adata, alpha_list = alpha_list) 119 >>> 120 >>> from matplotlib.pyplot import get_cmap 121 >>> 122 >>> scmkl.plot_conf_mat(results, title = '', cmap = get_cmap('Blues')) 123 124  125 126 Citiation 127 --------- 128 http://scikit-learn.org/stable/auto_examples/model_selection/ 129 plot_confusion_matrix.html 130 """ 131 # Determining type of results 132 if ('Observed' in results.keys()) and ('Metrics' in results.keys()): 133 multi_class = False 134 names = np.unique(results['Observed']) 135 else: 136 multi_class = True 137 names = np.unique(results['Truth_labels']) 138 139 if multi_class: 140 cm = metrics.confusion_matrix(y_true = results['Truth_labels'], 141 y_pred = results['Predicted_class'], 142 labels = names) 143 else: 144 min_alpha = np.min(list(results['Metrics'].keys())) 145 alpha = alpha if alpha != None else min_alpha 146 cm = metrics.confusion_matrix(y_true = results['Observed'], 147 y_pred = results['Predictions'][alpha], 148 labels = names) 149 150 accuracy = np.trace(cm) / float(np.sum(cm)) 151 misclass = 1 - accuracy 152 153 if cmap is None: 154 cmap = plt.get_cmap('Purples') 155 156 if normalize: 157 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 158 159 plt.figure(figsize=(8, 6)) 160 plt.imshow(cm, interpolation='nearest', cmap=cmap) 161 plt.title(title) 162 plt.colorbar() 163 164 tick_marks = np.arange(len(names)) 165 plt.xticks(tick_marks, names, rotation=45) 166 plt.yticks(tick_marks, names) 167 168 169 thresh = cm.max() / 1.5 if normalize else cm.max() / 2 170 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 171 if normalize: 172 plt.text(j, i, "{:0.4f}".format(cm[i, j]), 173 horizontalalignment="center", 174 color="white" if cm[i, j] > thresh else "black") 175 else: 176 plt.text(j, i, "{:,}".format(cm[i, j]), 177 horizontalalignment="center", 178 color="white" if cm[i, j] > thresh else "black") 179 180 acc_label = 'Predicted label\naccuracy={:0.4f}; misclass={:0.4f}' 181 acc_label = acc_label.format(accuracy, misclass) 182 183 plt.tight_layout() 184 plt.ylabel('True label') 185 plt.xlabel(acc_label) 186 187 if save != None: 188 plt.savefig(save) 189 plt.clf() 190 else: 191 plt.show() 192 193 return None
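A sketch of the non-default options, writing raw counts for a chosen alpha straight to disk:

```python
# Raw counts instead of proportions, for a specific alpha
scmkl.plot_conf_mat(results, normalize=False, alpha=0.1,
                    save='conf_mat_alpha_0.1.png')
```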
354def read_files(dir: str, pattern: str | None=None) -> dict: 355 """ 356 This function takes a directory of scMKL results as pickle files 357 and returns a dictionary with the file names as keys and the data 358 from the respective files as the values. 359 360 Parameters 361 ---------- 362 dir : str 363 A string specifying the file path for the output scMKL runs. 364 365 pattern : str 366 A regex string for filtering down to desired files. If 367 `None`, all files in the directory with the pickle file 368 extension will be added to the dictionary. 369 370 Returns 371 ------- 372 results : dict 373 A dictionary with the file names as keys and data as values. 374 375 Examples 376 -------- 377 >>> filepath = 'scMKL_results/rna+atac/' 378 ... 379 >>> all_results = scmkl.read_files(filepath) 380 >>> all_results.keys() 381 dict_keys(['Rep_1.pkl', Rep_2.pkl, Rep_3.pkl, ...]) 382 """ 383 # Reading all pickle files in patter is None 384 if pattern is None: 385 data = {file : np.load(f'{dir}/{file}', allow_pickle = True) 386 for file in os.listdir(dir) if '.pkl' in file} 387 388 # Reading only files matching pattern if not None 389 else: 390 pattern = repr(pattern) 391 data = {file : np.load(f'{dir}/{file}', allow_pickle = True) 392 for file in os.listdir(dir) 393 if re.fullmatch(pattern, file) is not None} 394 395 return data
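With the regex filter, only file names fully matching the pattern are loaded; for example, restricting to replicate files:

```python
# Load only replicate files such as 'Rep_1.pkl', 'Rep_12.pkl', ...
rfiles = scmkl.read_files('scMKL_results/rna+atac/',
                          pattern=r'Rep_\d+\.pkl')
```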
642def read_gtf(path: str, filter_to_coding: bool=False): 643 """ 644 Reads and formats a gtf file. Adds colnames: `['chr', 'source', 645 'feature', 'start', 'end', 'score', 'strand', 'frame', 646 'attribute']`. 647 648 Parameters 649 ---------- 650 path : str 651 The file path to the gtf file to be read in. If the file is 652 gzipped, file name must end with .gz. 653 654 filter_to_coding : bool 655 If `True`, will filter rows in gtf data frame to only 656 protein coding genes. Will add column `'gene_name'` containing 657 the gene name for each row. 658 659 Returns 660 ------- 661 df : pd.DataFrame 662 A pandas dataframe of the input gtf file. 663 664 Examples 665 -------- 666 >>> import scmkl 667 >>> 668 >>> file = 'data/hg38_subset_protein_coding.annotation.gtf' 669 >>> gtf = scmkl.read_gtf(file) 670 >>> 671 >>> gtf.head() 672 chr source feature start end score strand frame 673 0 chr1 HAVANA gene 11869 14409 . + . 674 1 chr1 HAVANA transcript 11869 14409 . + . 675 2 chr1 HAVANA exon 11869 12227 . + . 676 3 chr1 HAVANA exon 12613 12721 . + . 677 4 chr1 HAVANA exon 13221 14409 . + . 678 attribute 679 gene_id "ENSG00000223972.5"; gene_type "transc... 680 gene_id "ENSG00000223972.5"; transcript_id "EN... 681 gene_id "ENSG00000223972.5"; transcript_id "EN... 682 gene_id "ENSG00000223972.5"; transcript_id "EN... 683 gene_id "ENSG00000223972.5"; transcript_id "EN... 684 """ 685 df = pd.read_csv(path, sep='\t', comment='#', 686 skip_blank_lines=True, header=None) 687 688 df.columns = ['chr', 'source', 'feature', 'start', 'end', 689 'score', 'strand', 'frame', 'attribute'] 690 691 if filter_to_coding: 692 prot_rows = df['attribute'].str.contains('protein_coding') 693 df = df[prot_rows] 694 df = df[df['feature'] == 'gene'] 695 696 # Capturing and adding gene name to df 697 df['gene_name'] = [re.findall(r'(?<=gene_name ")[A-z0-9]+', 698 attr)[0] 699 for attr in df['attribute']] 700 701 return df
Reads and formats a GTF file. Adds the column names `['chr', 'source', 'feature', 'start', 'end', 'score', 'strand', 'frame', 'attribute']`.

Parameters
- path (str): The file path to the GTF file to be read in. If the file is gzipped, the file name must end with .gz.
- filter_to_coding (bool): If `True`, rows in the GTF dataframe will be filtered to protein-coding genes only, and a `'gene_name'` column containing the gene name for each row will be added.

Returns
- df (pd.DataFrame): A pandas dataframe of the input GTF file.

Examples
```python
>>> import scmkl
>>>
>>> file = 'data/hg38_subset_protein_coding.annotation.gtf'
>>> gtf = scmkl.read_gtf(file)
>>>
>>> gtf.head()
    chr  source     feature  start    end score strand frame
0  chr1  HAVANA        gene  11869  14409     .      +     .
1  chr1  HAVANA  transcript  11869  14409     .      +     .
2  chr1  HAVANA        exon  11869  12227     .      +     .
3  chr1  HAVANA        exon  12613  12721     .      +     .
4  chr1  HAVANA        exon  13221  14409     .      +     .
                                           attribute
0  gene_id "ENSG00000223972.5"; gene_type "transc...
1  gene_id "ENSG00000223972.5"; transcript_id "EN...
2  gene_id "ENSG00000223972.5"; transcript_id "EN...
3  gene_id "ENSG00000223972.5"; transcript_id "EN...
4  gene_id "ENSG00000223972.5"; transcript_id "EN...
```
```python
def run(adata: ad.AnnData, alpha_list: np.ndarray,
        metrics: list | None = None,
        return_probs: bool = False) -> dict:
    """Train and test models across multiple alpha values, returning
    metrics, predictions, group weights, and resource usage."""
    if metrics is None:
        metrics = ['AUROC', 'F1-Score', 'Accuracy', 'Precision', 'Recall']

    # Initializing variables to capture metrics
    group_names = list(adata.uns['group_dict'].keys())
    preds = {}
    group_norms = {}
    mets_dict = {}
    selected_groups = {}
    train_time = {}
    models = {}
    probs = {}

    D = adata.uns['D']

    # Generating models for each alpha and outputs
    tracemalloc.start()
    for alpha in alpha_list:

        print(f'  Evaluating model.\n  Alpha: {alpha}', flush=True)

        train_start = time.time()

        adata = train_model(adata, group_size=2 * D, alpha=alpha)

        if return_probs:
            alpha_res = predict(adata,
                                metrics=metrics,
                                return_probs=return_probs)
            preds[alpha], mets_dict[alpha], probs[alpha] = alpha_res

        else:
            alpha_res = predict(adata,
                                metrics=metrics,
                                return_probs=return_probs)
            preds[alpha], mets_dict[alpha] = alpha_res

        selected_groups[alpha] = find_selected_groups(adata)

        # Each group's kernel weight is the norm of its block of 2 * D
        # model coefficients
        kernel_weights = adata.uns['model'].coef_
        group_norms[alpha] = [
            np.linalg.norm(kernel_weights[i * 2 * D : (i + 1) * 2 * D])
            for i in np.arange(len(group_names))
        ]

        models[alpha] = adata.uns['model']

        train_end = time.time()
        train_time[alpha] = train_end - train_start

    # Combining results into one object
    results = {}
    results['Metrics'] = mets_dict
    results['Selected_groups'] = selected_groups
    results['Norms'] = group_norms
    results['Predictions'] = preds
    results['Observed'] = adata.obs['labels'].iloc[adata.uns['test_indices']]
    results['Test_indices'] = adata.uns['test_indices']
    results['Group_names'] = group_names
    results['Models'] = models
    results['Train_time'] = train_time
    results['RAM_usage'] = f'{tracemalloc.get_traced_memory()[1] / 1e9} GB'
    results['Probabilities'] = probs

    return results
```
Wrapper function for training and testing with multiple alpha values. Returns metrics, predictions, group weights, and resource usage.

Parameters
- adata (ad.AnnData): A processed `ad.AnnData` with `'Z_train'`, `'Z_test'`, and `'group_dict'` keys in `adata.uns`.
- alpha_list (np.ndarray): Sparsity values to create models with. Alpha refers to the penalty parameter in Group Lasso. Larger alphas force group weights to shrink towards zero while smaller alphas apply a lesser penalty to kernel weights. Values that are too large will result in models that weight all groups as zero.
- metrics (list[str]): Metrics that should be calculated on predictions. Options are `['AUROC', 'F1-Score', 'Accuracy', 'Precision', 'Recall']`. When set to `None`, all metrics are calculated.
- return_probs (bool): If `True`, predicted class probabilities for the test set are included in the results under `'Probabilities'`.

Returns
- results (dict): Results with keys and values:
  - `'Metrics'` (dict): A nested dictionary as `[alpha][metric] = value`.
  - `'Group_names'` (np.ndarray): Array of group names used in model(s).
  - `'Selected_groups'` (dict): A nested dictionary as `[alpha] = np.array([nonzero_groups])`. Nonzero groups are groups that had a kernel weight above zero.
  - `'Norms'` (dict): A nested dictionary as `[alpha] = np.array([kernel_weights])`. The order of `kernel_weights` matches the `'Group_names'` values.
  - `'Observed'` (np.ndarray): An array of ground truth cell labels from the test set.
  - `'Predictions'` (dict): A nested dictionary as `[alpha] = predicted_class` respective to `'Observed'` for `alpha`.
  - `'Test_indices'` (np.ndarray): Indices of samples in adata used in the test set.
  - `'Models'` (dict): A nested dictionary where `[alpha] = celer.GroupLasso` object for `alpha`.
  - `'Train_time'` (dict): Seconds taken to train the model for each `alpha`.
  - `'RAM_usage'` (str): Peak traced memory in GB while training the models.
  - `'Probabilities'` (dict): Populated when `return_probs=True`; a nested dictionary as `[alpha] = class_probabilities` for the test set.
Examples
```python
>>> results = scmkl.run(adata = adata,
...                     alpha_list = np.array([0.05, 0.1, 0.5]))
>>> results.keys()
dict_keys(['Metrics', 'Selected_groups', 'Norms', 'Predictions',
           'Observed', 'Test_indices', 'Group_names', 'Models',
           'Train_time', 'RAM_usage', 'Probabilities'])
>>>
>>> # alpha values
>>> results['Metrics'].keys()
dict_keys([0.05, 0.1, 0.5])
>>>
>>> results['Metrics'][0.05]
{'AUROC': 0.9859,
 'Accuracy': 0.945,
 'F1-Score': 0.9452736318407959,
 'Precision': 0.9405940594059405,
 'Recall': 0.95}
```
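A quick group ranking can be pulled out of a finished run by pairing `'Group_names'` with the matching `'Norms'` entry. A minimal sketch, assuming `results` came from `scmkl.run()` with `0.05` in `alpha_list`:

```python
import numpy as np

alpha = 0.05

# Pair each group with its kernel weight for the chosen alpha
weights = dict(zip(results['Group_names'], results['Norms'][alpha]))

# Sort by weight, largest first; zero-weight groups were dropped
# from the model by the Group Lasso penalty
ranked = sorted(weights.items(), key=lambda kv: kv[1], reverse=True)
for group, weight in ranked[:5]:
    print(f'{group}: {weight:.4f}')
```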
```python
def sort_groups(df: pd.DataFrame, group_col: str = 'Group',
                norm_col: str = 'Kernel Weight'):
    """Return group names sorted in descending order by kernel weight."""
    df = df.copy()
    df = df.sort_values(norm_col, ascending=False)
    group_order = list(df[group_col])

    return group_order
```
Takes a dataframe with `group_col` and returns a list of groups in descending order by their weights. Assumes there is one instance of each group.

Parameters
- df (pd.DataFrame): A dataframe with `group_col` and `norm_col` to be sorted by.
- group_col (str): The column containing the group names.
- norm_col (str): The column containing the kernel weights.

Returns
- group_order (list): A list of groups in descending order according to their kernel weights.
Examples
```python
>>> result = scmkl.run(adata, alpha_list)
>>> weights = scmkl.get_weights(result)
>>> group_order = scmkl.sort_groups(weights, 'Group',
...                                 'Kernel Weight')
>>>
>>> group_order
['HALLMARK_ESTROGEN_RESPONSE_EARLY', 'HALLM...', ...]
```
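Because `sort_groups()` only needs the two named columns, it can be exercised on a toy dataframe. A minimal sketch with made-up weights:

```python
import pandas as pd
import scmkl

# Toy weights table with one row per group
weights = pd.DataFrame({
    'Group': ['Pathway_A', 'Pathway_B', 'Pathway_C'],
    'Kernel Weight': [0.2, 0.9, 0.5]
})

print(scmkl.sort_groups(weights))
# ['Pathway_B', 'Pathway_C', 'Pathway_A']
```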
```python
def tfidf_normalize(adata: ad.AnnData, binarize: bool = False):
    """TF-IDF normalize the data in an adata object."""
    X = adata.X.copy()
    row_sums = np.sum(X, axis=1)
    assert np.all(row_sums > 0), "TFIDF requires all row sums be positive"

    if binarize:
        X[X > 0] = 1

    if 'train_indices' in adata.uns_keys():

        train_indices = adata.uns['train_indices'].copy()
        test_indices = adata.uns['test_indices'].copy()

        # Calculate the train TF-IDF matrix on just the training data so
        # it is not biased by testing data
        tfidf_train = tfidf(X[train_indices, :], mode='normalize')

        # Calculate the test TF-IDF on the train and test data together,
        # then index out the test rows
        tfidf_test = tfidf(X, mode='normalize')[test_indices, :]

        # Impossible to add rows back to original location so we need to
        # stack the matrices to maintain train/test
        if scipy.sparse.issparse(X):
            tfidf_norm = scipy.sparse.vstack((tfidf_train, tfidf_test))
        else:
            tfidf_norm = np.vstack((tfidf_train, tfidf_test))

        # Reassigning the indices; without this, the values are saved as
        # 0s in adata
        adata.uns['train_indices'] = train_indices
        adata.uns['test_indices'] = test_indices

        combined_indices = np.concatenate((train_indices, test_indices))

        # AnnData indexes by "rownames", not position, so rows must be
        # renamed to properly index
        adata_index = adata.obs_names[combined_indices].astype(int)
        tfidf_norm = tfidf_norm[np.argsort(adata_index), :]

    else:

        tfidf_norm = tfidf(X, mode='normalize')

    adata.X = tfidf_norm.copy()

    return adata
```
Function to TF-IDF normalize the data in an adata object. All rows must have a positive sum; an assertion fails if any row is entirely zero.

Parameters
- adata (ad.AnnData): `ad.AnnData` with `.X` to be normalized. If `'train_indices'` and `'test_indices'` are in `adata.uns.keys()`, normalization will be done separately for the training and testing data. Otherwise, it will be calculated on the entire dataset.
- binarize (bool): If `True`, all values in `adata.X` greater than 0 will become 1.

Returns
- adata (ad.AnnData): `adata` with `adata.X` TF-IDF normalized. The train data will now be stacked on the test data, and the indices adjusted accordingly.
Examples
```python
>>> adata = scmkl.create_adata(X = data_mat,
...                            feature_names = gene_names,
...                            group_dict = group_dict)
>>>
>>> adata = scmkl.tfidf_normalize(adata)
```
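For ATAC data, binarizing accessibility counts before TF-IDF normalization is a common choice. A minimal sketch, assuming hypothetical `peak_mat`, `peak_names`, and `group_dict` inputs like those passed to `scmkl.create_adata()`:

```python
import scmkl

# Hypothetical ATAC inputs; see create_adata() for the expected formats
adata = scmkl.create_adata(X=peak_mat,
                           feature_names=peak_names,
                           group_dict=group_dict)

# Binarize counts, then TF-IDF normalize; if 'train_indices' is present
# in adata.uns, train and test cells are normalized separately
adata = scmkl.tfidf_normalize(adata, binarize=True)
```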
```python
def train_model(adata: ad.AnnData, group_size: int | None = None,
                alpha: float = 0.9):
    """Fit a Group Lasso model to the provided data."""
    assert alpha > 0, 'Alpha must be positive'

    if group_size is None:
        group_size = 2 * adata.uns['D']

    y_train = adata.obs['labels'].iloc[adata.uns['train_indices']]
    X_train = adata.uns['Z_train'][adata.uns['train_indices']]

    cell_labels = np.unique(y_train)

    # This is a regression algorithm. We need to make the labels
    # 'continuous' for classification, but they will remain binary.
    # Casts training labels to an array of -1, 1
    train_labels = np.ones(y_train.shape)
    train_labels[y_train == cell_labels[1]] = -1

    # Alphamax is a calculation to regularize the effect of alpha across
    # different data sets
    alphamax = np.max(np.abs(X_train.T.dot(train_labels)))
    alphamax /= X_train.shape[0]
    alphamax *= alpha

    # Instantiate celer Group Lasso regression model object
    model = celer.GroupLasso(groups=group_size, alpha=alphamax)

    # Fit model using training data
    model.fit(X_train, train_labels.ravel())

    adata.uns['model'] = model
    return adata
```
Fit a Group Lasso model to the provided data.

Parameters
- adata (ad.AnnData): Has `'Z_train'` and `'Z_test'` keys in `.uns.keys()`.
- group_size (None | int): Argument describing how the features are grouped. If `None`, `2 * adata.uns['D']` will be used. For more information, see the [celer documentation](https://mathurinm.github.io/celer/generated/celer.GroupLasso.html).
- alpha (float): Group Lasso regularization coefficient, a positive floating-point value controlling model solution sparsity. The smaller the value, the more feature groups will be selected in the trained model.

Returns
- adata (ad.AnnData): Trained model accessible with `adata.uns['model']`.
Examples
```python
>>> adata = scmkl.estimate_sigma(adata)
>>> adata = scmkl.calculate_z(adata)
>>> d = scmkl.calculate_d(adata.shape[0])
>>> group_size = 2 * d
>>> adata = scmkl.train_model(adata, group_size)
>>>
>>> 'model' in adata.uns.keys()
True
```
See Also
- celer: https://mathurinm.github.io/celer/generated/celer.GroupLasso.html
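The `alphamax` rescaling in the source above is what lets the same `alpha` produce comparable sparsity on data sets of different sizes and scales. A standalone numpy sketch of that computation on random data, with made-up dimensions:

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up training design matrix and -1/1 labels
X_train = rng.normal(size=(200, 40))
train_labels = rng.choice([-1.0, 1.0], size=200)

# Largest absolute correlation between a feature and the labels,
# normalized by sample count; alpha scales this data-dependent maximum
alpha = 0.9
alphamax = np.max(np.abs(X_train.T.dot(train_labels))) / X_train.shape[0]
penalty = alphamax * alpha

print(f'alphamax = {alphamax:.4f}, penalty passed to celer = {penalty:.4f}')
```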
```python
def weights_barplot(result, n_groups: int = 1, alpha: None | float = None,
                    color: str = 'red'):
    """Plot the top `n_groups` weighted groups for each cell class."""
    is_multi, is_many = _parse_result_type(result)
    assert not is_many, "This function only supports single results"

    alpha = _get_alpha(alpha, result, is_multi)
    df = get_weights(result)

    # Subsetting to only alpha and filtering to top groups
    df = df[df['Alpha'] == alpha]
    if is_multi:
        for ct in result['Classes']:
            cor_ct = df['Class'] == ct
            other_ct = ~cor_ct

            temp_df = df.copy()
            temp_df = temp_df[cor_ct]
            temp_df = temp_df.sort_values('Kernel Weight', ascending=False)
            top_groups = temp_df['Group'].iloc[0:n_groups].to_numpy()

            cor_groups = np.isin(df['Group'], top_groups)
            filter_array = np.logical_and(cor_ct, cor_groups)
            filter_array = np.logical_or(filter_array, other_ct)

            df = df[filter_array]

        df['Group'] = format_group_names(df['Group'], rm_words=['Markers'])
        # Category order must be computed after formatting so the
        # categories match the formatted group names
        g_order = sorted(set(df['Group']))

    else:
        df = df.sort_values('Kernel Weight', ascending=False)
        df = df.iloc[0:n_groups]
        df['Group'] = format_group_names(df['Group'], rm_words=['Markers'])
        g_order = sort_groups(df)[::-1]

    df['Group'] = pd.Categorical(df['Group'], g_order)

    plot = (ggplot(df)
            + theme_classic()
            + coord_flip()
            + labs(y=f'Kernel Weight (λ = {alpha})')
            + theme(
                axis_text=element_text(weight='bold', size=10),
                axis_title=element_text(weight='bold', size=12),
                axis_title_y=element_blank()
            ))

    # This needs to be reworked for multiclass runs
    if is_multi:
        height = 3 * ceil(len(set(df['Class'])) / 3)
        plot += geom_bar(aes(x='Group', y='Kernel Weight'),
                         stat='identity', fill=color)
        plot += facet_wrap('Class', scales='free', ncol=3)
        plot += theme(figure_size=(15, height))
    else:
        plot += geom_bar(aes(x='Group', y='Kernel Weight'),
                         stat='identity', fill=color)
        plot += theme(figure_size=(7, 9))

    return plot
```
Plots the top `n_groups` weighted groups for each cell class. Works for a single scMKL result (either multiclass or binary).

Parameters
- result (dict): The output of `scmkl.run()`.
- n_groups (int): The number of top groups to plot for each cell class.
- alpha (None | float): The alpha parameter to create the figure for. If `None`, the smallest alpha is used.
- color (str): The color the bars should be.

Returns
- plot (plotnine.ggplot.ggplot): A barplot of weights.

Examples
```python
>>> result = scmkl.run(adata, alpha_list)
>>> plot = scmkl.weights_barplot(result)
```
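Since the return value is a plotnine `ggplot` object, it can be written to disk with plotnine's own `save()` method. A short usage sketch with a hypothetical output path:

```python
# Top three groups per class, with a custom bar color
plot = scmkl.weights_barplot(result, n_groups=3, color='steelblue')

# plotnine ggplot objects support save(); path and dpi are arbitrary here
plot.save('figures/top_group_weights.png', dpi=300)
```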
```python
def weights_dotplot(result, n_groups: None | int = None,
                    class_lab: str | None = None, low: str = 'white',
                    high: str = 'red', alpha: float | None = None,
                    scale_weights: bool = False):
    """Plot a dotplot of kernel weights by group and alpha (or class)."""
    is_multi, is_many = _parse_result_type(result)
    assert not is_many, "This function only supports single results"

    if type(class_lab) is str:
        result = result[class_lab]

    df = get_weights(result)
    df['Group'] = format_group_names(df['Group'], ['Markers'])

    # Filtering and sorting values
    sum_df = df.groupby('Group')['Kernel Weight'].sum()
    sum_df = sum_df.reset_index()
    order = sort_groups(sum_df)[::-1]
    df['Group'] = pd.Categorical(df['Group'], categories=order)

    if type(n_groups) is int:
        sum_df = sum_df.sort_values(by='Kernel Weight', ascending=False)
        top_groups = sum_df.iloc[0:n_groups]['Group'].to_numpy()
        df = df[np.isin(df['Group'], top_groups)]
    else:
        n_groups = len(set(df['Group']))

    df['Alpha'] = pd.Categorical(df['Alpha'], np.unique(df['Alpha']))

    # Slightly shorter figure when few groups are shown
    fig_size = (7, 6) if n_groups < 25 else (7, 8)

    if 'Class' in df.columns:
        alpha = _get_alpha(alpha, result, is_multi)
        df = df[df['Alpha'] == alpha]
        x_lab = 'Class'
    else:
        x_lab = 'Alpha'

    # Scaling is per class, so it is ignored for binary results
    if scale_weights and is_multi:
        max_norms = dict()
        for ct in set(df['Class']):
            g_rows = df['Class'] == ct
            max_norms[ct] = np.max(df[g_rows]['Kernel Weight'])
        scale_cols = ['Class', 'Kernel Weight']

        new = df[scale_cols].apply(lambda x: x[1] / max_norms[x[0]], axis=1)
        df['Kernel Weight'] = new

        l_title = 'Scaled\nKernel Weight'

    else:
        l_title = 'Kernel Weight'

    plot = (ggplot(df, aes(x=x_lab, y='Group', fill='Kernel Weight',
                           color='Kernel Weight'))
            + geom_point(size=5)
            + scale_fill_gradient(high=high, low=low)
            + scale_color_gradient(high=high, low=low)
            + theme_classic()
            + theme(
                figure_size=fig_size,
                axis_text=element_text(weight='bold', size=10),
                axis_text_x=element_text(rotation=90),
                axis_title=element_text(weight='bold', size=12),
                axis_title_y=element_blank(),
                legend_title=element_text(text=l_title, weight='bold',
                                          size=12),
                legend_text=element_text(weight='bold', size=10)
            ))

    return plot
```
Plots a dotplot of kernel weights with groups on the y-axis and alpha on the x-axis if a binary result. If a multiclass result, one alpha is used per class and the x-axis is class.

Parameters
- result (dict): The output of `scmkl.run()`.
- n_groups (int): The number of top groups to plot. Not recommended for multiclass results.
- class_lab (str | None): For multiclass results, if not `None`, will only plot group weights for `class_lab`.
- low (str): The color for low kernel weight.
- high (str): The color for high kernel weight.
- alpha (None | float): The alpha parameter to create the figure for. If `None`, the smallest alpha is used.
- scale_weights (bool): If `True`, the kernel weights will be scaled for each group within each class. Ignored if the result is from a binary classification.

Returns
- plot (plotnine.ggplot.ggplot): A dotplot of weights.

Examples
```python
>>> result = scmkl.run(adata, alpha_list)
>>> plot = scmkl.weights_dotplot(result)
```
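For multiclass results, scaling each class's weights to its own maximum makes classes with small absolute weights easier to compare. A short usage sketch, assuming `result` is a multiclass scMKL result:

```python
# Scaled dotplot across classes; scale_weights is ignored for binary runs
plot = scmkl.weights_dotplot(result, n_groups=15, scale_weights=True,
                             low='white', high='darkred')
plot.save('figures/group_weights_dotplot.png', dpi=300)
```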
```python
def weights_heatmap(result, n_groups: None | int = None,
                    class_lab: str | None = None, low: str = 'white',
                    high: str = 'red', alpha: float | None = None,
                    scale_weights: bool = False):
    """Plot a heatmap of kernel weights by group and alpha (or class)."""
    is_multi, is_many = _parse_result_type(result)
    assert not is_many, "This function only supports single results"

    if type(class_lab) is str:
        result = result[class_lab]

    df = get_weights(result)
    df['Group'] = format_group_names(df['Group'], ['Markers'])

    # Filtering and sorting values
    sum_df = df.groupby('Group')['Kernel Weight'].sum()
    sum_df = sum_df.reset_index()
    order = sort_groups(sum_df)[::-1]
    df['Group'] = pd.Categorical(df['Group'], categories=order)

    if type(n_groups) is int:
        sum_df = sum_df.sort_values(by='Kernel Weight', ascending=False)
        top_groups = sum_df.iloc[0:n_groups]['Group'].to_numpy()
        df = df[np.isin(df['Group'], top_groups)]
    else:
        n_groups = len(set(df['Group']))

    df['Alpha'] = pd.Categorical(df['Alpha'], np.unique(df['Alpha']))

    # Slightly shorter figure when few groups are shown
    fig_size = (7, 6) if n_groups < 25 else (7, 8)

    if 'Class' in df.columns:
        alpha = _get_alpha(alpha, result, is_multi)
        df = df[df['Alpha'] == alpha]
        x_lab = 'Class'
    else:
        x_lab = 'Alpha'

    if scale_weights and is_multi:
        max_norms = dict()
        for ct in set(df['Class']):
            g_rows = df['Class'] == ct
            max_norms[ct] = np.max(df[g_rows]['Kernel Weight'])
        scale_cols = ['Class', 'Kernel Weight']

        new = df[scale_cols].apply(lambda x: x[1] / max_norms[x[0]], axis=1)
        df['Kernel Weight'] = new

        l_title = 'Scaled\nKernel Weight'

    else:
        l_title = 'Kernel Weight'

    plot = (ggplot(df, aes(x=x_lab, y='Group', fill='Kernel Weight'))
            + geom_tile(color='black')
            + scale_fill_gradient(high=high, low=low)
            + theme_classic()
            + theme(
                figure_size=fig_size,
                axis_text=element_text(weight='bold', size=10),
                axis_text_x=element_text(rotation=90),
                axis_title=element_text(weight='bold', size=12),
                axis_title_y=element_blank(),
                legend_title=element_text(text=l_title, weight='bold',
                                          size=12),
                legend_text=element_text(weight='bold', size=10)
            ))

    return plot
```
Plots a heatmap of kernel weights with groups on the y-axis and alpha on the x-axis if a binary result. If a multiclass result, one alpha is used per class and the x-axis is class.

Parameters
- result (dict): The output of `scmkl.run()`.
- n_groups (int): The number of top groups to plot. Not recommended for multiclass results.
- class_lab (str | None): For multiclass results, if not `None`, will only plot group weights for `class_lab`.
- low (str): The color for low kernel weight.
- high (str): The color for high kernel weight.
- alpha (None | float): The alpha parameter to create the figure for. If `None`, the smallest alpha is used.
- scale_weights (bool): If `True`, the kernel weights will be scaled for each group within each class. Ignored if the result is from a binary classification.

Returns
- plot (plotnine.ggplot.ggplot): A heatmap of weights.

Examples
```python
>>> result = scmkl.run(adata, alpha_list)
>>> plot = scmkl.weights_heatmap(result)
```
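For binary results, every alpha in `alpha_list` gets its own column, which makes the heatmap a quick way to see how group selection changes with sparsity. A short usage sketch with hypothetical settings:

```python
# Heatmap of the 30 heaviest groups across all alphas in the run
plot = scmkl.weights_heatmap(result, n_groups=30, low='white', high='navy')
plot.save('figures/group_weights_heatmap.png', dpi=300)
```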