-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy paththreads_attack_solution.c
91 lines (75 loc) · 1.88 KB
/
threads_attack_solution.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>
/*
 * NUM_THREADS: number of worker threads. The Peterson lock in this file
 *              only distinguishes two parties, so this must stay 2.
 * NUM_ITERS:   increments each worker performs on the shared counter.
 *
 * Declared as enum constants (not `const int`) so they are true integer
 * constant expressions: arrays sized by them are ordinary fixed-size arrays
 * rather than VLAs, and they may appear in _Static_assert / initializers.
 */
enum { NUM_THREADS = 2, NUM_ITERS = 1000000 };
/*
 * State for Peterson's mutual-exclusion algorithm between two threads,
 * referred to as thread 1 and thread 2.
 *
 * Every field is _Atomic, so plain reads and writes of these fields are
 * sequentially consistent atomic operations; Peterson's algorithm relies
 * on that ordering guarantee.
 */
typedef struct {
    _Atomic bool wantsToEnter1;  // thread 1 is trying to enter, or is inside
    _Atomic bool wantsToEnter2;  // thread 2 is trying to enter, or is inside
    _Atomic int turn;            // 1 or 2: the thread given priority when both want in
} MyLock;

/*
 * Returns a lock in the unlocked state: neither thread wants to enter and
 * turn starts at 1. Initialize the shared lock with this before starting
 * the threads that use it.
 *
 * Declared with (void) so it is a real prototype; an empty parameter list
 * is an obsolescent non-prototype declaration in C before C23.
 */
MyLock buildLock(void) {
    return (MyLock) {
        .wantsToEnter1 = false,
        .wantsToEnter2 = false,
        .turn = 1
    };
}
/* Parameters handed to each worker thread's entry point. */
typedef struct ThreadArgument {
    MyLock *lock;   /* lock shared by all workers */
    int *counter;   /* shared counter; workers modify it only while holding lock */
    int threadIdx;  /* 1 takes the thread-1 side of the lock, any other value the thread-2 side */
} ThreadArgument;
/*
 * Acquires the two-thread Peterson lock for the calling thread.
 *
 * threadIdx == 1 plays the thread-1 role; any other value plays the
 * thread-2 role. Each side first announces that it wants to enter, then
 * hands priority (turn) to the other side. It spins until the other side
 * is not interested or has handed priority back.
 *
 * The order of the two stores is essential. The own flag must be raised
 * BEFORE turn is given away. If turn is written first, both threads can pass
 * the spin check:
 *   1. T1 writes turn = 2.
 *   2. T2 writes turn = 1 and flag2 = true, sees flag1 == false, and enters.
 *   3. T1 writes flag1 = true, sees turn == 1, and enters as well.
 *
 * All accessed fields are _Atomic, so these accesses are sequentially
 * consistent, which the algorithm requires.
 */
void performLock(MyLock* lock, int threadIdx) {
    if (threadIdx == 1) {
        lock->wantsToEnter1 = true;
        lock->turn = 2;
        while (lock->wantsToEnter2 && lock->turn == 2) {
            // Busy wait: thread 2 is interested and currently has priority.
        }
    } else {
        lock->wantsToEnter2 = true;
        lock->turn = 1;
        while (lock->wantsToEnter1 && lock->turn == 1) {
            // Busy wait: thread 1 is interested and currently has priority.
        }
    }
}
/*
 * Releases the lock held by the calling thread by clearing that thread's
 * "wants to enter" flag. threadIdx == 1 clears wantsToEnter1; any other
 * value clears wantsToEnter2. The store is a seq_cst atomic store because
 * the flag is _Atomic.
 */
void performUnlock(MyLock *lock, int threadIdx) {
    _Atomic bool *ownFlag = (threadIdx == 1) ? &lock->wantsToEnter1
                                             : &lock->wantsToEnter2;
    *ownFlag = false;
}
void* thread_main(void* argument) {
ThreadArgument* threadArgument = (ThreadArgument*) argument;
for (int i = 0; i < NUM_ITERS; i ++) {
performLock(threadArgument->lock, threadArgument->threadIdx);
*(threadArgument->counter) += 1;
performUnlock(threadArgument->lock, threadArgument->threadIdx);
}
return NULL;
}
int main() {
int counter = 0;
MyLock lock = buildLock();
ThreadArgument threadArgument1 = ((ThreadArgument) {
.lock = &lock,
.counter = &counter,
.threadIdx = 1
});
ThreadArgument threadArgument2 = ((ThreadArgument) {
.lock = &lock,
.counter = &counter,
.threadIdx = 2
});
pthread_t threads[NUM_THREADS];
pthread_create(&threads[0], NULL, thread_main, &threadArgument1);
pthread_create(&threads[1], NULL, thread_main, &threadArgument2);
for (int i = 0; i < NUM_THREADS; i ++) {
pthread_join(threads[i], NULL);
}
printf("Final Result: %d", counter);
return 0;
}