| 记录编号 |
615990 |
评测结果 |
AAAAAAAAAA |
| 题目名称 |
数列求和 |
最终得分 |
100 |
| 用户昵称 |
RpUtl |
是否通过 |
通过 |
| 代码语言 |
C++ |
运行时间 |
1.282 s |
| 提交时间 |
2026-05-29 20:19:29 |
内存使用 |
15.12 MiB |
显示代码纯文本
#include <bits/stdc++.h>
using namespace std;
// Dimension of every table below. solve() evaluates directly for n <= 2000,
// so indices up to 2000 are used there; main() indexes C and val by k.
const int N = 2005;
typedef long long ll;
// C      : binomial coefficients reduced by `mod`, filled by Pascal's rule in main().
// s, a, k: program inputs read in main(); s is the number of terms passed to
//          solve(), a is the base whose powers are taken, k is the largest
//          exponent (val[k] is the value printed).
// val    : result of solve(n): val[i] = sum_{j=1..n} j^i * a^j (mod `mod`).
// bas    : scratch for the direct evaluation in solve(); bas[j] holds j^i.
// res    : scratch holding the next val[] while solve() combines two halves.
ll C[N][N], s, a, k, val[N], bas[N], res[N];
// Powers of m = (n / 2) % mod (powm, even n) and of m + 1 (powm1, odd n),
// rebuilt at each recursion level of solve().
ll powm[N], powm1[N];
// Modulus for all arithmetic; read from input in main().
int mod;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a  base; reduced by the global `mod` before use
 * @param b  exponent; each loop iteration consumes its lowest bit
 * @return   a^b reduced by the global `mod`
 */
ll ksm(ll a, ll b) {
    ll result = 1;
    for (ll base = a % mod; b != 0; b >>= 1) {
        // Fold in the current power of the base when this bit is set.
        if ((b & 1) != 0) {
            result = result * base % mod;
        }
        base = base * base % mod;
    }
    return result;
}
/**
 * Fills val[0..k] with
 *
 *     val[e] = sum_{j=1..n} j^e * a^j   (mod `mod`)
 *
 * using the globals a, k, mod and the binomial table C.
 *
 * For n <= 2000 the sums are evaluated term by term. Otherwise the result
 * for h = n / 2 is computed recursively and extended to n. Writing
 * S_e(h) for val[e] at length h and d = h (n even) or d = h + 1 (n odd):
 *
 *     sum_{j=1..h} (j + d)^e * a^(j + d)
 *         = a^d * sum_{t=0..e} C(e, t) * d^t * S_{e-t}(h)
 *
 * so  S_e(n) = S_e(h) + a^d * (that binomial sum), plus the single extra
 * term a^d * d^e (the j = 0 term) when n is odd.
 * Each recursion level performs O(k^2) multiplications.
 *
 * @param n  number of terms to sum
 */
void solve(ll n) {
    // Reduce the base into [0, mod) once. Multiplying by the raw `a` would
    // overflow signed 64-bit arithmetic for a >= 2^32 and would propagate
    // negative residues for negative a; ksm() already reduces its base, so
    // this keeps both code paths consistent.
    const ll base = ((a % mod) + mod) % mod;
    const int K = (int)k;

    if (n <= 2000) {
        // a_pow[j] = a^j (mod `mod`).
        ll a_pow[N];
        a_pow[0] = 1;
        for (int j = 1; j <= n; ++j) {
            a_pow[j] = a_pow[j - 1] * base % mod;
        }
        // bas[j] tracks j^e; it starts at j^0 and gains one factor j per e.
        for (int j = 1; j <= n; ++j) {
            bas[j] = 1;
        }
        for (int e = 0; e <= K; ++e) {
            // At most 2000 products, each below mod^2: no overflow in 128 bits.
            __int128 sum = 0;
            for (int j = 1; j <= n; ++j) {
                sum += (__int128)bas[j] * a_pow[j];
                bas[j] = bas[j] * j % mod;
            }
            val[e] = (ll)(sum % mod);
        }
        return;
    }

    const ll half = n / 2;
    solve(half);  // val[] now holds the sums for the first `half` terms

    const bool odd = (n & 1) != 0;
    // d: how far the second block of indices is shifted (see header comment).
    const ll shift = half % mod + (odd ? 1 : 0);
    const ll P = ksm(base, half + (odd ? 1 : 0));  // a^d

    // Powers of the shift go to powm for even n and powm1 for odd n.
    ll *shift_pow = odd ? powm1 : powm;
    shift_pow[0] = 1;
    for (int t = 1; t <= K; ++t) {
        shift_pow[t] = shift_pow[t - 1] * shift % mod;
    }

    for (int e = 0; e <= K; ++e) {
        // Binomial expansion of sum_j (j + d)^e * a^j over the first half.
        // Each product is below mod^3 and there are at most K + 1 of them,
        // so the 128-bit accumulator cannot overflow.
        __int128 sum = 0;
        const ll *row = C[e];
        for (int t = 0; t <= e; ++t) {
            sum += (__int128)row[t] * shift_pow[t] * val[e - t];
        }

        ll total = (ll)(sum % mod) * P % mod;
        total += val[e];
        if (total >= mod) {
            total -= mod;
        }
        if (odd) {
            // Extra term of the odd split: a^d * d^e.
            total += P * shift_pow[e] % mod;
            if (total >= mod) {
                total -= mod;
            }
        }
        // Stored in res[] because later iterations still read the old val[].
        res[e] = total;
    }

    for (int e = 0; e <= K; ++e) {
        val[e] = res[e];
    }
}
int main() {
freopen("oeis.in", "r", stdin);
freopen("oeis.out", "w", stdout);
scanf("%lld %lld %lld %d", &s, &a, &k, &mod);
for (int i = 0; i <= k; ++i) {
C[i][0] = C[i][i] = 1;
for (int j = 1; j < i; ++j) {
ll x = C[i - 1][j] + C[i - 1][j - 1];
if (x >= mod) {
x -= mod;
}
C[i][j] = x;
}
}
solve(s);
printf("%lld\n", val[k]);
return 0;
}