scmkl.plotting

  1import numpy as np
  2import pandas as pd
  3import matplotlib.pyplot as plt
  4import scanpy as sc
  5import anndata as ad
  6import muon
  7from muon import atac as ac
  8import itertools
  9from math import ceil
 10from sklearn import metrics
 11from plotnine import (ggplot, aes, theme_classic, ylim, element_text, theme,
 12                      geom_point, scale_x_reverse, annotate, geom_bar, 
 13                      coord_flip, element_blank, labs, geom_tile, 
 14                      scale_fill_gradient, facet_wrap, 
 15                      scale_color_manual, scale_color_gradient)
 16
 17from scmkl.dataframes import (_parse_result_type, get_weights, sort_groups, 
 18                              format_group_names)
 19
 20
 21def _get_alpha(alpha: None | float, result: dict, is_multiclass: bool):
 22    """
 23    Gets the smallest alpha from a results file. Works for both binary 
 24    and multiclass results.
 25    """
 26    if type(alpha) == float:
 27        return alpha
 28    
 29    if 'Alpha_star' in result.keys():
 30        return result['Alpha_star']
 31    
 32    if is_multiclass:
 33        classes = list(result['Classes'])
 34        alpha_list = list(result[classes[0]]['Norms'].keys())
 35        alpha = np.min(alpha_list)
 36
 37    else:
 38        alpha_list = list(result['Norms'].keys())
 39        alpha = np.min(alpha_list)
 40
 41    return alpha
 42
 43
 44def color_alpha_star(alphas, alpha_star, color):
 45    """
 46    Takes an array of alphas and returns a list of the same size where 
 47    each element is `'black'` except where 
 48    `alpha_star == alphas`, which will be `'gold'`.
 49
 50    Parameters
 51    ----------
 52    alphas : list | tuple | np.ndarray
 53        The 1D array of alphas.
 54
 55    alpha_star: float
 56        The best performing alpha from cross-validation.
 57
 58    color : str
 59        The color of all alphas other than `alpha_star`.
 60
 61    Returns
 62    -------
 63    c_array, c_dict : np.ndarray, dict
 64        `c_array` is the array of colors corresponding to alphas. 
 65        `c_dict` is the color dict with alphas as keys and color as 
 66        values.
 67    """
 68    c_array = np.array([color] * len(alphas), dtype='<U15')
 69    as_pos = np.where(alphas == alpha_star)[0]
 70    c_array[as_pos] = 'gold'
 71
 72    c_dict = {alphas[i] : c_array[i]
 73              for i in range(len(alphas))}
 74
 75    return c_array, c_dict
 76
 77
 78
 79def plot_conf_mat(results, title = '', cmap = None, normalize = True,
 80                          alpha = None, save = None) -> None:
 81    """
 82    Creates a confusion matrix from the output of scMKL.
 83
 84    Parameters
 85    ----------
 86    results : dict
 87        The output from either scmkl.run() or scmkl.one_v_rest()
 88        containing results from scMKL.
 89
 90    title : str
 91        The text to display at the top of the matrix.
 92
 93    cmap : matplotlib.colors.LinearSegmentedColormap
 94        The gradient of the values displayed from `matplotlib.pyplot`.
 95        If `None`, `'Purples'` is used see matplotlib color map 
 96        reference for more information. 
 97
 98    normalize : bool
 99        If `False`, plot the raw numbers. If `True`, plot the 
100        proportions.
101
102    alpha : None | float
103        Alpha that matrix should be created for. If `results` is from
104        `scmkl.one_v_all()`, this is ignored. If `None`, smallest alpha
105        will be used.
106
107    save : None | str
108        File path to save plot. If `None`, plot is not saved.
109
110    Returns
111    -------
112    None
113    
114    Examples
115    --------
116    >>> # Running scmkl and capturing results
117    >>> results = scmkl.run(adata = adata, alpha_list = alpha_list)
118    >>> 
119    >>> from matplotlib.pyplot import get_cmap
120    >>> 
121    >>> scmkl.plot_conf_mat(results, title = '', cmap = get_cmap('Blues'))
122
123    ![conf_mat](../tests/figures/plot_conf_mat_binary.png)
124
125    Citiation
126    ---------
127    http://scikit-learn.org/stable/auto_examples/model_selection/
128    plot_confusion_matrix.html
129    """
130    # Determining type of results
131    if ('Observed' in results.keys()) and ('Metrics' in results.keys()):
132        multi_class = False
133        names = np.unique(results['Observed'])
134    else:
135        multi_class = True
136        names = np.unique(results['Truth_labels'])
137
138    if multi_class:
139        cm = metrics.confusion_matrix(y_true = results['Truth_labels'], 
140                              y_pred = results['Predicted_class'], 
141                              labels = names)
142    else:
143        min_alpha = np.min(list(results['Metrics'].keys()))
144        alpha = alpha if alpha != None else min_alpha
145        cm = metrics.confusion_matrix(y_true = results['Observed'],
146                              y_pred = results['Predictions'][alpha],
147                              labels = names)
148
149    accuracy = np.trace(cm) / float(np.sum(cm))
150    misclass = 1 - accuracy
151
152    if cmap is None:
153        cmap = plt.get_cmap('Purples')
154
155    if normalize:
156        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
157
158    plt.figure(figsize=(8, 6))
159    plt.imshow(cm, interpolation='nearest', cmap=cmap)
160    plt.title(title)
161    plt.colorbar()
162
163    tick_marks = np.arange(len(names))
164    plt.xticks(tick_marks, names, rotation=45)
165    plt.yticks(tick_marks, names)
166
167
168    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
169    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
170        if normalize:
171            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
172                     horizontalalignment="center",
173                     color="white" if cm[i, j] > thresh else "black")
174        else:
175            plt.text(j, i, "{:,}".format(cm[i, j]),
176                     horizontalalignment="center",
177                     color="white" if cm[i, j] > thresh else "black")
178
179    acc_label = 'Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'
180    acc_label = acc_label.format(accuracy, misclass)
181
182    plt.tight_layout()
183    plt.ylabel('True label')
184    plt.xlabel(acc_label)
185
186    if save != None:
187        plt.savefig(save)
188        plt.clf()
189    else:
190        plt.show()
191
192    return None
193
194
def plot_metric(summary_df: pd.DataFrame, alpha_star=None,
                x_axis: str = 'Alpha', color='red'):
    """
    Takes a data frame of model metrics and optionally alpha star and
    creates a scatter plot given metrics against alpha values. For 
    multiclass results, `alpha_star` is not shown and points are 
    colored by class.
    
    Parameters
    ----------
    summary_df : pd.DataFrame
        Dataframe created by `scmkl.get_summary()`. It is not modified.

    alpha_star : None | float
        If `not None`, a label will be added for tuned `alpha_star` 
        being optimal model parameter for performance from cross 
        validation on the training data. Can be calculated with 
        `scmkl.optimize_alpha()`. Is ignored if `summary_df` is from a 
        multiclass result.

    x_axis : str
        Must be either `'Alpha'` or `'Number of Selected Groups'`. Is 
        the variable that will be plotted on the x-axis.

    color : str
        Color to make points on plot.

    Returns
    -------
    metric_plot : plotnine.ggplot.ggplot
        A plot with alpha values on x-axis and metric on y-axis.

    Examples
    --------
    >>> results = scmkl.run(adata, alpha_list)
    >>> summary_df = scmkl.get_summary(results)
    >>> metric_plot = scmkl.plot_metric(summary_df)
    >>>
    >>> metric_plot.save('scMKL_performance.png')

    ![metric_plot](../tests/figures/plot_metric_binary.png)
    """
    # Work on a copy so the helper column added below does not leak
    # into the caller's DataFrame.
    summary_df = summary_df.copy()

    # The first available metric column (alphabetical) is plotted.
    metric_options = ['AUROC', 'Accuracy', 'F1-Score', 'Precision', 'Recall']
    metric = np.intersect1d(metric_options, summary_df.columns)[0]

    x_axis_ticks = np.unique(summary_df[x_axis])
    summary_df['Alpha Star'] = summary_df['Alpha'] == alpha_star

    is_multiclass = 'Class' in summary_df.columns
    color_lab = 'Class' if is_multiclass else 'Alpha Star'

    # Use the tighter 0.6 lower bound only when every value fits above
    # it; otherwise `ylim` would silently drop points in [0.5, 0.6).
    min_y = 0.6 if np.min(summary_df[metric]) >= 0.6 else 0

    plot = (ggplot(summary_df, aes(x=x_axis, y=metric, color=color_lab))
            + geom_point(size=4)
            + theme_classic()
            + ylim(min_y, 1)
            + scale_x_reverse(breaks=x_axis_ticks)
            + theme(
                axis_text=element_text(weight='bold', size=10),
                axis_title=element_text(weight='bold', size=12),
                legend_title=element_text(weight='bold', size=14),
                legend_text=element_text(weight='bold', size=12)
            )
            )

    if not is_multiclass:
        # Highlight alpha_star in gold; every other alpha uses `color`.
        plot += scale_color_manual({True: 'gold', False: color})

    return plot
