AGC040-C Neither AB nor BA (800) 解説

AGCの高配点系って解説少ないよね

概要

  • 長さ$N$ の文字列が与えられて、その文字は全て'A'か'B'か'C'
  • 文字列から連続した2文字を選んで消す
  • "AB"と"BA"を選んで消してはいけない
  • 全消しできるような文字列の通り数はいくつ?

制約

  • \(N\)は偶数
  • \( N \)は107以下

考察

こういう数え上げは初手条件整理が命な気がする。このまま存在条件を考えてそれを数える形式に持って行きたい。

まず、消去前と消去後で文字の位置の偶奇は変わらないことに気づくので偶奇に分解する。すると、奇数番目の'A'と偶数番目の'B', 奇数番目の'B'と偶数番目の'A'をマッチさせてはいけないという条件に言い換えられる。

ここで奇数番目にある'A'の個数を\(A_1\), 奇数番目にある'B'の個数を\(B_1\) , 偶数番目にある'A'の個数を\(A_2\), 偶数番目にある'B'の個数を\(B_2\)とおく。また、\(N/2=n\)とする。

明らかに、

\(A_1+B_1 \leq n \tag{1}\) \(A_2+B_2 \leq n \tag{2}\)

が成り立つ。

そして、文字列が全消しできることは

\(A_1+B_2 \leq n \tag{3} \) \(A_2+B_1 \leq n \tag{4} \)

の両方が成り立つことの必要十分な気がする。

(3)式を変形することで\(A_1 \leq n-B_2\)となり、\(n-B_2\)とは偶数番目にある'A'と’C'の個数のことなので、これに反すると奇数番目の'A'と組まなければならない偶数番目の’B'が出てきて矛盾。 (4)に関しても同様なので、十分性は示せた。たぶんこれが必要条件にもなってるんだろうけど証明できてない

あとはこれを満たす4つの数\(A_1,B_1,A_2,B_2\)を数え上げればいいが、\(O(N4)\)かかってしまう。(3)と(4)の関係がANDで結ばれていて条件が厳しくなっているため、余事象をとる。

そうすると

\(A_1+B_2 > n \tag{5} \) \(A_2+B_1 > n \tag{6} \)

となり、(1)(2)の元で(5)か(6)が満たされている通り数を 3N から引けばよい。ここで、(5)と(6)は排反事象である。(なぜなら、(1)と(2), (5)と(6)をそれぞれ辺々足すと矛盾する)

対称性より(5)のみ考えてあとで2倍する。 \(A_1\)は1から\(n\)までの範囲を自由に動くことができる。まず\(A_1\)を固定すると、それに対応して\(B_2\)としてとれる値が決まる。 \(A_1=a, B_2=b\)のとき、そのような文字列の個数は

\[ \binom{n}{a} \times \binom{n}{b} \times 2^{2n-a-b} \]

となる。(長さnの文字列から'A'を置く場所を決め、'B'を置く場所を決め、残りを余り物2種類で埋める通り数)

この式で\(a\)が現れる場所をくくりだすと\[ \binom{n}{a} \times 2^{2n-a} \times (\binom{n}{b} \times 2^{-b}) \]となる。\(a\)を固定したとき、条件を満たす\(b\)は区間になっているので\( (\binom{n}{b} \times 2^{-b})\)についてあらかじめ累積和を計算しておくことで\(O(1)\)で各\(a\)について計算できるようになる。 最後に全ての\(a\)について和を計算することによって、正しい答えを得ることができる。

実装例

#include "bits/stdc++.h"
using namespace std;
#define rep(i,n) for(int i=0;i<n;i++)
#define int long long
const int inf = 1e17;
const int mod = 998244353;
const int maxN = 10000003;
int kj[maxN], kji[maxN];
int rwa[maxN];
int modpow(int a, int x, int mod) {
    int res = 1;
    while (x) {
        if (x & 1)res = res*a%mod;
        x >>= 1;
        a = a*a%mod;
    }
    return res;
}
void setkj(int n) {
    kj[0] = 1;
    rep(i, n)kj[i + 1] = kj[i] * (i + 1) % mod;
    rep(i, n + 1)kji[i] = modpow(kj[i], mod - 2, mod);
}
int comb(int r,int c) {
    if (c<0 || r<c)return 0;
    return kj[r] * kji[c] % mod*kji[r - c] % mod;
}

signed main() {
    int n; cin >> n;
    setkj(n);
    n /= 2;
    int sum = 0;
    for (int j = 1; j <= n; j++)rwa[j] = comb(n, j)*modpow(2, mod - 1 - j, mod) % mod;
    for (int j = 0; j <= n; j++)rwa[j + 1] += rwa[j];

    for (int i = 1; i <= n; i++) {
        int lj = n + 1 - i, rj = n;
        int res1 = rwa[rj] - rwa[lj - 1] + mod;
        res1 %= mod;
        sum += comb(n, i)*modpow(2, n + n - i, mod) % mod*res1%mod;
        sum %= mod;
    }

    sum = modpow(3, 2 * n, mod) - sum * 2 % mod;
    sum = (sum + mod) % mod;
    cout << sum << endl; 
}

所感

解説が甜菜すぎる そしてMathjaxの機嫌が悪い...