QOJ #3091. Japanese Knowledge 题解

Description

给定一个非递减的正整数序列 A=(A1,A2,,AN)A = (A_1, A_2, \ldots, A_N)

对于每个 k=0,1,2,,Nk = 0, 1, 2, \ldots, N,要求计算满足以下条件的、长度为 NN非递减非负整数序列 x=(x1,x2,,xN)x = (x_1, x_2, \ldots, x_N) 的数量,并对结果取模 998244353998244353

  1. 对所有 1iN1 \leq i \leq N,有 xiAix_i \leq A_i
  2. 恰好有 kk 个下标 ii 满足 xi=Aix_i = A_i

1N,Ai2.5×1051\leq N,A_i\leq 2.5\times 10^5

Solution

fk(a1,a2,,an)f_k(a_1,a_2,\ldots,a_n) 为有恰好 kkxi=aix_i=a_i 的方案数,g(a1,a2,,an)g(a_1,a_2,\ldots,a_n) 为不考虑 xi=aix_i=a_i 这条限制的总方案数。

首先这题只有第一个限制是能做的,但是加上第二个限制就不好做了。所以考虑怎么把第二个限制去掉。

有个想法是对第二个进行容斥,但是会发现是不行的。于是需要找一些关于 xi=aix_i=a_i 的性质。

这里先给出结论:所有 (ak+11,ak+21,,an1)(a_{k+1}-1,a_{k+2}-1,\ldots,a_n-1) 的方案都一一对应了一个原序列有 kkxi=aix_i=a_i 的方案。

证明就考虑如果有恰好 kkxi=aix_i=a_i,把这 kk 个数去掉后的 xix'_i 一定要小于 ai+ka_{i+k}。因为如果 xiai+kx'_i\geq a_{i+k},则 ai+kxixia_{i+k}\leq x'_i\leq x_i,这时 xi,xi+1,,xkx_i,x_{i+1},\ldots,x_k 都必须顶到上界,就矛盾了。

对于一个满足 xi<ai+kx'_i<a_{i+k} 的方案,考虑从小到大放到原数组中。每次一定是找到 xi<ajx'_i<a_j 的最小的 jj 放进去,如果不是最小的 jj,则最小的 kk 一定会顶到上界,而 xi<akx'_i<a_k,就矛盾了。由于已经满足了 xi<ai+kx'_i<a_{i+k},所以这么操作一定能找到唯一的一组方案。

现在问题变为对于每个后缀,求 g(ak1,ak+11,,an1)g(a_k-1,a_{k+1}-1,\ldots,a_n-1)


考虑分治。

问题可以转化为有 nn 个柱子,第 ii 个柱子高 hi=ai1h_i=a_i-1,问从最右侧每次往左或者往下走,有多少种走到 (i,0)(i,0) 的方案。

我们定义函数 solve(l,r,F)solve(l,r,F) 为当前分治区间为 [l,r][l,r],且这个区间内的每个柱子的高度都减去 hl1h_{l-1},从 nn 走到当前的 (r,i)(r,i) 的方案数为 FiF_i;返回值为一个多项式 GG,其中 GiG_i 为走到 l+il+i 的方案数。

(图是贺的)

每次找到中点 midmid,然后调用 solve(mid+1,r,F)solve(mid+1,r,F),求出从最右边走到横着的黄线的每个位置的方案数,再通过 NTT 求出从横着的黄线走到竖黄线的每个位置的方案数 HH,最后调用 solve(l,mid,H)solve(l,mid,H)

在分治的过程中,还要更新 [mid+1,r][mid+1,r] 终点在黄线下面的方案数,和上面的东西是类似的。

时间复杂度:O((N+V)log2N)O((N+V)\log^2 N)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#include <bits/stdc++.h>

// #define int int64_t

const int kMaxN = 5e5 + 5, kMod = 998244353;

int n;
int a[kMaxN], res[kMaxN], fac[kMaxN], ifac[kMaxN];

constexpr int qpow(int bs, int64_t idx = kMod - 2) {
int ret = 1;
for (; idx; idx >>= 1, bs = (int64_t)bs * bs % kMod)
if (idx & 1)
ret = (int64_t)ret * bs % kMod;
return ret;
}

