zcmimi's blog

点分治

点分治用于大规模处理树上路径

树的重心

树的最大子树最小的点

性质:每一个子树的大小都不超过\frac n2

int rt,mxs=inf,siz[N];// 当前 根,最大子树大小
void frt(int x,int f){
    siz[x]=1;
    int res=0;
    fl(i,x)if(!v[to]&&to!=f){
        frt(to,x);
        siz[x]+=siz[to];
        res=max(res,siz[to]);
    }
    res=max(res,SZ-siz[x]);
    /*
     SZ是当前树大小

     因为是无根树,所以除了自己的子树外的部分也要当做一棵子树
     */
    if(res<mxs)rt=x,mxs=res;
}

实现

void sol(int x){
    v[x]=1;//标记为已统计
    ans+=calc(x,0);//找出当前节点的答案
    fl(i,x)if(!v[to]){
        ans-=calc(to,e[i].w);//减去重复的
        rt=0,mxs=inf;SZ=siz[to];
        frt(to,0);//找出子树的重心
        sol(rt);//进入子树,继续分治
    }
}
1234567

从点1开始的路径有:

1\rightarrow 2

1\rightarrow 2\rightarrow 4

1\rightarrow 2\rightarrow 5

1\rightarrow 3

1\rightarrow 3\rightarrow 6

1\rightarrow 3\rightarrow 7

有些路径合并起来后,并不会经过点1,要减去这部分的答案

比如1\rightarrow 2\rightarrow 41\rightarrow 2\rightarrow 5合并起来其实是:

4\rightarrow 2\rightarrow 5,并没有经过1

合并统计路径

举个例子: 求长度为k的路径数

我们将从x开始的所有路径的长度求出来存到数组中

排序可以用类似双指针的方法扫描

如果k的值较小的话可以用桶的方式来统计

// get dis,获取从x开始的所有路径的长度
void gd(int x,int d,int f){
    if(d>k)return;
    b[++tt]=d;
    fl(i,x)if(!v[to]&&to!=f)gd(to,d+1,x);
}
ll calc(int x,int d){// d是x的深度(用来减去子树中多余部分时用到)
    tt=0;gd(x,d,0);
    sort(b+1,b+tt+1);
    int l=0,r=tt;
    ll res=0;
    while(l<r){
        ++l;
        while(b[l]+b[r]>k&&l<r)--r;
        if(b[l]+b[r]==k){
            int t1=1,t2=1;// 统计两边的个数
            while(l<r&&b[l]==b[l+1])++l,++t1;
            while(l<r&&b[r]==b[r-1])--r,++t2;
            if(b[l]!=b[r])res+=1ll*t1*t2;
            else res+=1ll*(t1+t2)*(t1+t2-1)/2ll;
            //将所有长度为k/2的路径组合
        }
    }
    return res;
}

例题

CF161D Distance in Tree

CodeForces 161D Distance in Tree

输入点数为N一棵树

求树上长度恰好为K的路径个数

模板题

#include<bits/stdc++.h>
using namespace std;
#define N 500011
#define inf 1000000000
#define ll long long
int n,k,cnt=0,SZ,rt,mxs,head[N],siz[N];
bool v[N];
struct edge{
    int to,nxt;
}e[N<<1];
void add(int x,int y){
    e[++cnt].to=y;e[cnt].nxt=head[x];head[x]=cnt;
}
#define fl(i,x) for(int i=head[x],to;to=e[i].to,i;i=e[i].nxt)
void frt(int x,int f){
    siz[x]=1;
    int res=0;
    fl(i,x)if(!v[to]&&to!=f){
        frt(to,x);
        siz[x]+=siz[to];
        res=max(res,siz[to]);
    }
    res=max(res,SZ-siz[x]);
    if(res<mxs)rt=x,mxs=res;
}
int b[N],tt;
ll ans=0;
void gd(int x,int d,int f){
    if(d>k)return;
    b[++tt]=d;
    fl(i,x)if(!v[to]&&to!=f)gd(to,d+1,x);
}
ll calc(int x,int d){
    tt=0;gd(x,d,0);
    sort(b+1,b+tt+1);
    int l=0,r=tt;
    ll res=0;
    while(l<r){
        ++l;
        while(b[l]+b[r]>k&&l<r)--r;
        if(b[l]+b[r]==k){
            int t1=1,t2=1;
            while(l<r&&b[l]==b[l+1])++l,++t1;
            while(l<r&&b[r]==b[r-1])--r,++t2;
            if(b[l]!=b[r])res+=1ll*t1*t2;
            else res+=1ll*(t1+t2)*(t1+t2-1)/2ll;
        }
    }
    return res;
}
void sol(int x){
    v[x]=1;ans+=calc(x,0);
    fl(i,x)if(!v[to]){
        ans-=calc(to,1);
        rt=0,mxs=inf;SZ=siz[to];
        frt(to,0);
        sol(rt);
    }
}
int main(){
    scanf("%d%d",&n,&k);
    int x,y;
    for(int i=1;i<=n-1;++i)
        scanf("%d%d",&x,&y),
        add(x,y),add(y,x);

    rt=0,mxs=inf,SZ=n;
    frt(1,0);
    sol(rt);
    printf("%lld\n",ans);
}

[国家集训队]聪聪可可

LG P2634 [国家集训队]聪聪可可

BZ 2152 聪聪可可

也可以算是模板吧

这个就是用桶来统计了

#include<bits/stdc++.h>
using namespace std;
#define N 20011
int n,cnt=0,head[N];
struct edge{
    int to,nxt,w;
}e[N<<1];
void add(int x,int y,int w){
    e[++cnt].to=y;e[cnt].nxt=head[x];head[x]=cnt;e[cnt].w=w;
}
#define fl(i,x) for(int i=head[x],to;to=e[i].to,i;i=e[i].nxt)
int siz[N],SZ,rt,mxs,ans=0,b[3];
bool v[N];
void frt(int x,int f){
    siz[x]=1;
    int res=0;
    fl(i,x)if(!v[to]&&to!=f){
        frt(to,x);
        siz[x]+=siz[to];
        res=max(res,siz[to]);
    }
    res=max(res,SZ-siz[x]);
    if(res<mxs)rt=x,mxs=res;
}
void gd(int x,int d,int f){
    ++b[d%3];
    fl(i,x)if(!v[to]&&to!=f)gd(to,d+e[i].w,x);
}
int calc(int x,int d){
    b[0]=b[1]=b[2]=0;
    gd(x,d,0);
    return b[0]*b[0]+b[1]*b[2]*2;
}
void sol(int x){
    v[x]=1;
    ans+=calc(x,0);
    fl(i,x)if(!v[to]){
        ans-=calc(to,e[i].w);
        rt=0,mxs=(1<<30),SZ=siz[to];
        frt(to,x);
        sol(rt);
    }
}
int main(){
    scanf("%d",&n);
    int x,y,w;
    for(int i=1;i<=n-1;++i)
        scanf("%d%d%d",&x,&y,&w),
        add(x,y,w),
        add(y,x,w);

    rt=0,mxs=(1<<30);SZ=n;
    frt(1,0);
    sol(rt);
    int fm=n*n,gcd=__gcd(fm,ans);
    fm/=gcd;ans/=gcd;
    printf("%d/%d\n",ans,fm);
}
点分治
comment评论
Search
search