SQLAlchemy、array_agg 和匹配输入列表 [英] SQLAlchemy, array_agg, and matching an input list

查看:39
本文介绍了SQLAlchemy、array_agg 和匹配输入列表的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试更充分地使用 SQLAlchemy,而不是在遇到困难的第一个迹象时就退回到纯 SQL.在这种情况下,我在 Postgres 数据库 (9.5) 中有一个表,它通过将单个项目 atom_id 与组标识符 group_id 相关联来将一组整数存储为一个组.

I am attempting to use SQLAlchemy more fully, rather than just falling back to pure SQL at the first sign of distress. In this case, I have a table in a Postgres database (9.5) which stores a set of integers as a group by associating individual items atom_id with a group identifier group_id.

给定一个 atom_ids 列表,我希望能够找出哪个 group_id,如果有的话,那组 atom_ids属于.仅使用 group_idatom_id 列解决这个问题很简单.

Given a list of atom_ids, I'd like to be able to figure out which group_id, if any, that set of atom_ids belong to. Solving this with just the group_id and atom_id columns was straightforward.

现在我试图概括,组"不仅由 atom_ids 的列表组成,还包括其他上下文.在下面的例子中,列表是通过包含一个 sequence 列来排序的,但从概念上讲,可以使用其他列来代替,例如一个 weight 列,它给出了每个 atom_id 一个 [0,1] 浮点值,表示该原子在组中的份额".

Now I'm trying to generalize such that a 'group' is made up of not just a list of atom_ids, but other context as well. In the example below, the list is ordered by including a sequence column, but conceptually other columns could be used instead, such as a weight column which gives each atom_id a [0,1] floating point value representing that atom's 'share' of the group.

以下是演示我的问题的大部分单元测试.

Below is most of a unit test demonstrating my issue.

首先,一些设置:

def test_multi_column_grouping(self):
    class MultiColumnGroups(base.Base):
        __tablename__ = 'multi_groups'

        group_id = Column(Integer)
        atom_id = Column(Integer)
        sequence = Column(Integer)  # arbitrary 'other' column.  In this case, an integer, but it could be a float (e.g. weighting factor)

    base.Base.metadata.create_all(self.engine)

    # Insert 6 rows representing 2 different 'groups' of values
    vals = [
        # Group 1
        {'group_id': 1, 'atom_id': 1, 'sequence': 1},
        {'group_id': 1, 'atom_id': 2, 'sequence': 2},
        {'group_id': 1, 'atom_id': 3, 'sequence': 3},
        # Group 2
        {'group_id': 2, 'atom_id': 1, 'sequence': 3},
        {'group_id': 2, 'atom_id': 2, 'sequence': 2},
        {'group_id': 2, 'atom_id': 3, 'sequence': 1},
    ]

    self.session.bulk_save_objects(
        [MultiColumnGroups(**x) for x in vals])
    self.session.flush()

    self.assertEqual(6, len(self.session.query(MultiColumnGroups).all()))

现在,我想查询上表以查找一组特定的输入属于哪个组.我正在使用(命名)元组列表来表示查询参数.

Now, I want to query the above table to find which group a specific set of inputs belongs to. I'm using a list of (named) tuples to represent the query parameters.

    from collections import namedtuple
    Entity = namedtuple('Entity', ['atom_id', 'sequence'])
    values_to_match = [
        # (atom_id, sequence)
        Entity(1, 3),
        Entity(2, 2),
        Entity(3, 1),
        ]
    # The above list _should_ match with `group_id == 2`

原始 SQL 解决方案.我宁愿不要依赖于此,因为本练习的一部分是学习更多 SQLAlchemy.

Raw SQL solution. I'd prefer not to fall back on this, as a part of this exercise is to learn more SQLAlchemy.

    r = self.session.execute('''
        select group_id
        from multi_groups
        group by group_id
        having array_agg((atom_id, sequence)) = :query_tuples
        ''', {'query_tuples': values_to_match}).fetchone()
    print(r)  # > (2,)
    self.assertEqual(2, r[0])

这是上面的原始 SQL 解决方案相当直接地转换为损坏的 SQLAlchemy 查询.运行它会产生一个 psycopg2 错误:(psycopg2.ProgrammingError) 操作符不存在:record[] = integer[].我相信我需要将 array_agg 转换为 int[] 吗?只要分组列都是整数(如果需要,这是一个可接受的限制),这将起作用,但理想情况下,这适用于混合类型的输入元组/表列.

Here is the above raw-SQL solution converted fairly directly into a broken SQLAlchemy query. Running this produces a psycopg2 error: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]. I believe that I need to cast the array_agg into an int[]? That would work so long as the grouping columns are all integers (which, if need be, is an acceptable limitation), but ideally this would work with mixed-type input tuples / table columns.

    from sqlalchemy import tuple_
    from sqlalchemy.dialects.postgresql import array_agg

    existing_group = self.session.query(MultiColumnGroups).
        with_entities(MultiColumnGroups.group_id).
        group_by(MultiColumnGroups.group_id).
        having(array_agg(tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.sequence)) == values_to_match).
        one_or_none()

    self.assertIsNotNone(existing_group)
    print('|{}|'.format(existing_group))

