用户登录
用户注册

分享至

动态dp初探

  • 作者: 北京夜场模特领队
  • 来源: 51数据库
  • 2021-10-27

动态区间最大子段和问题

给出长度为\(n\)的序列和\(m\)次操作,每次修改一个元素的值或查询区间的最大字段和(sp1714 gss3)。

\(f[i]\)为以下标\(i\)结尾的最大子段和,\(g[i]\)表示从起始位置到\(i\)以内的最大子段和。
\[ f[i]=\max(f[i-1]+a[i],a[i])\\g[i]=\max(g[i-1],f[i]) \]
定义如下的矩阵乘法,显然这满足乘法结合律和分配律。
\[ c=ab\\c[i,j]=\max_{k}(a[i,k]+b[k,j]) \]
将转移写为矩阵(注意\(g[i]=\max(g[i-1],f[i-1]+a[i],a[i])\)
\[ \begin{bmatrix} f[i]\\ g[i]\\ 0\end{bmatrix} = \begin{bmatrix} a[i]&-\infty&a[i]\\ a[i]&0&a[i]\\ -\infty&-\infty&0\end{bmatrix} \begin{bmatrix} f[i-1]\\ g[i-1]\\ 0\end{bmatrix} \]
可知每个元素\(a[i]\)都对应了一个矩阵,可以认为区间\([l,r]?\)的答案所在矩阵正是
\[ (\prod_{i=l}^k \begin{bmatrix} a[i]&-\infty&a[i]\\ a[i]&0&a[i]\\ -\infty&-\infty&0 \end{bmatrix} )\begin{bmatrix} 0\\ -\infty\\ 0\end{bmatrix} \]
因此可以用线段树维护区间矩阵乘积。

#include <bits/stdc++.h>
using namespace std;
const int n=5e4+10;
const int inf=0x3f3f3f3f;

struct mtr {
    int a[3][3];
    int*operator[](int d) {return a[d];}
    inline mtr() {}
    inline mtr(int val) {
        a[0][0]=a[0][2]=a[1][0]=a[1][2]=val;
        a[0][1]=a[2][0]=a[2][1]=-inf;
        a[1][1]=a[2][2]=0;
    }
    mtr operator*(mtr b) {
        static mtr c;
        memset(&c,-inf,sizeof c);
        for(int i=0; i<3; ++i) 
        for(int k=0; k<3; ++k) 
        for(int j=0; j<3; ++j) 
            c[i][j]=max(c[i][j],a[i][k]+b[k][j]);
        return c; 
    }
} t,a[n<<2];

#define ls (x<<1)
#define rs (x<<1|1)
void build(int x,int l,int r) {
    if(l==r) {
        scanf("%d",&r);
        a[x]=mtr(r);
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    a[x]=a[ls]*a[rs];
}
void modify(int x,int l,int r,int p,int val) {
    if(l==r) {
        a[x]=mtr(val);
        return;
    }
    int mid=(l+r)>>1;
    if(p<=mid) modify(ls,l,mid,p,val);
    else modify(rs,mid+1,r,p,val);
    a[x]=a[ls]*a[rs];
}
mtr query(int x,int l,int r,int l,int r) {
    if(l<=l&&r<=r) return a[x];
    int mid=(l+r)>>1;
    if(r<=mid) return query(ls,l,mid,l,r);
    if(mid<l) return query(rs,mid+1,r,l,r);
    return query(ls,l,mid,l,r)*query(rs,mid+1,r,l,r); 
}

int main() {
    memset(&t,-inf,sizeof t); //notice
    t[0][0]=t[2][0]=0;
    int n,q;
    scanf("%d",&n);
    build(1,1,n);
    scanf("%d",&q);
    for(int op,l,r; q--; ) {
        scanf("%d%d%d",&op,&l,&r);
        if(op==0) modify(1,1,n,l,r);
        else {
            mtr ret=query(1,1,n,l,r)*t;
            printf("%d\n",max(ret[0][0],ret[1][0]));
        } 
    }
    return 0;
}

动态树上最大权独立集

注意断句 给出一棵\(n?\)个节点树和\(m?\)次操作,每次操作修改一个点权并计算当前树上的最大权独立集的权值。

重链剖分,设\(y\)\(x\)的某个儿子,\(s\)是重儿子,\(f[x,t]\)表示在以\(x\)为根的子树中不选/选\(x\)时的最大权独立集权值,\(g[x,t]\)表示在以\(x\)的为根的子树中除去以\(s\)为根的子树部分内不选/选\(x\)的最大权独立集权值,。
\[ g[x,0]=\sum_{y\not=s}\max(f[y,0],f[y,1])\\ g[x,1]=a[x]+\sum_{y\not=s} f[y,0]\\ f[x,0]=g[x,0]+\max(f[s,0],f[s,1])\\ f[x,1]=g[x,1]+f[s,0] \]
然后改写为矩阵乘法
\[ \begin{bmatrix} f[x,0]\\ f[x,1] \end{bmatrix}= \begin{bmatrix} g[x,0]&g[x,0]\\ g[x,1]&-\infty \end{bmatrix} \begin{bmatrix} f[s,0]\\ f[s,1] \end{bmatrix} \]
\(s\)不存在时,钦定\(f[s,0]=0\)\(f[s,1]=-\infty\)。进一步可发现在一条链内,链顶的\(f[,t]?\)值正是链上所有的“g矩阵”(应该明白指的是那个吧)乘起来的第一列值。

因此我们可以树剖维护重链上这些矩阵的乘积,更新时从修改点跳到每条重链的链顶,计算链顶部\(f[,t]\),更新他父亲的\(g[,t]\)(显然他不是父亲的重儿子),然后再跳往父亲所在重链……。

也可以lct来做,(试了一下树剖发现麻烦爆了)每次access修改点到根,然后对这部分重计算就好了。

#include <bits/stdc++.h>
using namespace std;
const int n=1e5+10;
const int inf=0x3f3f3f3f;

struct mtr {
    int a[2][2];
    int*operator[](int d) {return a[d];}
    mtr() {memset(a,-inf,sizeof a);}
    mtr operator*(mtr b) {
        mtr c;
        for(int i=0; i<2; ++i) 
        for(int k=0; k<2; ++k) 
        for(int j=0; j<2; ++j) 
            c[i][j]=max(c[i][j],a[i][k]+b[k][j]);
        return c; 
    }
} g[n],pg[n]; 

int n,m,a[n];
int head[n],to[n<<1],last[n<<1];
int fa[n],ch[n][2],dp[n][2];

void add_edge(int x,int y) {
    static int cnt=0;
    to[++cnt]=y,last[cnt]=head[x],head[x]=cnt;
}
void dfs(int x) {
    dp[x][1]=a[x];
    for(int i=head[x]; i; i=last[i]) {
        if(to[i]==fa[x]) continue;
        fa[to[i]]=x;
        dfs(to[i]);
        dp[x][0]+=max(dp[to[i]][0],dp[to[i]][1]);
        dp[x][1]+=dp[to[i]][0];
    }
    g[x][0][0]=g[x][0][1]=dp[x][0];
    g[x][1][0]=dp[x][1];
    pg[x]=g[x];
} 
void update(int x) {
    pg[x]=g[x];
    if(ch[x][0]) pg[x]=pg[ch[x][0]]*pg[x]; //无交换律 
    if(ch[x][1]) pg[x]=pg[x]*pg[ch[x][1]];
}
int get(int x) {
    return ch[fa[x]][0]==x?0:(ch[fa[x]][1]==x?1:-1);
}
void rotate(int x) {
    int y=fa[x],k=get(x);
    if(~get(y)) ch[fa[y]][get(y)]=x;
    fa[x]=fa[y];
    fa[ch[y][k]=ch[x][k^1]]=y;
    fa[ch[x][k^1]=y]=x;
    update(y);
    update(x); 
}
void splay(int x) {
    while(~get(x)) {
        int y=fa[x];
        if(~get(y)) rotate(get(x)^get(y)?x:y);
        rotate(x);
    }
} 
void access(int x) {
    for(int y=0; x; x=fa[y=x]) {
        splay(x);
        if(ch[x][1]) { //旧的重儿子 
            g[x][0][0]+=max(pg[ch[x][1]][0][0],pg[ch[x][1]][1][0]);
            g[x][1][0]+=pg[ch[x][1]][0][0];
        }
        if(y) { //新的重儿子 
            g[x][0][0]-=max(pg[y][0][0],pg[y][1][0]);
            g[x][1][0]-=pg[y][0][0];
        }
        g[x][0][1]=g[x][0][0]; //别忘了 
        ch[x][1]=y;
        update(x);
    }
}
void modify(int x,int y) {
    access(x);
    splay(x);
    g[x][1][0]+=y-a[x];
    update(x);
    a[x]=y;
}

int main() {
    scanf("%d%d",&n,&m);
    for(int i=1; i<=n; ++i) scanf("%d",a+i);
    for(int x,y,i=n; --i; ) {
        scanf("%d%d",&x,&y);
        add_edge(x,y);
        add_edge(y,x);
    }
    dfs(1); //所有连边是轻边 
    for(int x,y; m--; ) {
        scanf("%d%d",&x,&y);
        modify(x,y);
        splay(1);
        printf("%d\n",max(pg[1][0][0],pg[1][1][0]));
    }
    return 0;
}

全局平衡二叉树

然后讲一讲这道题的毒瘤加强版。传送门

数据加强并且经过特殊构造,树剖和lct都过不了了。树剖本身复杂度太大, o(\(m\log^2n\))过不了百万是很正常的;而lct虽然只有一个\(\log\) ,但由于常数过大也被卡了。

树剖的两个 \(\log\) 基本上可以放弃治疗了。但是我们不禁要问,lct究竟慢在哪里?

仔细想想,lct的access复杂度之所以是一个 \(\log?\) ,是由于splay的势能分析在整棵lct上依然成立,也就是说可以把lct看作一棵大splay,在这棵大splay上的一次access只相当于一次splay。

话虽然是这么说,但是实际上当我们不停地随机access的时候,要调整的轻重链数量还是很多的。感受一下,拿极端情形来说,如果树是一条链,一开始全是轻边,那么对链末端的结点access一次显然应该是 \(o(n)\)的。所以其实lct的常数大就大在它是靠势能法得到的 \(o(\log n)\),这么不靠谱的玩意是容易gg的。

但是如果我们不让lct放任自由地access,而是一开始就给它构造一个比较优雅的姿态并让它静止(本来这棵树也不需要动),那么它也许就有救了。我们可以按照树链剖分的套路先划分出轻重边,然后对于重链建立一棵形态比较好的splay,至于轻儿子就跟原来的lct一样直接用轻边挂上即可。什么叫“形态比较好”呢?我们给每个点 \(x?\) 定义其权重为 size[x]-size[son[x]],其中 son[x] 是它的重儿子,那么对于一条重链,我们可以先找到它的带权重心作为当前节点,然后对左右分别递归建树。

by gkxx

似乎较lct常数更小,也蛮好些的。

#include <bits/stdc++.h> /*卡着时限过*/
using namespace std;

namespace io {
    const unsigned buffsize=1<<24,output=1<<24;
    static char ch[buffsize],*st=ch,*t=ch;
    inline char getc() {
        return((st==t)&&(t=(st=ch)+fread(ch,1,buffsize,stdin),st==t)?0:*st++);
    }
    static char out[output],*nowps=out;
    inline void flush() {
        fwrite(out,1,nowps-out,stdout);
        nowps=out;
    }
    template<typename t>inline void read(t&x) {
        x=0;
        static char ch;
        t f=1;
        for(ch=getc(); !isdigit(ch); ch=getc())if(ch=='-')f=-1;
        for(; isdigit(ch); ch=getc())x=x*10+(ch^48);
        x*=f;
    }
    template<typename t>inline void write(t x,char ch='\n') {
        if(!x)*nowps++=48;
        if(x<0)*nowps++='-',x=-x;
        static unsigned sta[111],tp;
        for(tp=0; x; x/=10)sta[++tp]=x%10;
        for(; tp; *nowps++=sta[tp--]^48);
        *nowps++=ch;
        flush();
    }
}
using io::read;
using io::write;

const int n=1e6+10;
const int inf=0x3f3f3f3f;

struct mtr {
    int a[2][2];
    int*operator[](int x) {return a[x]; }
    inline mtr() {}
    inline mtr(int g0,int g1) {
        a[0][0]=a[0][1]=g0;
        a[1][0]=g1;
        a[1][1]=-inf;
    }
    inline mtr operator*(mtr b) {
        mtr c;
        c[0][0]=max(a[0][0]+b[0][0],a[0][1]+b[1][0]);
        c[0][1]=max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
        c[1][0]=max(a[1][0]+b[0][0],a[1][1]+b[1][0]);
        c[1][1]=max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
        return c;
    }
    void print() {
        printf("%d %d\n%d %d\n\n",a[0][0],a[0][1],a[1][0],a[1][1]); 
    }
};

int n,m,a[n];
int head[n],to[n<<1],last[n<<1];
int siz[n],son[n],g[n][2];
inline void add_edge(int x,int y) {
    static int cnt=0;
    to[++cnt]=y,last[cnt]=head[x],head[x]=cnt;
}
void dfs1(int x,int pa) {
    siz[x]=1;
    g[x][1]=a[x];
    for(int i=head[x]; i; i=last[i]) {
        if(to[i]==pa) continue;
        dfs1(to[i],x);
        siz[x]+=siz[to[i]];
        if(siz[to[i]]>siz[son[x]]) son[x]=to[i];
        g[x][0]+=max(g[to[i]][0],g[to[i]][1]);
        g[x][1]+=g[to[i]][0];
    }
}
void dfs2(int x,int pa) {
    if(!son[x]) return;
    g[x][0]-=max(g[son[x]][0],g[son[x]][1]);
    g[x][1]-=g[son[x]][0];
    for(int i=head[x]; i; i=last[i]) 
        if(to[i]!=pa) dfs2(to[i],x); 
}

mtr g[n],pg[n];
int root,fa[n],ch[n][2];
int stk[n],tp;
bool is_root[n];

inline void update(int x) {
    pg[x]=g[x];
    if(ch[x][0]) pg[x]=pg[ch[x][0]]*pg[x];
    if(ch[x][1]) pg[x]=pg[x]*pg[ch[x][1]];
}
int chain(int l,int r) {
    if(r<l) return 0;
    int sum=0,pre=0;
    for(int i=l; i<=r; ++i) sum+=siz[stk[i]]-siz[son[stk[i]]];
    for(int i=l; i<=r; ++i) {
        pre+=siz[stk[i]]-siz[son[stk[i]]];
        if((pre<<1)>=sum) {
            int x=stk[i];
            ch[x][0]=chain(l,i-1);
            ch[x][1]=chain(i+1,r);
            if(ch[x][0]) fa[ch[x][0]]=x;
            if(ch[x][1]) fa[ch[x][1]]=x;
            update(x);
            return x;
        }
    }
    return 2333;
}
int tree(int top,int pa) {
    for(int x=top; x; x=son[pa=x]) {
        for(int i=head[x]; i; i=last[i]) {
            if(to[i]!=son[x]&&to[i]!=pa) {
                fa[tree(to[i],x)]=x;
            }
        } 
        g[x]=mtr(g[x][0],g[x][1]);
    }
    tp=0;
    for(int x=top; x; x=son[x]) stk[++tp]=x;
    return chain(1,tp);
}
inline void build() {
    root=tree(1,0);
    for(int i=1; i<=n; ++i) {
        is_root[i]=ch[fa[i]][0]!=i&&ch[fa[i]][1]!=i;
    }
}
inline int solve(int x,int y) {
    g[x][1]+=y-a[x];
    a[x]=y;
    for(int f0,f1; x; x=fa[x]) {
        f0=pg[x][0][0];
        f1=pg[x][1][0];
        g[x]=mtr(g[x][0],g[x][1]);
        update(x);
        if(fa[x]&&is_root[x]) {
            g[fa[x]][0]+=max(pg[x][0][0],pg[x][1][0])-max(f0,f1);
            g[fa[x]][1]+=pg[x][0][0]-f0;
        }
    }
    return max(pg[root][0][0],pg[root][1][0]);
}

int main() {
    read(n);
    read(m);
    for(int i=1; i<=n; ++i) read(a[i]);
    for(int x,y,i=n; --i; ) {
        read(x);
        read(y);
        add_edge(x,y);
        add_edge(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    build();
    int lastans=0;
    for(int x,y; m--; ) {
        read(x);
        read(y);
        x^=lastans;
        lastans=solve(x,y);
        write(lastans);
    }
    return 0;
}

noip18 保卫王国

给出一棵\(n?\)个节点树和\(m?\)次询问,每次询问强制选/不选两个点然后计算当前树上的最小覆盖集,询问互相独立。

提示:强制选一个点就是把它的点权改成0,强制不选就是改成 \(\infty\);最小覆盖权值+最大独立集权值=总权值。

#include <bits/stdc++.h>
using namespace std;

const int n=1e6+10;
const long long inf=1e10;

struct mtr {
    long long a[2][2];
    long long*operator[](int x) {return a[x]; }
    inline mtr() {} 
    inline mtr(long long g0,long long g1) {
        a[0][0]=a[0][1]=g0;
        a[1][0]=g1;
        a[1][1]=-inf;
    }
    inline mtr operator*(mtr b) {
        mtr c;
        c[0][0]=max(a[0][0]+b[0][0],a[0][1]+b[1][0]);
        c[0][1]=max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
        c[1][0]=max(a[1][0]+b[0][0],a[1][1]+b[1][0]);
        c[1][1]=max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
        return c;
    }
};

int n,m;
long long a[n],g[n][2];
int head[n],to[n<<1],last[n<<1];
int prt[n],siz[n],son[n];
inline void add_edge(int x,int y) {
    static int cnt=0;
    to[++cnt]=y,last[cnt]=head[x],head[x]=cnt;
}
void dfs1(int x,int pa) {
    siz[x]=1;
    g[x][1]=a[x];
    for(int i=head[x]; i; i=last[i]) {
        if(to[i]==pa) continue;
        prt[to[i]]=x;
        dfs1(to[i],x);
        siz[x]+=siz[to[i]];
        if(siz[to[i]]>siz[son[x]]) son[x]=to[i];
        g[x][0]+=max(g[to[i]][0],g[to[i]][1]);
        g[x][1]+=g[to[i]][0];
    }
}
void dfs2(int x,int pa) {
    if(!son[x]) return;
    g[x][0]-=max(g[son[x]][0],g[son[x]][1]);
    g[x][1]-=g[son[x]][0];
    for(int i=head[x]; i; i=last[i]) 
        if(to[i]!=pa) dfs2(to[i],x); 
}

mtr g[n],pg[n];
int root,fa[n],ch[n][2];
int stk[n],tp;
bool is_root[n];

inline void update(int x) {
    pg[x]=g[x];
    if(ch[x][0]) pg[x]=pg[ch[x][0]]*pg[x];
    if(ch[x][1]) pg[x]=pg[x]*pg[ch[x][1]];
}
int chain(int l,int r) {
    if(r<l) return 0;
    int sum=0,pre=0;
    for(int i=l; i<=r; ++i) sum+=siz[stk[i]]-siz[son[stk[i]]];
    for(int i=l; i<=r; ++i) {
        pre+=siz[stk[i]]-siz[son[stk[i]]];
        if((pre<<1)>=sum) {
            int x=stk[i];
            ch[x][0]=chain(l,i-1);
            ch[x][1]=chain(i+1,r);
            if(ch[x][0]) fa[ch[x][0]]=x;
            if(ch[x][1]) fa[ch[x][1]]=x;
            update(x);
            return x;
        }
    }
    return 2333; 
}
int tree(int top,int pa) {
    for(int x=top; x; x=son[pa=x]) {
        for(int i=head[x]; i; i=last[i]) {
            if(to[i]!=son[x]&&to[i]!=pa) {
                fa[tree(to[i],x)]=x;
            }
        } 
        g[x]=mtr(g[x][0],g[x][1]);
    }
    tp=0;
    for(int x=top; x; x=son[x]) stk[++tp]=x;
    return chain(1,tp);
}
inline void build() {
    root=tree(1,0);
    for(int i=1; i<=n; ++i) {
        is_root[i]=ch[fa[i]][0]!=i&&ch[fa[i]][1]!=i;
    }
}
long long tot,res;
inline void solve(int x,long long y) {
    tot+=y;
    g[x][1]+=y;
    for(long long f0,f1; x; x=fa[x]) {
        f0=pg[x][0][0];
        f1=pg[x][1][0];
        g[x]=mtr(g[x][0],g[x][1]);
        update(x);
        if(fa[x]&&is_root[x]) {
            g[fa[x]][0]+=max(pg[x][0][0],pg[x][1][0])-max(f0,f1);
            g[fa[x]][1]+=pg[x][0][0]-f0;
        }
    }
}
inline void solve(int x,int p,int y,int q) {
    long long sx,sy;
    if(!p&&!q) sx=inf,sy=inf,res=0;
    else if(!p&&q) sx=inf,sy=0,res=a[y];
    else if(p&&!q) sx=0,sy=inf,res=a[x];
    else sx=0,sy=0,res=a[x]+a[y];
    solve(x,sx-a[x]);
    solve(y,sy-a[y]);
    res+=tot-max(pg[root][0][0],pg[root][1][0]);
    solve(x,a[x]-sx);
    solve(y,a[y]-sy);
}

char type[10]; 
int main() { //此代码 在-o2时极快
    freopen("defense.in","r",stdin);
    freopen("defense.ans","w",stdout); 
    scanf("%d%d%s",&n,&m,type);
    for(int i=1; i<=n; ++i) {
        scanf("%lld",a+i);
        tot+=a[i];
    }
    for(int x,y,i=n; --i; ) {
        scanf("%d%d",&x,&y);
        add_edge(x,y);
        add_edge(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    build();
    for(int x,p,y,q; m--; ) {
        scanf("%d%d%d%d",&x,&p,&y,&q);
        if(!p&&!q&&(prt[x]==y||prt[y]==x)) {
            puts("-1");
            continue;
        }
        solve(x,p,y,q);
        printf("%lld\n",res);
    }
    return 0;
}

更多习(tian)题(keng)

bzoj4911 [sdoi2017]切树游戏

bzoj4721 洪水

软件
前端设计
程序设计
Java相关