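// Stray text, apparently a code-viewer UI label captured when this file was copied:
// "显示代码纯文本" ("show code / plain text"). Not part of the program.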
#include <cstring>
#include <cstdio>
#include <algorithm>
#define ll long long
// Prime modulus; all helpers below expect operands already in [0, MOD).
const int MOD = 998244353;

// Product of two residues modulo MOD; the 64-bit intermediate prevents overflow.
inline int mul(int x, int y) { return static_cast<ll>(x) * y % MOD; }

// Sum of two residues modulo MOD; one conditional subtraction is enough
// because x + y < 2 * MOD for operands in [0, MOD).
inline int add(int x, int y) {
    const int sum = x + y;
    return sum >= MOD ? sum - MOD : sum;
}

// a^p modulo MOD by square-and-multiply. Expects p >= 0.
inline int qpow(int a, int p) {
    int result = 1;
    for (int base = a; p; p >>= 1) {
        if (p & 1) result = mul(result, base);
        base = mul(base, base);
    }
    return result;
}

// Modular inverse by Fermat's little theorem (MOD is prime): x^(MOD-2).
// For x == 0 there is no inverse and the result is 0.
inline int inv(int x) { return qpow(x, MOD - 2); }
using namespace std;
// Reads the next non-negative decimal integer from stdin into x.
// Any non-digit characters before it are skipped (so a leading '-' is ignored),
// and the first non-digit character after it is consumed.
// If end of input is reached before any digit, x is set to 0 instead of
// looping forever (ch is an int so EOF is distinguishable from every char).
inline void read(int &x) {
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
}
// Reads the next non-negative decimal integer from stdin and stores its value
// modulo MOD in x; digits are reduced as they arrive, so numbers with any
// number of digits are accepted.
// Non-digit characters before the number are skipped; the first non-digit
// after it is consumed. At end of input with no digit, x is set to 0.
inline void readmod(int &x) {
    int ch = getchar();  // int, not char, so EOF can be detected
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    x = 0;
    while (ch >= '0' && ch <= '9') {
        // mul() yields a value in [0, MOD); adding a digit may overshoot by < 10,
        // which the next mul() absorbs.
        x = mul(x, 10) + (ch - '0');
        ch = getchar();
    }
    // Fold the possible overshoot from the last digit back into [0, MOD).
    x = mul(x, 1);
}
// V, invfact and pre are indexed 0..K+2, so K must not exceed MAXK - 3.
const int MAXK = 500010;
int K, N, V[MAXK];
int invfact[MAXK], pre[MAXK];

// Returns (1^k + 2^k + ... + n^k) mod MOD, for 0 <= n < MOD and 0 <= k <= MAXK - 3.
//
// The sum is a polynomial f of degree k + 1 in n, so its values V[x] = f(x) at the
// m = k + 2 nodes x = 1..m determine it, and Lagrange interpolation gives
//   f(n) = sum_i V[i] * prod_{j != i} (n - j) / (i - j).
// The denominator equals (i-1)! * (m-i)! * (-1)^(m-i).
// The numerator is built from prefix/suffix products rather than dividing one
// full product by (n - i): that division fails when n coincides with a node
// (n - i == 0 has no inverse), which happens for every small n.
static int power_sum(int n, int k) {
    const int m = k + 2;
    int fact_m = 1;
    V[0] = 0;
    for (int i = 1; i <= m; i++) {
        V[i] = add(V[i - 1], qpow(i, k));
        fact_m = mul(fact_m, i);
    }
    // A single modular inverse, then walk down using 1/(i-1)! = i * (1/i!).
    invfact[m] = inv(fact_m);
    for (int i = m; i >= 1; i--) invfact[i - 1] = mul(invfact[i], i);
    // pre[i] = prod_{j=1}^{i-1} (n - j); add(n, MOD - j) keeps n - j in [0, MOD)
    // even when n < j.
    pre[1] = 1;
    for (int i = 2; i <= m; i++) pre[i] = mul(pre[i - 1], add(n, MOD - (i - 1)));
    int ans = 0;
    int suf = 1;  // prod_{j=i+1}^{m} (n - j) for the current i
    for (int i = m; i >= 1; i--) {
        const int term = mul(mul(V[i], mul(pre[i], suf)),
                             mul(invfact[i - 1], invfact[m - i]));
        // (-1)^(m-i) sign of the denominator.
        ans = ((m - i) & 1) ? add(ans, MOD - term) : add(ans, term);
        suf = mul(suf, add(n, MOD - i));
    }
    return ans;
}

int main() {
    // "r"/"w" are the standard mode strings; "rt"/"wt" rely on a non-standard 't' flag.
    if (!freopen("Crazy_Sum.in", "r", stdin)) return 1;
    if (!freopen("Crazy_Sum.out", "w", stdout)) return 1;
    readmod(N);  // only N mod MOD is kept
    read(K);
    if (K > MAXK - 3) {
        fprintf(stderr, "K=%d exceeds the supported maximum %d\n", K, MAXK - 3);
        return 1;
    }
    // Ensure 0 <= N < MOD, as power_sum() requires. Evaluating f at N mod MOD is
    // valid because f's coefficients have denominators built only from primes
    // <= k + 1 < MOD.
    N %= MOD;
    printf("%d\n", power_sum(N, K));
    return 0;
}