275
276
def weights_barplot(result, n_groups: int = 1, alpha: None | float = None,
                    color: str = 'red'):
    """
    Plots the top `n_groups` weighted groups for each cell class. Works 
    for a single scmkl result (either multiclass or binary).

    Parameters
    ----------
    result : dict
        The output of `scmkl.run()`.

    n_groups : int
        The number of top groups to plot for each cell class.

    alpha : None | float
        The alpha parameter to create figure for. If `None`, the 
        smallest alpha is used.

    color : str
        The color the bars should be.

    Returns
    -------
    plot : plotnine.ggplot.ggplot
        A barplot of weights. Multiclass results are faceted by class, 
        three facets per row.

    Examples
    --------
    >>> result = scmkl.run(adata, alpha_list)
    >>> plot = scmkl.weights_barplot(result)

    ![weights_barplot](../tests/figures/weights_barplot_binary.png)
    """
    is_multi, is_many = _parse_result_type(result)
    assert not is_many, "This function only supports single results"

    alpha = _get_alpha(alpha, result, is_multi)
    df = get_weights(result)

    # Keep only the requested alpha; copy so later column assignments
    # operate on an independent frame rather than a slice.
    df = df[df['Alpha'] == alpha].copy()

    if is_multi:
        # For each class keep the rows of its `n_groups` highest weighted
        # groups. Rows whose class is not in result['Classes'] are kept.
        keep = ~df['Class'].isin(result['Classes']).to_numpy()
        for ct in result['Classes']:
            in_class = (df['Class'] == ct).to_numpy()
            ranked = df[in_class].sort_values('Kernel Weight', 
                                              ascending=False)
            top_groups = ranked['Group'].iloc[0:n_groups].to_numpy()
            keep |= in_class & np.isin(df['Group'], top_groups)

        df = df[keep].copy()
        df['Group'] = format_group_names(df['Group'], rm_words=['Markers'])

    else:
        df = df.sort_values('Kernel Weight', ascending=False)
        df = df.iloc[0:n_groups].copy()
        df['Group'] = format_group_names(df['Group'], rm_words=['Markers'])
        g_order = sort_groups(df)[::-1]
        df['Group'] = pd.Categorical(df['Group'], g_order)

    plot = (ggplot(df)
            + theme_classic()
            + coord_flip()
            + labs(y=f'Kernel Weight (λ = {alpha})')
            + theme(
                axis_text=element_text(weight='bold', size=10),
                axis_title=element_text(weight='bold', size=12),
                axis_title_y=element_blank()
            )
            + geom_bar(aes(x='Group', y='Kernel Weight'),
                       stat='identity', fill=color)
            )

    if is_multi:
        # Three facets per row, 3 inches of height per row of facets.
        n_rows = ceil(len(set(df['Class'])) / 3)
        plot += facet_wrap('Class', scales='free', ncol=3)
        plot += theme(figure_size=(15, 3 * n_rows))
    else:
        plot += theme(figure_size=(7, 9))

    return plot
