| 记录编号 | 610011 | 评测结果 | AAAAAAAAAA |
|---|---|---|---|
| 题目名称 | 分组游戏 | 最终得分 | 100 |
| 用户昵称 | 梦那边的美好TE | 是否通过 | 通过 |
| 代码语言 | C++ | 运行时间 | 2.420 s |
| 提交时间 | 2025-12-06 14:56:57 | 内存使用 | 18.02 MiB |
显示代码纯文本
#include <algorithm>
#include <cstdio>
#include <functional>
#include <iostream>
#include <vector>
using namespace std;
// Modulus for every computation in this file: 998244353 = 119 * 2^23 + 1,
// so power-of-two transform lengths up to 2^23 divide mod - 1.
constexpr int mod = 998244353;
// Root used by the forward NTT (3 is a primitive root of this prime).
constexpr int G = 3;
// Inverse of G modulo mod: 3 * 332748118 == mod + 1.
constexpr int Gi = 332748118;

// Binary exponentiation: returns a^b reduced modulo `mod` (callers pass b >= 0).
int qpow(int a, int b) {
    long long base = a;
    long long acc = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1) acc = acc * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(acc);
}
// In-place iterative radix-2 number-theoretic transform over Z/mod.
//   a    : array whose first `len` entries are transformed (len <= a.size()).
//   len  : transform length; must be a power of two dividing mod - 1.
//   type : 1 uses root G (forward transform); any other value uses Gi, and
//          type == -1 additionally multiplies every entry by len^{-1}, so
//          ntt(x, len, 1) followed by ntt(x, len, -1) restores x.
void ntt(vector<int>& a, int len, int type) {
    // Permute into bit-reversed index order; `rev` is a reversed counter
    // that is "incremented" from the top bit downwards.
    for (int i = 1, rev = 0; i < len; ++i) {
        int bit = len >> 1;
        while (rev & bit) {
            rev ^= bit;
            bit >>= 1;
        }
        rev ^= bit;
        if (i < rev) swap(a[i], a[rev]);
    }

    const int root = (type == 1) ? G : Gi;
    for (int step = 2; step <= len; step <<= 1) {
        const int half = step >> 1;
        // Principal step-th root of unity for this stage.
        const int wStep = qpow(root, (mod - 1) / step);
        for (int start = 0; start < len; start += step) {
            long long w = 1;
            for (int off = 0; off < half; ++off) {
                int& lo = a[start + off];
                int& hi = a[start + off + half];
                const int u = lo;
                const int t = static_cast<int>(w * hi % mod);
                lo = (u + t) % mod;
                hi = (u - t + mod) % mod;
                w = w * wStep % mod;
            }
        }
    }

    if (type != -1) return;
    const int invLen = qpow(len, mod - 2);
    for (int i = 0; i < len; ++i) {
        a[i] = static_cast<int>(1LL * a[i] * invLen % mod);
    }
}
// Multiplies two polynomials (coefficient index = degree, values in [0, mod))
// with the NTT. Returns the a.size() + b.size() - 1 coefficients of the
// product, or an empty vector when either operand is empty (zero polynomial).
// The padded length must stay within the 2^23 power-of-two limit of mod - 1.
vector<int> multiply(const vector<int>& a, const vector<int>& b) {
    if (a.empty() || b.empty()) {
        // Without this guard n + m - 1 == -1 would be handed to resize()
        // as a huge size_t and throw.
        return {};
    }
    const int n = static_cast<int>(a.size());
    const int m = static_cast<int>(b.size());
    const int outSize = n + m - 1;

    // A cyclic convolution of length >= outSize has no wrap-around, so pad
    // only up to the next power of two that covers outSize.
    int len = 1;
    while (len < outSize) len <<= 1;

    vector<int> A(a);
    vector<int> B(b);
    A.resize(len, 0);
    B.resize(len, 0);

    ntt(A, len, 1);
    ntt(B, len, 1);
    for (int i = 0; i < len; i++) {
        A[i] = static_cast<int>(1LL * A[i] * B[i] % mod);
    }
    ntt(A, len, -1);

    A.resize(outSize);
    return A;
}
// Capacity of every global array below. main() reads a[1..n] and indexes
// row n + 1 of f, so inputs must satisfy n <= N - 2.
constexpr int N = 2005;
// Input size and input values (1-based; a[0] is unused).
int n;
int a[N];
// fac[k] = k! mod `mod` and invfac[k] = (k!)^{-1} mod `mod`, filled by init().
int fac[N];
int invfac[N];
// DP tables. Static storage zero-initializes them, and the DP accumulates
// into them with += starting from that zero state.
int dp[N][N];
int f[N][N];
// Fills fac[0..n] with k! mod `mod` and invfac[0..n] with the modular inverses
// of those factorials. Reads the global n, so call it after n is known;
// assumes 0 <= n <= N - 1.
void init() {
    fac[0] = 1;
    for (int k = 1; k <= n; ++k) {
        fac[k] = static_cast<int>(1LL * fac[k - 1] * k % mod);
    }
    // One Fermat inversion of n!, then walk down using 1/(k-1)! = k / k!.
    invfac[n] = qpow(fac[n], mod - 2);
    for (int k = n; k > 0; --k) {
        invfac[k - 1] = static_cast<int>(1LL * invfac[k] * k % mod);
    }
}
int main() {
freopen("gamem.in", "r", stdin);
freopen("gamem.out", "w", stdout);
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", a+i);
}
sort(a+1, a+1+n, greater<int>());
init();
dp[1][0] = 1;
f[1][0] = dp[1][0];
for (int i = 1; i <= n; i++) {
for (int j = 0; j < i; j++) {
f[i+1][j+1] = (f[i+1][j+1] + 1ll * f[i][j] * (j+1) % mod) % mod;
}
int limit = min(a[i] - 1, i - 1);
if (limit >= 0) {
vector<int> A(i, 0);
for (int j = 0; j < i; j++) {
A[j] = f[i][j];
}
vector<int> B(limit + 1, 0);
for (int k = 0; k <= limit; k++) {
B[limit - k] = invfac[k];
}
vector<int> C = multiply(A, B);
for (int j = 0; j < i; j++) {
if (j + limit < (int)C.size()) {
f[i+1][j] = (f[i+1][j] + C[j + limit]) % mod;
}
}
}
for (int j = 0; j <= i; j++) {
dp[i+1][j] = 1ll * f[i+1][j] * invfac[j] % mod;
}
}
printf("%d\n", dp[n+1][0]);
return 0;
}