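/*
 * Contest: noi2017 template practice+ (noi2017模板练习+)
 * Judge result: AAAAAAAAAA (all 10 test cases accepted), final score 100
 * Problem: Sum (求和)
 * Author nickname: lemonoil
 * Language: C++; running time 7.185 s; memory usage 13.57 MiB
 * Submitted: 2017-07-18 19:26:42
 * (Judge page option: "显示代码纯文本" - show source as plain text.)
 */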
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
// Capacity of every global table: indices 0..n of fac/inv/g, and the largest
// NTT length used by solve() for the f/h scratch buffers.
const int maxn = 500010;
using namespace std;
typedef long long ll;
// Prime modulus used for all arithmetic (998244353 = 119 * 2^23 + 1, so it
// supports power-of-two NTT lengths) and the generator used by NTT().
const ll md = 998244353;
const ll G = 3;
// Problem size, read in main().
int n;

// g: DP table filled by solve() (main() multiplies g[i] by i! when summing).
// f, h: scratch buffers for the NTT convolution inside solve().
ll g[maxn], f[maxn], h[maxn];
 
// Binary (square-and-multiply) exponentiation: returns a^b mod md.
// With the default exponent md - 2 this is the modular inverse of a by
// Fermat's little theorem (md is prime), provided a is not 0 mod md.
// Assumes 0 <= a < md so that every product fits in a long long.
ll power_mod(ll a, ll b = md - 2){
    ll result = 1;
    for(ll base = a; b > 0; b >>= 1, base = base * base % md){
        if(b & 1)
            result = result * base % md;
    }
    return result;
}
 
 
// In-place iterative number-theoretic transform of A[0..n-1] modulo md.
// n must be a power of two. type > 0 selects the forward transform; any other
// value uses inverse roots, and type < 0 additionally scales by 1/n so that
// NTT(A, n, 1) followed by NTT(A, n, -1) restores A.
// Assumes entries lie in [0, md); outputs stay in that range.
void NTT(ll A[], int n, int type){
    // Bit-reversal permutation: rev tracks the bit-reversed index of pos by
    // incrementing from the most significant bit downwards.
    for(int pos = 0, rev = 0; pos < n; ++pos){
        if(pos > rev)
            swap(A[pos], A[rev]);
        int bit = n >> 1;
        while((rev ^= bit) < bit)
            bit >>= 1;
    }

    // Cooley-Tukey butterflies for block lengths 2, 4, ..., n.
    for(int len = 2; len <= n; len <<= 1){
        const int half = len >> 1;
        // len-th root of unity (forward) or its inverse (inverse transform).
        const ll step = type > 0 ? power_mod(G, (md - 1) / len)
                                 : power_mod(G, md - 1 - (md - 1) / len);
        for(int start = 0; start < n; start += len){
            ll w = 1;
            for(int j = 0; j < half; ++j){
                ll &lo = A[start + j];
                ll &hi = A[start + j + half];
                const ll t = w * hi % md;
                hi = (lo - t + md) % md;
                lo = (lo + t) % md;
                w = w * step % md;
            }
        }
    }

    if(type >= 0)
        return;
    // Inverse transform: divide every entry by n.
    const ll scale = power_mod(n);
    for(int i = 0; i < n; ++i)
        A[i] = A[i] * scale % md;
}
 
// Tables filled by main() for indices 0..n: inv[i] is the inverse FACTORIAL
// (i!)^{-1} mod md (not 1/i), fac[i] is i! mod md. Unfilled entries stay 0.
ll inv[maxn];
ll fac[maxn];
 
// CDQ divide and conquer over g[l..r]. On entry, each g[i] in [l, r] must
// already hold every contribution from indices j < l. On return, g[l..r] also
// include all contributions from indices inside [l, r], i.e. overall
//     g[i] += 2 * sum_{j < i} g[j] * inv[i - j]   (inv[] = inverse factorials).
// The left half is finished first, then its effect on the right half is added
// with a single NTT convolution, then the right half is solved.
void solve(int l, int r){
    if(l == r)
        return;
    const int mid = (l + r) >> 1;
    solve(l, mid);   // g[l..mid] are final from here on

    // Transform length: any power of two >= len suffices. h is nonzero only on
    // [0, mid-l] and f only on [0, len-1], so cyclic wrap-around can only land
    // on positions < mid+1-l, which are never read below.
    const int len = r - l + 1;
    int sz = 1;
    while(sz < len)
        sz <<= 1;
    for(int i = 0; i < sz; i ++){
        f[i] = i < len ? inv[i] : 0;            // only inv[0..r-l] is needed
        h[i] = (l + i <= mid) ? g[l + i] : 0;   // finished left half
    }
    NTT(f, sz, 1);
    NTT(h, sz, 1);
    for(int i = 0; i < sz; i ++)
        h[i] = f[i] * h[i] % md;
    NTT(h, sz, -1);

    // Position i - l of the product is sum_{j in [l, mid]} g[j] * inv[i - j].
    for(int i = mid + 1; i <= r; i ++)
        (g[i] += 2 * h[i - l] % md) %= md;
    solve(mid + 1, r);
}
 
 
int main(){
    freopen("heoi2016_sum.in", "r", stdin);
    freopen("heoi2016_sum.out", "w", stdout);
    scanf("%d", &n);
    g[0] = inv[0] = fac[0] = 1;
    for(int i = 1; i <= n; i ++)
        fac[i] = fac[i-1] * i % md;
    inv[n] = power_mod(fac[n]);
    for(int i = n - 1; i; i --)
        inv[i] = inv[i+1] * (i+1) % md;
 
    solve(0, n);
    ll ans = 0;
    for(int i = 0; i <= n; i ++)
        (ans += g[i] * fac[i] % md) %= md;
    (ans += md) %= md;
    printf("%lld\n", ans);
    return 0;
}