import collections
import json
import os.path
import pprint
import zipfile
import h5py
import numpy as np
import rich.console
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.saving import saving_lib
from keras.src.saving.saving_lib import H5IOStore
from keras.src.utils import naming
from keras.src.utils import summary_utils
try:
import IPython as ipython
except ImportError:
ipython = None
def is_ipython_notebook():
    """Tell whether an active IPython shell is available.

    IPython is imported lazily inside the function, so a missing
    installation simply yields `False`.

    Returns:
        `True` if `IPython.get_ipython()` returns a shell instance,
        `False` if it returns `None` or IPython cannot be imported.
    """
    try:
        from IPython import get_ipython

        # `get_ipython()` yields `None` outside of an IPython session.
        shell = get_ipython()
        return shell is not None
    except ImportError:
        return False
@keras_export("keras.saving.KerasFileEditor")
class KerasFileEditor:
"""Utility to inspect, edit, and resave Keras weights files.
You will find this class useful when adapting
an old saved weights file after having made
architecture changes to a model.
Args:
filepath: The path to a local file to inspect and edit.
Examples:
```python
editor = KerasFileEditor("my_model.weights.h5")
# Displays current contents
editor.summary()
# Remove the weights of an existing layer
editor.delete_object("layers/dense_2")
# Add the weights of a new layer
editor.add_object("layers/einsum_dense", weights={"0": ..., "1": ...})
# Save the weights of the edited model
editor.resave_weights("edited_model.weights.h5")
```
"""
def __init__(
    self,
    filepath,
):
    """Open a Keras file and load its weight structure into memory.

    Args:
        filepath: Path (string or path-like) to a `.keras` archive
            or a `.weights.h5` file.

    Raises:
        ValueError: If `filepath` has neither supported extension.
    """
    # Accept path-like objects as well, consistent with `save()`.
    filepath = str(filepath)
    self.filepath = filepath
    self.metadata = None
    self.config = None
    self.model = None
    self.console = rich.console.Console(highlight=False)

    archive = None
    weights_store = None
    try:
        if filepath.endswith(".keras"):
            archive = zipfile.ZipFile(filepath, "r")
            weights_store = H5IOStore(
                f"{saving_lib._VARS_FNAME}.h5",
                archive=archive,
                mode="r",
            )
            with archive.open(saving_lib._CONFIG_FILENAME, "r") as f:
                config_json = f.read()
            with archive.open(saving_lib._METADATA_FILENAME, "r") as f:
                metadata_json = f.read()
            self.config = json.loads(config_json)
            self.metadata = json.loads(metadata_json)
        elif filepath.endswith(".weights.h5"):
            weights_store = H5IOStore(filepath, mode="r")
        else:
            raise ValueError(
                "Invalid filename: "
                "expected a `.keras` or `.weights.h5` extension. "
                f"Received: filepath={filepath}"
            )

        # All values are read into memory here, so the underlying
        # files can be released right after.
        weights_dict, object_metadata = self._extract_weights_from_store(
            weights_store.h5_file
        )
    finally:
        # Release file handles even if reading fails part-way; the
        # store must be closed before the archive that backs it.
        if weights_store is not None:
            weights_store.close()
        if archive is not None:
            archive.close()

    self.weights_dict = weights_dict
    self.object_metadata = object_metadata  # {path: object_name}
    self.console.print(self._generate_filepath_info(rich_style=True))

    if self.metadata is not None:
        self.console.print(self._generate_metadata_info(rich_style=True))
def summary(self):
    """Prints the weight structure of the opened file.

    Delegates to `_weights_summary_cli()`.
    """
    self._weights_summary_cli()