371
372
def weights_heatmap(result, n_groups: None | int = None,
                    class_lab: str | None = None, low: str = 'white',
                    high: str = 'red', alpha: float | None = None,
                    scale_weights: bool = False):
    """
    Plots a heatmap of kernel weights with groups on the y-axis and 
    alpha on the x-axis if binary result. If a multiclass result, one 
    alpha is used per class and the x-axis is class.

    Parameters
    ----------
    result : dict
        The output of `scmkl.run()`.

    n_groups : None | int
        The number of top groups (by total kernel weight) to plot. If 
        `None`, all groups are plotted. Not recommended for multiclass 
        results.

    class_lab : str | None
        For multiclass results, if `not None`, will only plot group 
        weights for `class_lab`.

    low : str
        The color for low kernel weight.

    high : str
        The color for high kernel weight.

    alpha : None | float
        The alpha parameter to create figure for when the weights have 
        a `'Class'` column. If `None`, the smallest alpha is used.

    scale_weights : bool
        If `True`, the kernel weights are divided by the maximum weight 
        within each class. Ignored if result is from a binary 
        classification or the weights have no `'Class'` column.

    Returns
    -------
    plot : plotnine.ggplot.ggplot
        A heatmap of weights.

    Examples
    --------
    >>> result = scmkl.run(adata, alpha_list)
    >>> plot = scmkl.weights_heatmap(result)

    ![weights_heatmap](../tests/figures/weights_heatmap_binary.png)
    """
    is_multi, is_many = _parse_result_type(result)
    assert not is_many, "This function only supports single results"

    if isinstance(class_lab, str):
        result = result[class_lab]

    df = get_weights(result)
    df['Group'] = format_group_names(df['Group'], ['Markers'])

    # Order groups on the y-axis by their total kernel weight.
    sum_df = df.groupby('Group')['Kernel Weight'].sum().reset_index()
    order = sort_groups(sum_df)[::-1]
    df['Group'] = pd.Categorical(df['Group'], categories=order)

    if type(n_groups) is int:
        sum_df = sum_df.sort_values(by='Kernel Weight', ascending=False)
        top_groups = sum_df.iloc[0:n_groups]['Group'].to_numpy()
        df = df[np.isin(df['Group'], top_groups)].copy()
    else:
        n_groups = len(set(df['Group']))

    df['Alpha'] = pd.Categorical(df['Alpha'], np.unique(df['Alpha']))

    # Taller figure once there are 25 or more groups.
    fig_size = (7, 6) if n_groups < 25 else (7, 8)

    if 'Class' in df.columns:
        alpha = _get_alpha(alpha, result, is_multi)
        df = df[df['Alpha'] == alpha].copy()
        x_lab = 'Class'
    else:
        x_lab = 'Alpha'

    # Scaling is per class, so it is only possible when a 'Class' column
    # exists; indexing df['Class'] without one raises KeyError.
    if scale_weights and is_multi and 'Class' in df.columns:
        class_max = df.groupby('Class')['Kernel Weight'].transform('max')
        df['Kernel Weight'] = df['Kernel Weight'] / class_max
        l_title = 'Scaled\nKernel Weight'
    else:
        l_title = 'Kernel Weight'

    plot = (ggplot(df, aes(x=x_lab, y='Group', fill='Kernel Weight'))
            + geom_tile(color='black')
            + scale_fill_gradient(high=high, low=low)
            + theme_classic()
            + theme(
                figure_size=fig_size,
                axis_text=element_text(weight='bold', size=10),
                axis_text_x=element_text(rotation=90),
                axis_title=element_text(weight='bold', size=12),
                axis_title_y=element_blank(),
                legend_title=element_text(text=l_title, weight='bold', size=12),
                legend_text=element_text(weight='bold', size=10)
            ))

    return plot
490
491
def weights_dotplot(result, n_groups: None | int = None,
                    class_lab: str | None = None, low: str = 'white',
                    high: str = 'red', alpha: float | None = None,
                    scale_weights: bool = False):
    """
    Plots a dotplot of kernel weights with groups on the y-axis and 
    alpha on the x-axis if binary result. If a multiclass result, one 
    alpha is used per class and the x-axis is class.

    Parameters
    ----------
    result : dict
        The output of `scmkl.run()`.

    n_groups : None | int
        The number of top groups (by total kernel weight) to plot. If 
        `None`, all groups are plotted. Not recommended for multiclass 
        results.

    class_lab : str | None
        For multiclass results, if `not None`, will only plot group 
        weights for `class_lab`.

    low : str
        The color for low kernel weight.

    high : str
        The color for high kernel weight.

    alpha : None | float
        The alpha parameter to create figure for when the weights have 
        a `'Class'` column. If `None`, the smallest alpha is used.

    scale_weights : bool
        If `True`, the kernel weights are divided by the maximum weight 
        within each class. Ignored if result is from a binary 
        classification or the weights have no `'Class'` column.

    Returns
    -------
    plot : plotnine.ggplot.ggplot
        A dotplot of weights.

    Examples
    --------
    >>> result = scmkl.run(adata, alpha_list)
    >>> plot = scmkl.weights_dotplot(result)

    ![weights_dotplot](../tests/figures/weights_dotplot_binary.png)
    """
    is_multi, is_many = _parse_result_type(result)
    assert not is_many, "This function only supports single results"

    if isinstance(class_lab, str):
        result = result[class_lab]

    df = get_weights(result)
    df['Group'] = format_group_names(df['Group'], ['Markers'])

    # Order groups on the y-axis by their total kernel weight.
    sum_df = df.groupby('Group')['Kernel Weight'].sum().reset_index()
    order = sort_groups(sum_df)[::-1]
    df['Group'] = pd.Categorical(df['Group'], categories=order)

    if type(n_groups) is int:
        sum_df = sum_df.sort_values(by='Kernel Weight', ascending=False)
        top_groups = sum_df.iloc[0:n_groups]['Group'].to_numpy()
        df = df[np.isin(df['Group'], top_groups)].copy()
    else:
        n_groups = len(set(df['Group']))

    df['Alpha'] = pd.Categorical(df['Alpha'], np.unique(df['Alpha']))

    # Taller figure once there are 25 or more groups.
    fig_size = (7, 6) if n_groups < 25 else (7, 8)

    if 'Class' in df.columns:
        alpha = _get_alpha(alpha, result, is_multi)
        df = df[df['Alpha'] == alpha].copy()
        x_lab = 'Class'
    else:
        x_lab = 'Alpha'

    # Scaling is per class, so it is only possible when a 'Class' column
    # exists; previously a binary result with scale_weights=True raised
    # KeyError on df['Class'].
    if scale_weights and is_multi and 'Class' in df.columns:
        class_max = df.groupby('Class')['Kernel Weight'].transform('max')
        df['Kernel Weight'] = df['Kernel Weight'] / class_max
        l_title = 'Scaled\nKernel Weight'
    else:
        l_title = 'Kernel Weight'

    plot = (ggplot(df, aes(x=x_lab, y='Group', fill='Kernel Weight', 
                           color='Kernel Weight'))
            + geom_point(size=5)
            + scale_fill_gradient(high=high, low=low)
            + scale_color_gradient(high=high, low=low)
            + theme_classic()
            + theme(
                figure_size=fig_size,
                axis_text=element_text(weight='bold', size=10),
                axis_text_x=element_text(rotation=90),
                axis_title=element_text(weight='bold', size=12),
                axis_title_y=element_blank(),
                legend_title=element_text(text=l_title, weight='bold', size=12),
                legend_text=element_text(weight='bold', size=10)
            ))

    return plot
