BZOJ5109:[CodePlus 2017]大吉大利,晚上吃鸡! (最短路+Hash表+二进制压位)

  • Post author:
  • Post category:其他

题目传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=5109


题目分析:过了挺久终于把这个坑填了。一开始以为是一道很难的题,后来发现也不难想。

由于懒得打题解,直接引用出题人的题解好了(主要是来贴代码)QAQ:

虽然题目中给定的是无向图,但是实际上我们可以先从
S
出发求一遍最短路,然后问题变成了:“在有向无环图上,求有多少个满足条件的点对 A,B ,满足从
S
T 的所有路径一定经过
A,B
其中一点,并且不存在路径同时经过
A,B
”。

求解这到题目的一个关键点在于: 满足条件的点对
A,B
具有特点:从
S
A 的方案数
×

A
T 的方案数 + 从
S
B 的方案数
×

B
T 的方案数
=
S
T
的方案数。

所以在有向无环图上用动态规划求解路径条数,再去掉 A 可以到达
B
B 可以到达
A
的情况即可求解这到题目。

PS:方案数可能会爆掉怎么办?可以对方案数求余一个大整数,如果觉得不够的话可以求余两个大整数。

定义 F(X)=
S
X 的方案数
×

X
T 的方案数 = 从
S
经过 X 到达
T
的方案数,所以满足条件的点对 A,B 为:


  1. F(A)+F(B)=F(T)

  2. A
    B 不能相互到达

对于条件
1
,我们可以使用数据结构进行优化(使用std::map即可),而对于条件 2 ,我们可以使用 bitset 位压
32
或者
64
位进行加速,使得最终时间和空间都能够承受。

时间复杂度:
O(nlogn+nmw)
,其中
w
<script type=”math/tex” id=”MathJax-Element-2256″>w</script> 是位压的字长。

我写代码的时候模了3个大质数,并且手写了个Hash表处理条件1,结果代码比标程长到不知哪里去了……


CODE:

#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;

const int maxn=50010;
const int maxm=800;
const unsigned long long Max=1ULL<<22;
const long long M[3]={998244353,1000000007,1998585857};
const long long Mod=55837;

const long long M1=1333331;
const long long M2=23252729;
const long long M3=19260817;

typedef long long LL;
typedef unsigned long long ULL;

struct edge
{
    int obj,len;
    edge *Next;
} e[maxn<<2];
edge *head[maxn];
edge *nhead[maxn];
int cur=-1;

int Heap[maxn];
int id[maxn];
LL dis[maxn];
int tail;

int pin[maxn];
int que[maxn];
int he,ta;

struct data
{
    LL cnt1[3],cnt2[3];
    int num;
} a[maxn];
ULL get[maxn][maxm];

struct Hash_data
{
    LL val[3];
    ULL Node[maxm];
    int Num;
} Hash[Mod];

int Cnt[Max];
int n,m,s,t,sn;
LL ans=0;

void Add(edge **Head,int x,int y,int z)
{
    cur++;
    e[cur].obj=y;
    e[cur].len=z;
    e[cur].Next=Head[x];
    Head[x]=e+cur;
}

int Delete()
{
    int temp=Heap[1];
    Heap[1]=Heap[tail];
    tail--;
    id[ Heap[1] ]=1;
    int x=1;
    while (1)
    {
        int y=x,Left=x<<1,Right=Left|1;
        if ( Left<=tail && dis[ Heap[Left] ]<dis[ Heap[y] ] ) y=Left;
        if ( Right<=tail && dis[ Heap[Right] ]<dis[ Heap[y] ] ) y=Right;
        if (y==x) break;
        swap(Heap[x],Heap[y]);
        swap(id[ Heap[x] ],id[ Heap[y] ]);
        x=y;
    }
    return temp;
}

void Update(int x)
{
    while (x>1)
    {
        int y=x>>1;
        if (dis[ Heap[y] ]<=dis[ Heap[x] ]) break;
        swap(Heap[x],Heap[y]);
        swap(id[ Heap[x] ],id[ Heap[y] ]);
        x=y;
    }
}

void Release(int x,int y,LL v)
{
    if (dis[y]<=dis[x]+v) return;
    dis[y]=dis[x]+v;
    Update(id[y]);
}

void Dijkstra()
{
    for (int i=1; i<=n; i++) dis[i]=1e15;
    dis[s]=0;

    tail=1;
    Heap[1]=s;
    id[s]=1;
    for (int i=1; i<=n; i++) if (i!=s) Heap[++tail]=i,id[i]=tail;

    for (int i=1; i<n; i++)
    {
        int node=Delete();
        for (edge *p=head[node]; p; p=p->Next)
            Release(node,p->obj,p->len);
    }
}

void Work(int x,int y)
{
    y--;
    int u=y/64,v=y%64;
    ULL temp=1;
    temp<<=v;
    get[x][u]|=temp;
}