def compare(self, reference_model):
    """Compares the opened file to a reference model.

    This method will list all mismatches between the
    currently opened file and the provided reference model.
    The comparison runs in both directions so that entries missing
    from either side are reported.

    Args:
        reference_model: Model instance to compare to.

    Returns:
        Dict with the following keys:
        `'status'`, `'error_count'`, `'match_count'`.
        Status can be `'success'` or `'error'`.
        `'error_count'` is the number of mismatches found.
        `'match_count'` is the number of matching weights found.
    """
    self.console.print("Running comparison")
    ref_spec = {}
    get_weight_spec_of_saveable(reference_model, ref_spec)

    def _compare(
        target,
        ref_spec,
        inner_path,
        target_name,
        ref_name,
        checked_paths,
    ):
        """Return `(error_count, match_count)` for this subtree only."""
        # Counts start at zero for each subtree so that the caller can
        # simply add them up; threading the running totals through the
        # recursion and adding them again would count entries twice.
        error_count = 0
        match_count = 0
        for ref_key, ref_val in ref_spec.items():
            path = f"{inner_path}/{ref_key}"
            if path in checked_paths:
                # Already reported during a previous pass.
                continue

            if ref_key not in target:
                error_count += 1
                checked_paths.add(path)
                if isinstance(ref_val, dict):
                    self.console.print(
                        f"[color(160)]...Object [bold]{path}[/] "
                        f"present in {ref_name}, "
                        f"missing from {target_name}[/]"
                    )
                    self.console.print(
                        f" In {ref_name}, {path} contains "
                        f"the following keys: {list(ref_val.keys())}"
                    )
                else:
                    self.console.print(
                        f"[color(160)]...Weight [bold]{path}[/] "
                        f"present in {ref_name}, "
                        f"missing from {target_name}[/]"
                    )
            elif isinstance(ref_val, dict):
                sub_errors, sub_matches = _compare(
                    target[ref_key],
                    ref_val,
                    path,
                    target_name,
                    ref_name,
                    checked_paths=checked_paths,
                )
                error_count += sub_errors
                match_count += sub_matches
            elif target[ref_key].shape != ref_val.shape:
                error_count += 1
                checked_paths.add(path)
                self.console.print(
                    f"[color(160)]...Weight shape mismatch "
                    f"for [bold]{path}[/][/]\n"
                    f" In {ref_name}: "
                    f"shape={ref_val.shape}\n"
                    f" In {target_name}: "
                    f"shape={target[ref_key].shape}"
                )
            else:
                match_count += 1
        return error_count, match_count

    checked_paths = set()
    # Pass 1: everything the reference model expects must be in the file.
    error_count, match_count = _compare(
        self.weights_dict,
        ref_spec,
        inner_path="",
        target_name="saved file",
        ref_name="reference model",
        checked_paths=checked_paths,
    )
    # Pass 2: everything in the file must exist in the reference model.
    # Paths reported in pass 1 are skipped via `checked_paths`; matches
    # were already counted, so only the errors are kept.
    extra_errors, _ = _compare(
        ref_spec,
        self.weights_dict,
        inner_path="",
        target_name="reference model",
        ref_name="saved file",
        checked_paths=checked_paths,
    )
    error_count += extra_errors

    self.console.print("─────────────────────")
    if error_count == 0:
        status = "success"
        self.console.print(
            "[color(28)][bold]Comparison successful:[/] "
            "saved file is compatible with the reference model[/]"
        )
        plural = "" if match_count == 1 else "s"
        self.console.print(
            f" Found {match_count} matching weight{plural}"
        )
    else:
        status = "error"
        plural = "" if error_count == 1 else "s"
        self.console.print(
            f"[color(160)][bold]Found {error_count} error{plural}:[/] "
            "saved file is not compatible with the reference model[/]"
        )
    return {
        "status": status,
        "error_count": error_count,
        "match_count": match_count,
    }
