包除原理(ABC172E)
包除原理を応用する問題が出題されて本番中は手も足も出なかった。
いろんな解説を見てACしたので理解をまとめる。
atcoder.jp
包除原理とは
ものすごくざっくり書くと、複数の集合から構成される全体集合を重複なく求めるための考え方。基本的には全体を足して、重複する部分に関しては別途足し引きして辻褄を合わせる。
詳細は下記を参照。
mathtrain.jp
どう応用したか
今回の例では、ユニークな数字で構成される2つの数列の内、任意のでとなるパターンを数え上げる。
このような数列を直接求めるのは難しいので、全体の数列の組み合わせを全通り求めてそこから条件を満たさないものを除いていくことを考える。
全体集合
全体の数列を表す集合は、m個の中からn個を2数列分並べればよいので次の式で求まる。
条件を満たさない集合
今回の条件を満たさない集合とは、任意のでとなるが1つ以上ある集合のことなのでこれを数え上げていく。
ここで、直接「ちょうど箇所ので重複する集合」を求められれば良いが、これも難しい。「少なくとも箇所ので重複する集合」なら数えやすそうだが、これを全て足してしまうと重複する部分が発生する。例えば「少なくとも箇所ので重複する集合」には「少なくとも箇所ので重複する集合」も含まれている。
ここで、包除原理を使って重複を削除し「箇所ので値が重複する集合,」を数え上げればよい。
上記を計算するために、「少なくとも箇所ので重複する集合」について考える。
これを数え上げるのは簡単で、次のように考える。
- まず、長さの数列の中で重複させる位置を箇所選ぶ ->
- 次に、その箇所に重複させる数字を個の中から選んで並べる ->
- 最後に、残ったの位置に当てはめる数字を個の中から選んで並べる(数列2つ分を選ぶため二乗となることに注意) ->
よって重複がのときの集合は下記のようになる
ここまでくれば、を求めるにはを包除原理に従って足し合わせればよい。
よって今求めたい集合は全体からこれを引けばよいので、
となる。ここで、のときは、
となるので、上式を整理すると結局、
を求めればよいことになる。
実装例
#include <bits/stdc++.h> #define rep(i, n) for (int i = 0; i < (n); i++) #define reps(i, n, s) for (int i = (s); i < (n); i++) using ll = long long; using namespace std; constexpr long long MAX = 5100000; constexpr long long INF = 1LL << 60; constexpr int MOD = 1000000007; ll fac[MAX], finv[MAX], inv[MAX]; void COMinit() { fac[0] = fac[1] = 1; finv[0] = finv[1] = 1; inv[1] = 1; for (int i = 2; i < MAX; i++) { fac[i] = fac[i - 1] * i % MOD; inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD; finv[i] = finv[i - 1] * inv[i] % MOD; } } ll nCk(int n, int k) { if (n < k) return 0; if (n < 0 || k < 0) return 0; return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD; } ll nPk(ll n, ll k) { if (n == 0 || k == 0) return 1; return fac[n] * finv[n - k] % MOD; } class mint { long long x; public: mint(long long x = 0) : x((x % MOD + MOD) % MOD) {} mint operator-() const { return mint(-x); } mint& operator+=(const mint& a) { if ((x += a.x) >= MOD) x -= MOD; return *this; } mint& operator-=(const mint& a) { if ((x += MOD - a.x) >= MOD) x -= MOD; return *this; } mint& operator*=(const mint& a) { (x *= a.x) %= MOD; return *this; } mint operator+(const mint& a) const { mint res(*this); return res += a; } mint operator-(const mint& a) const { mint res(*this); return res -= a; } mint operator*(const mint& a) const { mint res(*this); return res *= a; } mint pow(ll t) const { if (!t) return 1; mint a = pow(t >> 1); a *= a; if (t & 1) a *= *this; return a; } // for prime MOD mint inv() const { return pow(MOD - 2); } mint& operator/=(const mint& a) { return (*this) *= a.inv(); } mint operator/(const mint& a) const { mint res(*this); return res /= a; } friend ostream& operator<<(ostream& os, const mint& m) { os << m.x; return os; } }; int main() { cin.tie(0); ios::sync_with_stdio(false); ll n, m; cin >> n >> m; COMinit(); // nCk*mPk*(m-kPn-k)^2 mint ans = 0; rep(k, n + 1) { mint now = nPk(m - k, n - k); now *= now; now *= nPk(m, k); now *= nCk(n, k); if (k % 2 != 0) now = -now; ans += now; } cout << ans << endl; return 0; }
参考
以下、参考にさせて頂いたサイトです。
www.youtube.com
qiita.com