This is an explanation of this problem from USACO's training website, converted to markdown. Please do not just copy the code: you will not learn anything that way. At least type it out and make sure you understand it, so you can do it yourself in the future!

Sorting is one of the most frequently performed computational tasks. Consider the special sorting problem in which the records to be sorted have at most three different key values. This happens for instance when we sort medalists of a competition according to medal value, that is, gold medalists come first, followed by silver, and bronze medalists come last.

In this task the possible key values are the integers 1, 2 and 3. The required sorting order is non-decreasing. However, sorting has to be accomplished by a sequence of exchange operations. An exchange operation, defined by two position numbers p and q, exchanges the elements in positions p and q.

You are given a sequence of key values. Write a program that computes the minimal number of exchange operations that are necessary to make the sequence sorted.

PROGRAM NAME: sort3

INPUT FORMAT

Line 1: N (1 <= N <= 1000), the number of records to be sorted
Lines 2-N+1: A single integer from the set {1, 2, 3}

SAMPLE INPUT (file sort3.in)

9
2
2
1
3
3
3
2
3
1

OUTPUT FORMAT

A single line containing the number of exchanges required

SAMPLE OUTPUT (file sort3.out)

4
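To make the sample concrete, here is a small sketch in C that applies one possible sequence of four exchanges to the sample input and checks that the result is sorted. The helper names (`exchange`, `is_sorted`, `sample_trace`) and the particular (p, q) pairs are illustrative choices, not part of the original task. Four is also a lower bound here: seven of the nine elements start out of place, and a single exchange can fix at most two of them.

```c
#include <assert.h>

/* Exchange the elements at 1-based positions p and q, as the task defines it. */
void exchange(int *a, int n, int p, int q)
{
    int tmp;

    assert(p >= 1 && p <= n && q >= 1 && q <= n);
    tmp = a[p - 1];
    a[p - 1] = a[q - 1];
    a[q - 1] = tmp;
}

/* Return 1 if a[0..n-1] is in non-decreasing order, 0 otherwise. */
int is_sorted(const int *a, int n)
{
    int i;

    for (i = 1; i < n; i++)
        if (a[i - 1] > a[i])
            return 0;
    return 1;
}

/* Apply one hand-picked sequence of exchanges to the sample input.
   Returns the number of exchanges used, or -1 if the result is not sorted. */
int sample_trace(void)
{
    int a[9] = {2, 2, 1, 3, 3, 3, 2, 3, 1};
    /* (p, q) pairs chosen by hand; other sequences of four also work */
    int ops[4][2] = {{1, 3}, {2, 9}, {4, 7}, {5, 9}};
    int k;

    for (k = 0; k < 4; k++)
        exchange(a, 9, ops[k][0], ops[k][1]);
    return is_sorted(a, 9) ? 4 : -1;
}
```

Calling `sample_trace()` from a `main` of your own returns the number of exchanges applied once the array has been verified as sorted.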

ANALYSIS

Russ Cox

We read the input into an array, and sort a copy of it in another array, so we know what we have and what we want.

A swap touches two elements, so it can correct at most two misplaced elements. We run through the array looking for pairs of misplaced elements that a swap would correct, and do those swaps.

The remaining misplaced elements form little cycles: a 1 where a 2 should be, a 2 where a 3 should be, and that 3 where the 1 should be. It takes two swaps to correct such a cycle. So we count the number of such cycles (by counting misplaced elements and dividing by three) and then add in two times that many swaps.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>

#define MAXN 1000

int n;
int have[MAXN];
int want[MAXN];

int
intcmp(const void *a, const void *b)
{
    /* keys are 1..3, so the subtraction cannot overflow */
    return *(const int*)a - *(const int*)b;
}

int
main(void)
{
    /* n is the global declared above; redeclaring it here would shadow it */
    int i, j, nswap, nbad;
    FILE *fin, *fout;

    fin = fopen("sort3.in", "r");
    fout = fopen("sort3.out", "w");
    assert(fin != NULL && fout != NULL);

    fscanf(fin, "%d", &n);

    for(i=0; i<n; i++) {
        fscanf(fin, "%d", &have[i]);
        want[i] = have[i];
    }
    qsort(want, n, sizeof(want[0]), intcmp);

    /* swaps that correct two elements */
    nswap = 0;
    for(i=0; i<n; i++)
        for(j=0; j<n; j++) {
            if(have[i] != want[i] && have[j] != want[j]
                    && have[i] == want[j] && have[j] == want[i]) {
                have[i] = want[i];
                have[j] = want[j];
                nswap++;
            }
        }

    /* remainder are pairs of swaps that correct three elements */
    nbad = 0;
    for(i=0; i<n; i++)
        if(have[i] != want[i])
            nbad++;

    assert(nbad%3 == 0);
    nswap += nbad/3 * 2;

    fprintf(fout, "%d\n", nswap);
    exit(0);
}

Dan Jasper

Dan Jasper from Germany writes:

The previous solution needs a copy of the original list and O(N^2) time. I think that was necessary for the original task, where you had to print the exchange operations, but if you only have to count them there is a more efficient solution. Count the 1s, 2s and 3s, so you can calculate which part (bucket) of the list each value belongs in. The rest of the solution is essentially the same as yours, but overall it uses O(N) time and does not need a copy of the list.

#include <stdio.h>

int list[1000], N, res, count[3], start[3];
int in[3][3]; // in[v][b] counts the number of (v+1)s in bucket b+1, e.g. in[0][1] is the number of "1"s in bucket "2"

void readFile() {
    FILE *f; int i;
    f = fopen("sort3.in", "r");
    fscanf(f, "%d", &N);
    for(i = 0; i < N; i++) {
        fscanf(f, "%d", &list[i]);
    }
    fclose(f);
}

void writeFile() {
    FILE *f;
    f = fopen("sort3.out", "w");
    fprintf(f, "%d\n", res);
    fclose(f);
}

void findBuckets() {
    int i;
    for(i = 0; i < N; i++) count[list[i]-1]++;
    start[0] = 0;
    start[1] = count[0];
    start[2] = count[0] + count[1];
}

void findWhere() {
    int i, j;
    for(i = 0; i < 3; i++) {
        for(j = 0; j < count[i]; j++) in[list[start[i] + j]-1][i]++;
    }
}

void sort() {
    int h;
    // 1 <-> 2
    h = in[0][1];
    if(in[1][0] < h) h = in[1][0];
    res += h; in[0][1] -= h; in[1][0] -= h;
    // 3 <-> 2
    h = in[2][1];
    if(in[1][2] < h) h = in[1][2];
    res += h; in[2][1] -= h; in[1][2] -= h;
    // 1 <-> 3
    h = in[0][2];
    if(in[2][0] < h) h = in[2][0];
    res += h; in[0][2] -= h; in[2][0] -= h;

    // Cycles
    res += (in[0][1] + in[0][2]) * 2;
}

void process() {
    findBuckets();
    findWhere();
    sort();
}

int main () {
    readFile();
    process();
    writeFile();
    return 0;
}
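For experimenting without input files, the same bucket-counting idea can be wrapped in a single function. This is only a sketch: the name `count_exchanges` and its array-plus-length interface are assumptions made for illustration, and it assumes every element is 1, 2 or 3.

```c
#include <assert.h>

/* Minimum number of exchanges needed to sort a[0..n-1],
   assuming every a[i] is 1, 2 or 3. */
int count_exchanges(const int *a, int n)
{
    int count[3] = {0, 0, 0};
    int in[3][3] = {{0}};   /* in[v][b]: how many values v+1 lie in bucket b+1 */
    int pairs[3][2] = {{0, 1}, {1, 2}, {0, 2}};
    int i, b, k, res = 0;

    for (i = 0; i < n; i++) {
        assert(a[i] >= 1 && a[i] <= 3);
        count[a[i] - 1]++;
    }

    /* bucket b covers the positions where value b+1 belongs after sorting */
    i = 0;
    for (b = 0; b < 3; b++) {
        int end = i + count[b];
        for (; i < end; i++)
            in[a[i] - 1][b]++;
    }

    /* direct exchanges: an x sitting in bucket y paired with a y in bucket x */
    for (k = 0; k < 3; k++) {
        int x = pairs[k][0], y = pairs[k][1];
        int h = in[x][y] < in[y][x] ? in[x][y] : in[y][x];
        res += h;
        in[x][y] -= h;
        in[y][x] -= h;
    }

    /* every remaining 3-cycle contains exactly one misplaced 1
       and costs two exchanges */
    res += 2 * (in[0][1] + in[0][2]);
    return res;
}
```

On the sample input this yields one direct 1-2 exchange, one direct 2-3 exchange and one three-element cycle, for a total of 4.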

Evgeni Dzhelyov

Bulgaria’s Evgeni Dzhelyov writes:

I read the elements one by one and count them, so we know exactly how many 1s, 2s and 3s we have and what the sorted array looks like. Then I count the 2s in the 1s' region and the 1s in the 2s' region; we clearly need min(2sin1, 1sin2) direct swaps for that pair, and I do the same for 1-3 and 2-3. The sum of these three mins gives the number of direct swaps. After those swaps, K misplaced 1s, K misplaced 2s and K misplaced 3s remain, where K = max(2sin1, 1sin2) - min(2sin1, 1sin2); they form K three-element cycles, which need 2*K more swaps.

#include <fstream>

using namespace std;

int min (int a, int b) { return a < b ? a : b; }
int max (int a, int b) { return a > b ? a : b; }

int main () {
    int s[1024];
    int n;
    int sc[4] = {0};
    
    ifstream fin("sort3.in");
    ofstream fout("sort3.out");
    fin>>n;
    for (int i = 0; i < n; i++) {
        fin>>s[i];
        sc[s[i]]++;
    }
    int s12 = 0, s13 = 0, s21 = 0, s31 = 0, s23 = 0, s32 = 0;
    for (int i = 0; i < sc[1]; i++){
        if (s[i] == 2) s12++;
        if (s[i] == 3) s13++;
    }
    
    for (int i = sc[1]; i < sc[1]+sc[2]; i++){
        if (s[i] == 1) s21++;
        if (s[i] == 3) s23++;
    }
    
    for (int i = sc[1]+sc[2]; i < sc[1]+sc[2]+sc[3]; i++){
        if (s[i] == 1) s31++;
        if (s[i] == 2) s32++;
    }
    
    fout<<min(s12, s21)+min(s13, s31)+min(s23, s32) +
          2*(max(s12, s21) - min(s12, s21))<<endl;
    return 0;
}
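The final expression uses only the 1-2 pair to count the cycles. A short derivation suggests why that is enough. The notation follows the variable names in the code above, so $s_{12}$ is the number of 2s found in the region where the 1s belong; the symbol $K$ is introduced here only for illustration. The region for the 1s has exactly as many slots as there are 1s, so the number of foreign elements inside it equals the number of 1s outside it, and the same argument applies to the region for the 2s:

```latex
\begin{align*}
s_{12} + s_{13} &= s_{21} + s_{31}
  &&\Rightarrow\quad s_{12} - s_{21} = s_{31} - s_{13} \\
s_{21} + s_{23} &= s_{12} + s_{32}
  &&\Rightarrow\quad s_{12} - s_{21} = s_{23} - s_{32} \\
K &= |s_{12} - s_{21}| = |s_{13} - s_{31}| = |s_{23} - s_{32}| \\
\text{swaps} &= \min(s_{12}, s_{21}) + \min(s_{13}, s_{31}) + \min(s_{23}, s_{32}) + 2K
\end{align*}
```

For the sample input, $s_{12} = 2$, $s_{21} = 1$, $s_{13} = 0$, $s_{31} = 1$, $s_{23} = 2$ and $s_{32} = 1$, so $K = 1$ and the total is $1 + 0 + 1 + 2 \cdot 1 = 4$.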
