Fast linear regression by group


Question

I have 500K users and I need to compute a linear regression (with intercept) for each of them.

Each user has around 30 records.

I tried dplyr with lm, and it is far too slow: around 2 seconds per user.

  library(dplyr)

  # One lm() fit per user; keep only the intercept (lm_b0) and slope (lm_b1)
  df %>%
    group_by(user_id) %>%
    do(lm = lm(Y ~ x, data = .)) %>%
    mutate(lm_b0 = summary(lm)$coeff[1],
           lm_b1 = summary(lm)$coeff[2]) %>%
    select(user_id, lm_b0, lm_b1) %>%
    ungroup()

I tried lm.fit, which is known to be faster, but it doesn't seem to be compatible with dplyr.

Is there a fast way to do a linear regression by group?
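For reference, lm.fit itself does not depend on dplyr; it only needs an explicit design matrix instead of a formula. A minimal base R sketch, assuming a data frame with the question's column names (user_id, x, Y), toy data generated here purely for illustration, and a hypothetical helper fit_group(), calls it once per user via split() and vapply():

```r
# Hypothetical toy data; column names follow the question (user_id, x, Y)
set.seed(42)
df <- data.frame(
  user_id = rep(1:3, each = 30),
  x = runif(90)
)
df$Y <- 1 + 2 * df$x + rnorm(90, sd = 0.1)

# Hypothetical helper: OLS with intercept for one user's rows
fit_group <- function(d) {
  # lm.fit needs a design matrix: a column of 1s for the intercept, then x
  X <- cbind(1, d$x)
  lm.fit(X, d$Y)$coefficients
}

# split() yields one data frame per user; vapply() collects 2 coefficients each
coefs <- vapply(split(df, df$user_id), fit_group, numeric(2))
res_lmfit <- data.frame(
  user_id = as.integer(colnames(coefs)),
  lm_b0 = unname(coefs[1, ]),
  lm_b1 = unname(coefs[2, ])
)
res_lmfit
```

This avoids formula parsing and model-frame construction, but it still runs a QR decomposition per group, so it is not expected to match the closed-form approach in the answer.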

Answer

You can just use the basic formulas for calculating the slope and intercept. lm does a lot of unnecessary work if all you need are those two numbers. Here I use data.table for the aggregation, but you could do it in base R as well (or with dplyr):

library(data.table)

system.time(
  res <- DT[,
    {
      # closed-form OLS for one predictor, evaluated separately for each user
      ux <- mean(x)
      uy <- mean(y)
      slope <- sum((x - ux) * (y - uy)) / sum((x - ux) ^ 2)
      list(slope=slope, intercept=uy - slope * ux)
    }, by=user.id
  ]
)
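The assignments inside the `{ }` block are the ordinary least-squares estimates for a single predictor with an intercept, computed over the n rows of one user:

```latex
\bar{x} = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\bar{y} = \frac{1}{n}\sum_{i=1}^{n} y_i

\hat{\beta}_1 = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n} (x_i - \bar{x})^2}, \qquad
\hat{\beta}_0 = \bar{y} - \hat{\beta}_1 \bar{x}
```

Here `slope` is the estimate of the slope coefficient and `intercept` is the estimate of the intercept. Each group needs only a few means and sums, which is why this is so much cheaper than a full lm() fit. One edge case: if all of a user's x values are identical, the denominator is 0 and `slope` becomes NaN, whereas lm() would report NA for the slope.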

Running the data.table code above on 500K users with ~30 observations each takes (in seconds):

 user  system elapsed 
 7.35    0.00    7.36 

That is about 15 microseconds per user (7.36 s / 500,000 ≈ 14.7 µs). To confirm that this works as expected, compare it with lm() for one user:

> summary(DT[user.id==89663, lm(y ~ x)])$coefficients
             Estimate Std. Error   t value  Pr(>|t|)
(Intercept) 0.1965844  0.2927617 0.6714826 0.5065868
x           0.2021210  0.5429594 0.3722580 0.7120808
> res[user.id == 89663]
   user.id    slope intercept
1:   89663 0.202121 0.1965844
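The check above covers a single user. A minimal sketch that repeats it for every user, assuming the data.table package is installed and using a small hypothetical data set (DT_small: 50 users with exactly 30 rows each, generated the same way as the answer's data):

```r
library(data.table)

# Hypothetical small data set in the same shape as the answer's DT
set.seed(1)
n_users <- 50
DT_small <- data.table(
  user.id = rep(seq_len(n_users), each = 30),
  x = runif(n_users * 30)
)
DT_small[, y := x + runif(.N) * 4 - 2]

# Closed-form estimates, as in the answer
res_small <- DT_small[, {
  ux <- mean(x)
  uy <- mean(y)
  slope <- sum((x - ux) * (y - uy)) / sum((x - ux)^2)
  list(slope = slope, intercept = uy - slope * ux)
}, by = user.id]

# Reference estimates from lm() for every user
ref_small <- DT_small[, {
  cf <- coef(lm(y ~ x))
  list(slope = cf[[2]], intercept = cf[[1]])
}, by = user.id]

# TRUE if the closed form matches lm() up to floating-point tolerance
all.equal(res_small, ref_small)
```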

Data:

library(data.table)

set.seed(1)
users <- 5e5
records <- 30
x <- runif(users * records)
DT <- data.table(
  x=x, y=x + runif(users * records) * 4 - 2,
  user.id=sample(users, users * records, replace=TRUE)
)
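The answer notes that the same aggregation can be done in base R. A minimal sketch of that, assuming no packages and using a hypothetical helper grouped_ols() (not part of any library): ave() supplies each row's group means, and rowsum() computes the per-group sums without an explicit loop over groups.

```r
# Hypothetical helper: closed-form OLS (slope + intercept) for every group g
grouped_ols <- function(x, y, g) {
  ux <- ave(x, g)  # per-group mean of x, repeated on every row of the group
  uy <- ave(y, g)  # per-group mean of y
  dx <- x - ux
  # rowsum() sums within groups; rows come back sorted by group value
  sxy <- rowsum(dx * (y - uy), g)[, 1]
  sxx <- rowsum(dx^2, g)[, 1]
  slope <- sxy / sxx
  # intercept = mean(y) - slope * mean(x), with one mean per group
  n <- rowsum(rep(1, length(x)), g)[, 1]
  mx <- rowsum(x, g)[, 1] / n
  my <- rowsum(y, g)[, 1] / n
  data.frame(
    user.id = names(slope),
    slope = unname(slope),
    intercept = unname(my - slope * mx)
  )
}

# Usage on small hypothetical data: 10 groups of 30 rows
set.seed(1)
x <- runif(300)
g <- rep(1:10, each = 30)
y <- x + runif(300) * 4 - 2
res_base <- grouped_ols(x, y, g)
head(res_base)
```

The result is sorted by group, and user.id comes back as character because it is taken from rowsum()'s row names.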
