跳转至

距离

原题链接

题目描述

给出 \(n\) 个点的一棵树(每条边带有长度),以及 \(m\) 次询问,每次询问两点之间的最短距离。

注意:

边是无向的。

所有节点的编号是 \(1, 2, \ldots, n\)。


最近公共祖先

倍增LCA

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <bits/stdc++.h>
using namespace std;

// Array capacities: N bounds the vertex indices, M bounds the number of
// adjacency entries (every undirected edge is stored as two entries).
const int N = 10010;
const int M = 20010;

// n and m are the two integers read first from the input; n - 1 edges and
// m queries follow.
int n;
int m;

// Linked adjacency lists, filled by add():
//   h[v]   - index of the first entry in v's list (-1 when empty)
//   e[i]   - target vertex of entry i
//   w[i]   - weight of entry i
//   ne[i]  - index of the next entry in the same list
//   idx    - next unused entry index
int h[N];
int e[M];
int w[M];
int ne[M];
int idx;

// Binary-lifting tables, filled by bfs():
//   depth[v]   - BFS level of v (the root is at level 1, index 0 at level 0)
//   cost[v][k] - summed edge weight from v up to fa[v][k]
//   fa[v][k]   - ancestor of v reached after 2^k steps up (0 when none)
int depth[N];
int cost[N][15];
int fa[N][15];

// Inserts a directed entry a -> b with weight c at the front of a's
// adjacency list.
void add(int a, int b, int c) {
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];  // the previous head becomes the successor
    h[a] = idx;
    idx++;
}

// Breadth-first traversal from `root` that fills depth[], fa[][] and cost[][].
//
// Index 0 serves as a sentinel at depth 0, so lifting past the root lands on
// a vertex that is shallower than every real vertex.
void bfs(int root) {
    // Every byte set to 0x3f gives each entry a large "not reached" depth.
    memset(depth, 0x3f, sizeof depth);
    depth[0] = 0;
    depth[root] = 1;

    queue<int> pending;
    pending.push(root);

    while (!pending.empty()) {
        const int cur = pending.front();
        pending.pop();

        for (int edge = h[cur]; edge != -1; edge = ne[edge]) {
            const int nxt = e[edge];
            // Already reached at this level or a shallower one.
            if (depth[nxt] <= depth[cur] + 1) continue;

            depth[nxt] = depth[cur] + 1;
            pending.push(nxt);

            fa[nxt][0] = cur;
            cost[nxt][0] = w[edge];
            // A 2^k jump is two consecutive 2^(k-1) jumps; the table rows of
            // the intermediate vertex were filled when it was reached.
            for (int k = 1; k < 15; k++) {
                const int mid = fa[nxt][k - 1];
                fa[nxt][k] = fa[mid][k - 1];
                cost[nxt][k] = cost[mid][k - 1] + cost[nxt][k - 1];
            }
        }
    }
}

// Returns the summed edge weight of the tree path between a and b, using the
// tables built by bfs(). Despite its name, the function reports the path
// length, not the ancestor vertex itself.
int lca(int a, int b) {
    // Make `a` the deeper endpoint.
    if (depth[a] < depth[b]) swap(a, b);

    int total = 0;

    // Lift `a` until it is on the same level as `b`.
    for (int k = 14; k >= 0; --k) {
        const int up = fa[a][k];
        if (depth[up] < depth[b]) continue;
        total += cost[a][k];
        a = up;
    }

    if (a == b) return total;

    // Lift both endpoints together while their 2^k-th ancestors differ.
    for (int k = 14; k >= 0; --k) {
        const int up_a = fa[a][k];
        const int up_b = fa[b][k];
        if (up_a == up_b) continue;
        total += cost[a][k] + cost[b][k];
        a = up_a;
        b = up_b;
    }

    // Both now sit directly below their meeting point; add the last two edges.
    total += cost[a][0] + cost[b][0];
    return total;
}