def _edit_object(self, edit_fn, source_name, target_name=None):
if target_name is not None and "/" in target_name:
raise ValueError(
"Argument `target_name` should be a leaf name, "
"not a full path name. "
f"Received: target_name='{target_name}'"
)
if "/" in source_name:
# It's a path
elements = source_name.split("/")
weights_dict = self.weights_dict
for e in elements[:-1]:
if e not in weights_dict:
raise ValueError(
f"Path '{source_name}' not found in model."
)
weights_dict = weights_dict[e]
if elements[-1] not in weights_dict:
raise ValueError(f"Path '{source_name}' not found in model.")
edit_fn(
weights_dict, source_name=elements[-1], target_name=target_name
)
else:
# Ensure unicity
def count_occurences(d, name, count=0):
for k in d:
if isinstance(d[k], dict):
count += count_occurences(d[k], name, count)
if name in d:
count += 1
return count
occurrences = count_occurences(self.weights_dict, source_name)
if occurrences > 1:
raise ValueError(
f"Name '{source_name}' occurs more than once in the model; "
"try passing a complete path"
)
if occurrences == 0:
raise ValueError(
f"Source name '{source_name}' does not appear in the "
"model. Use `editor.weights_summary()` "
"to list all objects."
)
def _edit(d):
for k in d:
if isinstance(d[k], dict):
_edit(d[k])
if source_name in d:
edit_fn(d, source_name=source_name, target_name=target_name)
_edit(self.weights_dict)
def rename_object(self, object_name, new_name):
"""Rename an object in the file (e.g. a layer).
Args:
object_name: String, name or path of the
object to rename (e.g. `"dense_2"` or
`"layers/dense_2"`).
new_name: String, new name of the object.
"""
def rename_fn(weights_dict, source_name, target_name):
weights_dict[target_name] = weights_dict[source_name]
weights_dict.pop(source_name)
self._edit_object(rename_fn, object_name, new_name)
def delete_object(self, object_name):
"""Removes an object from the file (e.g. a layer).
Args:
object_name: String, name or path of the
object to delete (e.g. `"dense_2"` or
`"layers/dense_2"`).
"""
def delete_fn(weights_dict, source_name, target_name=None):
weights_dict.pop(source_name)
self._edit_object(delete_fn, object_name)
def add_object(self, object_path, weights):
"""Add a new object to the file (e.g. a layer).
Args:
object_path: String, full path of the
object to add (e.g. `"layers/dense_2"`).
weights: Dict mapping weight names to weight
values (arrays),
e.g. `{"0": kernel_value, "1": bias_value}`.
"""
if not isinstance(weights, dict):
raise ValueError(
"Argument `weights` should be a dict "
"where keys are weight names (usually '0', '1', etc.) "
"and values are NumPy arrays. "
f"Received: type(weights)={type(weights)}"
)
if "/" in object_path:
# It's a path
elements = object_path.split("/")
partial_path = "/".join(elements[:-1])
weights_dict = self.weights_dict
for e in elements[:-1]:
if e not in weights_dict:
raise ValueError(
f"Path '{partial_path}' not found in model."
)
weights_dict = weights_dict[e]
weights_dict[elements[-1]] = weights
else:
self.weights_dict[object_path] = weights
def delete_weight(self, object_name, weight_name):
"""Removes a weight from an existing object.
Args:
object_name: String, name or path of the
object from which to remove the weight
(e.g. `"dense_2"` or `"layers/dense_2"`).
weight_name: String, name of the weight to
delete (e.g. `"0"`).
"""
def delete_weight_fn(weights_dict, source_name, target_name=None):
if weight_name not in weights_dict[source_name]:
raise ValueError(
f"Weight {weight_name} not found "
f"in object {object_name}. "
"Weights found: "
f"{list(weights_dict[source_name].keys())}"
)
weights_dict[source_name].pop(weight_name)
self._edit_object(delete_weight_fn, object_name)
def add_weights(self, object_name, weights):
"""Add one or more new weights to an existing object.
Args:
object_name: String, name or path of the
object to add the weights to
(e.g. `"dense_2"` or `"layers/dense_2"`).
weights: Dict mapping weight names to weight
values (arrays),
e.g. `{"0": kernel_value, "1": bias_value}`.
"""
if not isinstance(weights, dict):
raise ValueError(
"Argument `weights` should be a dict "
"where keys are weight names (usually '0', '1', etc.) "
"and values are NumPy arrays. "
f"Received: type(weights)={type(weights)}"
)
def add_weight_fn(weights_dict, source_name, target_name=None):
weights_dict[source_name].update(weights)
self._edit_object(add_weight_fn, object_name)
def save(self, filepath):
"""Save the edited weights file.
Args:
filepath: Path to save the file to.
Must be a `.weights.h5` file.
"""
filepath = str(filepath)
if not filepath.endswith(".weights.h5"):
raise ValueError(
"Invalid `filepath` argument: "
"expected a `.weights.h5` extension. "
f"Received: filepath={filepath}"
)
weights_store = H5IOStore(filepath, mode="w")
def _save(weights_dict, weights_store, inner_path):
vars_to_create = {}
for name, value in weights_dict.items():
if isinstance(value, dict):
if value:
_save(
weights_dict[name],
weights_store,
inner_path=os.path.join(inner_path, name),
)
else:
# e.g. name="0", value=HDF5Dataset
vars_to_create[name] = value
if vars_to_create:
var_store = weights_store.make(inner_path)
for name, value in vars_to_create.items():
var_store[name] = value
_save(self.weights_dict, weights_store, inner_path="")
weights_store.close()
def resave_weights(self, filepath):
    """Alias of `save()`: passes `filepath` straight through to it.

    Args:
        filepath: Path to save the file to; see `save()`.
    """
    self.save(filepath)
