序列取数

Luogu P1430 序列取数 蒟蒻觉得很复杂的区间dp题… $A, B$两人从给定的长度为$N$的序列中轮流取数,都取自己最大结果,求A最终得分 当$A$取得数字和最大时,由于总和一定,$B$取得的数字和最小 让$dp[l, r]$表示取到$[l, r]$时,$A$的得分减去$B$的得分 若更新$dp[l, r]$,区间$[l, r]$中的每一段都应该已被更新 对于这一区间,有三种操作方法:

  • $A$选取区间内所有的数
  • $A$从左到右选取一部分
  • $A$从右到左选取一部分

对于后两种情况,以$A$选取了$[l, i]$,$B$选取了$[i+1, r]$为例: 前半段$A$选取$S[A]$,后半段$S[A_1]$,$B$选取$S[B_1]$ $dp[l, r] = S[A] + S[A_1] - S[B_1] = S[A] - (S[B_1] - S[A_1])$ 数组$dp$表示的是$A$先选取的情况,若$B$先选取,则$dp$意义变为$B$的得分减去$A$的得分 我们求出前缀和数组$sum$,则有: $dp[l, r] = S[A] - dp[i+1, r] = sum[i - 1] - sum[l - 1] - dp[i][r]$,$l + 1 \leq i \leq r$ 基于此思路写出$O(n^3)$算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <bits/stdc++.h>
using namespace std;

inline int read()
{
int a = 1,b = 0; char c;
do {c=getchar(); if (c == '-') a = -1;} while (c < '0' c > '9');
do {b = b * 10 + c - '0'; c = getchar();} while (c >= '0' && c <= '9');
return a * b;
}

int a[1010];
int sum[1010];
int dp[1010][1010];

int main()
{
int T = read();
while (T--) {
int n = read();
for (int i = 1; i <= n; ++i) {
a[i] = read();
sum[i] = sum[i - 1] + a[i];
}
for (int len = 1; len <= n; ++len) {
for (int l = 1, r = len; r <= n; ++l, ++r) {
dp[l][r] = sum[r] - sum[l - 1];
for (int i = l + 1; i <= r; ++i) {
dp[l][r] = max(dp[l][r], sum[i - 1] - sum[l - 1] - dp[i][r]);
}
for (int i = r - 1; i >= l; --i) {
dp[l][r] = max(dp[l][r], sum[r] - sum[i] - dp[l][i]);
}
}
}
printf("%d\n", (sum[n] + dp[1][n]) / 2);
}
return 0;
}

40分,爽到……


于是考虑优化到$O(n^2)$的算法: 考虑将dp转移方程种有i的两项放在一起,我们可以维护两个数组,保存它们的最大或最小值,省去一次循环 仍然以刚才的情况举例: $sum[i - 1] - sum[l - 1] - dp[i][r] = (sum[i - 1] - dp[i][r]) - sum[l - 1]$ 令$P[l][r] = sum[l - 1] - dp[l][r]$,$left[l][r] = min(P[i][j] \mid l \leq i \leq j \leq r) $ $dp[l][r]$最大,则$P[l][r]$最大,维护$left$最大值即可 $dp[l][r] = max(dp[l][r], left[l+1][r] - sum[l-1])$ $left[l][r] = max(left[l+1][r], sum[l-1] - dp[l][r])$ 如果从$len = 1$开始,会出现$l = r = 1$,$l + 1 > r$,所以$len = 2$开始 为所有$len = 1$,即$l = r$增加一段预处理即可 最终AC代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <bits/stdc++.h>
using namespace std;

inline int read()
{
int a = 1,b = 0; char c;
do { c=getchar(); if (c == '-') a = -1; } while (c < '0' c > '9');
do { b = b * 10 + c - '0'; c = getchar(); } while (c >= '0' && c <= '9');
return a * b;
}

const int MAXN = 1010;
int a[MAXN], sum[MAXN], dp[MAXN][MAXN];
int lft[MAXN][MAXN], rigt[MAXN][MAXN];

int main()
{
int T = read();
while (T--) {
int n = read();
for (int i = 1; i <= n; ++i) {
a[i] = read();
sum[i] = sum[i - 1] + a[i];
}
memset(dp, 0, sizeof(dp));
memset(lft, 0, sizeof(lft));
memset(rigt, 0, sizeof(rigt));
for (int i = 1; i <= n; ++i) {
dp[i][i] = a[i];
lft[i][i] = sum[i - 1] - dp[i][i];
rigt[i][i] = dp[i][i] + sum[i];
}
for (int len = 2; len <= n; ++len) {
for (int l = 1, r = len; r <= n; ++l, ++r) {
dp[l][r] = sum[r] - sum[l - 1];
dp[l][r] = max(dp[l][r], lft[l + 1][r] - sum[l - 1]);
dp[l][r] = max(dp[l][r], sum[r] - rigt[l][r - 1]);
lft[l][r] = max(lft[l + 1][r], sum[l - 1] - dp[l][r]);
rigt[l][r] = min(rigt[l][r - 1], dp[l][r] + sum[r]);
}
}
printf("%d\n", (dp[1][n] + sum[n]) / 2);
}
return 0;
}