Kruskal重构树

Kruskal重构树是一个图论算法,目的是解决“到$u$的最大边不超过$k$的所有点”这类问题。使用Kruskal重构树可以将图变成二叉树,便于使用各种数据结构来维护

实现

先说实现再说性质。

对Kruskal算法进行魔改。在连边时,我们建出一个新的节点,点权为这条边的点权,作为这条边两个端点的父节点。并查集中,将这两个点的根指向新的节点。

很简单是不是

性质

有这么几个显然的性质

  1. 树是二叉的,根是最后一条连的边
  2. 不看叶子节点,是一个大顶堆
  3. 两点间的LCA就是连接这两点的所有简单路径中最大边的最小值

有了这三条性质,就可以解决类似“只允许走边权$\leq k$的边,能到达的节点”这样的问题了。

例题

peaks

#include <iostream>
#include <cstdio>
#include <algorithm>

using namespace std;
const int MAXN = 2e5 + 5;
const int MAXM = 5e5 + 5;

int q;
int n, m, tot; 
int a[MAXN];

struct Edge{
    int x, y, z;
};

int Comp(Edge a, Edge b) {
    return a.z < b.z;
}

Edge e[MAXM];
int bel[MAXN];

int Find(int x) {
    if (bel[x] == x) return x;
    else return bel[x] = Find(bel[x]);
}

int to[MAXN << 1], nxt[MAXN << 1];
int head[MAXN], ecnt;
int fa[MAXN][25];

void Add(int u, int v) {
    ecnt++;
    to[ecnt] = v; nxt[ecnt] = head[u]; head[u] = ecnt;
}

void Kruskal() {
    tot = n;
    for (int i = 1; i <= n; i++) bel[i] = i;
    sort(e + 1, e + m + 1, Comp);
    for (int i = 1; i <= m; i++) {
        int x = e[i].x, y = e[i].y, z = e[i].z;
        x = Find(x); y = Find(y);
        if (x == y) continue;
        tot++;
        a[tot] = z;
        bel[x] = bel[y] = tot;
        bel[tot] = tot;
        Add(tot, x); Add(tot, y);
        fa[x][0] = fa[y][0] = tot;
        //cerr << x << " " << y << " " << tot << " " << z << "\n";
    }
}

int idcnt;
int ida[MAXN], idb[MAXN];
int part[MAXN][2];
int sum[MAXN];

void DFS(int u) {
    for (int i = 1; i <= 20; i++) fa[u][i] = fa[fa[u][i - 1]][i - 1];
    if (u <= n) {
        ida[u] = ++idcnt;
        idb[idcnt] = u;
        part[u][0] = part[u][1] = idcnt;
        sum[u] = 1;
        return;
    }
    part[u][0] = 0x7fffffff; part[u][1] = 0;
    for (int i = head[u]; i; i = nxt[i]) {
        int v = to[i];
        DFS(v);
        part[u][0] = min(part[u][0], part[v][0]);
        part[u][1] = max(part[u][1], part[v][1]);
        sum[u] += sum[v];
    }
}

struct Node{
    int siz;
    Node *ch[2];
}npool[4000000];

int ncnt;
Node *rt[MAXN];

Node *New() {
    return &npool[ncnt++];
}

Node *Copy(Node *x) {
    npool[ncnt] = *x;
    return &npool[ncnt++];
}

void Insert(Node *&now, int bit, int k) {
    if (!now) now = New();
    else now = Copy(now);
    now->siz++;
    if (bit == 0) return;
    int f = (k & (1 << bit - 1)) ? 1 : 0;
    Insert(now->ch[f], bit - 1, k);
}

int Query(Node *now1, Node *now2, int bit, int k, int res) {
    if (bit == 0) return res;
    int rs2 = (now2 && now2->ch[1] ? now2->ch[1]->siz : 0);
    int rs1 = (now1 && now1->ch[1] ? now1->ch[1]->siz : 0);
    int rs = rs2 - rs1;
    if (k <= rs) return Query(now1 ? now1->ch[1] : NULL, now2 ? now2->ch[1] : NULL, bit - 1, k, res | (1 << bit - 1));
    else return Query(now1 ? now1->ch[0] : NULL, now2 ? now2->ch[0] : NULL, bit - 1, k - rs, res);
}

void Print(Node *now, int bit, int res) {
    if (!now) return;
    if (bit == 0) {
        cerr << res << " ";
        return;
    }
    Print(now->ch[0], bit - 1, res);
    Print(now->ch[1], bit - 1, res | (1 << bit - 1));
}

void Build() {
    for (int i = 1; i <= n; i++) {
        rt[i] = rt[i - 1];
        Insert(rt[i], 31, a[idb[i]]);
    }
}

void Calculate(int u, int k1, int k2) {
    for (int i = 20; i >= 0; i--) {
        if (a[fa[u][i]] <= k1) u = fa[u][i];
    }
    //cerr << u << "\n";
    if (sum[u] < k2) {
        cout << "-1\n";
        return;
    }
    int l = part[u][0], r = part[u][1];
    //cerr << l << " " << r << "\n";
    int res = Query(rt[l - 1], rt[r], 31, k2, 0);
    cout << res << "\n";
}

void Init() {
    cin >> n >> m >> q;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i <= m; i++) {
        cin >> e[i].x >> e[i].y >> e[i].z;
    }
    a[0] = 0x7fffffff;
    //cerr << "###\n";
    Kruskal();
    //cerr << "###\n";
    DFS(tot);
    /*for (int i = 1; i <= n; i++) cerr << ida[i] << " ";
    cerr << "\n";
    for (int i = 1; i <= n; i++) cerr << idb[i] << " ";
    cerr << "\n";
    for (int i = 1; i <= tot; i++) cerr << sum[i] << " ";
    cerr << "\n";
    cerr << "###\n";
    for (int i = 1; i <= tot; i++) {
        for (int j = 0; j < 3; j++) cerr << fa[i][j] << " ";
        cerr << "\n";
    }
    cerr << "###\n";*/
    Build();
    /*for (int i = 1; i <= n; i++) {
        Print(rt[i], 31, 0);
        cerr << "\n";
    }
    cerr << "###\n";*/
}

void Work() {
    int x, y, z;
    for (int i = 1; i <= q; i++) {
        cin >> x >> y >> z;
        Calculate(x, y, z);
    }
}

int main() {
    ios::sync_with_stdio(false); cin.tie(NULL);
    Init();
    Work();
    return 0;
}
/*
10 11 4
1 2 3 4 5 6 7 8 9 10
1 4 4
2 5 3
9 8 2
7 8 10
7 1 4
6 7 1
6 4 8
2 1 5
10 8 10
3 4 7
3 4 6
1 5 2
1 5 6
1 5 8
8 9 2
*/