def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
    """Recursively read a store group into nested dicts of values.

    Args:
        data: Group-like object exposing `attrs`, `keys()` and
            item access (the root call receives the store's
            `h5_file`).
        metadata: Dict accumulating `{path: attrs_dict}` across the
            recursion. A new dict is created when omitted.
        inner_path: `/`-prefixed path of `data` from the root
            (`""` for the root itself).

    Returns:
        Tuple `(result, metadata)`. `result` is an ordered dict that
        maps each child key to either a nested dict (for groups) or
        the value read from the dataset. Intermediate `vars` groups
        are flattened into their parent, and empty groups are skipped.
    """
    metadata = metadata or {}

    object_metadata = dict(data.attrs.items())
    if object_metadata:
        metadata[inner_path] = object_metadata

    result = collections.OrderedDict()
    for key in data.keys():
        # Use a per-child path. Rebinding `inner_path` here would make
        # each sibling inherit the previous siblings' names.
        child_path = f"{inner_path}/{key}"
        value = data[key]
        if isinstance(value, h5py.Group):
            if len(value) == 0:
                continue
            if "vars" in value.keys() and len(value["vars"]) == 0:
                continue

        if hasattr(value, "keys"):
            # Skip the intermediate "vars" level when there is one.
            source = value["vars"] if "vars" in value.keys() else value
            result[key], metadata = self._extract_weights_from_store(
                source, metadata=metadata, inner_path=child_path
            )
        else:
            # Dataset: read the full value into memory.
            result[key] = value[()]
    return result, metadata
def _generate_filepath_info(self, rich_style=False):
if rich_style:
filepath = f"'{self.filepath}'"
filepath = f"{summary_utils.highlight_symbol(filepath)}"
else:
filepath = f"'{self.filepath}'"
return f"Keras model file {filepath}"
def _generate_config_info(self, rich_style=False):
return pprint.pformat(self.config)
def _generate_metadata_info(self, rich_style=False):
version = self.metadata["keras_version"]
date = self.metadata["date_saved"]
if rich_style:
version = f"{summary_utils.highlight_symbol(version)}"
date = f"{summary_utils.highlight_symbol(date)}"
return f"Saved with Keras {version} - date: {date}"
def _print_weights_structure(
self, weights_dict, indent=0, is_first=True, prefix="", inner_path=""
):
for idx, (key, value) in enumerate(weights_dict.items()):
inner_path = os.path.join(inner_path, key)
is_last = idx == len(weights_dict) - 1
if is_first:
is_first = False
connector = "> "
elif is_last:
connector = "└─ "
else:
connector = "├─ "
if isinstance(value, dict):
bold_key = summary_utils.bold_text(key)
object_label = f"{prefix}{connector}{bold_key}"
if inner_path in self.object_metadata:
metadata = self.object_metadata[inner_path]
if "name" in metadata:
name = metadata["name"]
object_label += f" ('{name}')"
self.console.print(object_label)
if is_last:
appended = " "
else:
appended = "│ "
new_prefix = prefix + appended
self._print_weights_structure(
value,
indent + 1,
is_first=is_first,
prefix=new_prefix,
inner_path=inner_path,
)
else:
if hasattr(value, "shape"):
bold_key = summary_utils.bold_text(key)
self.console.print(
f"{prefix}{connector}{bold_key}:"
+ f" shape={value.shape}, dtype={value.dtype}"
)
else:
self.console.print(f"{prefix}{connector}{key}: {value}")
def _weights_summary_cli(self):
self.console.print("Weights structure")
self._print_weights_structure(self.weights_dict, prefix=" " * 2)
def _weights_summary_interactive(self):
def _generate_html_weights(dictionary, margin_left=0, font_size=1):
html = ""
for key, value in dictionary.items():
if isinstance(value, dict) and value:
weights_html = _generate_html_weights(
value, margin_left + 20, font_size - 1
)
html += (
f'{key}
'
f"{weights_html}"
"'
f"{key} : shape={value.shape}"
f", dtype={value.dtype}
"
f"