子异和 组合线段树

子异和 分别考虑每一位对答案的贡献。假设第$i$位有$n$个数,$m$个$1$。贡献为: $$\sum_{2 \mid i, 1 \leq i \leq n} \binom{m}{i} 2^{n-m} 2^i = 2^{n-1}2^i $$ 这是一个好的性质,它告诉我们,答案与$1$的个数无关。 用线段树维护区间$or$值。然而,区间$or$并不能直接$xor$,需要维护额外信息。 考虑维护区间$0,1$分别有多少个。可以发现多少个不重要,我们只关心存在与否。进一步的,可以通过维护$and,or$两个值进行求解。 讨论进行更新。如果某一位$xor$的是$0$,显然不影响。$or,and$的关系,有$(0,0),(1,0),(1,1)$三种。简单计算后可以发现,新的$and,or$分别是原来的两个值取反后,$and,or$的结果。 使用树剖+线段树即可,LCT也可以。事实证明,理论复杂度更优的LCT跑的更慢。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#include <bits/stdc++.h>

inline int rd() {
int a = 1, b = 0; char c = getchar();
while (!isdigit(c)) a = c == '-' ? 0 : 1, c = getchar();
while (isdigit(c)) b = b * 10 + c - '0', c = getchar();
return a ? b : -b;
}

const int N = 200000 + 2333, P = 1e9 + 7;

inline int fpow(int x, int y) {
int ret = 1;
for ( ; y; y >>= 1) {
if (y & 1)
ret = (long long)ret * x % P;
x = (long long)x * x % P;
}
return ret;
}

int n, m, val[N];

struct Graph {
int to, nxt;
} g[N * 2];
int head[N], tot;

void addedge(int x, int y) {
g[++tot].to = y, g[tot].nxt = head[x],
head[x] = tot;
}

int fa[N], son[N], size[N], wt[N], num, dep[N], id[N], top[N];

void dfs1(int x, int fat) {
fa[x] = fat; dep[x] = dep[fat] + 1; size[x] = 1;
for (int i = head[x]; i; i = g[i].nxt) {
int y = g[i].to;
if (y != fat) {
dfs1(y, x);
if (size[son[x]] < size[y])
son[x] = y;
size[x] += size[y];
}
}
}

void dfs2(int x, int topf) {
top[x] = topf; id[x] = ++num; wt[num] = val[x];
if (!son[x]) return;
dfs2(son[x], topf);
for (int i = head[x]; i; i = g[i].nxt) {
int y = g[i].to;
if (y != fa[x] && y != son[x])
dfs2(y, y);
}
}

struct SegTree {
unsigned val_and, val_or, tag;
} t[N * 4];
#define ls(p) p << 1
#define rs(p) p << 1 1

void pushup(int p) {
t[p].val_and = t[ls(p)].val_and & t[rs(p)].val_and;
t[p].val_or = t[ls(p)].val_or t[rs(p)].val_or;
}

void push(int p, unsigned v) {
unsigned x = t[p].val_and, y = t[p].val_or;
t[p].val_and = (x ^ v) & (y ^ v);
t[p].val_or = (x ^ v) (y ^ v);
t[p].tag ^= v;
}

void pushdown(int p) {
if (t[p].tag) {
push(ls(p), t[p].tag);
push(rs(p), t[p].tag);
t[p].tag = 0;
}
}

void build(int p, int l, int r) {
if (l == r) {
t[p].val_and = t[p].val_or = wt[l];
return;
}
int mid = (l + r) >> 1;
build(ls(p), l, mid);
build(rs(p), mid + 1, r);
pushup(p);
}

void change(int p, int l, int r, unsigned v, int L, int R) {
if (l <= L && r >= R) {
push(p, v);
return;
}
pushdown(p);
int mid = (L + R) >> 1;
if (l <= mid)
change(ls(p), l, r, v, L, mid);
if (r > mid)
change(rs(p), l, r, v, mid + 1, R);
pushup(p);
}

unsigned query(int p, int l, int r, int L, int R) {
if (l <= L && r >= R)
return t[p].val_or;
pushdown(p);
int mid = (L + R) >> 1;
unsigned ret = 0;
if (l <= mid)
ret = query(ls(p), l, r, L, mid);
if (r > mid)
ret = query(rs(p), l, r, mid + 1, R);
return ret;
}

void range_change(int x, int y, int v) {
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]])
std::swap(x, y);
change(1, id[top[x]], id[x], v, 1, n);
x = fa[top[x]];
}
if (dep[x] > dep[y])
std::swap(x, y);
change(1, id[x], id[y], v, 1, n);
}

int range_query(int x, int y) {
int cnt = 0;
unsigned val = 0;
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]])
std::swap(x, y);
val = query(1, id[top[x]], id[x], 1, n);
cnt += dep[x] - dep[top[x]] + 1;
x = fa[top[x]];
}
if (dep[x] > dep[y])
std::swap(x, y);
val = query(1, id[x], id[y], 1, n);
cnt += dep[y] - dep[x] + 1;
return (long long)val * fpow(2, cnt - 1) % P;
}

int main() {
// freopen("data.txt", "r", stdin);
n = rd(), m = rd();
for (int i = 1; i < n; ++i) {
int x = rd(), y = rd();
addedge(x, y);
addedge(y, x);
}
for (int i = 1; i <= n; ++i)
val[i] = rd();
dfs1(1, 0);
dfs2(1, 1);
build(1, 1, n);
while (m--) {
int op = rd();
if (op == 1) {
int x = rd(), y = rd();
printf("%d\n", range_query(x, y));
} else {
int x = rd(), y = rd(), z = rd();
range_change(x, y, z);
}
}
return 0;
}