diff --git a/tinydb/table.py b/tinydb/table.py index 96b3f4a8..f7a15ceb 100644 --- a/tinydb/table.py +++ b/tinydb/table.py @@ -325,6 +325,9 @@ def perform_update(table, doc_id): # Update documents by setting all fields from the provided data table[doc_id].update(fields) + if cond is None and doc_ids is None and isinstance(fields, Document): + doc_ids = [fields.doc_id] + if doc_ids is not None: # Perform the update operation for documents specified by a list # of document IDs @@ -413,19 +416,23 @@ def upsert(self, document: Mapping, cond: Query) -> List[int]: def remove( self, - cond: Optional[Query] = None, + cond: Optional[Union[Document, Query]] = None, doc_ids: Optional[Iterable[int]] = None, ) -> List[int]: """ Remove all matching documents. - :param cond: the condition to check against + :param cond: the condition to check against, or the document to remove :param doc_ids: a list of document IDs :returns: a list containing the removed documents' ID """ if cond is None and doc_ids is None: raise RuntimeError('Use truncate() to remove all documents') + if doc_ids is None and isinstance(cond, Document): + doc_ids = [cond.doc_id] + cond = None + if cond is not None: removed_ids = []