【ネタ】動的木とカラバのアルゴリズムで最小全域木の辺の重みの総和を求める

カラバのアルゴリズム

カラバのアルゴリズム (Kalaba's Algorithm)とは、以下のように最小木問題を解くアルゴリズムです。

  • 適当な全域木をとる。
  • 最初にとった全域木に含まれない辺について、適当な順序で順に以下の操作を行う。
    • 注目する辺を全域木に追加する。
    • できた閉路(基本閉路)をなす辺で最も重みが大きい辺をとる。
    • とった辺を全域木から削除する。

組み合わせ最適化問題の文脈で、最小木問題をM凸関数の最小化と見ることで自然に導かれる貪欲アルゴリズムのようです。(私は講義で最小木問題のアルゴリズムの1つとして知っただけなので詳しくはありません。)

単純であるため一瞬これから使おうと思いましたが、冷静に考えてみるとオンラインで木の形状の変更を行いパスクエリを解かなければならないことからクラスカル法・プリム法・ブルーフカ法などと異なり容易に時間O(|E|\log|V|)を達成することが難しいです。
実用性よりも理論として重要ということでしょうか。

Link Cut Treeなどの動的木を用い、辺を表す頂点を追加した(|V|+|E|)頂点の森に対して辺の追加や削除とパスクエリを行えばならしO(|E|\log|V|)が達成できます。
動的木を自分で書かなければそこまで実装量が多いわけでもないので書きました。

ソースコード (C++)

Luzhiled's LibraryのLink Cut Treeを使用しました。

#include <algorithm>
#include <iostream>
#include <stack>
#include <utility>
#include <vector>
using namespace std;

template <typename TreeDPInfo>
struct LinkCutTree {
  using Path = typename TreeDPInfo::Path;
  using Info = typename TreeDPInfo::Info;

 private:
  struct Node {
    Node *l, *r, *p;

    Info info;

    Path sum, mus;

    bool rev;

    bool is_root() const { return not p or (p->l != this and p->r != this); }

    Node(const Info &info)
        : info(info), l(nullptr), r(nullptr), p(nullptr), rev(false) {}
  };

 public:
  using NP = Node *;

 private:
  void toggle(NP t) {
    swap(t->l, t->r);
    swap(t->sum, t->mus);
    t->rev ^= true;
  }

  void rotr(NP t) {
    NP x = t->p, y = x->p;
    push(x), push(t);
    if ((x->l = t->r)) t->r->p = x;
    t->r = x, x->p = t;
    update(x), update(t);
    if ((t->p = y)) {
      if (y->l == x) y->l = t;
      if (y->r == x) y->r = t;
    }
  }

  void rotl(NP t) {
    NP x = t->p, y = x->p;
    push(x), push(t);
    if ((x->r = t->l)) t->l->p = x;
    t->l = x, x->p = t;
    update(x), update(t);
    if ((t->p = y)) {
      if (y->l == x) y->l = t;
      if (y->r == x) y->r = t;
    }
  }

 public:
  LinkCutTree() = default;

  void push(NP t) {
    if (t->rev) {
      if (t->l) toggle(t->l);
      if (t->r) toggle(t->r);
      t->rev = false;
    }
  }

  void push_rev(NP t) {
    if (t->rev) {
      if (t->l) toggle(t->l);
      if (t->r) toggle(t->r);
      t->rev = false;
    }
  }

  void update(NP t) {
    Path key = TreeDPInfo::vertex(t->info);
    t->sum = key;
    t->mus = key;
    if (t->l) {
      t->sum = TreeDPInfo::compress(t->l->sum, t->sum);
      t->mus = TreeDPInfo::compress(t->mus, t->l->mus);
    }
    if (t->r) {
      t->sum = TreeDPInfo::compress(t->sum, t->r->sum);
      t->mus = TreeDPInfo::compress(t->r->mus, t->mus);
    }
  }

  void splay(NP t) {
    push(t);
    while (not t->is_root()) {
      NP q = t->p;
      if (q->is_root()) {
        push_rev(q), push_rev(t);
        if (q->l == t)
          rotr(t);
        else
          rotl(t);
      } else {
        NP r = q->p;
        push_rev(r), push_rev(q), push_rev(t);
        if (r->l == q) {
          if (q->l == t)
            rotr(q), rotr(t);
          else
            rotl(t), rotr(t);
        } else {
          if (q->r == t)
            rotl(q), rotl(t);
          else
            rotr(t), rotl(t);
        }
      }
    }
  }

  NP expose(NP t) {
    NP rp = nullptr;
    for (NP cur = t; cur; cur = cur->p) {
      splay(cur);
      cur->r = rp;
      update(cur);
      rp = cur;
    }
    splay(t);
    return rp;
  }

  void link(NP child, NP parent) {
    if (is_connected(child, parent)) {
      throw runtime_error(
          "child and parent must be different connected components");
    }
    if (child->l) {
      throw runtime_error("child must be root");
    }
    child->p = parent;
    parent->r = child;
    update(parent);
  }

  void cut(NP child) {
    expose(child);
    NP parent = child->l;
    if (not parent) {
      throw runtime_error("child must not be root");
    }
    child->l = nullptr;
    parent->p = nullptr;
    update(child);
  }

  void evert(NP t) {
    expose(t);
    toggle(t);
    push(t);
  }

  NP alloc(const Info &v) {
    NP t = new Node(v);
    update(t);
    return t;
  }

  bool is_connected(NP u, NP v) {
    expose(u), expose(v);
    return u == v or u->p;
  }

  vector<NP> build(vector<Info> &vs) {
    vector<NP> nodes(vs.size());
    for (int i = 0; i < (int)vs.size(); i++) {
      nodes[i] = alloc(vs[i]);
    }
    return nodes;
  }

  NP lca(NP u, NP v) {
    if (not is_connected(u, v)) return nullptr;
    expose(u);
    return expose(v);
  }

  void set_key(NP t, const Info &v) {
    expose(t);
    t->info = std::move(v);
    update(t);
  }

  const Path &query_path(NP u) {
    expose(u);
    return u->sum;
  }

  const Path &query_path(NP u, NP v) {
    evert(u);
    return query_path(v);
  }

  template <typename C>
  pair<NP, Path> find_first(NP u, const C &check) {
    expose(u);
    Path sum = TreeDPInfo::vertex(u->info);
    if (check(sum)) return {u, sum};
    u = u->l;
    while (u) {
      push(u);
      if (u->r) {
        Path nxt = TreeDPInfo::compress(u->r->sum, sum);
        if (check(nxt)) {
          u = u->r;
          continue;
        }
        sum = nxt;
      }
      Path nxt = TreeDPInfo::compress(TreeDPInfo::vertex(u->info), sum);
      if (check(nxt)) {
        splay(u);
        return {u, nxt};
      }
      sum = nxt;
      u = u->l;
    }
    return {nullptr, sum};
  }
};

struct TreeDPInfo {
    struct Path { long long max_weight; int idx; };
    struct Info { long long weight; int idx; };
    static Path vertex(const Info & u) { return {u.weight, u.idx}; };
    static Path compress(const Path& p, const Path& c) {
        if(p.max_weight>c.max_weight){
            return p;
        } else {
            return c;
        }
    };
};

struct Edge {
    int to, idx;
};

int main(void) {
    int n,m;
    cin >> n >> m;
    vector<int> a(m);
    vector<int> b(m);
    vector<long long> c(m);
    LinkCutTree<TreeDPInfo> lct;
    vector g(n, vector<Edge>());
    vector<TreeDPInfo::Info> vs(n+m);
    for(int i=0;i<n;++i){
        vs[i]={0, m};
    }
    for(int i=0;i<m;++i){
        cin >> a[i] >> b[i] >> c[i];
        --a[i];
        --b[i];
        g[a[i]].emplace_back(Edge{b[i], i});
        g[b[i]].emplace_back(Edge{a[i], i});
        vs[n+i]={c[i], i};
    }
    auto vertices=lct.build(vs);
    long long ans=0;
    stack<int> st;
    vector<bool> seen(n);
    vector<bool> used(m);
    st.emplace(0);
    seen[0]=true;
    while(!st.empty()){
        int v=st.top();
        st.pop();
        for(auto [u, i]: g[v]){
            if(!seen[u]){
                lct.evert(vertices[n+i]);
                lct.link(vertices[n+i], vertices[v]);
                lct.evert(vertices[n+i]);
                lct.link(vertices[n+i], vertices[u]);
                st.emplace(u);
                seen[u]=true;
                used[i]=true;
                ans+=c[i];
            }
        }
    }
    for(int i=0;i<m;++i){
        if(used[i]){
            continue;
        }
        int e=lct.query_path(vertices[a[i]], vertices[b[i]]).idx;
        if(c[i]<c[e]){
            lct.evert(vertices[n+e]);
            lct.cut(vertices[a[e]]);
            lct.evert(vertices[n+e]);
            lct.cut(vertices[b[e]]);
            ans-=c[e];
            lct.evert(vertices[n+i]);
            lct.link(vertices[n+i], vertices[a[i]]);
            lct.evert(vertices[n+i]);
            lct.link(vertices[n+i], vertices[b[i]]);
            ans+=c[i];
        }
    }
    cout << ans << endl;
    return 0;
}

提出 (鉄則A67)

普通に定数倍が重くて遅いです笑
atcoder.jp

感想

普通にクラスカル法かプリム法かブルーフカ法が書ければよくね?