国王奇遇记 线性插值

国王奇遇记加强版之再加强版 本题有$O(m^2)$做法,网上很多,这里只有$O(m)$。 题意:求 $$\sum_{i=1}^n i^m m^i$$ 令$F_m(n)$等于这个式子。经过大胆打表看题解,可以知道 $$F_m(n) = m^n P_m(n) - P_m(0)$$ 其中$P_m$是$m$次多项式。 下面开始不写$_m$了,懒w 这个结论虽然不知道怎么想出来的,但是可以证明,大约想法就是归纳+差分。 考虑怎么求$P$: $$F(n+1) - F(n) = (n+1)^m m^{n+1} = m^{n+1} P(n+1) - m^n P(n) $$ $$P(n+1) = \frac{P(n)}{m} + (m+1)^m$$ 把$P(n)$表示为$AP(0)+B$。 因为$P$是$m$次多项式,做$m+1$次差分: $$\sum_{i=0}^{m+1} \binom{m+1}{i} (-1)^{m - i} P(i) = 0 $$ 就可以解出$A,B$啦。 这篇文章中有一种给出$1,m+1$处点值插值的方法,推到用到了很多组合技巧,很有意思。 不过其实直接拉格朗日插值就行。。。结合一些预处理可以做到$O(m)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include <bits/stdc++.h>

const int N = 500000 + 233, P = 1e9 + 7;

typedef long long ll;

inline ll fpow(ll x, int y) {
ll ret = 1;
for ( ; y; y >>= 1, x = x * x % P)
if (y & 1) ret = ret * x % P;
return ret;
}

int n, m;
ll fac[N], inv[N], inv_num[N], pre[N], pow_m[N],
A[N], B[N], p[N], pre_all[N], suf_all[N];

int prime[N], vis[N], tot;

inline void init() {
pow_m[0] = pow_m[1] = 1;
for (int i = 1; i <= m + 1; ++i) {
if (!vis[i]) {
prime[++tot] = i;
pow_m[i] = fpow(i, m);
}
for (int j = 1; j <= tot && i * prime[j] <= m + 1; ++j) {
vis[i * prime[j]] = 1;
pow_m[i * prime[j]] = pow_m[i] * pow_m[prime[j]] % P;
if (!(i % prime[j]))
break;
}
}

for (int i = fac[0] = 1; i <= m + 1; ++i)
fac[i] = fac[i - 1] * i % P;
inv[m + 1] = fpow(fac[m + 1], P - 2);
for (int i = m + 1; i; --i)
inv[i - 1] = inv[i] * i % P;
inv_num[1] = 1;
for (int i = 2; i <= m + 1; ++i)
inv_num[i] = P - 1ll * P / i * inv_num[P % i] % P;

for (int i = pre_all[0] = suf_all[m + 2] = 1; i <= m + 1; ++i)
pre_all[i] = pre_all[i - 1] * (n - i) % P;
for (int i = m + 1; i; --i)
suf_all[i] = suf_all[i + 1] * (n - i) % P;
}

inline ll solve() {
A[0] = 1, A[1] = inv_num[m], B[1] = pow_m[1];
for (int i = 2; i <= m + 1; ++i) {
A[i] = A[i - 1] * inv_num[m] % P;
B[i] = (B[i - 1] * inv_num[m] % P + pow_m[i]) % P;
}

long long sum_A = 0, sum_B = 0;
for (int i = 0; i <= m + 1; ++i) {
long long tmp = fac[m + 1] * inv[i] % P * inv[m + 1 - i] % P;
if ((m - i) & 1)
tmp = P - tmp;
sum_A = (sum_A + tmp * A[i]) % P;
sum_B = (sum_B + tmp * B[i]) % P;
}

p[0] = P - sum_B * fpow(sum_A, P - 2) % P;
for (int i = 1; i <= m + 1; ++i)
p[i] = (A[i] * p[0] + B[i]) % P;

long long ret = 0;

for (int i = 1; i <= m + 1; ++i) {
long long tmp = p[i];
tmp = tmp * pre_all[i - 1] % P * suf_all[i + 1] % P;
tmp = tmp * inv[i - 1] % P * inv[m + 1 - i] % P;
if ((m + 1 - i) & 1)
tmp = P - tmp;
ret = (ret + tmp) % P;
}

ret = fpow(m, n) * ret % P - p[0];
ret = (ret % P + P) % P;
return ret;
}

int main() {
std::cin >> n >> m;
if (m == 1)
return std::cout << 1ll * n * (n + 1) / 2 << std::endl, 0;
init();
if (m >= n) {
long long ans = 0, tmp = 1;
for (int i = 1; i <= n; ++i) {
tmp = tmp * m % P;
ans = (ans + pow_m[i] * tmp) % P;
}
std::cout << ans << std::endl;
} else
std::cout << solve() << std::endl;
return 0;
}