inline int add(int x, int y) { return (x + y >= kMod ? x + y - kMod : x + y); }
inline int sub(int x, int y) { return (x >= y ? x - y : x - y + kMod); }
inline void inc(int &x, int y) { (x += y) >= kMod ? x -= kMod : x; }
inline void dec(int &x, int y) { (x -= y) < 0 ? x += kMod : x; }

namespace POLY {
constexpr int kMaxN = 4e6 + 5, kR = 3, kB = __builtin_ctz(kMod - 1), kG = qpow(kR, (kMod - 1) >> kB);

int polyg[kMaxN];
bool inited;

void prework(int n = (kMaxN - 5) / 2) {
inited = 1;
int c = 0;
for (; (1 << c) <= n; ++c) {}
c = std::min(c - 1, kB - 2);
polyg[0] = 1, polyg[1 << c] = qpow(kG, 1 << (kB - 2 - c));
for (int i = c; i; --i)
polyg[1 << i - 1] = (int64_t)polyg[1 << i] * polyg[1 << i] % kMod;
for (int i = 1; i < (1 << c); ++i)
polyg[i] = (int64_t)polyg[i & (i - 1)] * polyg[i & -i] % kMod;
}

int getlen(int n) {
int len = 1;
for (; len <= n; len <<= 1) {}
return len;
}

struct Poly : std::vector<int> {
using vector::vector;
using vector::operator [];

friend Poly operator -(Poly a) {
static Poly c;
c.resize(a.size());
for (int i = 0; i < c.size(); ++i)
c[i] = sub(0, c[i]);
return c;
}
friend Poly operator +(Poly a, Poly b) {
static Poly c;
c.resize(std::max(a.size(), b.size()));
for (int i = 0; i < c.size(); ++i)
c[i] = add((i < a.size() ? a[i] : 0), (i < b.size() ? b[i] : 0));
return c;
}
friend Poly operator -(Poly a, Poly b) {
static Poly c;
c.resize(std::max(a.size(), b.size()));
for (int i = 0; i < c.size(); ++i)
c[i] = sub((i < a.size() ? a[i] : 0), (i < b.size() ? b[i] : 0));
return c;
}
friend void dif(Poly &a, int len) {
if (a.size() < len) a.resize(len);
for (int l = len; l != 1; l >>= 1) {
int m = l / 2;
for (int i = 0, k = 0; i < len; i += l, ++k) {
for (int j = 0; j < m; ++j) {
int tmp = (int64_t)a[i + j + m] * polyg[k] % kMod;
a[i + j + m] = sub(a[i + j], tmp);
inc(a[i + j], tmp);
}
}
}
}
friend void dit(Poly &a, int len) {
if (a.size() < len) a.resize(len);
for (int l = 2; l <= len; l <<= 1) {
int m = l / 2;
for (int i = 0, k = 0; i < len; i += l, ++k) {
for (int j = 0; j < m; ++j) {
int tmp = a[i + j + m];
a[i + j + m] = (int64_t)sub(a[i + j], tmp) * polyg[k] % kMod;
inc(a[i + j], tmp);
}
}
}
int invl = qpow(len);
for (int i = 0; i < len; ++i)
a[i] = (int64_t)a[i] * invl % kMod;
std::reverse(a.begin() + 1, a.begin() + len);
}
friend Poly operator *(Poly a, Poly b) {
if (!inited) prework();
int n = a.size() + b.size() - 1, len = getlen(n);
a.resize(len), b.resize(len);
dif(a, len), dif(b, len);
for (int i = 0; i < len; ++i)
a[i] = (int64_t)a[i] * b[i] % kMod;
dit(a, len);
a.resize(n);
return a;
}
friend Poly operator *(Poly a, int b) {
static Poly c;
c = a;
for (auto &x : c) x = (int64_t)x * b % kMod;
return c;
}
friend Poly operator *(int a, Poly b) {
static Poly c;
c = b;
for (auto &x : c) x = (int64_t)x * a % kMod;
return c;
}
friend void operator *=(Poly &a, Poly b) {
if (!inited) prework();
int n = a.size() + b.size() - 1, len = getlen(n);
a.resize(len), b.resize(len);
dif(a, len), dif(b, len);
for (int i = 0; i < len; ++i)
a[i] = (int64_t)a[i] * b[i] % kMod;
dit(a, len);
a.resize(n);
}
friend Poly shift(Poly f, int d) {
if (d == 0) return f;
if ((int)f.size() + d < 0) return {};
Poly g((int)f.size() + d, 0);
for (int i = 0; i < g.size(); ++i)
if (i - d >= 0 && i - d < f.size())
g[i] = f[i - d];
return g;
}
};
} // namespace POLY

