zcmimi's blog

前置知识: 莫队算法(Mo's Algorithm)

树上莫队:

也就是把莫队搬到树上

入手模板: SP10707 COT2 - Count on a tree II

给定一个n个节点的树,每个节点表示一个整数,问u到v的路径上有多少个不同的整数。

要把莫队搬到树上,那我们如何把树上问题转化到序列上呢? 欧拉序

欧拉序

欧拉序就是前序遍历时,在遍历到x时将x加入序列,离开x子树时再将x加入序列

void dfs(int x){
    T[st[x]=++dfn]=x;//加入序列
    siz[x]=1;
    fl(i,x)if(to!=f[x]){//遍历x的儿子
        f[to]=x;
        d[to]=d[x]+1;
        dfs(to);
        siz[x]+=siz[to];
    }
    T[ed[x]=++dfn]=x;//加入序列
}

实现

这里我们设st_i表示访问到i时加入欧拉序的时间,ed_i表示回溯经过i时加入欧拉序的时间

不妨设st_x<st_y

lca(x,y)==x,x,y在一条链上,那么统计序列中[st_x,st_y]的答案

否则x,y在不同的子树内,那么统计序列中[ed_x,st_y]的答案

不理解的可以手动画图模拟一遍,并配合代码理解qwq

于是我们就把树上问题转化为序列上的了

接下来就和普通莫队一样了

代码

#include<bits/stdc++.h>
#define Fur(i,x,y) for(int i=x;i<=y;++i)
#define fl(i,x) for(int i=head[x],to;to=e[i].to,i;to=e[i].nxt)
using namespace std;
const int N=40011,M=100011;
int n,cnt=0,head[N],val[N],old[N];
struct edge{
    int to,nxt;
}e[N<<1];
void add(int x,int y){
    e[++cnt].to=y;e[cnt].nxt=head[x];head[x]=cnt;
}
int T[N],dfn=0,st[N],ed[N],top[N],siz[N],f[N],d[N];
void dfs(int x){
    T[st[x]=++dfn]=x;
    siz[x]=1;
    fl(i,x)if(to!=f[x]){
        f[to]=x;
        d[to]=d[x]+1;
        dfs(to);
        siz[x]+=siz[to];
    }
    T[ed[x]=++dfn]=x;
}
void bt(int x,int tp){
    top[x]=tp;
    int k=0;
    fl(i,x)if(to!=f[x]&&siz[to]>siz[k])k=to;
    if(!k)return;bt(k,tp);
    fl(i,x)if(!top[to])bt(to,to);
}
int lca(int x,int y){
    while(top[x]!=top[y]){
        if(d[top[x]]<d[top[y]])swap(x,y);
        x=f[top[x]];
    }
    return d[x]<d[y]?x:y;
}
int blk,bl[N<<1];
struct que{
    int l,r,lca,bl,br,id;
    bool operator<(que t)const{
        (bl==t.bl)?((bl&1)?r<t.r:r>t.r):(l<t.l);
    }
}Q[N];
int ans=0,c[N],ANS[M];
bool v[N];
void inc(int x){ans+=(++c[val[x]]==1);}
void dec(int x){ans-=(--c[val[x]]==0);}
void op(int x){
    v[x]?dec(x):inc(x);
    v[x]^=1;
}
struct node{
    int v,p;
    bool operator<(node t)const{
        return v<t.v;
    }
}a[M];
int main(){
    int q,x,y,l=1,r=0;
    scanf("%d%d",&n,&q);
    blk=sqrt(n);
    Fur(i,1,n*2)bl[i]=(i-1)/blk+1;
    Fur(i,1,n)scanf("%d",&a[i].v),a[i].p=i;
    sort(a+1,a+n+1);
    int tt=0;
    Fur(i,1,n)old[val[a[i].p]=(tt+=a[i].v!=a[i-1].v)]=a[i].v;
    Fur(i,1,n-1)scanf("%d%d",&x,&y),add(x,y),add(y,x);
    dfs(1);bt(1,1);
    Fur(i,1,q){
        scanf("%d%d",&x,&y);
        if(st[x]>st[y])swap(x,y);
        int t=lca(x,y);
        if(t==x)Q[i]=que{st[x],st[y],0,bl[st[x]],bl[st[y]],i};
        else Q[i]=que{ed[x],st[y],t,bl[ed[x]],bl[st[y]],i};
    }
    sort(Q+1,Q+q+1);
    Fur(i,1,q){
        while(l>Q[i].l)op(T[--l]);
        while(r<Q[i].r)op(T[++r]);
        while(l<Q[i].l)op(T[l++]);
        while(r>Q[i].r)op(T[r--]);
        if(Q[i].lca)op(Q[i].lca);
        ANS[Q[i].id]=ans;
        if(Q[i].lca)op(Q[i].lca);
    }
    Fur(i,1,q)printf("%d\n",ANS[i]);
}
树上莫队
comment评论
Search
search