Topic models: cross-validation with log-likelihood or perplexity

Problem description

I'm clustering documents using topic modeling, and I need to find the optimal number of topics. So I decided to do ten-fold cross-validation with 10, 20, ..., 60 topics.

I have divided my corpus into ten batches and set aside one batch as a holdout set. I have run latent Dirichlet allocation (LDA) on the other nine batches (180 documents in total) with 10 to 60 topics. Now I have to calculate the perplexity or log-likelihood of the holdout set.

I found the code below in one of CV's discussion threads, but I don't understand several lines of it. I have a document-term matrix, dtm, built from the holdout set (20 documents), but I don't know how to calculate the perplexity or log-likelihood of this holdout set.


Questions:

  1. Can anybody explain what seq(2, 100, by = 1) means here? Also, what does AssociatedPress[21:30] mean, and what is function(k) doing here?

    best.model <- lapply(seq(2, 100, by=1), function(k){ LDA(AssociatedPress[21:30,], k) })
    

  2. If I want to calculate the perplexity or log-likelihood of the holdout set called dtm, is there better code? I know there are perplexity() and logLik() functions, but since I'm new to this I cannot figure out how to use them with my holdout matrix, dtm.

  3. How can I do ten-fold cross-validation on my corpus of 200 documents? Is there existing code that I can call? I found caret for this purpose, but I cannot figure that out either.
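
The line quoted in question 1 combines three base-R pieces: seq(2, 100, by = 1) builds the vector 2, 3, ..., 100 (99 values); lapply() calls the anonymous function(k) { ... } once for each of those values and collects the results in a list; and AssociatedPress[21:30, ] selects rows 21 to 30 (ten documents) of the document-term matrix. So the line fits 99 LDA models, with 2 to 100 topics, all on the same ten documents. The sketch below shows the same pattern in plain base R, using a toy matrix and a hypothetical stand-in function fit_stub() in place of LDA(), so it runs without topicmodels.

```r
# Toy stand-in for a document-term matrix: 40 "documents" x 5 "terms".
toy_dtm <- matrix(1:200, nrow = 40, ncol = 5)

# Hypothetical stand-in for LDA(): it only records what it was given.
fit_stub <- function(x, k) {
  list(k = k, n_docs = nrow(x))
}

candidate_k <- seq(2, 100, by = 1)  # 2, 3, ..., 100 (same values as 2:100)

# lapply() calls the anonymous function once per value of k and
# returns a list with one element per k.
models <- lapply(candidate_k, function(k) {
  fit_stub(toy_dtm[21:30, ], k)      # rows 21 to 30, all columns
})

length(models)      # 99
models[[1]]$k       # 2
models[[1]]$n_docs  # 10
```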

Solution

The accepted answer to this question is good as far as it goes, but it doesn't actually address how to estimate perplexity on a validation dataset and how to use cross-validation.

Using perplexity for simple validation

Perplexity is a measure of how well a probability model fits a new set of data. In the topicmodels R package it is simple to compute with the perplexity() function, which takes as arguments a previously fitted topic model and a new set of data, and returns a single number. The lower, the better.
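
For held-out documents with word counts n_dw, perplexity is conventionally exp(-L / N), where L is the total log-likelihood of the held-out words and N is the total number of held-out words. The base-R sketch below computes that quantity when the document-topic proportions (theta) and topic-word probabilities (phi) are assumed to be already known. It illustrates the definition only and is not how topicmodels implements it: perplexity() has to estimate the topic proportions of the new documents itself. heldout_perplexity() is a hypothetical helper.

```r
# Hypothetical helper illustrating the definition of perplexity.
# counts: D x V matrix of word counts for the held-out documents
# theta:  D x K document-topic proportions (each row sums to 1)
# phi:    K x V topic-word probabilities (each row sums to 1)
heldout_perplexity <- function(counts, theta, phi) {
  word_prob <- theta %*% phi       # D x V matrix of p(word | document)
  seen <- counts > 0               # skip zero counts to avoid 0 * log(0)
  log_lik <- sum(counts[seen] * log(word_prob[seen]))
  exp(-log_lik / sum(counts))      # exp(-L / N)
}

# One document, one topic, three terms.
counts <- matrix(c(2, 1, 1), nrow = 1)
theta  <- matrix(1, nrow = 1, ncol = 1)
phi    <- matrix(c(0.5, 0.25, 0.25), nrow = 1)
heldout_perplexity(counts, theta, phi)  # 2^1.5, about 2.83
```

A model that spreads probability evenly over V terms has perplexity exactly V, which gives one reading of the number: the model is as uncertain as if it were choosing among that many equally likely words.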

For example, splitting the AssociatedPress data into a training set (75% of the rows) and a validation set (25% of the rows):

# load up some R packages including a few we'll need later
library(topicmodels)
library(doParallel)
library(ggplot2)
library(scales)

data("AssociatedPress", package = "topicmodels")

# Gibbs sampler settings, reused by every LDA() fit below
burnin <- 1000
iter <- 1000
keep <- 50

full_data  <- AssociatedPress
n <- nrow(full_data)
#-----------validation--------
k <- 5

# random 75/25 split of the documents (rows of the document-term matrix)
splitter <- sample(1:n, round(n * 0.75))
train_set <- full_data[splitter, ]
valid_set <- full_data[-splitter, ]

fitted <- LDA(train_set, k = k, method = "Gibbs",
              control = list(burnin = burnin, iter = iter, keep = keep))
perplexity(fitted, newdata = train_set) # about 2700 (varies with the random split)
perplexity(fitted, newdata = valid_set) # about 4300

The perplexity is higher for the validation set than the training set, because the topics have been optimised based on the training set.

Using perplexity and cross-validation to determine a good number of topics

The extension of this idea to cross-validation is straightforward. Divide the data into different subsets (say 5), and each subset gets one turn as the validation set and four turns as part of the training set. However, this is computationally intensive, particularly when trying out larger numbers of topics.
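
One way to give every document exactly one turn in a validation set is to shuffle a vector of fold labels. The base-R sketch below (make_folds() is a hypothetical helper) produces folds whose sizes differ by at most one, shown for the question's setup of 200 documents and ten folds. The cross-validation code in this answer uses sample(1:folds, n, replace = TRUE) instead, which is simpler but lets the fold sizes vary at random.

```r
# Hypothetical helper: assign n documents to `folds` folds of near-equal size.
make_folds <- function(n, folds) {
  sample(rep(seq_len(folds), length.out = n))  # shuffled fold labels
}

set.seed(1)
fold_id <- make_folds(n = 200, folds = 10)
table(fold_id)                       # 20 documents in every fold

for (i in 1:10) {
  train_rows <- which(fold_id != i)  # 180 documents for fitting
  valid_rows <- which(fold_id == i)  # 20 held-out documents
  # fit the model on train_rows and compute perplexity on valid_rows here
}
```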

You might be able to use caret to do this, but I suspect it doesn't handle topic modelling yet. In any case, it's the sort of thing I prefer to do myself to be sure I understand what's going on.

The code below, even with parallel processing on 7 logical CPUs, took 3.5 hours to run on my laptop:

#----------------5-fold cross-validation, different numbers of topics----------------
# set up a cluster for parallel processing
cluster <- makeCluster(detectCores(logical = TRUE) - 1) # leave one CPU spare...
registerDoParallel(cluster)

# load up the needed R package on all the parallel sessions
clusterEvalQ(cluster, {
   library(topicmodels)
})

folds <- 5
splitfolds <- sample(1:folds, n, replace = TRUE)
candidate_k <- c(2, 3, 4, 5, 10, 20, 30, 40, 50, 75, 100, 200, 300) # candidates for how many topics

# export all the needed R objects to the parallel sessions
clusterExport(cluster, c("full_data", "burnin", "iter", "keep", "splitfolds", "folds", "candidate_k"))

# Parallelise over the candidate numbers of topics: each worker is given one
# value of k and runs all the cross-validation folds for it serially. This
# assumes there are more candidate values of k than folds, so splitting the
# work by k keeps more workers busy.
system.time({
results <- foreach(j = 1:length(candidate_k), .combine = rbind) %dopar%{
   k <- candidate_k[j]
   results_1k <- matrix(0, nrow = folds, ncol = 2)
   colnames(results_1k) <- c("k", "perplexity")
   for(i in 1:folds){
      train_set <- full_data[splitfolds != i , ]
      valid_set <- full_data[splitfolds == i, ]

      fitted <- LDA(train_set, k = k, method = "Gibbs",
                    control = list(burnin = burnin, iter = iter, keep = keep) )
      results_1k[i,] <- c(k, perplexity(fitted, newdata = valid_set))
   }
   return(results_1k)
}
})
stopCluster(cluster)

results_df <- as.data.frame(results)

ggplot(results_df, aes(x = k, y = perplexity)) +
   geom_point() +
   geom_smooth(se = FALSE) +
   ggtitle("5-fold cross-validation of topic modelling with the 'Associated Press' dataset",
           "(ie five different models fit for each candidate number of topics)") +
   labs(x = "Candidate number of topics", y = "Perplexity when fitting the trained model to the hold-out set")

We see in the results that 200 topics is too many and has some over-fitting, and 50 is too few. Of the numbers of topics tried, 100 is the best, with the lowest average perplexity on the five different hold-out sets.
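
Choosing the number of topics here comes down to taking the candidate with the lowest mean perplexity across the folds. A small sketch of that selection step, assuming a data frame shaped like results_df above (columns k and perplexity, one row per fold and candidate); pick_k() is a hypothetical helper:

```r
# Hypothetical helper: average the per-fold perplexities for each k
# and return the k with the lowest mean.
pick_k <- function(results_df) {
  avg <- aggregate(perplexity ~ k, data = results_df, FUN = mean)
  avg$k[which.min(avg$perplexity)]
}

# Usage with the cross-validation output from above:
# pick_k(results_df)
```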
