1
0

Merge remote-tracking branch 'origin/release-v1.36' into hs/hacked-together-event-cache

This commit is contained in:
Will Hunt
2021-06-16 11:32:27 +01:00
233 changed files with 11035 additions and 4686 deletions

View File

@@ -30,7 +30,11 @@ def exit(status: int = 0, message: Optional[str] = None):
def format_plain(public_key: nacl.signing.VerifyKey):
print(
"%s:%s %s"
% (public_key.alg, public_key.version, encode_verify_key_base64(public_key),)
% (
public_key.alg,
public_key.version,
encode_verify_key_base64(public_key),
)
)
@@ -50,7 +54,10 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"key_file", nargs="+", type=argparse.FileType("r"), help="The key file to read",
"key_file",
nargs="+",
type=argparse.FileType("r"),
help="The key file to read",
)
parser.add_argument(
@@ -63,7 +70,7 @@ if __name__ == "__main__":
parser.add_argument(
"--expiry-ts",
type=int,
default=int(time.time() * 1000) + 6*3600000,
default=int(time.time() * 1000) + 6 * 3600000,
help=(
"The expiry time to use for -x, in milliseconds since 1970. The default "
"is (now+6h)."

View File

@@ -11,23 +11,22 @@ if __name__ == "__main__":
parser.add_argument(
"--config-dir",
default="CONFDIR",
help="The path where the config files are kept. Used to create filenames for "
"things like the log config and the signing key. Default: %(default)s",
"things like the log config and the signing key. Default: %(default)s",
)
parser.add_argument(
"--data-dir",
default="DATADIR",
help="The path where the data files are kept. Used to create filenames for "
"things like the database and media store. Default: %(default)s",
"things like the database and media store. Default: %(default)s",
)
parser.add_argument(
"--server-name",
default="SERVERNAME",
help="The server name. Used to initialise the server_name config param, but also "
"used in the names of some of the config files. Default: %(default)s",
"used in the names of some of the config files. Default: %(default)s",
)
parser.add_argument(
@@ -41,21 +40,22 @@ if __name__ == "__main__":
"--generate-secrets",
action="store_true",
help="Enable generation of new secrets for things like the macaroon_secret_key."
"By default, these parameters will be left unset."
"By default, these parameters will be left unset.",
)
parser.add_argument(
"-o", "--output-file",
type=argparse.FileType('w'),
"-o",
"--output-file",
type=argparse.FileType("w"),
default=sys.stdout,
help="File to write the configuration to. Default: stdout",
)
parser.add_argument(
"--header-file",
type=argparse.FileType('r'),
type=argparse.FileType("r"),
help="File from which to read a header, which will be printed before the "
"generated config.",
"generated config.",
)
args = parser.parse_args()

View File

@@ -41,7 +41,7 @@ if __name__ == "__main__":
parser.add_argument(
"-c",
"--config",
type=argparse.FileType('r'),
type=argparse.FileType("r"),
help=(
"Path to server config file. "
"Used to read in bcrypt_rounds and password_pepper."
@@ -72,8 +72,8 @@ if __name__ == "__main__":
pw = unicodedata.normalize("NFKC", password)
hashed = bcrypt.hashpw(
pw.encode('utf8') + password_pepper.encode("utf8"),
pw.encode("utf8") + password_pepper.encode("utf8"),
bcrypt.gensalt(bcrypt_rounds),
).decode('ascii')
).decode("ascii")
print(hashed)

View File

@@ -301,8 +301,7 @@ class Porter(object):
return table, already_ported, total_to_port, forward_chunk, backward_chunk
async def get_table_constraints(self) -> Dict[str, Set[str]]:
"""Returns a map of tables that have foreign key constraints to tables they depend on.
"""
"""Returns a map of tables that have foreign key constraints to tables they depend on."""
def _get_constraints(txn):
# We can pull the information about foreign key constraints out from
@@ -511,7 +510,9 @@ class Porter(object):
return
def build_db_store(
self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False,
self,
db_config: DatabaseConnectionConfig,
allow_outdated_version: bool = False,
):
"""Builds and returns a database store using the provided configuration.
@@ -747,7 +748,7 @@ class Porter(object):
return col
outrows = []
for i, row in enumerate(rows):
for row in rows:
try:
outrows.append(
tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
@@ -897,8 +898,7 @@ class Porter(object):
await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
async def _setup_events_stream_seqs(self) -> None:
"""Set the event stream sequences to the correct values.
"""
"""Set the event stream sequences to the correct values."""
# We get called before we've ported the events table, so we need to
# fetch the current positions from the SQLite store.
@@ -927,12 +927,14 @@ class Porter(object):
)
await self.postgres_store.db_pool.runInteraction(
"_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
"_setup_events_stream_seqs",
_setup_events_stream_seqs_set_pos,
)
async def _setup_sequence(self, sequence_name: str, stream_id_tables: Iterable[str]) -> None:
"""Set a sequence to the correct value.
"""
async def _setup_sequence(
self, sequence_name: str, stream_id_tables: Iterable[str]
) -> None:
"""Set a sequence to the correct value."""
current_stream_ids = []
for stream_id_table in stream_id_tables:
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
@@ -946,20 +948,25 @@ class Porter(object):
next_id = max(current_stream_ids) + 1
def r(txn):
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name, )
txn.execute(sql + " %s", (next_id, ))
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
txn.execute(sql + " %s", (next_id,))
await self.postgres_store.db_pool.runInteraction("_setup_%s" % (sequence_name,), r)
await self.postgres_store.db_pool.runInteraction(
"_setup_%s" % (sequence_name,), r
)
async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True
table="event_auth_chains",
keyvalues={},
retcol="MAX(chain_id)",
allow_none=True,
)
def r(txn):
txn.execute(
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
(curr_chain_id,),
(curr_chain_id + 1,),
)
if curr_chain_id is not None:
@@ -975,8 +982,7 @@ class Porter(object):
class Progress(object):
"""Used to report progress of the port
"""
"""Used to report progress of the port"""
def __init__(self):
self.tables = {}
@@ -1001,8 +1007,7 @@ class Progress(object):
class CursesProgress(Progress):
"""Reports progress to a curses window
"""
"""Reports progress to a curses window"""
def __init__(self, stdscr):
self.stdscr = stdscr
@@ -1027,7 +1032,7 @@ class CursesProgress(Progress):
self.total_processed = 0
self.total_remaining = 0
for table, data in self.tables.items():
for data in self.tables.values():
self.total_processed += data["num_done"] - data["start"]
self.total_remaining += data["total"] - data["num_done"]
@@ -1118,8 +1123,7 @@ class CursesProgress(Progress):
class TerminalProgress(Progress):
"""Just prints progress to the terminal
"""
"""Just prints progress to the terminal"""
def update(self, table, num_done):
super(TerminalProgress, self).update(table, num_done)