CF 1487G

連結

題意

cnt0cnt0個字母a、cnt1cnt1個字母b、…、cnt25cnt25個字母z
問有幾個長度為nn的字串,且沒有長度為奇數的回文子字串

(3n400,n3<cntin)(3 \leq n \leq 400, \frac{n}{3} < cnt_i \leq n)

解法

雖然這場pA-pE都還蠻糞的,但這題我覺得還蠻有趣的

奇數長度的回文子子串?看起來很噁心
但事實上只要檢查有沒有長度為3的回文子子串就好
也就是1in2,s[i]s[i+2]\forall 1 \leq i \leq n - 2, s[i] \neq s[i + 2]需成立

接下來,到底n3\frac{n}{3}能幹嘛?可以發現這代表最多只會有兩種字元超過限制
到底有誰想的到這個R,\羊神教我/

也就是說,採用排容的方式,答案會是「所有可能」-「a超過」-「b超過」-…-「z超過」
+「a和b超過」+「a和c超過」+…+「y和z超過」

想到這裡就差不多接近答案了
考慮dp[i][a][b][l][r]代表長度為ii的字串,其中包含aa個字元a、bb個字元b,且字串後兩個字元為llrr
字元不是有26個嗎?這樣會炸吧
仔細想一下就會發現字元實際只有3種:a、b和其他
所以就可以列出一個簡單的O(3)\mathcal{O}(3)轉移式:

1
2
3
4
5
6
7
// dp[i][a][b][l][r]
for (int place = 0; place < 3; ++place) {
// 記得要把回文的case去掉
if (place < 2 && l == place) continue;
int na = a + (place == 0), nb = b + (place == 1), letter = (place < 2 ? 1 : 24 - (l == place));
add(dp[i & 1 ^ 1][na][nb][r][place], 1ll * dp[i & 1][a][b][l][r] * letter % mod);
}

這樣dp式算完,因為ab可以換成任兩個字元
就可以直接用上面的排容式算出答案了~
複雜度會是O(27n3)\mathcal{O}(27n^3)

code

實作的心得:五維陣列的dp看起來好噁爛= =

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <bits/stdc++.h>
using namespace std;
#define lli long long int
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define test(x) cout << "Line(" << __LINE__ << ") " #x << ' ' << x << endl
#define printv(x) {\
for (auto i : x) cout << i << ' ';\
cout << endl;\
}
#define pii pair <int, int>
#define pll pair <lli, lli>
#define X first
#define Y second
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
template<typename A, typename B>
ostream& operator << (ostream& o, pair<A, B> a){
return o << a.X << ' ' << a.Y;
}
template<typename A, typename B>
istream& operator >> (istream& o, pair<A, B> &a){
return o >> a.X >> a.Y;
}
const int mod = 998244353, abc = 864197532, N = 402, logN = 17, K = 333, INF = 5e8;

int dp[2][N][N][3][3];
int sum[N];

void add(int &a, int b) {
a += b;
if (a >= mod) a -= mod;
}

void del(int &a, int b) {
a -= b;
if (a < 0) a += mod;
}

int main () {
ios::sync_with_stdio(false);
cin.tie(0);
int n;
cin >> n;
int cnt[26];
for (int i = 0; i < 26; ++i) cin >> cnt[i];
dp[0][0][0][2][2] = 24 * 24;
dp[0][0][1][2][1] = dp[0][0][1][1][2] = 24;
dp[0][0][2][1][1] = 1;
dp[0][1][0][2][0] = dp[0][1][0][0][2] = 24;
dp[0][1][1][0][1] = dp[0][1][1][1][0] = 1;
dp[0][2][0][0][0] = 1;
for (int i = 2; i < n; ++i) {
for (int a = 0; a <= i; ++a) for (int b = 0; a + b <= i; ++b) {
for (int prel = 0; prel < 3; ++prel) for (int prer = 0; prer < 3; ++prer) {
dp[i & 1 ^ 1][a][b][prel][prer] = 0;
}
}
for (int a = 0; a <= i; ++a) for (int b = 0; a + b <= i; ++b) {
for (int place = 0; place < 3; ++place) {
for (int prel = 0; prel < 3; ++prel) for (int prer = 0; prer < 3; ++prer) {
if (place < 2 && prel == place) continue;
int na = a + (place == 0), nb = b + (place == 1), letter = (place < 2 ? 1 : 24 - (prel == place));
add(dp[i & 1 ^ 1][na][nb][prer][place], 1ll * dp[i & 1][a][b][prel][prer] * letter % mod);
}
}
}
}
int ans = 0;
for (int a = 0; a <= n; ++a) for (int b = 0; a + b <= n; ++b) {
for (int prel = 0; prel < 3; ++prel) for (int prer = 0; prer < 3; ++prer) {
add(sum[a], dp[n & 1][a][b][prel][prer]);
}
}
for (int i = 0; i <= n; ++i) add(ans, sum[i]);
for (int i = 0; i < 26; ++i) for (int c = cnt[i] + 1; c <= n; ++c) {
del(ans, sum[c]);
}
for (int i = 0; i < 26; ++i) for (int j = i + 1; j < 26; ++j) {
for (int c1 = cnt[i] + 1; c1 <= n; ++c1) for (int c2 = cnt[j] + 1; c2 <= n; ++c2) {
for (int prel = 0; prel < 3; ++prel) for (int prer = 0; prer < 3; ++prer) {
add(ans, dp[n & 1][c1][c2][prel][prer]);
}
}
}
cout << ans << endl;
}