610
611
def group_umap(adata: ad.AnnData, g_name: str | list, is_binary: bool = False,
               labels: None | np.ndarray | list = None, title: str = '',
               save: str = ''):
    """
    Uses a scmkl formatted `ad.AnnData` object to show sample 
    separation using scmkl discovered groupings. The input `adata` is 
    not modified; all processing happens on a filtered copy.

    Parameters
    ----------
    adata : ad.AnnData
        A scmkl formatted `ad.AnnData` object with `'group_dict'` in 
        `.uns`.

    g_name : str | list
        The groups whose features should be used to filter `adata`. If 
        a list (or tuple/set), features from all listed groups are used.
    
    is_binary : bool
        If `True`, data will be processed using `muon` which includes 
        TF-IDF normalization and LSI.

    labels : None | np.ndarray | list
        If `None`, labels in `adata.obs['labels']` will be used to 
        color umap points. Else, provided labels (one per cell) will be 
        used to color points.

    title : str
        The title of the plot.

    save : str
        If provided, plot will be saved using `scanpy`'s `save` 
        argument. Should be the desired file name. Output will be 
        `<cwd>/figures/<save>`.

    Returns
    -------
    None

    Examples
    --------
    >>> adata_fp = 'data/_pbmc_rna.h5ad'
    >>> group_fp = 'data/_RNA_azimuth_pbmc_groupings.pkl'
    >>> adata = scmkl.format_adata(adata_fp, 'celltypes', group_fp, 
    ...                            allow_multiclass=True)
    >>> scmkl.group_umap(adata, 'CD16+ Monocyte Markers')

    ![group_umap](../tests/figures/umap_group_rna.png)
    """
    group_dict = adata.uns['group_dict']
    if isinstance(g_name, (list, tuple, set)):
        # Union of features across all requested groups.
        feats = {feature
                 for group in g_name
                 for feature in group_dict[group]}
        feats = np.array(list(feats))
    else:
        feats = np.array(list(group_dict[g_name]))

    # Subset to the group's features on a copy first, so the labels
    # assigned below do not leak into the caller's `adata.obs`.
    col_filter = np.isin(adata.var_names.to_numpy(), feats)
    adata = adata[:, col_filter].copy()

    # `is not None` (not truthiness) so np.ndarray labels are accepted;
    # `if labels:` raises "truth value of an array is ambiguous".
    if labels is not None:
        assert len(labels) == adata.shape[0], "`labels` do not match `adata`"
        adata.obs['labels'] = labels

    if not is_binary:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
        sc.tl.pca(adata)
        sc.pp.neighbors(adata)
        sc.tl.umap(adata, random_state=1)

    else:
        ac.pp.tfidf(adata, scale_factor=1e4)
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
        ac.tl.lsi(adata)
        sc.pp.scale(adata)
        sc.tl.pca(adata)
        sc.pp.neighbors(adata, n_neighbors=10, n_pcs=30)
        sc.tl.umap(adata, random_state=1)

    if save:
        sc.pl.umap(adata, title=title, color='labels', save=save, show=False)
    else:
        sc.pl.umap(adata, title=title, color='labels')

    return None
def color_alpha_star(alphas, alpha_star, color):
def color_alpha_star(alphas, alpha_star, color):
    """
    Takes an array of alphas and returns an array of the same size 
    where each element is `color` except where `alpha_star == alphas`, 
    which will be `'gold'`.

    Parameters
    ----------
    alphas : list | tuple | np.ndarray
        The 1D array of alphas.

    alpha_star: float
        The best performing alpha from cross-validation.

    color : str
        The color of all alphas other than `alpha_star`.

    Returns
    -------
    c_array, c_dict : np.ndarray, dict
        `c_array` is the array of colors corresponding to alphas. 
        `c_dict` is the color dict with alphas as keys and color as 
        values.
    """
    # Convert first so list/tuple input is compared element-wise; a plain
    # `list == float` comparison is a single False and never marks gold.
    is_star = np.asarray(alphas) == alpha_star

    # np.where sizes the string dtype to fit the longest color, so long
    # color names are no longer truncated by a fixed '<U15' dtype.
    c_array = np.where(is_star, 'gold', color)

    c_dict = dict(zip(alphas, c_array))

    return c_array, c_dict

Takes an array of alphas and returns an array of the same size where each element is color except where alpha_star == alphas, which will be 'gold', together with a dict mapping each alpha to its color.

Parameters
  • alphas (list | tuple | np.ndarray): The 1D array of alphas.
  • alpha_star (float): The best performing alpha from cross-validation.
  • color (str): The color of all alphas other than alpha_star.
Returns
  • c_array, c_dict (np.ndarray, dict): c_array is the array of colors corresponding to alphas. c_dict is the color dict with alphas as keys and color as values.
