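/*
Online-judge submission record captured together with the code below
(label | value pairs: record ID, judge result, problem name, final score,
user nickname, accepted?, code language, run time, submit time, memory usage;
the last line is the page's "show code as plain text" link).
记录编号 |
532934 |
评测结果 |
AAAAAAAAAA |
题目名称 |
[HZOI 2015]疯狂的求和问题 |
最终得分 |
100 |
用户昵称 |
神利·代目 |
是否通过 |
通过 |
代码语言 |
C++ |
运行时间 |
0.126 s |
提交时间 |
2019-06-08 20:15:52 |
内存使用 |
18.23 MiB |
显示代码纯文本
*/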
#define _CRT_SECURE_NO_WARNINGS
#include<stdio.h>
// Capacity of every global table; main() indexes the tables up to k+3.
const int maxn = 500010;
// Prime modulus for all arithmetic. The modular-inverse recurrence in main()
// (inv[i] computed from inv[mod % i]) relies on it being prime.
const int mod = 998244353;
// Binary exponentiation: returns x^y modulo `mod` (assumes y >= 0).
// For non-negative x the result lies in [0, mod).
// Special case: x == -1 returns exactly +1 or -1 (NOT reduced modulo `mod`),
// so callers can use it directly as a sign factor (-1)^y.
int mypow(int x, int y) {
    if (x == -1) {
        return (y & 1) ? -1 : 1;
    }
    int acc = 1;
    // Walk the bits of the exponent from least significant upward, squaring
    // the base at every step.
    for (int base = x, e = y; e != 0; e >>= 1) {
        if (e & 1) {
            acc = (long long)acc * base % mod;
        }
        base = (long long)base * base % mod;
    }
    return acc;
}
// Linear (Euler) sieve over [2, n].
//   prime[] - every prime found is appended; prime[0] holds the running count.
//   flag[]  - flag[m] is set to 1 for every composite m <= n.
//   g[]     - g[i] = i^k mod `mod` for 1 <= i <= n. Primes get mypow(p, k);
//             every composite cur * p is built as g[cur] * g[p], which is
//             valid because i -> i^k is completely multiplicative.
// Expects prime[0] == 0 and flag[] cleared on entry (main passes
// zero-initialized globals).
void get_all(int n, int k, int* prime, int* g, bool* flag) {
    g[1] = 1;
    for (int cur = 2; cur <= n; ++cur) {
        if (!flag[cur]) {
            prime[0] += 1;
            prime[prime[0]] = cur;
            g[cur] = mypow(cur, k);
        }
        for (int idx = 1; idx <= prime[0]; ++idx) {
            const int p = prime[idx];
            const int multiple = cur * p;
            if (multiple > n) {
                break;
            }
            flag[multiple] = 1;
            g[multiple] = 1ll * g[cur] * g[p] % mod;
            // Each composite is crossed out only by its smallest prime factor.
            if (cur % p == 0) {
                break;
            }
        }
    }
}
// Problem input: n is the upper summation bound (reduced modulo `mod` by
// read()), k is the exponent.
int n;
int k;
// Tables indexed up to k+1: P first holds i^k and is then prefix-summed;
// fac / inv / invfac hold i!, the inverses of 1..k+1, and (i!)^{-1}, all
// modulo `mod`.
int P[maxn];
int invfac[maxn];
int fac[maxn];
int inv[maxn];
// Lagrange-interpolation partial products built in main():
// premul[i] = n (n-1) ... (n-i+1), sufmul[i] = (n-i+1) (n-i) ... (n-k-1).
int premul[maxn];
int sufmul[maxn];
// Reads the next run of decimal digits from stdin into x, reducing modulo
// `mod` after every digit, so arbitrarily long inputs still fit in an int.
// Any non-digit characters in front of the number (including a '-' sign)
// are skipped. If EOF is reached before a digit is seen, x is left as 0.
//
// `register` was removed from the parameter: it is deprecated in C++11,
// ill-formed since C++17, and had no effect here.
void read(int& x) {
    x = 0;
    int c = getchar();
    // Stop at EOF. getchar() keeps returning EOF, which is never a digit, so
    // without this check truncated input made the skip loop spin forever.
    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    for (; c >= '0' && c <= '9'; c = getchar())
        x = static_cast<int>((x * 10LL + (c - '0')) % mod);
}
// Sieve scratch space handed to get_all(): composite markers, and the list of
// primes found (prime[0] is their count). Both start zeroed as globals.
bool flag[maxn];
int prime[maxn];
int main() {
freopen("Crazy_Sum.in", "r", stdin);
freopen("Crazy_Sum.out", "w", stdout);
read(n), read(k);
get_all(k+1, k, prime, P, flag);
for (int i = 1; i <= k + 1; ++i)
P[i] = (P[i] + P[i - 1]) % mod;
fac[0] = 1;
for (int i = 1; i <= k + 1; ++i)
fac[i] = 1ll * fac[i - 1] * i % mod;
inv[1] = 1;
for (int i = 2; i <= k + 1; ++i)
inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
invfac[0] = 1;
for (int i = 1; i <= k + 1; ++i)
invfac[i] = 1ll * invfac[i - 1] * inv[i] % mod;
if (n <= k + 1) {
printf("%d\n", P[n]);
return 0;
}
//预处理:premul[0 ... k+2]
premul[0] = 1;
for (int i = 1; i <= k + 2; ++i)
premul[i] = 1ll * premul[i - 1] * (n - i + 1) % mod;
//预处理:sufmul[1 ... k+3]
sufmul[k + 3] = 1;
for (int i = k + 2; i; --i)
sufmul[i] = 1ll * sufmul[i + 1] * (n - i + 1) % mod;
int ans = 0;
for (int i = 0; i <= k + 1; ++i) {
int temp = 1ll * premul[i] * sufmul[i + 2] % mod * invfac[i] % mod * invfac[k + 1 - i] % mod * mypow(-1, k + 1 - i);
ans = (ans + 1ll * P[i] * temp) % mod;
}
printf("%d\n", (ans % mod + mod) % mod);
}