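/*
 * Online-judge submission record (reformatted from a pasted results table)
 *   记录编号 (record ID):          616581
 *   评测结果 (per-test verdicts):  AAAAAAAAAA
 *   题目名称 (problem):            HS的自然数拆分 (HS's natural-number partition)
 *   最终得分 (final score):        100
 *   用户昵称 (user):               hsl_beat
 *   是否通过 (accepted):           通过 (yes)
 *   代码语言 (language):           C++
 *   运行时间 (run time):           8.406 s
 *   提交时间 (submitted at):       2026-06-28 16:42:04
 *   内存使用 (memory):             355.82 MiB
 *   显示代码纯文本 (show code as plain text)
 */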
#include<bits/stdc++.h>
using namespace std;
// Every `int` token below (parameters, locals, vector<int> elements) is really
// a 64-bit long long, which is why main is declared `signed main`. It also means
// the product of two residues below `mod` fits before the `% mod` reduction.
#define int long long
// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1, so power-of-two
// transform lengths up to 2^23 are possible.
const int mod = 998244353;
// Generator used by ntt() to derive roots of unity; 3 is a primitive root
// modulo 998244353.
const int G = 3;
// Modular exponentiation by repeated squaring: returns a^b mod `mod`.
//
// a   - base; any value (negative or >= mod) is first reduced into [0, mod).
// b   - exponent; must be >= 0. A negative exponent is treated as 0 instead of
//       looping forever.
// mod - modulus, expected >= 1 and small enough that (mod - 1)^2 fits in the
//       integer type.
// Returns a value in [0, mod).
int fastpow(int a, int b, int mod)
{
    // Reduce the base first: the squaring below only stays in range for
    // residues, and C++ `%` keeps the sign of a negative dividend.
    a %= mod;
    if (a < 0) {
        a += mod;
    }
    int res = 1 % mod;  // `1 % mod` so that mod == 1 yields 0 rather than 1
    while (b > 0) {
        if (b & 1) {
            res = res * a % mod;
        }
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Reorders a[0..len-1] in place into bit-reversed index order, which is the
// order an iterative NTT expects before its butterfly passes. `len` is
// expected to be a power of two. `rev` holds the bit-reversal of `pos` and is
// advanced by a "reversed increment" each step (Rader's method).
void change(vector<int> &a, int len)
{
    int rev = len / 2;  // bit-reversal of position 1
    for (int pos = 1; pos < len - 1; ++pos) {
        // Each pair is swapped once, from its smaller index.
        if (pos < rev) {
            swap(a[pos], a[rev]);
        }
        // Reversed increment: clear the run of set bits from the top down...
        int bit = len / 2;
        for (; rev >= bit; bit /= 2) {
            rev -= bit;
        }
        // ...then set the first clear one (the loop exits only once rev < bit).
        rev += bit;
    }
}
// In-place iterative number-theoretic transform of a[0..len-1] modulo `mod`.
//
// a   - must hold at least `len` entries, each a residue in [0, mod).
// len - transform length; must be a power of two dividing mod - 1.
// x   - 1 for the forward transform, -1 for the inverse. The inverse also
//       multiplies by len^{-1}, so a forward transform followed by an inverse
//       transform of the same length restores a[0..len-1].
// Entries at index >= len are left untouched.
void ntt(vector<int> &a, int len, int x)
{
    change(a, len);
    for (int h = 2; h <= len; h <<= 1) {
        // Principal h-th root of unity, inverted for the inverse transform.
        int omega = fastpow(G, (mod - 1) / h, mod);
        if (x == -1) {
            omega = fastpow(omega, mod - 2, mod);  // Fermat inverse, mod is prime
        }
        const int half = h / 2;
        for (int i = 0; i < len; i += h) {
            int w = 1;
            for (int j = i; j < i + half; j++) {
                int u = a[j];
                int v = a[j + half] * w % mod;
                a[j] = (u + v) % mod;
                a[j + half] = (u - v + mod) % mod;
                w = w * omega % mod;
            }
        }
    }
    if (x == -1) {
        // Scale only the transformed prefix: anything past `len` was never
        // part of the transform and must not be multiplied by len^{-1}.
        int invl = fastpow(len, mod - 2, mod);
        for (int i = 0; i < len; i++) {
            a[i] = a[i] * invl % mod;
        }
    }
}
// Linear (acyclic) convolution modulo `mod` via NTT:
// returns c with c[k] = sum over i + j == k of a[i] * b[j] (mod `mod`).
//
// a, b - residues in [0, mod); taken by value because they are zero-padded
//        and transformed in place.
// Returns exactly a.size() + b.size() - 1 entries, or {0} if either input is
// empty. The padded length (next power of two >= that size) must divide
// mod - 1, so that ntt() has the roots of unity it needs.
vector<int> convo(vector<int> a, vector<int> b) {
    if (a.empty() || b.empty()) return {0};
    // Record the true product length before padding changes the sizes.
    const int res_len = static_cast<int>(a.size() + b.size() - 1);
    int m = 1;
    while (m < res_len) {
        m <<= 1;
    }
    a.resize(m, 0);
    b.resize(m, 0);
    ntt(a, m, 1);
    ntt(b, m, 1);
    for (int i = 0; i < m; i++) {
        a[i] = a[i] * b[i] % mod;
    }
    ntt(a, m, -1);
    // Trim the zero padding back to the true product length.
    a.resize(res_len);
    return a;
}
// Table filled by calc(): ans[x] is the number of partitions of x into
// positive parts, reduced modulo `mod`. It is meaningful for 0 <= x <= 100000,
// the range calc() precomputes; main() answers each query from it.
vector<int> ans;
// Fills the global `ans` so that ans[x] = p(x) mod `mod` for every 0 <= x <= n,
// where p(x) is the number of partitions of x into positive parts.
//
// Every partition splits into "small" parts (<= b) and "large" parts (> b),
// with b ~ sqrt(2n):
//   * dp[j]  - partitions of j using only parts 1..b (unbounded knapsack).
//   * cnt[j] - partitions of j using only parts > b. A partition of j into
//              exactly i such parts comes either from one of j - i with every
//              part increased by 1, or from i - 1 parts summing to j - b - 1
//              plus a new part of size b + 1. At most n / (b + 1) large parts
//              fit, so only that many rows are needed.
// The convolution dp * cnt then counts all partitions.
//
// n - upper bound of the table (>= 0); defaults to the 100000 main() relies on.
void calc(int n = 100000)
{
    int b = static_cast<int>(sqrt(2 * n + 1));

    vector<int> dp(n + 1, 0);
    dp[0] = 1;
    for (int part = 1; part <= b; part++) {
        for (int j = part; j <= n; j++) {
            dp[j] = (dp[j] + dp[j - part]) % mod;
        }
    }

    // Two rolling rows of the "exactly i large parts" table (lastRow = i - 1,
    // curRow = i) replace the full (b + 1) x (n + 1) table. This cuts memory
    // from O(n * b) to O(n) and lets cnt be accumulated in a linear pass.
    vector<int> lastRow(n + 1, 0);
    vector<int> curRow(n + 1, 0);
    vector<int> cnt(n + 1, 0);
    lastRow[0] = 1;  // zero large parts: only the empty partition of 0
    cnt[0] = 1;
    for (int i = 1; i <= b && i * (b + 1) <= n; i++) {
        // Sums <= b cannot contain a large part. curRow may still hold row
        // i - 2 after the swap below, so clear that prefix explicitly.
        fill(curRow.begin(), curRow.begin() + min(b, n) + 1, 0);
        for (int j = b + 1; j <= n; j++) {
            curRow[j] = (curRow[j - i] + lastRow[j - b - 1]) % mod;
            cnt[j] = (cnt[j] + curRow[j]) % mod;
        }
        swap(lastRow, curRow);
    }

    ans = convo(dp, cnt);
}
signed main()
{
freopen("zrscf.in", "r", stdin);
freopen("zrscf.out", "w", stdout);
calc();
int T;
cin >> T;
while (T--) {
int x;
cin >> x;
cout << ans[x] - 1 << '\n';
}
return 0;
}