def plot_conf_mat( results, title='', cmap=None, normalize=True, alpha=None, save=None) -> None:
 80def plot_conf_mat(results, title = '', cmap = None, normalize = True,
 81                          alpha = None, save = None) -> None:
 82    """
 83    Creates a confusion matrix from the output of scMKL.
 84
 85    Parameters
 86    ----------
 87    results : dict
 88        The output from either scmkl.run() or scmkl.one_v_rest()
 89        containing results from scMKL.
 90
 91    title : str
 92        The text to display at the top of the matrix.
 93
 94    cmap : matplotlib.colors.LinearSegmentedColormap
 95        The gradient of the values displayed from `matplotlib.pyplot`.
 96        If `None`, `'Purples'` is used see matplotlib color map 
 97        reference for more information. 
 98
 99    normalize : bool
100        If `False`, plot the raw numbers. If `True`, plot the 
101        proportions.
102
103    alpha : None | float
104        Alpha that matrix should be created for. If `results` is from
105        `scmkl.one_v_all()`, this is ignored. If `None`, smallest alpha
106        will be used.
107
108    save : None | str
109        File path to save plot. If `None`, plot is not saved.
110
111    Returns
112    -------
113    None
114    
115    Examples
116    --------
117    >>> # Running scmkl and capturing results
118    >>> results = scmkl.run(adata = adata, alpha_list = alpha_list)
119    >>> 
120    >>> from matplotlib.pyplot import get_cmap
121    >>> 
122    >>> scmkl.plot_conf_mat(results, title = '', cmap = get_cmap('Blues'))
123
124    ![conf_mat](../tests/figures/plot_conf_mat_binary.png)
125
126    Citiation
127    ---------
128    http://scikit-learn.org/stable/auto_examples/model_selection/
129    plot_confusion_matrix.html
130    """
131    # Determining type of results
132    if ('Observed' in results.keys()) and ('Metrics' in results.keys()):
133        multi_class = False
134        names = np.unique(results['Observed'])
135    else:
136        multi_class = True
137        names = np.unique(results['Truth_labels'])
138
139    if multi_class:
140        cm = metrics.confusion_matrix(y_true = results['Truth_labels'], 
141                              y_pred = results['Predicted_class'], 
142                              labels = names)
143    else:
144        min_alpha = np.min(list(results['Metrics'].keys()))
145        alpha = alpha if alpha != None else min_alpha
146        cm = metrics.confusion_matrix(y_true = results['Observed'],
147                              y_pred = results['Predictions'][alpha],
148                              labels = names)
149
150    accuracy = np.trace(cm) / float(np.sum(cm))
151    misclass = 1 - accuracy
152
153    if cmap is None:
154        cmap = plt.get_cmap('Purples')
155
156    if normalize:
157        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
158
159    plt.figure(figsize=(8, 6))
160    plt.imshow(cm, interpolation='nearest', cmap=cmap)
161    plt.title(title)
162    plt.colorbar()
163
164    tick_marks = np.arange(len(names))
165    plt.xticks(tick_marks, names, rotation=45)
166    plt.yticks(tick_marks, names)
167
168
169    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
170    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
171        if normalize:
172            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
173                     horizontalalignment="center",
174                     color="white" if cm[i, j] > thresh else "black")
175        else:
176            plt.text(j, i, "{:,}".format(cm[i, j]),
177                     horizontalalignment="center",
178                     color="white" if cm[i, j] > thresh else "black")
179
180    acc_label = 'Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'
181    acc_label = acc_label.format(accuracy, misclass)
182
183    plt.tight_layout()
184    plt.ylabel('True label')
185    plt.xlabel(acc_label)
186
187    if save != None:
188        plt.savefig(save)
189        plt.clf()
190    else:
191        plt.show()
192
193    return None

Creates a confusion matrix from the output of scMKL.

Parameters
  • results (dict): The output from either scmkl.run or scmkl.one_v_rest containing results from scMKL.
  • title (str): The text to display at the top of the matrix.
  • cmap (matplotlib.colors.LinearSegmentedColormap): The gradient of the values displayed from matplotlib.pyplot. If None, 'Purples' is used; see the matplotlib colormap reference for more information.
  • normalize (bool): If False, plot the raw numbers. If True, plot the proportions.
  • alpha (None | float): Alpha that matrix should be created for. If results is from scmkl.one_v_all(), this is ignored. If None, smallest alpha will be used.
  • save (None | str): File path to save plot. If None, plot is not saved.
Returns
  • None
Examples
>>> # Running scmkl and capturing results
>>> results = scmkl.run(adata = adata, alpha_list = alpha_list)
>>> 
>>> from matplotlib.pyplot import get_cmap
>>> 
>>> scmkl.plot_conf_mat(results, title = '', cmap = get_cmap('Blues'))

conf_mat

Citation

http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

def plot_metric( summary_df: pandas.core.frame.DataFrame, alpha_star=None, x_axis: str = 'Alpha', color='red'):
def plot_metric(summary_df : pd.DataFrame, alpha_star = None, 
                x_axis: str='Alpha', color = 'red'):
    """
    Takes a data frame of model metrics and optionally alpha star and
    creates a scatter plot given metrics against alpha values. For 
    multiclass results, `alpha_star` is not shown and points are 
    colored by class.

    `summary_df` is not modified; the plot is built from a copy.
    
    Parameters
    ----------
    summary_df : pd.DataFrame
        Dataframe created by `scmkl.get_summary()`. Must contain an 
        `'Alpha'` column, the `x_axis` column and at least one of 
        `'AUROC'`, `'Accuracy'`, `'F1-Score'`, `'Precision'` or 
        `'Recall'`. If several metric columns are present, the 
        alphabetically first one is plotted. A `'Class'` column marks 
        the data as a multiclass result.

    alpha_star : None | float
        If `not None`, a label will be added for tuned `alpha_star` 
        being optimal model parameter for performance from cross 
        validation on the training data. Can be calculated with 
        `scmkl.optimize_alpha()`. Is ignored if `summary_df` is from a 
        multiclass result.

    x_axis : str
        Must be either `'Alpha'` or `'Number of Selected Groups'`. Is 
        the variable that will be plotted on the (reversed) x-axis.

    color : str
        Color to make points on plot. Points where `'Alpha'` equals 
        `alpha_star` are drawn in `'gold'`. Ignored for multiclass 
        results.

    Returns
    -------
    metric_plot : plotnine.ggplot.ggplot
        A plot with `x_axis` values on x-axis and metric on y-axis.

    Raises
    ------
    ValueError
        If `summary_df` has none of the supported metric columns.

    Examples
    --------
    >>> results = scmkl.run(adata, alpha_list)
    >>> summary_df = scmkl.get_summary(results)
    >>> metric_plot = plot_metric(summary_df)
    >>>
    >>> metric_plot.save('scMKL_performance.png')

    ![metric_plot](../tests/figures/plot_metric_binary.png)
    """
    # Copy so the 'Alpha Star' helper column is not added to the 
    # caller's data frame
    summary_df = summary_df.copy()

    # np.intersect1d returns sorted values, so the alphabetically first
    # available metric is the one plotted
    metric_options = ['AUROC', 'Accuracy', 'F1-Score', 'Precision', 'Recall']
    available = np.intersect1d(metric_options, summary_df.columns)
    if available.size == 0:
        raise ValueError("`summary_df` must contain one of the metric "
                         f"columns {metric_options}")
    metric = available[0]

    x_axis_ticks = np.unique(summary_df[x_axis])
    summary_df['Alpha Star'] = summary_df['Alpha'] == alpha_star

    is_multiclass = 'Class' in summary_df.columns
    color_lab = 'Class' if is_multiclass else 'Alpha Star'

    # ylim() removes points outside its range, so the lower limit must
    # never be above the smallest metric value
    min_metric = np.min(summary_df[metric])
    if min_metric < 0.5:
        min_y = 0
    elif min_metric < 0.6:
        min_y = 0.5
    else:
        min_y = 0.6

    plot = (ggplot(summary_df, aes(x=x_axis, y=metric, color=color_lab))
            + geom_point(size=4)
            + theme_classic()
            + ylim(min_y, 1)
            + scale_x_reverse(breaks=x_axis_ticks)
            + theme(
                axis_text=element_text(weight='bold', size=10),
                axis_title=element_text(weight='bold', size=12),
                legend_title=element_text(weight='bold', size=14),
                legend_text=element_text(weight='bold', size=12)
            )
            )

    if not is_multiclass:
        # Highlight alpha_star in gold; every other point uses `color`
        plot += scale_color_manual({True: 'gold', False: color})

    return plot

