zcmimi's blog
avatar
zcmimi
2020-09-03 21:35:00

LG 4719 【模板】"动态 DP"&动态树分治

给定一棵n个点的树,点带点权。

m次操作,每次操作给定x,y,表示修改点x的权值为y

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

f_{i,0}表示以i为根的子树,不选i的答案,f_{i,1}表示以i为根的子树,选i的答案

很容易列出dp方程:

f_{i,0}=\sum_{j} \max(f_{j,0},f_{j,1})\\ f_{i,1}=a_i + \sum_{j} f_{j,0}

这里ji的所有儿子

最后答案为\max(f_{1,0},f_{1,1})

考虑带修改:

修改了一个点的点权,只会更改从这个点到根这条路径上节点的答案

我们希望能够只更改这条链上的答案

由于树可能会退化成一条链,这样每次更新就是\mathcal{O(n)}的,不可接受。

我们希望这条链只更新\log n

使用重链剖分(树链剖分),原因:

  • 每个点到根的路径上,最多经过\log n条轻边
  • 每条重链的链尾都是叶子节点,且只有叶子节点没有重儿子

    (这为动态规划的初始状态和转移方式做了保障。)

  • 一条重链所在的区间在剖出的DFS序上,是连续的一段区间

    (这为可以使用数据结构维护区间信息)

考虑简化dp方程以迎合重连剖分:

g_{i,0}表示i点所有轻儿子,可取可不取形成的最大权独立集,g_{i,1}表示 i点只考虑轻儿子的、取自己的最大权独立集

那么

f_{i,0}=g_{i,0}+\max(f_{j,0},f_{j,1})\\ f_{i,1}=g_{i,1}+f_{j,0}

这里的j表示i点的重儿子。特殊地,对于叶子节点,g_{i, 0} = g_{i, 1} = 0

区间维护: 使用矩阵

定义矩阵

\begin{vmatrix} f_{i,0} & f_{i,1} \end{vmatrix}

构造转移矩阵使\begin{vmatrix}f_{j,0} & f_{j,1}\end{vmatrix} \to \begin{vmatrix}f_{i,0} & f_{i,1}\end{vmatrix}

可以发现我们的dp方程不满足矩阵乘法的形式,

那么我们重新定义矩阵乘法C=A*B为:

C_{i,j}=\max_{k} \{A_{i,k}+B_{k,j}\}

可以发现这个新定义的矩阵乘法仍具有结合律

构造转移矩阵:

\begin{vmatrix}f_{j,0} & f_{j,1}\end{vmatrix}* \begin{vmatrix} g_{i,0} & g_{i,1}\\ g_{i,0} & \infty \end{vmatrix} =\begin{vmatrix} f_{i,0} & f_{i,1} \end{vmatrix}

对于一条重链,我们可以用线段树维护区间乘积

到了一条重链链头,因为这个点是它父亲的轻儿子,我们需要更新它父亲节点所在的点的转移矩阵。

这样子一直跳到根节点就可以了

代码如下:

#include<bits/stdc++.h>
#define fl(i,x) for(int i=head[x],to;to=e[i].to,i;i=e[i].nxt)
const int N=100011;
void cmax(int&x,int y){if(x<y)x=y;}
int max(int x,int y){return x>y?x:y;}
int n,m,cnt,a[N],head[N];
struct edge{int to,nxt;}e[N<<1];
int f[N],top[N],siz[N],id[N],dfn[N],sz,F[N][2],end[N];
void add(int x,int y){e[++cnt].to=y,e[cnt].nxt=head[x],head[x]=cnt;}
struct mat{
    int a[2][2];
    mat(){memset(a,-0x3F,sizeof a);}
    int&operator()(int x,int y){return a[x][y];}
    mat operator*(const mat&x)const{
        mat c;
        for(int i=0;i<2;++i)
            for(int j=0;j<2;++j)
                for(int k=0;k<2;++k)
                    cmax(c.a[i][j],a[i][k]+x.a[k][j]);
        return c;
    }
    int mx()const{return max(a[0][0],a[1][0]);}
}val[N];
struct seg{
    mat s[N<<2];
    #define ls rt<<1
    #define rs rt<<1|1
    void pu(int rt){s[rt]=s[ls]*s[rs];}
    void build(int l=1,int r=n,int rt=1){
        if(l==r)return s[rt]=val[dfn[l]],void();
        int m=l+r>>1;
        build(l,m,ls),build(m+1,r,rs);
        pu(rt);
    }
    void upd(int x,int l=1,int r=n,int rt=1){
        if(l==r)return s[rt]=val[dfn[l]],void();
        int m=l+r>>1;
        if(x<=m)upd(x,l,m,ls);
        else upd(x,m+1,r,rs);
        pu(rt);
    }
    mat ask(int L,int R,int l=1,int r=n,int rt=1){
        if(L==l&&r==R)return s[rt];
        int m=l+r>>1;
        if(L>m)return ask(L,R,m+1,r,rs);
        else if(R<=m)return ask(L,R,l,m,ls);
        else return ask(L,m,l,m,ls)*ask(m+1,R,m+1,r,rs);
    }
    #undef ls
    #undef rs
}T;
void dfs(int x){
    siz[x]=1;
    fl(i,x)if(to!=f[x])
        f[to]=x,dfs(to),siz[x]+=siz[to];
}
void bt(int x,int tp){
    top[x]=tp;dfn[id[x]=++sz]=x;
    cmax(end[tp],sz);
    F[x][0]=0,F[x][1]=a[x];
    val[x](0,0)=val[x](0,1)=0;
    val[x](1,0)=a[x];
    int k=0;
    fl(i,x)if(to!=f[x]&&siz[to]>siz[k])k=to;
    if(!k)return;
    bt(k,tp);
    F[x][0]+=max(F[k][0],F[k][1]);
    F[x][1]+=F[k][0];
    fl(i,x)if(!top[to]){
        bt(to,to);
        F[x][0]+=max(F[to][0],F[to][1]);
        F[x][1]+=F[to][0];
        val[x](0,1)=val[x](0,0)+=max(F[to][0],F[to][1]);
        val[x](1,0)+=F[to][0];
    }
}
void upd(int x,int v){
    val[x](1,0)+=v-a[x];
    a[x]=v;
    mat bef,aft;
    while(x){
        bef=T.ask(id[top[x]],end[top[x]]);
        T.upd(id[x]);
        aft=T.ask(id[top[x]],end[top[x]]);
        x=f[top[x]];
        val[x](0,1)=val[x](0,0)+=aft.mx()-bef.mx();
        val[x](1,0)+=aft(0,0)-bef(0,0);
    }
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)scanf("%d",a+i);
    int x,y;
    for(int i=1;i<n;++i)
        scanf("%d%d",&x,&y),
        add(x,y),add(y,x);
    dfs(1),bt(1,1),T.build();
    while(m--){
        scanf("%d%d",&x,&y);
        upd(x,y);
        printf("%d\n",T.ask(id[1],end[1]).mx());
    }
}
动态dp
comment评论
Search
search