わりとよくある備忘録

競プロ他雑多な私的メモ

包除原理(ABC172E)

包除原理を応用する問題が出題されて本番中は手も足も出なかった。
いろんな解説を見てACしたので理解をまとめる。
atcoder.jp

包除原理とは

ものすごくざっくり書くと、複数の集合から構成される全体集合を重複なく求めるための考え方。基本的には全体を足して、重複する部分に関しては別途足し引きして辻褄を合わせる。
詳細は下記を参照。
mathtrain.jp

どう応用したか

今回の例では、ユニークな数字で構成される2つの数列A,Bの内、任意のiA_i≠B_iとなるパターンを数え上げる。
このような数列を直接求めるのは難しいので、全体の数列の組み合わせを全通り求めてそこから条件を満たさないものを除いていくことを考える。

全体集合

全体の数列を表す集合Uは、m個の中からn個を2数列分並べればよいので次の式で求まる。

U=( _m P_n)^2

条件を満たさない集合

今回の条件を満たさない集合とは、任意のiA_i=B_iとなるiが1つ以上ある集合のことなのでこれを数え上げていく。
ここで、直接「ちょうどk箇所のiで重複する集合」を求められれば良いが、これも難しい。「少なくともk箇所のiで重複する集合」なら数えやすそうだが、これを全て足してしまうと重複する部分が発生する。例えば「少なくとも1箇所のiで重複する集合」には「少なくとも2箇所のiで重複する集合」も含まれている。

ここで、包除原理を使って重複を削除し「k箇所のiで値が重複する集合S,k=\{1,2,...,n\}」を数え上げればよい。
上記を計算するために、「少なくともk箇所のiで重複する集合S_k」について考える。
これを数え上げるのは簡単で、次のように考える。

  1. まず、長さnの数列の中で重複させる位置をk箇所選ぶ ->  _n C_k
  2. 次に、そのk箇所に重複させる数字をm個の中から選んで並べる ->  _m P_k
  3. 最後に、残ったn-kの位置に当てはめる数字をm-k個の中から選んで並べる(数列A,B2つ分を選ぶため二乗となることに注意) -> ( _{m-k} P_{n-k})^2

よって重複がkのときの集合S_kは下記のようになる

 S_k=_n C_k* _m P_k*( _{m-k} P_{n-k})^2

ここまでくれば、Sを求めるにはS_kを包除原理に従って足し合わせればよい。

S=\sum_{k=1}^{n}(-1)^{(k-1)}* _n C_k* _m P_k*( _{m-k} P_{n-k})^2

よって今求めたい集合は全体からこれを引けばよいので、

U-S=( _m P_n)^2 - \sum_{k=1}^{n}(-1)^{(k-1)}* _n C_k* _m P_k*( _{m-k} P_{n-k})^2

となる。ここで、k=0のときは、
_n C_k* _m P_k*( _{m-k} P_{n-k})^2=( _m P_n)^2

となるので、上式を整理すると結局、
\sum_{k=0}^{n}(-1)^k* _n C_k* _m P_k*( _{m-k} P_{n-k})^2

を求めればよいことになる。

実装例

#include <bits/stdc++.h>

#define rep(i, n) for (int i = 0; i < (n); i++)
#define reps(i, n, s) for (int i = (s); i < (n); i++)

using ll = long long;
using namespace std;
constexpr long long MAX = 5100000;
constexpr long long INF = 1LL << 60;
constexpr int MOD = 1000000007;

ll fac[MAX], finv[MAX], inv[MAX];

void COMinit() {
  fac[0] = fac[1] = 1;
  finv[0] = finv[1] = 1;
  inv[1] = 1;
  for (int i = 2; i < MAX; i++) {
    fac[i] = fac[i - 1] * i % MOD;
    inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD;
    finv[i] = finv[i - 1] * inv[i] % MOD;
  }
}

ll nCk(int n, int k) {
  if (n < k) return 0;
  if (n < 0 || k < 0) return 0;
  return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD;
}

ll nPk(ll n, ll k) {
  if (n == 0 || k == 0) return 1;
  return fac[n] * finv[n - k] % MOD;
}

class mint {
  long long x;

 public:
  mint(long long x = 0) : x((x % MOD + MOD) % MOD) {}
  mint operator-() const { return mint(-x); }
  mint& operator+=(const mint& a) {
    if ((x += a.x) >= MOD) x -= MOD;
    return *this;
  }
  mint& operator-=(const mint& a) {
    if ((x += MOD - a.x) >= MOD) x -= MOD;
    return *this;
  }
  mint& operator*=(const mint& a) {
    (x *= a.x) %= MOD;
    return *this;
  }
  mint operator+(const mint& a) const {
    mint res(*this);
    return res += a;
  }
  mint operator-(const mint& a) const {
    mint res(*this);
    return res -= a;
  }
  mint operator*(const mint& a) const {
    mint res(*this);
    return res *= a;
  }
  mint pow(ll t) const {
    if (!t) return 1;
    mint a = pow(t >> 1);
    a *= a;
    if (t & 1) a *= *this;
    return a;
  }
  // for prime MOD
  mint inv() const { return pow(MOD - 2); }
  mint& operator/=(const mint& a) { return (*this) *= a.inv(); }
  mint operator/(const mint& a) const {
    mint res(*this);
    return res /= a;
  }

  friend ostream& operator<<(ostream& os, const mint& m) {
    os << m.x;
    return os;
  }
};
int main() {
  cin.tie(0);
  ios::sync_with_stdio(false);
  ll n, m;
  cin >> n >> m;

  COMinit();
  // nCk*mPk*(m-kPn-k)^2
  mint ans = 0;
  rep(k, n + 1) {
    mint now = nPk(m - k, n - k);
    now *= now;
    now *= nPk(m, k);
    now *= nCk(n, k);
    if (k % 2 != 0) now = -now;
    ans += now;
  }
  cout << ans << endl;
  return 0;
}

参考

以下、参考にさせて頂いたサイトです。
www.youtube.com
qiita.com