Takes a data frame of model metrics and optionally alpha star and creates a scatter plot given metrics against alpha values. For multiclass results, alpha_star is not shown and points are colored by class.

Parameters
  • summary_df (pd.DataFrame): Dataframe created by scmkl.get_summary().
  • alpha_star (None | float): If not None, a label will be added for tuned alpha_star being optimal model parameter for performance from cross validation on the training data. Can be calculated with scmkl.optimize_alpha. Is ignored if summary_df is from a multiclass result.
  • x_axis (str): Must be either 'Alpha' or 'Number of Selected Groups'. Is the variable that will be plotted on the x-axis.
  • color (str): Color to make points on plot.
Returns
  • metric_plot (plotnine.ggplot.ggplot): A plot with alpha values on x-axis and metric on y-axis.
Examples
>>> results = scmkl.run(adata, alpha_list)
>>> summary_df = scmkl.get_summary(results)
>>> metric_plot = plot_metric(summary_df)
>>>
>>> metric_plot.save('scMKL_performance.png')

metric_plot

def weights_barplot( result, n_groups: int = 1, alpha: None | float = None, color: str = 'red'):
def weights_barplot(result, n_groups: int=1, alpha: None | float=None, 
                    color: str='red'):
    """
    Plots the top `n_groups` weighted groups for each cell class. Works 
    for a single scmkl result (either multiclass or binary).

    Parameters
    ----------
    result : dict
        The output of `scmkl.run()`.

    n_groups : int
        The number of top groups to plot for each cell class.

    alpha : None | float
        The alpha parameter to create figure for. If `None`, 
        `result['Alpha_star']` is used when present, otherwise the 
        smallest alpha is used.

    color : str
        The color the bars should be.

    Returns
    -------
    plot : plotnine.ggplot.ggplot
        A barplot of weights. For multiclass results, bars are faceted 
        by class with three facets per row.

    Examples
    --------
    >>> result = scmkl.run(adata, alpha_list)
    >>> plot = scmkl.weights_barplot(result)

    ![weights_barplot](../tests/figures/weights_barplot_binary.png)
    """
    is_multi, is_many = _parse_result_type(result)
    assert not is_many, "This function only supports single results"

    alpha = _get_alpha(alpha, result, is_multi)
    df = get_weights(result)

    # Subsetting to only alpha; copying so later column assignments act
    # on an independent frame instead of a slice of the weights table
    df = df[df['Alpha'] == alpha].copy()

    if is_multi:
        # For each class, drop that class's rows outside its top 
        # `n_groups` groups; rows of other classes are kept unchanged
        for ct in result['Classes']:
            is_ct = df['Class'] == ct
            ct_df = df[is_ct].sort_values('Kernel Weight', ascending=False)
            top_groups = ct_df['Group'].iloc[0:n_groups].to_numpy()

            in_top = df['Group'].isin(top_groups)
            df = df[(is_ct & in_top) | ~is_ct]

        df = df.copy()
        df['Group'] = format_group_names(df['Group'], rm_words=['Markers'])

    else:
        df = (df.sort_values('Kernel Weight', ascending=False)
              .iloc[0:n_groups]
              .copy())
        df['Group'] = format_group_names(df['Group'], rm_words=['Markers'])
        g_order = sort_groups(df)[::-1]
        df['Group'] = pd.Categorical(df['Group'], g_order)

    plot = (ggplot(df)
            + theme_classic()
            + coord_flip()
            + labs(y=f'Kernel Weight (λ = {alpha})')
            + theme(
                axis_text=element_text(weight='bold', size=10),
                axis_title=element_text(weight='bold', size=12),
                axis_title_y=element_blank()
            )
            + geom_bar(aes(x='Group', y='Kernel Weight'), 
                       stat='identity', fill=color)
            )

    # This needs to be reworked for multiclass runs
    if is_multi:
        # Facets are laid out three per row; each row adds 3 units of 
        # figure height
        n_rows = ceil(len(set(df['Class'])) / 3)
        plot += facet_wrap('Class', scales='free', ncol=3)
        plot += theme(figure_size=(15, 3 * n_rows))
    else:
        plot += theme(figure_size=(7, 9))

    return plot

Plots the top n_groups weighted groups for each cell class. Works for a single scmkl result (either multiclass or binary).

Parameters
  • result (dict): The output of scmkl.run.
  • n_groups (int): The number of top groups to plot for each cell class.
  • alpha (None | float): The alpha parameter to create figure for. If None, the smallest alpha is used.
  • color (str): The color the bars should be.
Returns
  • plot (plotnine.ggplot.ggplot): A barplot of weights.
Examples
>>> result = scmkl.run(adata, alpha_list)
>>> plot = scmkl.weights_barplot(result)

weights_barplot

