以前,桁 DP の簡潔な実装 という記事を書きました.
その後,maspy(@maspy_stars)さんに,これとは別の桁 DP の考察・実装法を教えていただきました.ポイントは以下です:
- 最下位桁で場合分けする
- メモ化再帰で書く
この方針だと,ABC050 / ARC066: D - Xor Sum が自然に解けたので,紹介します.
maspyのHP では,多くの問題がこの方針で解説されています.
解法
以下の手順で解きます:
- 足し算 $+$ は bitwise な演算でなく厄介なので,bitwise な演算で書き換える.
- 書き換えた条件に分数が現れるので,分数を新たな文字でおいて整理する.
- $(a, b)$ の存在条件をさらに整理する.
- 冒頭の方針に従って桁 DP.
これ以降,bitwise XOR を $\oplus$ と,bitwise AND を $\land$ と書きます.
bitwise な演算での書き換え
足し算 $+$ は bitwise な演算ではなく少し扱いにくいです.そこで,競プロで頻出の式 \[a + b = (a \oplus b) + 2 (a \land b)\]を使います.この式は,二進法での足し算の筆算を考えると分かります.
これを使うと,問題文の条件は\[ a \oplus b = u, \qquad a \land b = \frac{v-u}{2} \] と変形できます.
分数の処理
足し算 $+$ を bitwise な演算である $\land$ に書き換えられたのは良かったのですが,分数が現れてしまいました.これも厄介です.そこで,$w := \dfrac{v-u}{2}$ とおきます.このとき $v = u + 2w$ です.
組 $(u, v)$ と $(u, w)$ が一対一に対応するため,組 $(u, v)$ を数える代わりに,組 $(u, w)$ を数えても良いです.
$u, v \leq N$ も考慮すると,問題の答えは,非負整数の組 $(u, w)$ であって,条件
- $u + 2w \leq N$
- $a \oplus b = u,\ $ $a \land b = w$ となる $(a, b)$ が存在する
をともに満たすものの数です.$v = u + 2w \leq N$ ならば,$w$ の非負性から $u \leq N$ も成り立つので,$u \leq N$ という条件は忘れてしまってよいことに注意してください.
$(a, b)$ の存在条件の整理
初手で足し算 $+$ を消したおかげで,条件 2 は bitwise な条件になっています.
そこで,$(a, b) = (0, 0), (0, 1), (1, 0), (1, 1)$ の各場合について,$a \oplus b,$ $a \land b$ を計算してみます.上で求めた $(u, w)$ についての条件は,以下と等価であることが分かります:
- $u + 2w \leq N$
- $u, w$ の下から $i$ 桁目の bit を $u_i, w_i$ とする.任意の $i$ に対して $(u_i, w_i) \neq (1, 1)$
桁 DP へ
ここまで来ると,あとは桁 DP で解けそうです.ここでは,冒頭に述べたとおり,
- $u, w$ の最下位 bit $u_0, w_0$ で場合分けする
- メモ化再帰で書く
という方針にします.
問題の答えを $\mathtt{dp}(N)$ とおきます.$u, w$ の最下位 bit で場合分けするために,$u = 2u' + u_0, $ $w = 2w' + w_0$ とおきます.すると,
- $(u_0, w_0) = (0, 0)$ のとき: $u + 2w \leq N$ $\iff$ $u' + 2w' \leq \dfrac{N}{2}$
- $(u_0, w_0) = (1, 0)$ のとき: $u + 2w \leq N$ $\iff$ $u' + 2w' \leq \dfrac{N-1}{2}$
- $(u_0, w_0) = (0, 1)$ のとき: $u + 2w \leq N$ $\iff$ $u' + 2w' \leq \dfrac{N-2}{2}$
です.これより,以下の漸化式が得られます!: \[ \mathtt{dp}(N) = \mathtt{dp}\Big(\Big\lfloor \frac{N}{2} \Big\rfloor \Big) + \mathtt{dp}\Big(\Big\lfloor \frac{N-1}{2} \Big\rfloor \Big) + \mathtt{dp}\Big(\Big\lfloor \frac{N-2}{2} \Big\rfloor \Big). \]
あとはこれをメモ化再帰で書けば OK です.
計算量の評価
$\mathtt{dp}(n)$ が呼び出されるような $n$ は,$N$ の上位 $k$ 桁,もしくはそこから $1$ or $2$ を引いた数に等しいです.よって,高々 $3 \log_2 N$ 個程度の $n$ に対してしか,$\mathtt{dp}(n)$ を計算する必要はありません.
メモにハッシュテーブル(std::unordered_map
など)を使うと,解法全体の計算量も $O(\log N)$ となります.
実装例(C++ with ACL)
std::map
を使っているので,計算量は $O(\log N \log \log N)$ です.
#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using namespace atcoder;
using ll = long long;
using mint = modint1000000007;
map<ll, mint> memo;
mint dp(ll n) {
if (memo.count(n)) return memo[n];
if (n == 0) return 1;
if (n == 1) return 2;
return memo[n] = dp(n/2) + dp((n-1)/2) + dp((n-2)/2);
}
int main() {
ll n; cin >> n;
cout << dp(n).val() << '\n';
}