JSOI2008最小生成树计数 并查集最小生成树

[JSOI2008]最小生成树计数 作为一枚蒟蒻,我完全没学过最小生成树的说!! 最小生成树有几个好的性质,这里给出简单证明。详细证明参考这里Sengxian’s Blog 性质1:一张图的所有最小生成树,把边排序后,对应下标的边边权相等。 证明:找到第一个选择的边不等的位置,记为$a_i,b_i$。如果$a_i$在$b$中出现,出现在$b_j$,那么有$j > i$,$[b_i,b_j]$都相等。如果没有出现,把$a_i$加进$b$中,形成一个环。这个环上所有边权一定$\leq a_i$。为了顶替$a_i$,一定有一个$b_j = a_i, j > i$。 性质2:最小生成树,从小到大加边,加完某种边权后,连通性相同 证明:如果不同,考虑Kruskal算法的过程,一定会把差异的边加上。 性质3:如果某种边权在一种最小生成树中出现$k$次,在此边权中任意选$k$条边加上,只要没有形成环,都是可行方案。 证明:如果没有形成环,一定形成了树,且边权,连通性相同。 有了这三个定理,又因为每种边权不超过$10$条,可以枚举使用了哪些边来求解答案。使用可撤销并查集,复杂度$O(2^{10} m logn)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include <bits/stdc++.h>

inline int rd() {
int a = 1, b = 0; char c = getchar();
while (!isdigit(c)) a = c == '-' ? 0 : 1, c = getchar();
while (isdigit(c)) b = b * 10 + c - '0', c = getchar();
return a ? b : -b;
}

const int N = 1e3 + 233, P = 31011;

int n, m, ans, fa[N], vis[N], st[N], pick[N], tot;

struct Edge {
int x, y, c;
} e[N];

bool operator<(const Edge &a, const Edge &b) {
return a.c < b.c;
}

int find(int x) {
if (x == fa[x])
return x;
return find(fa[x]);
}

int dfs(int s, int d, int p) {
if (d == st[s + 1]) {
if (p == pick[s])
return 1;
return 0;
}
int x = find(e[d].x), y = find(e[d].y), ret = 0;
if (x != y) {
fa[x] = y;
ret += dfs(s, d + 1, p + 1);
fa[x] = x;
}
return ret + dfs(s, d + 1, p);
}

int main() {
n = rd(), m = rd();
for (int i = 1; i <= m; ++i)
e[i].x = rd(), e[i].y = rd(), e[i].c = rd();
std::sort(e + 1, e + m + 1);
for (int i = 1; i <= n; ++i)
fa[i] = i;
for (int i = 1; i <= m; ++i) {
int x = find(e[i].x), y = find(e[i].y);
if (x != y) {
fa[x] = y;
vis[i] = 1;
++ans;
}
}
if (ans != n - 1) {
puts("0");
return 0;
}
for (int i = 1; i <= m; ++i)
if (e[i].c != e[i - 1].c)
st[++tot] = i;
st[tot + 1] = m + 1;
for (int i = 1; i <= tot; ++i)
for (int j = st[i]; j < st[i + 1]; ++j)
if (vis[j])
++pick[i];
for (int i = 1; i <= n; ++i)
fa[i] = i;
ans = 1;
for (int s = 1; s <= tot; ++s) {
ans = ans * dfs(s, st[s], 0) % P;
for (int i = st[s]; i < st[s + 1]; ++i) {
int x = find(e[i].x), y = find(e[i].y);
if (x != y)
fa[x] = y;
}
}
printf("%d\n", ans);
return 0;
}