1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
| #include <bits/stdc++.h>
// Capacity of every global table (indices up to m + 2 are touched), and the
// prime modulus 1e9 + 7 that all arithmetic is reduced by.
constexpr int N = 500000 + 233;
constexpr int P = 1000000007;
using ll = long long;
// Modular exponentiation by repeated squaring: returns x^y mod P.
// Expects y >= 0 and x small enough that x * x fits in ll
// (presumably callers only pass values below P).
inline ll fpow(ll x, int y) {
    ll acc = 1;
    while (y != 0) {
        if (y & 1) acc = acc * x % P;  // this bit of y contributes the current square
        x = x * x % P;
        y >>= 1;
    }
    return acc;
}
// Problem input: the summation bound n and the exponent/base m.
int n, m;
// fac[i] = i! mod P, inv[i] = (i!)^{-1} mod P, inv_num[i] = i^{-1} mod P.
ll fac[N], inv[N], inv_num[N];
// NOTE(review): pre is not referenced anywhere in the visible code.
ll pre[N];
// pow_m[i] = i^m mod P for 1 <= i <= m + 1 (filled by init()).
ll pow_m[N];
// Affine coefficients relating each p[i] to the unknown p[0]: p[i] = A[i] * p[0] + B[i].
ll A[N], B[N];
// p[i]: sample values at i = 0..m+1 that solve() interpolates at x = n.
ll p[N];
// pre_all[k] = prod_{j=1..k} (n - j), suf_all[k] = prod_{j=k..m+1} (n - j), mod P.
ll pre_all[N], suf_all[N];
// Sieve state used by init(): prime[1..tot] holds pushed values, vis[x] flags reached indices.
int prime[N], vis[N], tot;
inline void init() { pow_m[0] = pow_m[1] = 1; for (int i = 1; i <= m + 1; ++i) { if (!vis[i]) { prime[++tot] = i; pow_m[i] = fpow(i, m); } for (int j = 1; j <= tot && i * prime[j] <= m + 1; ++j) { vis[i * prime[j]] = 1; pow_m[i * prime[j]] = pow_m[i] * pow_m[prime[j]] % P; if (!(i % prime[j])) break; } }
for (int i = fac[0] = 1; i <= m + 1; ++i) fac[i] = fac[i - 1] * i % P; inv[m + 1] = fpow(fac[m + 1], P - 2); for (int i = m + 1; i; --i) inv[i - 1] = inv[i] * i % P; inv_num[1] = 1; for (int i = 2; i <= m + 1; ++i) inv_num[i] = P - 1ll * P / i * inv_num[P % i] % P;
for (int i = pre_all[0] = suf_all[m + 2] = 1; i <= m + 1; ++i) pre_all[i] = pre_all[i - 1] * (n - i) % P; for (int i = m + 1; i; --i) suf_all[i] = suf_all[i + 1] * (n - i) % P; }
// Computes sum_{i=1}^{n} i^m * m^i mod P using the closed form m^n * G(n) - G(0),
// where G satisfies G(i) = G(i - 1) / m + i^m and has degree at most m (m != 1 mod P).
// Expects init() to have filled the tables, and m < n so that every (n - j) factor
// in pre_all/suf_all is non-negative.
inline ll solve() {
    const ll inv_m = inv_num[m];

    // Write each p[i] = G(i) as A[i] * p[0] + B[i] with p[0] = G(0) still unknown.
    // B[0] is never written and stays at its zero-initialized value.
    A[0] = 1;
    A[1] = inv_m;
    B[1] = pow_m[1];
    for (int k = 2; k <= m + 1; ++k) {
        A[k] = A[k - 1] * inv_m % P;
        B[k] = (B[k - 1] * inv_m % P + pow_m[k]) % P;
    }

    // A degree-m polynomial has a vanishing (m+1)-th finite difference:
    //   sum_{i=0}^{m+1} (-1)^(m-i) C(m+1, i) p[i] = 0,
    // which splits into coef_a * p[0] + coef_b = 0.
    ll coef_a = 0;
    ll coef_b = 0;
    for (int k = 0; k <= m + 1; ++k) {
        ll binom = fac[m + 1] * inv[k] % P * inv[m + 1 - k] % P;
        if ((m - k) % 2 != 0) binom = P - binom;
        coef_a = (coef_a + binom * A[k]) % P;
        coef_b = (coef_b + binom * B[k]) % P;
    }

    // Solve for G(0), then recover G(1..m+1).
    p[0] = P - coef_b * fpow(coef_a, P - 2) % P;
    for (int k = 1; k <= m + 1; ++k) p[k] = (A[k] * p[0] + B[k]) % P;

    // Lagrange interpolation through x = 1..m+1, evaluated at x = n:
    // the denominator prod_{j != k} (k - j) equals (k-1)! (m+1-k)! (-1)^(m+1-k).
    ll g_n = 0;
    for (int k = 1; k <= m + 1; ++k) {
        ll term = p[k] * pre_all[k - 1] % P * suf_all[k + 1] % P;
        term = term * inv[k - 1] % P * inv[m + 1 - k] % P;
        if ((m + 1 - k) % 2 != 0) term = P - term;
        g_n = (g_n + term) % P;
    }

    // Telescoping: sum_{i=1}^{n} m^i i^m = m^n G(n) - G(0).
    ll answer = fpow(m, n) * g_n % P - p[0];
    return (answer % P + P) % P;
}
int main() { std::cin >> n >> m; if (m == 1) return std::cout << 1ll * n * (n + 1) / 2 << std::endl, 0; init(); if (m >= n) { long long ans = 0, tmp = 1; for (int i = 1; i <= n; ++i) { tmp = tmp * m % P; ans = (ans + pow_m[i] * tmp) % P; } std::cout << ans << std::endl; } else std::cout << solve() << std::endl; return 0; }
|