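// Submission record 616189 — judge result: AAAAAAAAAA
// Problem: 数列求和 (sequence summation) — final score: 100
// User: GravatarRpUtl — accepted: yes
// Language: C++ — run time: 6.765 s
// Submitted: 2026-05-31 09:20:13 — memory used: 45.71 MiB
// (Page control text: "show code as plain text")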
#include <bits/stdc++.h>
#define pb push_back
#define int long long
using namespace std;

// Note: `int` is macro-expanded to `long long` throughout this file.

// Modulus applied to every result; read from the input in main.
int MOD;
// Base of the a^i factor in the summed terms; read from the input in main.
int a;
// Binomial coefficients C[n][r], filled by init().
int C[2010][2010];

// Prefix-sum table: pre[n][k] = sum_{i=1..n} (i^k * a^i) % MOD
int pre[2005][2010];
bool pre_ok = false; // whether the prefix-sum table has been built yet

// Fast reader: skips every character that is not a decimal digit, then parses
// one run of digits from stdin and returns its value.
// - A leading '-' is skipped like any other non-digit, so the result is never
//   negative.
// - Returns 0 when the input ends before any digit is found. The character is
//   kept in an int so EOF can be detected; storing it in a char made the skip
//   loop spin forever once the stream was exhausted.
inline int read() {
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();

    int value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value;
}

// Fills C[n][r] for 0 <= r <= n <= 2005 using Pascal's rule
//   C[n][r] = C[n-1][r-1] + C[n-1][r].
// Interior entries are kept below MOD with a single conditional subtraction;
// the border entries C[n][0] and C[n][n] are stored as the literal 1.
// Reads the global MOD, so it must run after MOD has been set.
void init() {
    const int limit = 2005;
    C[0][0] = 1;
    for (int n = 1; n <= limit; ++n) {
        C[n][n] = 1;
        C[n][0] = 1;
        for (int r = 1; r < n; ++r) {
            int sum = C[n - 1][r - 1] + C[n - 1][r];
            if (sum >= MOD) sum -= MOD;
            C[n][r] = sum;
        }
    }
}

// Binary exponentiation: returns a^b reduced modulo the global MOD.
// The base is reduced once up front; each iteration consumes the lowest bit
// of the exponent and squares the base.
inline int qpow(int a, int b) {
    int result = 1;
    for (a %= MOD; b != 0; b >>= 1) {
        if (b & 1) {
            result = result * a % MOD;
        }
        a = a * a % MOD;
    }
    return result;
}

// ans[k]: sum_{i=1..n} i^k * a^i (mod MOD) for the chain value n most recently
//         processed in main; ans[K] is what the program prints.
// tmp:    scratch row that receives the next ans values while the current
//         ones are still being read.
int ans[2010], tmp[2010];
// Halving chain N, N/2, N/4, ... built in main and walked from the smallest
// value to the largest.
vector<int> vec;

// Fills pre[i][k] = sum_{t=1..i} t^k * a^t (mod MOD) for 1 <= i <= 2000 and
// 0 <= k <= K. Row 0 of pre is zero (static storage) and serves as the
// empty-sum base row.
static void build_prefix_table(int K) {
    const int mod = MOD;
    const int a_mod = a % mod;
    int a_pow = 1;  // a^i mod MOD, advanced once per row instead of calling qpow
    for (int i = 1; i <= 2000; i++) {
        a_pow = a_pow * a_mod % mod;
        int i_pow = 1;  // i^k, reduced mod MOD for k >= 1
        for (int k = 0; k <= K; k++) {
            int v = pre[i - 1][k] + i_pow * a_pow % mod;
            pre[i][k] = (v >= mod) ? (v - mod) : v;
            i_pow = i_pow * i % mod;
        }
    }
}

// Doubling step. On entry ans[k] = S_k(m) for m = n / 2 and every 0 <= k <= K,
// where S_k(x) = sum_{i=1..x} i^k * a^i (mod MOD). On exit ans[k] = S_k(n).
//
// Uses (i + m)^k = sum_j C(k, j) * m^(k-j) * i^j, which gives
//   S_k(2m) = S_k(m) + a^m * sum_{j=0..k} C(k, j) * m^(k-j) * S_j(m),
// and adds the single extra term n^k * a^n when n is odd (n = 2m + 1).
static void extend_from_half(int n, int K) {
    const int mod = MOD;
    const int mid = n / 2;
    const int a_mid = qpow(a, mid);
    const int mid_mod = mid % mod;

    const bool is_odd = (n & 1) != 0;
    const int a_n = is_odd ? qpow(a, n) : 0;
    const int n_mod = n % mod;
    int n_pow = 1;  // n^k mod MOD, only advanced when n is odd

    for (int k = 0; k <= K; k++) {
        int base = ans[k];
        if (is_odd) {
            base += n_pow * a_n % mod;
            if (base >= mod) base -= mod;
            n_pow = n_pow * n_mod % mod;
        }

        // Accumulate unreduced in 128 bits and reduce once per k.
        __int128 acc = base;
        int weight = a_mid;  // a^m * m^(k-j) mod MOD for the current j
        const int* binom = C[k];
        for (int j = k; j >= 0; j--) {
            acc += static_cast<__int128>(weight) * binom[j] * ans[j];
            weight = weight * mid_mod % mod;
        }
        tmp[k] = static_cast<int>(acc % mod);
    }

    memcpy(ans, tmp, sizeof(ans[0]) * (K + 1));
}

// Reads N, a, K, MOD from oeis.in and writes sum_{i=1..N} i^K * a^i (mod MOD)
// to oeis.out. Returns 1 without producing a result if the files cannot be
// opened or the parameters fall outside what the fixed-size tables support.
signed main() {
    if (!freopen("oeis.in", "r", stdin) || !freopen("oeis.out", "w", stdout)) {
        fprintf(stderr, "cannot open oeis.in / oeis.out\n");
        return 1;
    }

    const int N = read();
    a = read();
    const int K = read();
    MOD = read();

    // MOD == 0 would divide by zero; init() only fills binomial rows up to 2005.
    if (MOD <= 0 || K > 2005) {
        fprintf(stderr, "unsupported parameters\n");
        return 1;
    }
    init();

    // Halving chain N, N/2, ..., 1. It is strictly decreasing, so reversing it
    // yields ascending order and each value's predecessor is its half.
    for (int num = N; num; num /= 2) {
        vec.push_back(num);
    }
    std::reverse(vec.begin(), vec.end());

    for (int n : vec) {
        if (n <= 2000) {
            // Small values come straight from the prefix table, built on first use.
            if (!pre_ok) {
                build_prefix_table(K);
                pre_ok = true;
            }
            memcpy(ans, pre[n], sizeof(ans[0]) * (K + 1));
        } else {
            extend_from_half(n, K);
        }
    }

    printf("%lld\n", ans[K]);
    return 0;
}