scmkl.one_v_rest

import gc
import tracemalloc
import numpy as np
import pandas as pd
import anndata as ad
from sklearn.metrics import f1_score


from scmkl.run import run
from scmkl.calculate_z import calculate_z
from scmkl.multimodal_processing import multimodal_processing
from scmkl._checks import _check_adatas


def _eval_labels(cell_labels: np.ndarray, train_indices: np.ndarray, 
                 test_indices: np.ndarray) -> np.ndarray:
    """
    Takes an array of multiclass cell labels and returns the unique 
    cell labels to test for.

    Parameters
    ----------
    cell_labels : np.ndarray
        Cell labels that correspond to an AnnData object.

    train_indices : np.ndarray
        Indices for the training samples in an AnnData object.
    
    test_indices : np.ndarray
        Indices for the testing samples in an AnnData object.

    Returns
    -------
    uniq_labels : np.ndarray
        A numpy array of unique cell labels to be iterated through 
        during one-versus-rest setups.
    """
    train_uniq_labels = np.unique(cell_labels[train_indices])
    test_uniq_labels = np.unique(cell_labels[test_indices])

    # Keeping only labels present in both the training and testing sets
    uniq_labels = np.intersect1d(train_uniq_labels, test_uniq_labels)

    # Ensuring that at least one cell type label is shared between the
    #   training and testing samples
    assert len(uniq_labels) > 0, ("There are no common labels between cells "
                                  "in the training and testing samples")

    return uniq_labels


def get_prob_table(results : dict, alpha: float | dict):
    """
    Takes a results dictionary with class and probabilities keys and 
    returns a table of probabilities for each class and the most 
    probable class for each cell.

    Parameters
    ----------
    results : dict
        A nested dictionary that contains a dictionary for each class 
        containing probabilities for each cell class.

    alpha : float | dict
        The alpha value whose model probabilities should be evaluated, 
        or a dict mapping each class to the alpha value to use for 
        that class's model.

    Returns
    -------
    prob_table : pd.DataFrame
        Each column is a cell class and the elements are the class 
        probability outputs from the model.

    pred_class : list[str]
        The most probable cell class for each cell in the test set; 
        tied classes are joined with ' and '.

    low_conf : np.ndarray
        A boolean array that is `True` where a sample's maximum class 
        probability is less than 0.5.
    """
    if isinstance(alpha, float):
        prob_table = {class_ : results[class_]['Probabilities'][alpha][class_]
                      for class_ in results.keys()}
    else:
        prob_table = dict()
        for class_ in results.keys():
            cur_alpha = alpha[class_]
            prob_table[class_] = results[class_]['Probabilities'][cur_alpha][class_]

    prob_table = pd.DataFrame(prob_table)

    pred_class = []
    maxes = []

    for _, row in prob_table.iterrows():
        row_max = np.max(row)
        indices = np.where(row == row_max)[0]
        prediction = prob_table.columns[indices]

        # Ties are reported as a joined string of the tied classes
        if len(prediction) > 1:
            prediction = " and ".join(prediction)
        else:
            prediction = prediction[0]

        pred_class.append(prediction)
        maxes.append(row_max)

    # Flagging cells whose best class probability is below 0.5
    low_conf = np.array(maxes) < 0.5

    return prob_table, pred_class, low_conf


def per_model_summary(results: dict, uniq_labels: np.ndarray | list | tuple, 
                      alpha: float | dict) -> pd.DataFrame:
    """
    Takes the results dictionary from `scmkl.one_v_rest()` and returns 
    a summary dataframe showing metrics for each model generated 
    during the runs.

    Parameters
    ----------
    results : dict
        Results from `scmkl.one_v_rest()`.

    uniq_labels : array_like
        Unique cell classes from the runs.

    alpha : float | dict
        The alpha value to summarize, or a dict mapping each class to 
        the alpha value used for that class's model.

    Returns
    -------
    summary_df : pd.DataFrame
        Dataframe with classes as rows and metrics as columns.
    """
    # Getting metrics available in results
    if isinstance(alpha, dict):
        alpha_key = list(alpha.keys())[0]
        alpha_key = alpha[alpha_key]
        avail_mets = list(results[uniq_labels[0]]['Metrics'][alpha_key])
    else:
        avail_mets = list(results[uniq_labels[0]]['Metrics'][alpha])

    summary_df = {metric : list()
                  for metric in avail_mets}
    summary_df['Class'] = uniq_labels

    for lab in summary_df['Class']:
        for met in avail_mets:
            if isinstance(alpha, dict):
                cur_alpha = alpha[lab]
            else:
                cur_alpha = alpha

            val = results[lab]['Metrics'][cur_alpha][met]
            summary_df[met].append(val)

    return pd.DataFrame(summary_df)


def get_class_train(train_indices: np.ndarray,
                    cell_labels: np.ndarray | list | pd.Series,
                    seed_obj: np.random._generator.Generator,
                    other_factor: float = 1.5):
    """
    Returns a dict with each entry being the set of training indices 
    for one cell class, to be used in `scmkl.one_v_rest()`.

    Parameters
    ----------
    train_indices : np.ndarray
        The indices in the `ad.AnnData` object of samples available to 
        train on.

    cell_labels : array_like
        The identity of all cells in the anndata object.

    seed_obj : np.random._generator.Generator
        The seed object used to randomly sample non-target samples.

    other_factor : float
        The ratio of cells to sample for the other class for each 
        model. For example, if classifying B cells with 100 B cells in 
        training and `other_factor=1`, 100 cells that are not B cells 
        will be included in training alongside the B cells.

    Returns
    -------
    train_idx : dict
        Keys are cell classes and values are the train indices used to 
        train scmkl that include both target and non-target samples.
    """
    uniq_labels = np.unique(cell_labels)
    train_idx = dict()

    if isinstance(cell_labels, pd.Series):
        cell_labels = cell_labels.to_numpy()
    elif isinstance(cell_labels, list):
        cell_labels = np.array(cell_labels)

    for lab in uniq_labels:
        # Mapping target-class positions within the training subset 
        #   back to global indices in the AnnData object
        target_pos = np.where(lab == cell_labels[train_indices])[0]
        target_pos = train_indices[target_pos]

        # Remaining training cells are candidates for the 'other' class
        other_pos = np.setdiff1d(train_indices, target_pos)

        # Sampling up to other_factor times as many 'other' cells as 
        #   target cells
        if (other_factor*target_pos.shape[0]) <= other_pos.shape[0]:
            n_samples = int(other_factor*target_pos.shape[0])
        else:
            n_samples = other_pos.shape[0]

        other_pos = seed_obj.choice(other_pos, n_samples, False)

        lab_train = np.concatenate([target_pos, other_pos])
        train_idx[lab] = lab_train.copy()

    return train_idx


def one_v_rest(adatas : list | ad.AnnData, names : list, 
               alpha_params : np.ndarray, tfidf : list=None, batches: int=10, 
               batch_size: int=100, train_dict: dict=None, 
               force_balance: bool=False, other_factor: float=1.0) -> dict:
    """
    For each cell class, creates model(s) comparing that class to all 
    others. Then, predicts on the held-out testing data using 
    `scmkl.run()`. Only labels present in both the training and 
    testing data will be run.

    Parameters
    ----------
    adatas : list[AnnData]
        List of `ad.AnnData` objects created by `create_adata()` 
        where each `ad.AnnData` is one modality and composed of both 
        training and testing samples. Requires that `'train_indices'`
        and `'test_indices'` are the same between all `ad.AnnData`s.

    names : list[str]
        String variables that describe each modality respective to 
        `adatas` for labeling.

    alpha_params : np.ndarray | float | dict
        If a `dict`, expects keys to correspond to each unique label 
        with a float alpha as the value (ideally the output of 
        `scmkl.optimize_alpha()`). Otherwise, an array of alpha values 
        to create each model with, or a float to run with a single 
        alpha.

    tfidf : list[bool]
        If element `i` is `True`, `adatas[i]` will be TF-IDF 
        normalized. If `None`, no views will be TF-IDF normalized.

    batches : int
        The number of batches to use for the distance calculation. 
        This will average the result of `batches` distance calculations 
        of `batch_size` randomly sampled cells. More batches will 
        converge to population distance values at the cost of 
        scalability.

    batch_size : int
        The number of cells to include per batch for distance
        calculations. Higher batch size will converge to population
        distance values at the cost of scalability.
        If `batches*batch_size > num_training_cells`,
        `batch_size` will be reduced to 
        `int(num_training_cells / batches)`.

    train_dict : dict
        Optional mapping of each cell class to the training indices to 
        use for that class's model. If `None`, training indices are 
        taken from the `ad.AnnData` object (and balanced per class 
        when `force_balance=True`).

    force_balance : bool
        If `True`, training sets will be balanced to reduce class label 
        imbalance. Defaults to `False`.

    other_factor : float
        The ratio of cells to sample for the other class for each 
        model. For example, if classifying B cells with 100 B cells in 
        training and `other_factor=1`, 100 cells that are not B cells 
        will be included in training alongside the B cells.

    Returns
    -------
    results : dict
        Contains keys for each cell class with results from cell class
        versus all other samples. See `scmkl.run()` for further 
        details. Will also include a probabilities table with the 
        predictions from each model.

    Examples
    --------
    >>> adata = scmkl.create_adata(X = data_mat, 
    ...                            feature_names = gene_names, 
    ...                            group_dict = group_dict)
    >>>
    >>> results = scmkl.one_v_rest(adatas = [adata], names = ['rna'],
    ...                            alpha_params = np.array([0.05, 0.1]),
    ...                            tfidf = [False])
    >>>
    >>> results.keys()
    dict_keys(['B cells', 'Monocytes', 'Dendritic cells', ...])
    """
    if isinstance(adatas, ad.AnnData):
        adatas = [adatas]
    if isinstance(tfidf, type(None)):
        tfidf = len(adatas)*[False]

    _check_adatas(adatas, check_obs=True, check_uns=True)

    # Want to retain all original train indices
    train_indices = adatas[0].uns['train_indices'].copy()
    test_indices = adatas[0].uns['test_indices'].copy()

    uniq_labels = _eval_labels(cell_labels = adatas[0].obs['labels'], 
                               train_indices = train_indices,
                               test_indices = test_indices)

    if (len(adatas) == 1) and ('Z_train' not in adatas[0].uns.keys()):
        adata = calculate_z(adatas[0], n_features = 5000, 
                            batches=batches, batch_size=batch_size)
    elif len(adatas) > 1:
        adata = multimodal_processing(adatas=adatas, 
                                      names=names, 
                                      tfidf=tfidf,
                                      batches=batches,
                                      batch_size=batch_size)
    else:
        adata = adatas[0].copy()

    # Preventing multiple copies of adata(s) in memory
    del adatas
    gc.collect()

    # Need obj for capturing results
    results = dict()

    # Capturing cell labels to regenerate at each comparison
    cell_labels = np.array(adata.obs['labels'].copy())

    # Capturing per-class training indices for each comparison
    if train_dict:
        train_idx = train_dict
    else:
        if force_balance:
            train_idx = get_class_train(adata.uns['train_indices'], 
                                        cell_labels, 
                                        adata.uns['seed_obj'],
                                        other_factor)
    tracemalloc.start()
    for label in uniq_labels:

        print(f"Comparing {label} to other types", flush = True)
        cur_labels = cell_labels.copy()
        cur_labels[cell_labels != label] = 'other'

        # Need cur_label vs rest to run model
        adata.obs['labels'] = cur_labels

        if force_balance or train_dict:
            adata.uns['train_indices'] = train_idx[label]

        # Will only run scMKL with tuned alphas
        if isinstance(alpha_params, dict):
            alpha_list = np.array([alpha_params[label]])
        elif isinstance(alpha_params, float):
            alpha_list = np.array([alpha_params])
        else:
            alpha_list = alpha_params

        # Running scMKL
        results[label] = run(adata, alpha_list, return_probs=True)
        gc.collect()

    # Getting final predictions
    if isinstance(alpha_params, dict):
        alpha = alpha_params
    else:
        alpha = np.min(alpha_params)

    prob_table, pred_class, low_conf = get_prob_table(results, alpha)
    macro_f1 = f1_score(cell_labels[adata.uns['test_indices']], 
                        pred_class, average='macro')

    model_summary = per_model_summary(results, uniq_labels, alpha)

    # Global adata obj will be permanently changed if not reset
    adata.obs['labels'] = cell_labels
    adata.uns['train_indices'] = train_indices

    # Need to document vars, probs, and stats
    results['Per_model_summary'] = model_summary
    results['Classes'] = uniq_labels
    results['Probability_table'] = prob_table
    results['Predicted_class'] = pred_class
    results['Truth_labels'] = cell_labels[adata.uns['test_indices']]
    results['Low_confidence'] = low_conf
    results['Macro_F1-Score'] = macro_f1

    if force_balance or train_dict:
        results['Training_indices'] = train_idx

    return results
def get_prob_table(results: dict, alpha: float | dict):

Takes a results dictionary with class and probabilities keys and returns a table of probabilities for each class and the most probable class for each cell.

Parameters
  • results (dict): A nested dictionary that contains a dictionary for each class containing probabilities for each cell class.
  • alpha (float | dict): The alpha value whose model probabilities should be evaluated, or a dict mapping each class to the alpha value to use for that class's model.
Returns
  • prob_table (pd.DataFrame): Each column is a cell class and the elements are the class probability outputs from the model.
  • pred_class (list[str]): The most probable cell class for each cell in the test set; tied classes are joined with ' and '.
  • low_conf (np.ndarray): A boolean array that is True where a sample's maximum class probability is less than 0.5.
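
A minimal sketch of the expected input layout, using made-up probabilities and assuming the helper is importable from the scmkl.one_v_rest module shown above. Each class key holds a 'Probabilities' dict keyed by alpha, which in turn holds that class's per-cell probabilities.

import numpy as np
from scmkl.one_v_rest import get_prob_table

alpha = 0.05

# Toy nested results: results[class]['Probabilities'][alpha][class] -> per-cell probabilities
toy_results = {
    'B cells':   {'Probabilities': {alpha: {'B cells':   np.array([0.90, 0.20, 0.40])}}},
    'Monocytes': {'Probabilities': {alpha: {'Monocytes': np.array([0.10, 0.80, 0.45])}}},
}

prob_table, pred_class, low_conf = get_prob_table(toy_results, alpha)
print(prob_table)   # one column per class, one row per test cell
print(pred_class)   # most probable class per cell, e.g. ['B cells', 'Monocytes', 'Monocytes']
print(low_conf)     # True where the best class probability is below 0.5

In practice this nested dictionary is built by scmkl.one_v_rest(), which calls this helper internally.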
def per_model_summary( results: dict, uniq_labels: numpy.ndarray | list | tuple, alpha: float | dict) -> pandas.core.frame.DataFrame:

Takes the results dictionary from scmkl.one_v_rest and returns a summary dataframe showing metrics for each model generated during the runs.

Parameters
  • results (dict): Results from scmkl.one_v_rest.
  • uniq_labels (array_like): Unique cell classes from the runs.
  • alpha (float | dict): The alpha value to summarize, or a dict mapping each class to the alpha value used for that class's model.
Returns
  • summary_df (pd.DataFrame): Dataframe with classes as rows and metrics as columns.
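
A minimal sketch with made-up metric names and values, assuming the helper is importable from the scmkl.one_v_rest module shown above. It reads results[class]['Metrics'][alpha] for each class and stacks the metrics into a single dataframe.

from scmkl.one_v_rest import per_model_summary

alpha = 0.05

# Toy per-class metrics; the metric names here are illustrative only
toy_results = {
    'B cells':   {'Metrics': {alpha: {'AUROC': 0.97, 'F1-Score': 0.91}}},
    'Monocytes': {'Metrics': {alpha: {'AUROC': 0.93, 'F1-Score': 0.88}}},
}

summary = per_model_summary(toy_results, uniq_labels=['B cells', 'Monocytes'], alpha=alpha)
print(summary)   # columns: AUROC, F1-Score, Class; one row per class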
def get_class_train( train_indices: numpy.ndarray, cell_labels: numpy.ndarray | list | pandas.core.series.Series, seed_obj: numpy.random._generator.Generator, other_factor=1.5):

Returns a dict with each entry being the set of training indices for one cell class, to be used in scmkl.one_v_rest.

Parameters
  • train_indices (np.ndarray): The indices in the ad.AnnData object of samples available to train on.
  • cell_labels (array_like): The identity of all cells in the anndata object.
  • seed_obj (np.random._generator.Generator): The seed object used to randomly sample non-target samples.
  • other_factor (float): The ratio of cells to sample for the other class for each model. For example, if classifying B cells with 100 B cells in training and other_factor=1, 100 cells that are not B cells will be included in training alongside the B cells.
Returns
  • train_idx (dict): Keys are cell classes and values are the train indices to train scmkl that include both target and non-target samples.
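
A minimal sketch with made-up labels, assuming the helper is importable from the scmkl.one_v_rest module shown above. Every training cell of the target class is kept, and roughly other_factor times as many non-target training cells are sampled with the supplied generator.

import numpy as np
from scmkl.one_v_rest import get_class_train

cell_labels = np.array(['B cells'] * 20 + ['T cells'] * 60 + ['Monocytes'] * 20)
train_indices = np.arange(0, 100, 2)   # every other cell is a training cell
rng = np.random.default_rng(0)         # stands in for adata.uns['seed_obj']

train_idx = get_class_train(train_indices, cell_labels, rng, other_factor=1.5)
for lab, idx in train_idx.items():
    print(lab, len(idx))               # target training cells plus sampled 'other' cells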
def one_v_rest( adatas: list | anndata._core.anndata.AnnData, names: list, alpha_params: numpy.ndarray, tfidf: list = None, batches: int = 10, batch_size: int = 100, train_dict: dict = None, force_balance: bool = False, other_factor: float = 1.0) -> dict:

For each cell class, creates model(s) comparing that class to all others. Then, predicts on the held-out testing data using scmkl.run. Only labels present in both the training and testing data will be run.

Parameters
  • adatas (list[AnnData]): List of ad.AnnData objects created by create_adata() where each ad.AnnData is one modality and composed of both training and testing samples. Requires that 'train_indices' and 'test_indices' are the same between all ad.AnnDatas.
  • names (list[str]): String variables that describe each modality respective to adatas for labeling.
  • alpha_params (np.ndarray | float | dict): If a dict, expects keys to correspond to each unique label with a float alpha as the value (ideally the output of scmkl.optimize_alpha). Otherwise, an array of alpha values to create each model with, or a float to run with a single alpha.
  • tfidf (list[bool]): If element i is True, adatas[i] will be TF-IDF normalized. If None, no views will be TF-IDF normalized.
  • batches (int): The number of batches to use for the distance calculation. This will average the result of batches distance calculations of batch_size randomly sampled cells. More batches will converge to population distance values at the cost of scalability.
  • batch_size (int): The number of cells to include per batch for distance calculations. Higher batch size will converge to population distance values at the cost of scalability. If batches*batch_size > num_training_cells, batch_size will be reduced to int(num_training_cells / batches).
  • train_dict (dict): Optional mapping of each cell class to the training indices to use for that class's model. If None, training indices are taken from the AnnData object (and balanced per class when force_balance=True).
  • force_balance (bool): If True, training sets will be balanced to reduce class label imbalance. Defaults to False.
  • other_factor (float): The ratio of cells to sample for the other class for each model. For example, if classifying B cells with 100 B cells in training and other_factor=1, 100 cells that are not B cells will be included in training alongside the B cells.
Returns
  • results (dict): Contains keys for each cell class with results from cell class versus all other samples. See scmkl.run for further details. Will also include a probabilities table with the predictions from each model.
Examples
>>> adata = scmkl.create_adata(X = data_mat, 
...                            feature_names = gene_names, 
...                            group_dict = group_dict)
>>>
>>> results = scmkl.one_v_rest(adatas = [adata], names = ['rna'],
...                            alpha_params = np.array([0.05, 0.1]),
...                            tfidf = [False])
>>>
>>> results.keys()
dict_keys(['B cells', 'Monocytes', 'Dendritic cells', ...])
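
A hedged follow-up to the example above, continuing from the same results object: the summary objects documented under Returns can be pulled straight from the returned dictionary.

import pandas as pd

per_class = results['Per_model_summary']    # pd.DataFrame of metrics, one row per class
prob_table = results['Probability_table']   # per-cell class probabilities

predictions = pd.DataFrame({'predicted': results['Predicted_class'],
                            'truth': results['Truth_labels'],
                            'low_confidence': results['Low_confidence']})

print(f"Macro F1: {results['Macro_F1-Score']:.3f}")
print(predictions.head())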