using POLY::Poly;

int C(int m, int n) {
if (m < n || m < 0 || n < 0) return 0;
return 1ll * fac[m] * ifac[n] % kMod * ifac[m - n] % kMod;
}

void prework(int n = 5e5) {
fac[0] = 1;
for (int i = 1; i <= n; ++i) fac[i] = 1ll * i * fac[i - 1] % kMod;
ifac[n] = qpow(fac[n]);
for (int i = n; i; --i) ifac[i - 1] = 1ll * i * ifac[i] % kMod;
}

Poly solve(int l, int r, Poly F) {
assert(F.size() == a[r] - a[l - 1] + 1);
if (l == r) {
int ret = 0;
for (auto x : F) inc(ret, x);
return {ret};
}
int mid = (l + r) >> 1, llen = mid - l + 1, rlen = r - mid, d = a[mid] - a[l - 1];
Poly f = solve(mid + 1, r, shift(F, -d));
Poly G(d + 1, 0);
G[d] = f[0];
Poly pf(rlen, 0), pg(d + rlen - 1, 0);
for (int i = 0; i <= rlen - 1; ++i) pf[i] = 1ll * f[rlen - 1 - i] * ifac[rlen - 1 - i] % kMod;
for (int i = 0; i <= d + rlen - 2; ++i) pg[i] = fac[i];
pf *= pg;
for (int i = 0; i <= d - 1; ++i)
if (d + rlen - 2 - i >= 0 && d + rlen - 2 - i < pf.size())
inc(G[i], 1ll * pf[d + rlen - 2 - i] * ifac[d - 1 - i] % kMod);
//
if (d) {
pf.clear(), pg.clear();
pf.resize(d, 0), pg.resize(d, 0);
for (int i = 0; i < d; ++i) pf[i] = F[d - 1 - i];
for (int i = 0; i < d; ++i) pg[i] = 1ll * fac[i + rlen - 1] * ifac[i] % kMod;
pf *= pg;
for (int i = 0; i <= d - 1; ++i) inc(G[i], 1ll * pf[d - 1 - i] * ifac[rlen - 1] % kMod);
}
Poly ff = f;
if (d) {
pf.clear(), pg.clear();
pf.resize(rlen, 0), pg.resize(rlen, 0);
for (int i = 0; i < rlen; ++i) pf[i] = f[rlen - 1 - i];
for (int i = 1; i < rlen; ++i) pg[i] = 1ll * fac[d - 1 + i] * ifac[i] % kMod;
pf *= pg;
for (int i = 0; i <= rlen - 1; ++i) inc(ff[i], 1ll * pf[rlen - 1 - i] * ifac[d - 1] % kMod);
}
if (d) {
pf.clear(), pg.clear();
pf.resize(d, 0), pg.resize(d + rlen - 1, 0);
for (int i = 0; i < d; ++i) pf[i] = 1ll * F[d - 1 - i] * ifac[d - 1 - i] % kMod;
for (int i = 0; i < d + rlen - 1; ++i) pg[i] = fac[i];
pf *= pg;
for (int i = 0; i <= rlen - 1; ++i)
if (d + rlen - 2 - i >= 0 && d + rlen - 2 - i < pf.size())
inc(ff[i], 1ll * pf[d + rlen - 2 - i] * ifac[rlen - 1 - i] % kMod);
}
assert(ff.size() == r - mid);
Poly g = solve(l, mid, G);
return g + shift(ff, g.size());
}

void dickdreamer() {
std::cin >> n; prework();
for (int i = 1; i <= n; ++i) std::cin >> a[i], --a[i];
auto res = solve(1, n, Poly(a[n] + 1, 1));
res.emplace_back(1);
for (auto x : res) std::cout << x << ' ';
}

int32_t main() {
#ifdef ORZXKR
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);
int T = 1;
// std::cin >> T;
while (T--) dickdreamer();
// std::cerr << 1.0 * clock() / CLOCKS_PER_SEC << "s\n";
return 0;
}

QOJ #3091. Japanese Knowledge 题解
https://sobaliuziao.github.io/2025/08/27/post/5472a497.html
作者
Egg_laying_master
发布于
2025年8月27日
许可协议