def weights_heatmap( result, n_groups: None | int = None, class_lab: str | None = None, low: str = 'white', high: str = 'red', alpha: float | None = None, scale_weights: bool = False):
def weights_heatmap(result, n_groups: None | int=None, 
                    class_lab: str | None=None, low: str='white', 
                    high: str='red', alpha: float | None=None,
                    scale_weights: bool=False):
    """
    Plots a heatmap of kernel weights with groups on the y-axis and 
    alpha on the x-axis if binary result. If a multiclass result, one 
    alpha is used per class and the x-axis is class.

    Parameters
    ----------
    result : dict
        The output of `scmkl.run()`.

    n_groups : None | int
        The number of top groups (ranked by kernel weight summed over 
        all rows) to plot. If `None`, all groups are plotted. Not 
        recommended for multiclass results.

    class_lab : str | None
        For multiclass results, if `not None`, will only plot group 
        weights for `class_lab` (weights are taken from 
        `result[class_lab]`).

    low : str
        The color for low kernel weight.

    high : str
        The color for high kernel weight.

    alpha : None | float
        The alpha parameter to create figure for when class is on the 
        x-axis. If `None`, `result['Alpha_star']` is used when present, 
        otherwise the smallest alpha. Not used when alpha is on the 
        x-axis.

    scale_weights : bool
        If `True`, each kernel weight is divided by the maximum kernel 
        weight of its class. Ignored if result is from a binary 
        classification or the weights have no `'Class'` column.

    Returns
    -------
    plot : plotnine.ggplot.ggplot
        A heatmap of weights.

    Examples
    --------
    >>> result = scmkl.run(adata, alpha_list)
    >>> plot = scmkl.weights_heatmap(result)

    ![weights_heatmap](../tests/figures/weights_heatmap_binary.png)
    """
    is_multi, is_many = _parse_result_type(result)
    assert not is_many, "This function only supports single results"

    if isinstance(class_lab, str):
        result = result[class_lab]

    df = get_weights(result)
    df['Group'] = format_group_names(df['Group'], ['Markers'])

    # Ordering groups on the y-axis by their summed kernel weight
    sum_df = df.groupby('Group')['Kernel Weight'].sum().reset_index()
    order = sort_groups(sum_df)[::-1]
    df['Group'] = pd.Categorical(df['Group'], categories=order)

    if isinstance(n_groups, (int, np.integer)):
        sum_df = sum_df.sort_values(by='Kernel Weight', ascending=False)
        top_groups = sum_df.iloc[0:n_groups]['Group'].to_numpy()
        df = df[np.isin(df['Group'], top_groups)].copy()

    df['Alpha'] = pd.Categorical(df['Alpha'], np.unique(df['Alpha']))

    # Taller figure when many groups are actually shown
    n_shown = len(set(df['Group']))
    fig_size = (7, 6) if n_shown < 25 else (7, 8)

    if 'Class' in df.columns:
        alpha = _get_alpha(alpha, result, is_multi)
        df = df[df['Alpha'] == alpha].copy()
        x_lab = 'Class'
    else:
        x_lab = 'Alpha'

    # Scaling is per class, so it needs a 'Class' column; after 
    # subsetting with `class_lab` the weights may not have one
    if scale_weights and is_multi and 'Class' in df.columns:
        class_max = (df.groupby('Class', observed=True)['Kernel Weight']
                     .transform('max'))
        df['Kernel Weight'] = df['Kernel Weight'] / class_max
        l_title = 'Scaled\nKernel Weight'
    else:
        l_title = 'Kernel Weight'

    plot = (ggplot(df, aes(x=x_lab, y='Group', fill='Kernel Weight'))
            + geom_tile(color='black')
            + scale_fill_gradient(high=high, low=low)
            + theme_classic()
            + theme(
                figure_size=fig_size,
                axis_text=element_text(weight='bold', size=10),
                axis_text_x=element_text(rotation=90),
                axis_title=element_text(weight='bold', size=12),
                axis_title_y=element_blank(),
                legend_title=element_text(text=l_title, weight='bold', size=12),
                legend_text=element_text(weight='bold', size=10)
            ))

    return plot

Plots a heatmap of kernel weights with groups on the y-axis and alpha on the x-axis if binary result. If a multiclass result, one alpha is used per class and the x-axis is class.

Parameters
  • result (dict): The output of scmkl.run.
  • n_groups (None | int): The number of top groups to plot. If None, all groups are plotted. Not recommended for multiclass results.
  • class_lab (str | None): For multiclass results, if not None, will only plot group weights for class_lab.
  • low (str): The color for low kernel weight.
  • high (str): The color for high kernel weight.
  • alpha (None | float): The alpha parameter to create figure for. If None, the smallest alpha is used.
  • scale_weights (bool): If True, each kernel weight is divided by the maximum kernel weight of its class. Ignored if result is from a binary classification.
Returns
  • plot (plotnine.ggplot.ggplot): A heatmap of weights.
Examples
>>> result = scmkl.run(adata, alpha_list)
>>> plot = scmkl.weights_heatmap(result)

weights_heatmap

def weights_dotplot( result, n_groups: None | int = None, class_lab: str | None = None, low: str = 'white', high: str = 'red', alpha: float | None = None, scale_weights: bool = False):
def weights_dotplot(result, n_groups: None | int=None, 
                    class_lab: str | None=None, low: str='white', 
                    high: str='red', alpha: float | None=None, 
                    scale_weights: bool=False):
    """
    Plots a dotplot of kernel weights with groups on the y-axis and 
    alpha on the x-axis if binary result. If a multiclass result, one 
    alpha is used per class and the x-axis is class.

    Parameters
    ----------
    result : dict
        The output of `scmkl.run()`.

    n_groups : None | int
        The number of top groups (ranked by kernel weight summed over 
        all rows) to plot. If `None`, all groups are plotted. Not 
        recommended for multiclass results.

    class_lab : str | None
        For multiclass results, if `not None`, will only plot group 
        weights for `class_lab` (weights are taken from 
        `result[class_lab]`).

    low : str
        The color for low kernel weight.

    high : str
        The color for high kernel weight.

    alpha : None | float
        The alpha parameter to create figure for when class is on the 
        x-axis. If `None`, `result['Alpha_star']` is used when present, 
        otherwise the smallest alpha. Not used when alpha is on the 
        x-axis.

    scale_weights : bool
        If `True`, each kernel weight is divided by the maximum kernel 
        weight of its class. Ignored when the weights have no `'Class'` 
        column (i.e. when alpha is on the x-axis).

    Returns
    -------
    plot : plotnine.ggplot.ggplot
        A dotplot of weights.

    Examples
    --------
    >>> result = scmkl.run(adata, alpha_list)
    >>> plot = scmkl.weights_dotplot(result)

    ![weights_dotplot](../tests/figures/weights_dotplot_binary.png)
    """
    is_multi, is_many = _parse_result_type(result)
    assert not is_many, "This function only supports single results"

    if isinstance(class_lab, str):
        result = result[class_lab]

    df = get_weights(result)
    df['Group'] = format_group_names(df['Group'], ['Markers'])

    # Ordering groups on the y-axis by their summed kernel weight
    sum_df = df.groupby('Group')['Kernel Weight'].sum().reset_index()
    order = sort_groups(sum_df)[::-1]
    df['Group'] = pd.Categorical(df['Group'], categories=order)

    if isinstance(n_groups, (int, np.integer)):
        sum_df = sum_df.sort_values(by='Kernel Weight', ascending=False)
        top_groups = sum_df.iloc[0:n_groups]['Group'].to_numpy()
        df = df[np.isin(df['Group'], top_groups)].copy()

    df['Alpha'] = pd.Categorical(df['Alpha'], np.unique(df['Alpha']))

    # Taller figure when many groups are actually shown
    n_shown = len(set(df['Group']))
    fig_size = (7, 6) if n_shown < 25 else (7, 8)

    if 'Class' in df.columns:
        alpha = _get_alpha(alpha, result, is_multi)
        df = df[df['Alpha'] == alpha].copy()
        x_lab = 'Class'
    else:
        x_lab = 'Alpha'

    # Scaling is per class, so it is only possible with a 'Class' column
    if scale_weights and 'Class' in df.columns:
        class_max = (df.groupby('Class', observed=True)['Kernel Weight']
                     .transform('max'))
        df['Kernel Weight'] = df['Kernel Weight'] / class_max
        l_title = 'Scaled\nKernel Weight'
    else:
        l_title = 'Kernel Weight'

    plot = (ggplot(df, aes(x=x_lab, y='Group', fill='Kernel Weight', color='Kernel Weight'))
            + geom_point(size=5)
            + scale_fill_gradient(high=high, low=low)
            + scale_color_gradient(high=high, low=low)
            + theme_classic()
            + theme(
                figure_size=fig_size,
                axis_text=element_text(weight='bold', size=10),
                axis_text_x=element_text(rotation=90),
                axis_title=element_text(weight='bold', size=12),
                axis_title_y=element_blank(),
                legend_title=element_text(text=l_title, weight='bold', size=12),
                legend_text=element_text(weight='bold', size=10)
            ))

    return plot