// Reads n and m, then n - 1 weighted edges, builds the lifting tables from
// vertex 1, and prints the path length for each of the m queries.
int main() {
    // Untie C++ streams from C stdio and from each other for bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    memset(h, -1, sizeof h);  // -1 marks an empty adjacency list
    for (int i = 0; i < n - 1; i++) {
        int a, b, c;
        cin >> a >> b >> c;
        // Store each undirected edge in both directions.
        add(a, b, c);
        add(b, a, c);
    }

    bfs(1);

    while (m--) {
        int a, b;
        cin >> a >> b;
        // '\n' instead of endl: avoids flushing stdout after every answer.
        cout << lca(a, b) << '\n';
    }

    return 0;
}

Tarjan

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <bits/stdc++.h>
using namespace std;

// Shorthand for the two members of a pair.
#define ff first
#define ss second

using PII = pair<int, int>;

// Array capacities: N bounds the vertex indices, M bounds both the number of
// adjacency entries and the number of stored answers.
const int N = 10010;
const int M = 20010;

// Values stored in st[]:
//   2 - the vertex has been visited and its traversal has finished
//   1 - the vertex is currently being visited
//   0 - the vertex has not been visited yet

// n and m are the two integers read first from the input; n - 1 edges and
// m queries follow.
int n;
int m;

// Linked adjacency lists, filled by add():
//   h[v]   - index of the first entry in v's list (-1 when empty)
//   e[i]   - target vertex of entry i
//   w[i]   - weight of entry i
//   ne[i]  - index of the next entry in the same list
//   idx    - next unused entry index
int h[N];
int e[M];
int w[M];
int ne[M];
int idx;

// dist[v] - summed edge weight from the dfs() start vertex to v
// st[v]   - visit state, see the table above
// p[v]    - disjoint-set parent used by find()
// res[i]  - answer to query number i
int dist[N];
int st[N];
int p[N];
int res[M];

// query[v] holds (other endpoint, query number) for every query touching v.
vector<PII> query[N];

// Inserts a directed entry a -> b with weight c at the front of a's
// adjacency list.
void add(int a, int b, int c) {
    const int slot = idx++;
    e[slot] = b;
    w[slot] = c;
    ne[slot] = h[a];  // the previous head becomes the successor
    h[a] = slot;
}

// Depth-first walk that sets dist[] for every vertex below u to dist[u] plus
// the edge weights on the way down. `fa` is the vertex we arrived from and is
// skipped so the walk never goes back up.
void dfs(int u, int fa) {
    int i = h[u];
    while (i != -1) {
        const int child = e[i];
        if (child != fa) {
            dist[child] = dist[u] + w[i];
            dfs(child, u);
        }
        i = ne[i];
    }
}

// Returns the representative of x's set in p[] and points every vertex on
// the walked path directly at that representative (path compression).
int find(int x) {
    // First pass: locate the representative.
    int root = x;
    while (p[root] != root) root = p[root];

    // Second pass: re-point each vertex on the path at the representative.
    while (p[x] != root) {
        const int next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}

void tarjan(int u) {
    st[u] = 1;
    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        if (!st[j]) {
            tarjan(j);
            p[j] = u;
        }
    }

    for (auto it : query[u]) {
        int y = it.ff, id = it.ss;
        if (st[y] == 2) {
            int anc = find(y);
            res[id] = dist[u] + dist[y] - dist[anc] * 2;
        }
    }
    st[u] = 2;
}

// Reads n and m, then n - 1 weighted edges and m queries, answers all queries
// offline with dfs() + tarjan() from vertex 1, and prints them in input order.
int main() {
    // Untie C++ streams from C stdio and from each other for bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    memset(h, -1, sizeof h);  // -1 marks an empty adjacency list
    cin >> n >> m;
    for (int i = 0; i < n - 1; i++) {
        int a, b, c;
        cin >> a >> b >> c;
        // Store each undirected edge in both directions.
        add(a, b, c);
        add(b, a, c);
    }

    // Attach every query to both endpoints so whichever is finished second
    // can answer it.
    for (int i = 0; i < m; i++) {
        int a, b;
        cin >> a >> b;
        query[a].push_back({b, i});
        query[b].push_back({a, i});
    }

    // Each vertex starts as its own set.
    for (int i = 1; i <= n; i++) p[i] = i;

    dfs(1, -1);
    tarjan(1);

    // '\n' instead of endl: avoids flushing stdout after every answer.
    for (int i = 0; i < m; i++) cout << res[i] << '\n';

    return 0;
}
回到页面顶部