diff options
author | Lars Wirzenius <liw@liw.fi> | 2021-05-29 18:50:38 +0300 |
---|---|---|
committer | Lars Wirzenius <liw@liw.fi> | 2021-05-29 18:50:38 +0300 |
commit | 88165aace3c1586adfee8285c9e9689a00bb2f31 (patch) | |
tree | 03828a67f1d1ed98c276b86405994ab7070d2ca4 | |
parent | 35e374cb434166721d26d1b94590b6c88a170607 (diff) | |
download | oso-work-sample-88165aace3c1586adfee8285c9e9689a00bb2f31.tar.gz |
feat: add max-client.py from OSO
Sponsored-by: author
-rw-r--r-- | max-client.py | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/max-client.py b/max-client.py new file mode 100644 index 0000000..d691909 --- /dev/null +++ b/max-client.py @@ -0,0 +1,88 @@ +import argparse +import requests +from dataclasses import asdict, dataclass + + +@dataclass +class Compare: + request_id: int + left: int + right: int + type: str = "compare" + +@dataclass +class ComparisonResult: + request_id: int + answer: bool + type: str = "comp_result" + +@dataclass +class ComputeMax: + length: int + type: str = "compute_max" + +@dataclass +class ComputeMin: + length: int + type: str = "compute_min" + +@dataclass +class Done: + result: int + type: str = "done" + +def message_to_struct(message): + if message["type"] == "compare": + return Compare(**message) + elif message["type"] == "comp_result": + return ComparisonResult(**message) + elif message["type"] == "compute_min": + return ComputeMin(**message) + elif message["type"] == "compute_max": + return ComputeMax(**message) + elif message["type"] == "done": + return Done(**message) + +class Client: + def __init__(self, address, log=False): + self.address = address if address else "http://localhost:5000" + self.log = log + + def send(self, data): + response = requests.post(self.address, json=asdict(data)) + json = response.json() + if self.log: + print(json) + return message_to_struct(json) + + def compute(self, values, op): + req = None + if op == "min": + req = ComputeMin(len(values)) + elif op == "max": + req = ComputeMax(len(values)) + else: + assert False, "not supported operation: " + op + next_message = self.send(req) + + while True: + if next_message.type == "done": + return values[next_message.result] + elif next_message.type == "compare": + request_id = next_message.request_id + left = next_message.left + right = next_message.right + next_message = self.send(ComparisonResult(request_id, values[left] < values[right])) + else: + raise Exception("Unexpected message: ", next_message) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Run the max computer') + parser.add_argument('--address', type=str, required=False, + help='address of the max computer (defaults to http://localhost:5000)') + + args = parser.parse_args() + client = Client(args.address, log=True) + assert 3 == client.compute([1, 2, 3, 1], op="max") + assert 1 == client.compute([1, 2, 3, 1], op="min") |