Plots a dotplot of kernel weights with groups on the y-axis and alpha on the x-axis if binary result. If a multiclass result, one alpha is used per class and the x-axis is class.

Parameters
  • result (dict): The output of scmkl.run.
  • n_groups (None | int): The number of top groups to plot. If None, all groups are plotted. Not recommended for multiclass results.
  • class_lab (str | None): For multiclass results, if not None, will only plot group weights for class_lab.
  • low (str): The color for low kernel weight.
  • high (str): The color for high kernel weight.
  • alpha (None | float): The alpha parameter to create figure for. If None, the smallest alpha is used.
  • scale_weights (bool): If True, each kernel weight is divided by the maximum kernel weight of its class.
Returns
  • plot (plotnine.ggplot.ggplot): A dotplot of weights.
Examples
>>> result = scmkl.run(adata, alpha_list)
>>> plot = scmkl.weights_dotplot(result)

weights_dotplot

def group_umap( adata: anndata._core.anndata.AnnData, g_name: str | list, is_binary: bool = False, labels: None | numpy.ndarray | list = None, title: str = '', save: str = ''):
def group_umap(adata: ad.AnnData, g_name: str | list, is_binary: bool=False, 
               labels: None | np.ndarray | list=None, title: str='', 
               save: str=''):
    """
    Uses a scmkl formatted `ad.AnnData` object to show sample 
    separation using scmkl discovered groupings.

    The input `adata` is not modified: features are subset into a copy 
    and `labels`, if given, are attached only to that copy.

    Parameters
    ----------
    adata : ad.AnnData
        A scmkl formatted `ad.AnnData` object with `'group_dict'` in 
        `.uns`.

    g_name : str | list
        The group whose features should be used to filter `adata`. If 
        a list, the union of features from all listed groups is used.
    
    is_binary : bool
        If `True`, data will be processed using `muon` which includes 
        TF-IDF normalization and LSI.

    labels : None | np.ndarray | list
        If `None`, labels in `adata.obs['labels']` will be used to 
        color umap points. Else, provided labels (one per cell) will be 
        used to color points.

    title : str
        The title of the plot.

    save : str
        If provided, plot will be saved using `scanpy`'s `save` 
        argument. Should be the desired file name. Output will be 
        `<cwd>/figures/<save>`.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `labels` does not have one entry per cell in `adata`, or if 
        none of the selected group features are in `adata.var_names`.

    Examples
    --------
    >>> adata_fp = 'data/_pbmc_rna.h5ad'
    >>> group_fp = 'data/_RNA_azimuth_pbmc_groupings.pkl'
    >>> adata = scmkl.format_adata(adata_fp, 'celltypes', group_fp, 
    ...                            allow_multiclass=True)
    >>> scmkl.group_umap(adata, 'CD16+ Monocyte Markers')

    ![group_umap](../tests/figures/umap_group_rna.png)
    """
    group_dict = adata.uns['group_dict']
    if isinstance(g_name, list):
        feats = {feature 
                 for group in g_name 
                 for feature in group_dict[group]}
        feats = np.array(list(feats))
    else:
        feats = np.array(list(group_dict[g_name]))

    # `is not None` instead of truthiness: bool() of a multi-element 
    # np.ndarray raises ValueError
    if labels is not None and len(labels) != adata.shape[0]:
        raise ValueError("`labels` do not match `adata`")

    col_filter = np.isin(adata.var_names.to_numpy(), feats)
    if not col_filter.any():
        raise ValueError("None of the features for `g_name` were found "
                         "in `adata.var_names`")

    # Subsetting into a copy so the caller's object (including .obs) is 
    # left untouched
    adata = adata[:, col_filter].copy()
    if labels is not None:
        adata.obs['labels'] = labels

    if not is_binary:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
        sc.tl.pca(adata)
        sc.pp.neighbors(adata)
        sc.tl.umap(adata, random_state=1)

    else:
        ac.pp.tfidf(adata, scale_factor=1e4)
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
        ac.tl.lsi(adata)
        sc.pp.scale(adata)
        sc.tl.pca(adata)
        sc.pp.neighbors(adata, n_neighbors=10, n_pcs=30)
        sc.tl.umap(adata, random_state=1)

    if save:
        sc.pl.umap(adata, title=title, color='labels', save=save, show=False)

    else:
        sc.pl.umap(adata, title=title, color='labels')

    return None

Uses a scmkl formatted ad.AnnData object to show sample separation using scmkl discovered groupings.

Parameters
  • adata (ad.AnnData): A scmkl formatted ad.AnnData object with 'group_dict' in .uns.
  • g_name (str | list): The group whose features should be used to filter adata. If a list, features from multiple groups will be used.
  • is_binary (bool): If True, data will be processed using muon which includes TF-IDF normalization and LSI.
  • labels (None | np.ndarray | list): If None, labels in adata.obs['labels'] will be used to color umap points. Else, provided labels will be used to color points.
  • title (str): The title of the plot.
  • save (str): If provided, plot will be saved using scanpy's save argument. Should be the desired file name. Output will be <cwd>/figures/<save>.
Returns
  • None
Examples
>>> adata_fp = 'data/_pbmc_rna.h5ad'
>>> group_fp = 'data/_RNA_azimuth_pbmc_groupings.pkl'
>>> adata = scmkl.format_adata(adata_fp, 'celltypes', group_fp, 
...                            allow_multiclass=True)
>>> scmkl.group_umap(adata, 'CD16+ Monocyte Markers')

group_umap