How can we link topics with samples? #34

Open
elmonten opened this issue Dec 6, 2024 · 4 comments

Comments

@elmonten

elmonten commented Dec 6, 2024

Hello,

How can we map back which topics are present in which samples?

Best wishes,
Elena

@JCSzamosi

I have written a set of functions that do this, based on the plot_beta() structure, but using the gammas instead of the betas. They are in draft form right now, but I'll copy them below. @lasy, please let me know if you would be interested in a pull request for these. Also, please let me know if I did this correctly. I'm reasonably confident in it, but would welcome any feedback/corrections you have.

```r
library(alto)
library(dplyr)
library(tidyr)    # pivot_longer()
library(purrr)    # map_dfr()
library(ggplot2)

# Plot each topic's weight (gamma) in each sample, one facet per model.
plot_gamma = function(x, models = 'all', x_axis = 'label', gamma_aes = 'size',
                      color_by = 'path'){
    gamma_aes <- match.arg(gamma_aes, choices = c('size', 'alpha'))
    color_by <- match.arg(color_by, choices = c('topic', 'path', 
                                                'refinement', 'coherence'))
    x_axis <- match.arg(x_axis, choices = c('label', 'index'))
    
    # Long data frame: one row per (model m, topic k, sample s) with weight g
    gamma <- (x
              %>% plot_gamma_layout(models, color_by)
              %>% format_gamma(x_axis = x_axis))
    g <- ggplot(gamma, aes(x = x, y = s)) + guides(size = 'none')
    # Inside aes(), `g` is the gamma column of the data, not the plot object
    if (gamma_aes == 'size'){
        g <- g + geom_point(aes(size = g, col = topic_col)) +
            scale_size(range = c(0,5), limits = c(0,1))
    } else {
        g <- g + geom_tile(aes(alpha = g, fill = topic_col)) +
            scale_alpha(range = c(0, 1), limits = c(0, 1))
    }
    
    if (color_by %in% c('topic', 'path')){
        g <- g + scale_color_discrete(color_by) +
            scale_fill_discrete(color_by)
    } else {
        max_score <- ifelse(color_by == 'refinement', n_models(x), 1)
        g <- g + scale_color_gradient(color_by, low = 'brown1',
                                      high = 'cornflowerblue',
                                      limits = c(0, max_score)) +
            scale_fill_gradient(color_by, low = 'brown1',
                                high = 'cornflowerblue',
                                limits = c(0, max_score))
    }
    g <- g + facet_grid(.~m, scales = 'free', space = 'free') +
        labs(x = '', y = '') +
        theme(panel.border = element_rect(color = 'black', fill = NA),
              panel.spacing.x = unit(0, 'pt'),
              strip.background = element_rect(color = 'black'))
    return(g)
}

plot_gamma_layout = function(x, subset = 'all', color_by = 'path'){
    # Select the requested models: the last one, all of them, or a subset
    model_params <- models(x)
    if (length(subset) == 1 && subset == 'last'){
        model_params <- model_params[n_models(x)]
    } else if (length(subset) == 1 && subset == 'all'){
        model_params <- model_params
    } else {
        model_params <- model_params[subset]
    }
    # Stack the transposed gamma matrices: rows are topics, columns are samples
    gammas <- (model_params
               %>% map_dfr(~as.data.frame(t(.$gamma)), .id = 'm')
               %>% mutate(m = factor(m, levels = rev(names(model_params)))))
    # Topic metadata; the chosen color_by column is renamed to topic_col
    topic_weights <- (topics(x)
                      %>% filter(m %in% gammas$m)
                      %>% mutate(topic = factor(k))
                      %>% rename(topic_col = !!color_by))
    lst = list(gammas = gammas, weights = topic_weights)
    return(lst)
}

format_gamma = function(p, x_axis = 'label'){
    # Number the topics within each model, attach their metadata, then
    # reshape to one row per (model, topic, sample)
    gamma <- (p$gammas
              %>% group_by(m)
              %>% mutate(k = row_number())
              %>% ungroup()
              %>% left_join(p$weights
                            %>% select(m, k, k_label, topic_col),
                            by = c('m','k'))
              %>% pivot_longer(-c(m, k, k_label, topic_col),
                               names_to = 's', values_to = 'g'))
    if (x_axis == 'label'){
        gamma$x <- gamma$k_label
    } else {
        gamma$x <- gamma$k %>% factor()
    }
    # Order samples by their dominant topic in the model at the first
    # level of m
    s_order <- (gamma
                %>% slice_min(m)
                %>% arrange(s, -g)
                %>% group_by(s)
                %>% slice_head(n = 1)
                %>% ungroup()
                %>% arrange(x))
    gamma <- (gamma
              %>% mutate(s = factor(s, levels = (s_order$s
                                                 %>% rev())),
                         m = factor(m, levels = rev(levels(m)))))
    return(gamma)
}
```
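If you only need the sample-by-topic table behind this plot rather than the figure, the same reshaping can be sketched in base R. This assumes each model's gamma is a samples × topics matrix with sample IDs as row names, as in the `models[[8]]$gamma` output later in this thread; `gamma_long()` is a hypothetical helper, not part of alto, and the two toy models are made up.

```r
# Hypothetical helper: stack a named list of gamma matrices
# (samples x topics) into one long data frame with columns
# m (model), k (topic index), s (sample) and g (gamma weight).
gamma_long <- function(gammas) {
  pieces <- lapply(names(gammas), function(m) {
    gm <- gammas[[m]]
    data.frame(
      m = m,
      # as.vector() reads the matrix column by column (topic by topic)
      k = rep(seq_len(ncol(gm)), each = nrow(gm)),
      s = rep(rownames(gm), times = ncol(gm)),
      g = as.vector(gm),
      stringsAsFactors = FALSE
    )
  })
  do.call(rbind, pieces)
}

# Toy example: two made-up models (1 and 2 topics) over three samples.
gammas <- list(
  "1" = matrix(1, nrow = 3, ncol = 1,
               dimnames = list(c("a", "b", "c"), NULL)),
  "2" = matrix(c(0.9, 0.2, 0.5,
                 0.1, 0.8, 0.5), nrow = 3,
               dimnames = list(c("a", "b", "c"), NULL))
)
long <- gamma_long(gammas)

# Samples where a topic of model "2" carries more than half the weight
long[long$m == "2" & long$g > 0.5, ]
```

The long format makes "which topics are present in which samples" a simple filter on `g`.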

@JCSzamosi

One thing to keep in mind is that this plot communicates something slightly different from the plot_beta() one. With plot_beta(), each point shows a feature's contribution to a topic, so the sizes in each column (topic) sum to 1. With plot_gamma(), each point shows a topic's contribution to a sample. This means the columns don't have to sum to 1, but within each facet the rows (samples) do.
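The row-sum property can be checked directly on a gamma matrix, and the same matrix answers the original question of which topic dominates each sample. This is a minimal sketch on a made-up 4 × 3 matrix (rows are samples, columns are topics); with alto you would substitute a real matrix such as `models[[8]]$gamma`.

```r
# Made-up gamma matrix: 4 samples (rows) x 3 topics (columns).
gamma <- matrix(c(0.70, 0.20, 0.10,
                  0.05, 0.90, 0.05,
                  0.30, 0.30, 0.40,
                  0.10, 0.60, 0.30),
                nrow = 4, byrow = TRUE,
                dimnames = list(paste0("sample", 1:4), NULL))

# Each row (sample) is a distribution over topics, so rows sum to 1 ...
rowSums(gamma)
# ... while the columns (topics) need not.
colSums(gamma)

# Index of the topic with the largest weight in each sample.
dominant <- apply(gamma, 1, which.max)
dominant
```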

@krisrs1128
Collaborator

Thank you @elmonten for the question and @JCSzamosi for the great functions. For completeness, here's how you can look up samples with large weights on specific topics. First, this code runs an alignment:

```r
library(purrr)
library(alto)
params <- map(set_names(1:10), ~ list(k = .))
models <- run_lda_models(vm_data$counts, params)
result <- align_topics(models)
```
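For readers less familiar with purrr, `map(set_names(1:10), ~ list(k = .))` builds a named list of ten parameter lists, one per model with K = 1 through 10 topics. The base-R construction below should give the same structure and is shown only to make that structure explicit; the names "1" to "10" appear to be what the alignment later reports as the model label `m`.

```r
# Base-R equivalent of purrr::map(set_names(1:10), ~ list(k = .)):
# a list named "1".."10" whose elements are list(k = 1) through list(k = 10).
params <- setNames(lapply(1:10, function(k) list(k = k)), 1:10)
str(params[1:2])
```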

The `topics` slot of the alignment output summarizes all the topics across the models. `m` refers to a column (model) in the alignment plot and `k` to the topic within that model.

```
r$> result@topics
# A tibble: 55 × 8
   m         k k_label  mass  prop path  coherence
   <fct> <int> <fct>   <dbl> <dbl> <fct>     <dbl>
 1 1         1 1       2179  1     1         0.349
 2 2         1 1       1190. 0.546 1         0.522
 3 2         2 2        989. 0.454 10        0.547
 4 3         1 1        952. 0.437 1         0.562
 5 3         2 3        628. 0.288 4         0.451
 6 3         3 2        599. 0.275 10        0.826
 7 4         1 4        685. 0.315 1         0.619
 8 4         2 1        303. 0.139 7         0.313
 9 4         3 3        601. 0.276 4         0.463
10 4         4 2        589. 0.270 10        0.834
# ℹ 45 more rows
# ℹ 1 more variable: refinement <dbl>
# ℹ Use `print(n = ...)` to see more rows
```
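To trace a single path across models, you can filter this table on its `path` column. The minimal sketch below copies only the ten rows printed above into a plain data frame so that it runs without alto; with the real object you would filter `result@topics` the same way.

```r
# The first ten rows of result@topics printed above (m, k and path only).
topics_head <- data.frame(
  m    = c(1, 2, 2, 3, 3, 3, 4, 4, 4, 4),
  k    = c(1, 1, 2, 1, 2, 3, 1, 2, 3, 4),
  path = c(1, 1, 10, 1, 4, 10, 1, 7, 4, 10)
)

# Every (model, topic) pair that lies on path 10.
on_path_10 <- topics_head[topics_head$path == 10, c("m", "k")]
on_path_10
```

Each resulting (m, k) pair can then be used to index the LDA output, as in `models[[m]]$gamma[, k]`.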

If we realize we are interested in a specific topic, we can go back to the original LDA output to find the samples that have high weights on that topic. For example, for topic 3 in model 8, we could use:

```r
sort_order <- order(models[[8]]$gamma[, 3], decreasing = TRUE)
head(models[[8]]$gamma[sort_order, ], 10) |>
    round(2)
```

which shows that sample 1063701288 is nearly a pure representative of that topic.

```
r$> head(models[[8]]$gamma[sort_order, ], 10) |>
        round(2)
           [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8]
1063701288 0.00    0 0.99 0.00 0.01 0.00    0 0.00
1060901118 0.00    0 0.99 0.00 0.00 0.00    0 0.01
1060901268 0.00    0 0.98 0.00 0.00 0.00    0 0.02
1063701208 0.00    0 0.97 0.00 0.02 0.00    0 0.00
1060901188 0.00    0 0.95 0.00 0.00 0.00    0 0.05
4009201318 0.12    0 0.87 0.00 0.00 0.01    0 0.00
1063701338 0.00    0 0.87 0.03 0.10 0.00    0 0.00
1063701198 0.00    0 0.85 0.01 0.14 0.00    0 0.00
1060901168 0.00    0 0.84 0.00 0.00 0.10    0 0.06
1063701318 0.00    0 0.84 0.08 0.08 0.00    0 0.00
```
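If you repeat this lookup for many topics, the two lines above can be wrapped in a small function. `top_samples()` is a hypothetical helper, not part of alto, and the 5 × 2 matrix below is made up; it assumes samples are in rows and topics in columns, as in the output above.

```r
# Hypothetical helper: the n samples with the largest weight on `topic`,
# given a gamma matrix with samples in rows and topics in columns.
top_samples <- function(gamma, topic, n = 10) {
  ord <- order(gamma[, topic], decreasing = TRUE)
  head(round(gamma[ord, topic], 2), n)
}

# Made-up gamma matrix: 5 samples x 2 topics.
gamma <- matrix(c(0.10, 0.90,
                  0.95, 0.05,
                  0.40, 0.60,
                  0.75, 0.25,
                  0.99, 0.01),
                ncol = 2, byrow = TRUE,
                dimnames = list(c("s1", "s2", "s3", "s4", "s5"), NULL))
top_samples(gamma, topic = 1, n = 3)
```

With the objects from this thread, `top_samples(models[[8]]$gamma, topic = 3)` should return the topic-3 column of the table above, named by sample ID.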

To go from the alignment diagram to a specific path ID, you might want to show the path IDs in the legend. For this, we have to manually override the legend setting of the plot object returned by `plot()` (not elegant, but it works).

```r
library(ggplot2)  # for guides() and guide_legend()
p <- plot(result)
p$guides <- guides(fill = guide_legend())
p
```

[Image: alignment plot with the path IDs shown in the fill legend]
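An alternative to assigning `p$guides` directly may be to add a new `guides()` call, since in ggplot2 a later `guides()` setting replaces an earlier one for the same aesthetic. This sketch assumes the alignment plot hides its fill legend via `guides(fill = "none")`, which is an assumption about alto's `plot()` method, so it uses a self-contained toy plot as a stand-in.

```r
library(ggplot2)

# Toy plot whose fill legend is switched off; it stands in for a plot
# object that hides its legend (an assumption about plot(result)).
df <- data.frame(x = c("a", "b", "c"), y = c(1, 3, 2),
                 path = c("1", "4", "10"))
p <- ggplot(df, aes(x, y, fill = path)) +
  geom_col() +
  guides(fill = "none")

# A later guides() call overrides the earlier setting for `fill`.
p_with_legend <- p + guides(fill = guide_legend())
p_with_legend
```

If alto hides the legend some other way (for example through `theme(legend.position = "none")`), this will not bring it back, and the `p$guides` approach above remains the tested one.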

I'll let @lasy decide whether we should make a pull request. I've been able to use @JCSzamosi's plot_gamma() to see labels in the gamma plot for our demo data. She might also have other ideas for inspecting samples.

@elmonten
Author

Perfect, thank you all very much!
