倍增LCA

预计阅读时间: 4 分钟 592 次阅读 840 字 最后更新于 2023-04-07 算法与数据结构


前言

最近公共祖先问题求的是一棵树上任意两个节点的最近的相同父节点,又称LCA,这个问题最容易想到的做法就是暴力,我们首先记录一下每个节点在树中的深度,找到两个节点,然后把深度更深的那一个节点提升到和另一个节点相同的深度,然后两个节点一步一步往上走,直到相遇为止,相遇的节点就是他们的最近公共祖先。

这个算法时间复杂度有点高,我们有两种处理的办法,一种是通过树链剖分求,另一种则是通过倍增的思想,之前的blog有写过树剖的写法,这里则是记录一下倍增的写法。

我想翻blog找找我的LCA模版代码来着,结果全是清一色的树剖,谁考场上LCA写树剖呀(

思路

还是回到刚才的暴力,我们来尝试用倍增来优化这个暴力,首先创建一个数组 $f_{ij}$ 表示从$i$号节点走$2^j$步所能到达的节点,对于f[i][0],它表示向上走$2^0$步也就是1步时的节点,也就是它的父亲,而$2^i = 2^{i-1} + 2^{i-1}$,我们不难得到以下下的递推柿子。

fa[now][i] = fa[fa[now][i - 1]][i - 1];

好了,现在我们已经有了这个倍增数组了,刚才我们要先把两个节点上升到同一深度,这里我们可以用这个倍增数组实现,代码如下

if (dep[x] < dep[y])
    swap(x, y);
for (int i = log2max; i >= 0; i--)
    if (dep[fa[x][i]] >= dep[y])
        x = fa[x][i];

这里的log2max要根据题目的数据范围来定。

我们升到了同一深度之后,直接判断是否重合,如果没重合,我们继续使用倍增的思想往上跳,不断去缩小范围,代码和上面的十分相似

for (int i = log2max; i >= 0; i--)
    if (fa[x][i] != fa[y][i])
    {
        x = fa[x][i];
        y = fa[y][i];
    }

要注意,这里的判断条件,最后判断到fa[x][0] == fa[y][0] 就会停止,此时答案不是x,而是fa[x][0],别写错了(

Code

以洛谷LCA模版题为例

#include <iostream>
#include <vector>
using std::cin;
using std::cout;
using std::endl;
using std::ios;
using std::swap;
using std::vector;

const int maxn = 1000005, log2max = 20;
vector<int> edge[maxn]; // vector 实现邻接表
int fa[maxn][50], dep[maxn], n, m, s;

void dfs(int now, int father)
{
    fa[now][0] = father;
    dep[now] = dep[father] + 1;
    for (int i = 1; i <= log2max; i++)
        fa[now][i] = fa[fa[now][i - 1]][i - 1];

    for (auto i : edge[now])
        if (i != father)
            dfs(i, now);
}

int lca(int x, int y)
{
    if (dep[x] < dep[y])
        swap(x, y);

    for (int i = log2max; i >= 0; i--)
        if (dep[fa[x][i]] >= dep[y])
            x = fa[x][i];
    if (x == y)
        return x;

    for (int i = log2max; i >= 0; i--)
        if (fa[x][i] != fa[y][i])
        {
            x = fa[x][i];
            y = fa[y][i];
        }
    return fa[x][0];
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m >> s;

    for (int i = 1, x, y; i <= n - 1; i++)
    {
        cin >> x >> y;
        edge[x].push_back(y);
        edge[y].push_back(x);
    }
    dfs(s, 0);
    while (m --> 0)
    {
        int x, y;
        cin >> x >> y;
        if (x == y)
            cout << x << endl;
        else
            cout << lca(x, y) << endl;
    }
    return 0;
}

END