void Calc()
{
    for (int i=1; i<=n; i++) a[i].num=i,Work(i,i);
    sn=(n-1)/64;

    he=0,ta=1;
    que[1]=s;
    for (int k=0; k<3; k++) a[s].cnt1[k]=1;
    while (he<ta)
    {
        int node=que[++he];
        for (edge *p=nhead[node]; p; p=p->Next)
        {
            int to=p->obj;
            for (int k=0; k<3; k++) a[to].cnt1[k]=(a[to].cnt1[k]+a[node].cnt1[k])%M[k];
            pin[to]--;
            if (!pin[to]) que[++ta]=to;
        }
    }

    for (int k=0; k<3; k++) a[t].cnt2[k]=1;
    for (int i=ta; i>=1; i--)
    {
        int node=que[i];
        for (edge *p=nhead[node]; p; p=p->Next)
        {
            int to=p->obj;
            for (int k=0; k<3; k++) a[node].cnt2[k]=(a[node].cnt2[k]+a[to].cnt2[k])%M[k];
        }
    }

    for (int i=1; i<=n; i++)
        for (int k=0; k<3; k++) a[i].cnt1[k]=(a[i].cnt1[k]*a[i].cnt2[k])%M[k];
    for (int i=1; i<=ta; i++)
    {
        int node=que[i];
        if ( a[node].cnt1[0] && a[node].cnt1[1] && a[node].cnt1[2] )
            for (edge *p=nhead[node]; p; p=p->Next)
            {
                int to=p->obj;
                if ( a[to].cnt1[0] && a[to].cnt1[1] && a[to].cnt1[2] )
                    for (int k=0; k<=sn; k++) get[to][k]|=get[node][k];
            }
    }
}

void Push(int x,LL v1,LL v2,LL v3)
{
    LL y=(v1*M1+v2*M2+v3*M3)%Mod;
    while ( Hash[y].val[0]!=-1 &&
            ( Hash[y].val[0]!=v1 || Hash[y].val[1]!=v2 || Hash[y].val[2]!=v3 ) )
        y=(y+1)%Mod;
    Hash[y].val[0]=v1;
    Hash[y].val[1]=v2;
    Hash[y].val[2]=v3;
    Hash[y].Node[(x-1)/64]|=( 1ULL<<((x-1)%64) );
    Hash[y].Num++;
}

int Get(ULL v)
{
    int temp=0;
    for (int i=0; i<3; i++) temp+=Cnt[v&(Max-1ULL)],v>>=22;
    return temp;
}

void Check(int x,LL v1,LL v2,LL v3)
{
    LL y=(v1*M1+v2*M2+v3*M3)%Mod;
    while ( Hash[y].val[0]!=-1 &&
            ( Hash[y].val[0]!=v1 || Hash[y].val[1]!=v2 || Hash[y].val[2]!=v3 ) )
        y=(y+1)%Mod;
    if ( Hash[y].val[0]==-1 ) return;
    ans+=Hash[y].Num;
    for (int k=0; k<=sn; k++) ans-=( Get(get[x][k]&Hash[y].Node[k])<<1 );
}

void Solve()
{
    for (int i=0; i<Mod; i++) Hash[i].val[0]=-1;
    for (int i=1; i<Max; i++) Cnt[i]=Cnt[i^(i&(-i))]+1;
    for (int i=1; i<=n; i++) Push(i,a[i].cnt1[0],a[i].cnt1[1],a[i].cnt1[2]);
    for (int i=1; i<=n; i++)
    {
        LL v1=(a[t].cnt1[0]-a[i].cnt1[0]+M[0])%M[0];
        LL v2=(a[t].cnt1[1]-a[i].cnt1[1]+M[1])%M[1];
        LL v3=(a[t].cnt1[2]-a[i].cnt1[2]+M[2])%M[2];
        Check(i,v1,v2,v3);
        if ( v1==a[i].cnt1[0] && v2==a[i].cnt1[1] && v3==a[i].cnt1[2] ) ans++;
    }
    ans>>=1;
}

int main()
{
    freopen("chicken.in","r",stdin);
    freopen("chicken.out","w",stdout);

    scanf("%d%d%d%d",&n,&m,&s,&t);
    for (int i=1; i<=n; i++) head[i]=nhead[i]=NULL;
    for (int i=1; i<=m; i++)
    {
        int u,v,w;
        scanf("%d%d%d",&u,&v,&w);
        Add(head,u,v,w);
        Add(head,v,u,w);
    }

    Dijkstra();

    for (int i=1; i<=n; i++)
        for (edge *p=head[i]; p; p=p->Next)
        {
            int to=p->obj;
            if ( dis[i]+(long long)p->len==dis[to] )
                Add(nhead,i,to,0),pin[to]++;
        }

    Calc();

    Solve();
    printf("%lld\n",ans);

    return 0;
}


版权声明:本文为KsCla原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。