题目传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=5109
题目分析:过了挺久终于把这个坑填了。一开始以为是一道很难的题,后来发现也不难想。
由于懒得打题解,直接引用出题人的题解好了(主要是来贴代码)QAQ:
虽然题目中给定的是无向图,但是实际上我们可以先从
S
出发求一遍最短路,然后问题变成了:“在有向无环图上,求有多少个满足条件的点对A,B ,满足从
S
到T 的所有路径一定经过
A,B
其中一点,并且不存在路径同时经过
A,B
”。求解这到题目的一个关键点在于: 满足条件的点对
A,B
具有特点:从
S
到A 的方案数
×
从
A
到T 的方案数 + 从
S
到B 的方案数
×
从
B
到T 的方案数
=
从S 到
T
的方案数。所以在有向无环图上用动态规划求解路径条数,再去掉
A 可以到达
B
或B 可以到达
A
的情况即可求解这到题目。PS:方案数可能会爆掉怎么办?可以对方案数求余一个大整数,如果觉得不够的话可以求余两个大整数。
定义
F(X)= 从
S
到X 的方案数
×
从
X
到T 的方案数 = 从
S
经过X 到达
T
的方案数,所以满足条件的点对A,B 为:
F(A)+F(B)=F(T)
A
和B 不能相互到达对于条件
1
,我们可以使用数据结构进行优化(使用std::map
即可),而对于条件2 ,我们可以使用bitset
位压
32
或者
64
位进行加速,使得最终时间和空间都能够承受。时间复杂度:
O(nlogn+nmw)
,其中
w
<script type=”math/tex” id=”MathJax-Element-2256″>w</script> 是位压的字长。
我写代码的时候模了3个大质数,并且手写了个Hash表处理条件1,结果代码比标程长到不知哪里去了……
CODE:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;
const int maxn=50010;
const int maxm=800;
const unsigned long long Max=1ULL<<22;
const long long M[3]={998244353,1000000007,1998585857};
const long long Mod=55837;
const long long M1=1333331;
const long long M2=23252729;
const long long M3=19260817;
typedef long long LL;
typedef unsigned long long ULL;
struct edge
{
int obj,len;
edge *Next;
} e[maxn<<2];
edge *head[maxn];
edge *nhead[maxn];
int cur=-1;
int Heap[maxn];
int id[maxn];
LL dis[maxn];
int tail;
int pin[maxn];
int que[maxn];
int he,ta;
struct data
{
LL cnt1[3],cnt2[3];
int num;
} a[maxn];
ULL get[maxn][maxm];
struct Hash_data
{
LL val[3];
ULL Node[maxm];
int Num;
} Hash[Mod];
int Cnt[Max];
int n,m,s,t,sn;
LL ans=0;
void Add(edge **Head,int x,int y,int z)
{
cur++;
e[cur].obj=y;
e[cur].len=z;
e[cur].Next=Head[x];
Head[x]=e+cur;
}
int Delete()
{
int temp=Heap[1];
Heap[1]=Heap[tail];
tail--;
id[ Heap[1] ]=1;
int x=1;
while (1)
{
int y=x,Left=x<<1,Right=Left|1;
if ( Left<=tail && dis[ Heap[Left] ]<dis[ Heap[y] ] ) y=Left;
if ( Right<=tail && dis[ Heap[Right] ]<dis[ Heap[y] ] ) y=Right;
if (y==x) break;
swap(Heap[x],Heap[y]);
swap(id[ Heap[x] ],id[ Heap[y] ]);
x=y;
}
return temp;
}
void Update(int x)
{
while (x>1)
{
int y=x>>1;
if (dis[ Heap[y] ]<=dis[ Heap[x] ]) break;
swap(Heap[x],Heap[y]);
swap(id[ Heap[x] ],id[ Heap[y] ]);
x=y;
}
}
void Release(int x,int y,LL v)
{
if (dis[y]<=dis[x]+v) return;
dis[y]=dis[x]+v;
Update(id[y]);
}
void Dijkstra()
{
for (int i=1; i<=n; i++) dis[i]=1e15;
dis[s]=0;
tail=1;
Heap[1]=s;
id[s]=1;
for (int i=1; i<=n; i++) if (i!=s) Heap[++tail]=i,id[i]=tail;
for (int i=1; i<n; i++)
{
int node=Delete();
for (edge *p=head[node]; p; p=p->Next)
Release(node,p->obj,p->len);
}
}
void Work(int x,int y)
{
y--;
int u=y/64,v=y%64;
ULL temp=1;
temp<<=v;
get[x][u]|=temp;
}
void Calc()
{
for (int i=1; i<=n; i++) a[i].num=i,Work(i,i);
sn=(n-1)/64;
he=0,ta=1;
que[1]=s;
for (int k=0; k<3; k++) a[s].cnt1[k]=1;
while (he<ta)
{
int node=que[++he];
for (edge *p=nhead[node]; p; p=p->Next)
{
int to=p->obj;
for (int k=0; k<3; k++) a[to].cnt1[k]=(a[to].cnt1[k]+a[node].cnt1[k])%M[k];
pin[to]--;
if (!pin[to]) que[++ta]=to;
}
}
for (int k=0; k<3; k++) a[t].cnt2[k]=1;
for (int i=ta; i>=1; i--)
{
int node=que[i];
for (edge *p=nhead[node]; p; p=p->Next)
{
int to=p->obj;
for (int k=0; k<3; k++) a[node].cnt2[k]=(a[node].cnt2[k]+a[to].cnt2[k])%M[k];
}
}
for (int i=1; i<=n; i++)
for (int k=0; k<3; k++) a[i].cnt1[k]=(a[i].cnt1[k]*a[i].cnt2[k])%M[k];
for (int i=1; i<=ta; i++)
{
int node=que[i];
if ( a[node].cnt1[0] && a[node].cnt1[1] && a[node].cnt1[2] )
for (edge *p=nhead[node]; p; p=p->Next)
{
int to=p->obj;
if ( a[to].cnt1[0] && a[to].cnt1[1] && a[to].cnt1[2] )
for (int k=0; k<=sn; k++) get[to][k]|=get[node][k];
}
}
}
void Push(int x,LL v1,LL v2,LL v3)
{
LL y=(v1*M1+v2*M2+v3*M3)%Mod;
while ( Hash[y].val[0]!=-1 &&
( Hash[y].val[0]!=v1 || Hash[y].val[1]!=v2 || Hash[y].val[2]!=v3 ) )
y=(y+1)%Mod;
Hash[y].val[0]=v1;
Hash[y].val[1]=v2;
Hash[y].val[2]=v3;
Hash[y].Node[(x-1)/64]|=( 1ULL<<((x-1)%64) );
Hash[y].Num++;
}
int Get(ULL v)
{
int temp=0;
for (int i=0; i<3; i++) temp+=Cnt[v&(Max-1ULL)],v>>=22;
return temp;
}
void Check(int x,LL v1,LL v2,LL v3)
{
LL y=(v1*M1+v2*M2+v3*M3)%Mod;
while ( Hash[y].val[0]!=-1 &&
( Hash[y].val[0]!=v1 || Hash[y].val[1]!=v2 || Hash[y].val[2]!=v3 ) )
y=(y+1)%Mod;
if ( Hash[y].val[0]==-1 ) return;
ans+=Hash[y].Num;
for (int k=0; k<=sn; k++) ans-=( Get(get[x][k]&Hash[y].Node[k])<<1 );
}
void Solve()
{
for (int i=0; i<Mod; i++) Hash[i].val[0]=-1;
for (int i=1; i<Max; i++) Cnt[i]=Cnt[i^(i&(-i))]+1;
for (int i=1; i<=n; i++) Push(i,a[i].cnt1[0],a[i].cnt1[1],a[i].cnt1[2]);
for (int i=1; i<=n; i++)
{
LL v1=(a[t].cnt1[0]-a[i].cnt1[0]+M[0])%M[0];
LL v2=(a[t].cnt1[1]-a[i].cnt1[1]+M[1])%M[1];
LL v3=(a[t].cnt1[2]-a[i].cnt1[2]+M[2])%M[2];
Check(i,v1,v2,v3);
if ( v1==a[i].cnt1[0] && v2==a[i].cnt1[1] && v3==a[i].cnt1[2] ) ans++;
}
ans>>=1;
}
int main()
{
freopen("chicken.in","r",stdin);
freopen("chicken.out","w",stdout);
scanf("%d%d%d%d",&n,&m,&s,&t);
for (int i=1; i<=n; i++) head[i]=nhead[i]=NULL;
for (int i=1; i<=m; i++)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
Add(head,u,v,w);
Add(head,v,u,w);
}
Dijkstra();
for (int i=1; i<=n; i++)
for (edge *p=head[i]; p; p=p->Next)
{
int to=p->obj;
if ( dis[i]+(long long)p->len==dis[to] )
Add(nhead,i,to,0),pin[to]++;
}
Calc();
Solve();
printf("%lld\n",ans);
return 0;
}