跳转至

糖果

原题链接

题目描述

幼儿园里有 \(N\) 个小朋友,老师现在想要给这些小朋友们分配糖果,要求每个小朋友都要分到糖果。

但是小朋友们也有嫉妒心,总是会提出一些要求,比如小明不希望小红分到的糖果比他的多,于是在分配糖果的时候, 老师需要满足小朋友们的 \(K\) 个要求。

幼儿园的糖果总是有限的,老师想知道他至少需要准备多少个糖果,才能使得每个小朋友都能够分到糖果,并且满足小朋友们所有的要求。

  • 如果 \(X=1\),表示第 \(A\) 个小朋友分到的糖果必须和第 \(B\) 个小朋友分到的糖果一样多。
  • 如果 \(X=2\),表示第 \(A\) 个小朋友分到的糖果必须少于第 \(B\) 个小朋友分到的糖果。
  • 如果 \(X=3\),表示第 \(A\) 个小朋友分到的糖果必须不少于第 \(B\) 个小朋友分到的糖果。
  • 如果 \(X=4\),表示第 \(A\) 个小朋友分到的糖果必须多于第 \(B\) 个小朋友分到的糖果。
  • 如果 \(X=5\),表示第 \(A\) 个小朋友分到的糖果必须不多于第 \(B\) 个小朋友分到的糖果。

差分约束

首先,题目要求取最小值,需要用最长路来求,取所有 \(X_i\) 下界的最大值。

情况1:\(A=B\Rightarrow A\ge B,B\ge A\)

情况2:\(A<B\Rightarrow B\ge A+1\)

情况3:\(A\ge B\Rightarrow A\ge B\)

情况4:\(A>B\Rightarrow A\ge B+1\)

情况5:\(A\le B\Rightarrow B\ge A\)

限制条件:\(1\le A\le N\),因此需要建立一个超级源点 \(X_0=0\),为了满足 \(A\ge 1\),即连一条超级源点到各点,长度为 \(1\) 的边。

代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 100010, M = 300010;

int n, m;
int h[N], e[M], w[M], ne[M], idx;
LL dist[N];
int cnt[N], q[N];
bool st[N];

void add(int a, int b, int c) {
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

bool spfa() {
    memset(dist, -0x3f, sizeof dist);
    dist[0] = 0;

    int hh = 0, tt = 1;
    q[0] = 0;

    while (hh != tt) {
        int u = q[--tt];

        st[u] = false;

        for (int i = h[u]; i != -1; i = ne[i]) {
            int j = e[i];
            if (dist[j] < dist[u] + w[i]) {
                dist[j] = dist[u] + w[i];
                cnt[j] = cnt[u] + 1;
                if (cnt[j] >= n + 1) return false;
                if (!st[j]) {
                    q[tt++] = j;
                    st[j] = true;
                }
            }
        }
    }

    return true;
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    memset(h, -1, sizeof h);
    cin >> n >> m;
    while (m--) {
        int x, a, b;
        cin >> x >> a >> b;
        if (x == 1) add(a, b, 0), add(b, a, 0);
        else if (x == 2) add(a, b, 1);
        else if (x == 3) add(b, a, 0);
        else if (x == 4) add(b, a, 1);
        else add(a, b, 0);
    }

    for (int i = 1; i <= n; i++) add(0, i, 1);

    if (!spfa()) cout << -1 << endl;
    else {
        LL res = 0;
        for (int i = 1; i <= n; i++) res += dist[i];
        cout << res << endl;
    }

    return 0;
}
回到页面顶部