4229 – Kdist

De la Universitas MediaWiki

Bujorel s-a apucat de pomicultură şi a însămânţat un arbore (graf conex aciclic) cu N noduri, fiecare nod având o culoare dată dintr-un interval [1, K]. Acum, după ce arborele a crescut, el doreşte să ştie, pentru fiecare culoare, suma distanţelor dintre toate perechile de noduri ale arborelui ce au culoarea respectivă. Distanţa dintre două noduri se defineşte ca fiind numărul de muchii de pe drumul dintre cele două noduri.

Cerința

Deoarece Bujorel a folosit foarte mult îngrăşământ la plantarea arborelui, acesta a crescut foarte mult şi voi trebuie să scrieţi un program care calculează suma distanţelor dintre nodurile cu aceeaşi culoare.

Date de intrare

Pe prima linie se află două numere naturale N şi K, numărul de noduri ale arborelui, respectiv numărul de culori în care sunt vopsite nodurile. Pe următoarele N-1 linii este descris arborele, fiecare linie conţinând două numere naturale x şi y, reprezentând o muchie dintre nodul x şi nodul y. În continuare sunt prezente N linii, a i-a dintre aceste linii având un număr întreg c aparţinând intervalului [1, K] reprezentând culoarea nodului i.

Date de ieșire

Se vor afișa K linii, cea de-a i-a linie conţinând suma distanţelor dintre toate perechile de noduri ce au culoarea i.

Restricții și precizări

  • 1 ≤ K ≤ N ≤ 200.000
  • În calcularea sumei distanţelor dintre nodurile cu aceeaşi culoare, fiecare pereche de noduri (x, y) va fi considerată o singură dată.

Exemplu:

Intrare

6 3
1 2
1 3
3 4
3 5
5 6
1
2
2
1
2
3

Ieșire

2
6
0

Explicație

Pentru culoarea 1 avem o pereche: (1, 4), cu distanţa 2. Pentru culoarea 2 avem trei perechi: (2, 3) cu distanţa 2, (2, 5) cu distanţa 3 și (3, 5) cu distanţa 1. Pentru culoarea 3 avem un singur nod cu această culoare, deci răspunsul este 0.\

Rezolvare

from collections import defaultdict, deque


def calculate_color_distances(n, k, edges, colors):
 
    tree = defaultdict(list)
    for u, v in edges:
        tree[u].append(v)
        tree[v].append(u)


    color_nodes = defaultdict(list)
    for i in range(n):
        color_nodes[colors[i]].append(i + 1)


    def bfs_distance_sum(start, valid_nodes):
        visited = set()
        queue = deque([(start, 0)])
        visited.add(start)
        distance_sum = 0
        node_count = 0

        while queue:
            node, dist = queue.popleft()
            distance_sum += dist
            node_count += 1

            for neighbor in tree[node]:
                if neighbor not in visited and neighbor in valid_nodes:
                    visited.add(neighbor)
                    queue.append((neighbor, dist + 1))

        return distance_sum, node_count


    color_distance_sums = [0] * k
    for color in range(1, k + 1):
        nodes = color_nodes[color]
        if len(nodes) < 2:
            continue

        total_distance_sum = 0
        for node in nodes:
            dist_sum, node_count = bfs_distance_sum(node, set(nodes))
            total_distance_sum += dist_sum


        color_distance_sums[color - 1] = total_distance_sum // 2

    return color_distance_sums



n, k = map(int, input().split())
edges = [tuple(map(int, input().split())) for _ in range(n - 1)]
colors = [int(input()) for _ in range(n)]


result = calculate_color_distances(n, k, edges, colors)
for dist_sum in result:
    print(dist_sum)