上面的 session.query() 关闭了吗?我是否在这里蒙蔽了自己,遗漏了一些可以以其他方式解决这个问题的非常明显的东西?

Is the above session.query() close? Have I blinded myself here, and am missing something super obvious that would solve this problem in some other way?

推荐答案

我认为你的解决方案会产生不确定的结果,因为组内的行是未指定的顺序,因此数组聚合和给定数组之间的比较可能会产生真假基于:

I think your solution would produce indeterminate results, because the rows within a group are in unspecified order, and so the comparison between the array aggregate and given array may produce true or false based on that:

[local]:5432 u@sopython*=> select group_id
[local] u@sopython- > from multi_groups 
[local] u@sopython- > group by group_id
[local] u@sopython- > having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
 group_id 
----------
        2
(1 row)

[local]:5432 u@sopython*=> update multi_groups set atom_id = atom_id where atom_id = 2;
UPDATE 2
[local]:5432 u@sopython*=> select group_id                                             
from multi_groups 
group by group_id
having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
 group_id 
----------
(0 rows)

您可以对两者应用排序,或者尝试一些完全不同的方法:您可以使用 关系划分.

You could apply an ordering to both, or try something entirely different: instead of array comparison you could use relational division.

为了划分,您必须从您的 Entity 记录列表中形成一个临时关系.同样,有很多方法可以解决这个问题.这是一个使用非嵌套数组的方法:

In order to divide you have to form a temporary relation from your list of Entity records. Again, there are many ways to approach that. Here's one using unnested arrays:

In [112]: vtm = select([
     ...:     func.unnest(postgresql.array([
     ...:         getattr(e, f) for e in values_to_match
     ...:     ])).label(f)
     ...:     for f in Entity._fields
     ...: ]).alias()

另一个使用联合:

In [114]: vtm = union_all(*[
     ...:     select([literal(e.atom_id).label('atom_id'),
     ...:             literal(e.sequence).label('sequence')])
     ...:     for e in values_to_match
     ...: ]).alias()

临时表也可以.

有了手头的新关系,您想找到找到那些不存在不在组中的实体的multi_groups"的答案.这是一个可怕的句子,但有道理:

With the new relation at hand you want to find the answer to "find those multi_groups for which no entity exists that is not in the group". It's a horrible sentence, but makes sense:

In [117]: mg = aliased(MultiColumnGroups)

In [119]: session.query(MultiColumnGroups.group_id).
     ...:     filter(~exists().
     ...:         select_from(vtm).
     ...:         where(~exists().
     ...:             where(MultiColumnGroups.group_id == mg.group_id).
     ...:             where(tuple_(vtm.c.atom_id, vtm.c.sequence) ==
     ...:                   tuple_(mg.atom_id, mg.sequence)).
     ...:             correlate_except(mg))).
     ...:     distinct().
     ...:     all()
     ...: 
Out[119]: [(2)]

<小时>

另一方面,您也可以只选择具有给定实体的组的交集:


On the other hand you could also just select the intersection of groups with the given entities:

In [19]: gs = intersect(*[
    ...:     session.query(MultiColumnGroups.group_id).
    ...:         filter(MultiColumnGroups.atom_id == vtm.atom_id,
    ...:                MultiColumnGroups.sequence == vtm.sequence)
    ...:     for vtm in values_to_match
    ...: ])

In [20]: session.execute(gs).fetchall()
Out[20]: [(2,)]

<小时>

错误

ProgrammingError: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]
LINE 3: ...gg((multi_groups.atom_id, multi_groups.sequence)) = ARRAY[AR...
                                                             ^
HINT:  No operator matches the given name and argument type(s). You might need to add explicit type casts.
 [SQL: 'SELECT multi_groups.group_id AS multi_groups_group_id 
FROM multi_groups GROUP BY multi_groups.group_id 
HAVING array_agg((multi_groups.atom_id, multi_groups.sequence)) = %(array_agg_1)s'] [parameters: {'array_agg_1': [[1, 3], [2, 2], [3, 1]]}] (Background on this error at: http://sqlalche.me/e/f405)

是您的 values_to_match 首先转换为列表列表(原因未知)然后 由您的 DB-API 驱动程序转换为数组.它导致一个整数数组的数组,而不是一个记录数组(int,int).使用 原始 DB-API 连接 和游标,传递元组列表按您的预期工作.

is a result of how your values_to_match is first converted to a list of lists (for reasons unknown) and then converted to an array by your DB-API driver. It results in an array of array of integer, not an array of record (int, int). Using a raw DB-API connection and cursor, passing a list of tuples works as you'd expect.

在 SQLAlchemy 中,如果您使用 sqlalchemy.dialects.postgresql.array(),它按您的意思工作,但请记住,结果是不确定的.

In SQLAlchemy if you wrap the list values_to_match with sqlalchemy.dialects.postgresql.array(), it works as you meant it to work, though remember that the results are indeterminate.

这篇关于SQLAlchemy、array_agg 和匹配输入列表的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