- {{ year }}
- {%- if tags != "" %}
- ·
- {% for tag in page.tags -%}
-
- {{ tag }}
- {% endfor -%}
- {% endif %}
-
- {%- if categories != "" %}
- ·
- {% for category in page.categories -%}
-
- {{ category }}
- {% endfor -%}
- {% endif %}
-
-
-
-
-
- {{ content }}
-
-
-
-
- For attribution in academic contexts, please cite this work as
-
-
-
-
- BibTeX citation
-
-
-
-
-
-
- {%- if site.disqus_shortname and page.comments -%}
-
-
-
- {%- endif %}
-
-
-
-
diff --git a/_news/announcement_1.md b/_news/announcement_1.md
deleted file mode 100644
index 98e5af5c..00000000
--- a/_news/announcement_1.md
+++ /dev/null
@@ -1,7 +0,0 @@
----
-layout: post
-date: 2015-10-22 15:59:00-0400
-inline: true
----
-
-A simple inline announcement.
diff --git a/_news/announcement_2.md b/_news/announcement_2.md
deleted file mode 100644
index dbd4b4d4..00000000
--- a/_news/announcement_2.md
+++ /dev/null
@@ -1,31 +0,0 @@
----
-layout: post
-title: A long announcement with details
-date: 2015-11-07 16:11:00-0400
-inline: false
----
-
-Announcements and news can be much longer than just quick inline posts. In fact, they can have all the features available for the standard blog posts. See below.
-
-***
-
-Jean shorts raw denim Vice normcore, art party High Life PBR skateboard stumptown vinyl kitsch. Four loko meh 8-bit, tousled banh mi tilde forage Schlitz dreamcatcher twee 3 wolf moon. Chambray asymmetrical paleo salvia, sartorial umami four loko master cleanse drinking vinegar brunch. Pinterest DIY authentic Schlitz, hoodie Intelligentsia butcher trust fund brunch shabby chic Kickstarter forage flexitarian. Direct trade cold-pressed meggings stumptown plaid, pop-up taxidermy. Hoodie XOXO fingerstache scenester Echo Park. Plaid ugh Wes Anderson, freegan pug selvage fanny pack leggings pickled food truck DIY irony Banksy.
-
-#### Hipster list
-
-
-- brunch
-- fixie
-- raybans
-- messenger bag
-
-
-Hoodie Thundercats retro, tote bag 8-bit Godard craft beer gastropub. Truffaut Tumblr taxidermy, raw denim Kickstarter sartorial dreamcatcher. Quinoa chambray slow-carb salvia readymade, bicycle rights 90's yr typewriter selfies letterpress cardigan vegan.
-
-***
-
-Pug heirloom High Life vinyl swag, single-origin coffee four dollar toast taxidermy reprehenderit fap distillery master cleanse locavore. Est anim sapiente leggings Brooklyn ea. Thundercats locavore excepteur veniam eiusmod. Raw denim Truffaut Schlitz, migas sapiente Portland VHS twee Bushwick Marfa typewriter retro id keytar.
-
-> We do not grow absolutely, chronologically. We grow sometimes in one dimension, and not in another, unevenly. We grow partially. We are relative. We are mature in one realm, childish in another.
-> —Anais Nin
-
-Fap aliqua qui, scenester pug Echo Park polaroid irony shabby chic ex cardigan church-key Odd Future accusamus. Blog stumptown sartorial squid, gastropub duis aesthetic Truffaut vero. Pinterest tilde twee, odio mumblecore jean shorts lumbersexual.
diff --git a/_news/announcement_3.md b/_news/announcement_3.md
deleted file mode 100644
index d9072191..00000000
--- a/_news/announcement_3.md
+++ /dev/null
@@ -1,7 +0,0 @@
----
-layout: post
-date: 2016-01-15 07:59:00-0400
-inline: true
----
-
-A simple inline announcement with Markdown emoji! :sparkles: :smile:
diff --git a/_pages/about.md b/_pages/about.md
deleted file mode 100644
index 20832a24..00000000
--- a/_pages/about.md
+++ /dev/null
@@ -1,115 +0,0 @@
----
-layout: about
-title: about
-permalink: /about/
-nav: true
-nav_order: 1
-subtitle:
-
-# profile:
-# align: right
-# image:
-# image_circular: false # crops the image to make it circular
-# address:
-
-# news: false # includes a list of news items
-# selected_papers: false # includes a list of papers marked as "selected={true}"
-# social: false # includes social icons at the bottom of the page
----
-
-**Announcements**:
- - The deadlines for submitting an abstract and your final blog post have been updated to December 11th and 17th (00:00 GMT), respectively!
-
-# ICLR 2024 Blogposts Track
-
-The Machine Learning community is currently experiencing a [reproducibility crisis](https://neuripsconf.medium.com/designing-the-reproducibility-program-for-neurips-2020-7fcccaa5c6ad) and a reviewing crisis [[Littman, 2021]](#Litt). Because of the highly competitive and noisy reviewing process of ML conferences [[Tran et al., 2020]](#Tran), researchers have an incentive to oversell their results, slowing down progress and diminishing the integrity of the scientific community. Moreover, with the growing number of papers published and submitted to the main ML conferences [[Lin et al., 2020]](#Lin), it has become more challenging to keep track of the latest advances in the field.
-
-Blog posts are becoming an increasingly popular and useful way to talk about science [[Brown and Woolston, 2018]](#Brow). They offer substantial value to the scientific community by providing a flexible platform to foster open, human, and transparent discussions about new insights or limitations of a scientific publication. However, because they are not as recognized as standard scientific publications, only a minority of researchers manage to maintain an active blog and get visibility for their efforts. Many are well-established researchers ([Francis Bach](https://francisbach.com/), [Ben Recht](https://www.argmin.net/), [Ferenc Huszár](https://www.inference.vc/), [Lilian Weng](https://lilianweng.github.io/lil-log/)) or big corporations that leverage entire teams of graphic designers and writers to polish their blogs ([Facebook AI](https://ai.facebook.com/blog/?page=1), [Google AI](https://ai.googleblog.com/), [DeepMind](https://deepmind.com/blog), [OpenAI](https://openai.com/blog/)). As a result, the incentives for writing scientific blog posts are largely personal; it is unreasonable to expect a significant portion of the machine learning community to contribute to such an initiative when everyone is trying to establish themselves through publications.
-
-**Submit** your blogpost on [Openreview](https://openreview.net/group?id=ICLR.cc/2024/BlogPosts&referrer=%5BHomepage%5D(%2F))
-
-## A Blog Post Conference Track
-
-Last year, we ran the **second** iteration of the [Blogpost track](https://iclr-blogposts.github.io/2023/about) at ICLR 2023!
-
-It was very successful, with accepted posts presented in person at the main conference.
-
-Our goal is to create a formal call for blog posts at ICLR to incentivize and reward researchers who review past work and summarize the outcomes, develop new intuitions, or highlight some shortcomings. A very influential initiative of this kind happened after the Second World War in France. Because of the lack of up-to-date textbooks, a collective of mathematicians writing under the pseudonym Nicolas Bourbaki [[Halmos 1957]](#Halm) decided to start a series of textbooks about the foundations of mathematics [[Bourbaki, 1939]](#Bour). In the same vein, we aim to provide a new way to summarize scientific knowledge in the ML community.
-
-Due to the large diversity of topics that can be discussed in a blog post, we decided to restrict the range of topics for this call for blog posts. We identified that the blog posts that would bring the most value to the community and the conference would be posts that distill and discuss *previously published papers*.
-
-## Key Dates
-
-**Abstract deadline**: December 11th 00:00GMT, 2023 (submit to OpenReview - to be announced soon).
-
-**Submission deadline**: December 17th 00:00GMT, 2023 (any modifications to your blog post, via a pull request on GitHub).
-
-**Decision Notification**: ~~January 30th, 2024~~ UPDATED: February 15th, 2024
-
-**Camera-ready merge**: March 15th, 2024
-
-## A call for blog posts discussing work previously published at ICLR
-
-#### Content
-
-Write a post on a subject that has been published at a top-tier venue (ICLR, ICML, NeurIPS, AAAI, UAI, CVPR, SIGGRAPH, ECCV, ICCV, etc.) relatively recently.
-
-#### Conflict of interest
-
-The authors of the blog posts will have to declare their conflicts of interest (positive or negative) with the paper (and the paper's authors) they write about. Conflicts of interest include:
- - Recent collaborators (less than 3 years)
- - Current institution
-
-**Blog Posts must not be used to highlight or advertise past publications of the authors or of their lab**.
-
-We will only ask the authors to report whether they have a conflict of interest. If so, reviewers will be asked to judge if the submission is sufficiently critical and objective of the papers addressed in the blog post.
-
-
-## Publication
-
-#### Blog post
-
-The posts will be created and published under a unified template; see [the submission instructions]({{ '/submitting' | relative_url }}) and the [sample post]({% post_url 2024-05-07-distill-example %}) hosted on the blog of this website.
-
-#### Poster
-Additionally, accepted posts will have the option to present their work as a poster during the main poster session. For more information about the main poster session (time, poster format, etc.) please refer to the ICLR homepage.
-
-## Submissions
-
-Our goal is to avoid heavily engineered, professionally made blog posts (such as the “100+ hours” mentioned as a standard by the [Distill guidelines](https://distill.pub/journal/)) and to encourage clear ideas and writing rather than dynamic visualizations or embedded JavaScript engines.
-Please check our [submission instructions]({{ '/submitting' | relative_url }}) for more details.
-We accept submissions in both Markdown and HTML. We believe this is a good trade-off between complexity and flexibility.
-
-**Submit** your blogpost on [Openreview](https://openreview.net/group?id=ICLR.cc/2024/BlogPosts&referrer=%5BHomepage%5D(%2F))
-
-## Contact
-
-For any technical issues with the blog post repository (for example, blog posts not displaying correctly or issues while following the [submission instructions](https://iclr-blogposts.github.io/2024/submitting/#creating-a-blog-post)), please open an [issue in our github repository](https://github.com/iclr-blogposts/2024/issues).
-
-For other inquiries, reach us via email at: [blog.track.chairs@gmail.com](mailto:blog.track.chairs@gmail.com)
-
-## Organizers
-
-
- {% include people_horizontal.html name="Gauthier Gidel" affiliation="Mila, Université de Montréal" url="https://gauthiergidel.github.io/" img="assets/img/organizers/gg.jpg" %}
- {% include people_horizontal.html name="Charlie Gauthier" affiliation="Mila, Université de Montréal" url="https://velythyl.github.io/" img="assets/img/organizers/cg.jpg" %}
- {% include people_horizontal.html name="David Dobre" affiliation="Mila, Université de Montréal" url="" img="assets/img/organizers/dd.jpg" %}
- {% include people_horizontal.html name="Claire Vernade" affiliation="University of Tuebingen" url="https://www.cvernade.com/" img="assets/img/organizers/cv.jpg" %}
- {% include people_horizontal.html name="Fabian Pedregosa" affiliation="Google DeepMind" url="https://fa.bianp.net/pages/about.html" img="assets/img/organizers/fp.jpg" %}
- {% include people_horizontal.html name="Leo Schwinn" affiliation="Technical University of Munich" url="https://schwinnl.github.io//" img="assets/img/organizers/ls.jpg" %}
-
-
----
-
-## References
-
-Michael L Littman. Collusion rings threaten the integrity of computer science research. Communications of the ACM, 2021.
-
-David Tran, Alex Valtchanov, Keshav Ganapathy, Raymond Feng, Eric Slud, Micah Goldblum, and Tom Goldstein. An open review of OpenReview: A critical analysis of the machine learning conference review process. arXiv, 2020.
-
-Hsuan-Tien Lin, Maria-Florina Balcan, Raia Hadsell, and Marc’Aurelio Ranzato. What we learned from NeurIPS 2020 reviewing process. Medium https://medium.com/@NeurIPSConf/what-we-learned-from-neurips-2020-reviewing-process-e24549eea38f, 2020.
-
-Eryn Brown and Chris Woolston. Why science blogging still matters. Nature, 2018.
-
-Paul R Halmos. Nicolas Bourbaki. Scientific American, 1957.
-
-Nicolas Bourbaki. Elements of mathematics. Éditions Hermann, 1939.
diff --git a/_pages/call.md b/_pages/call.md
deleted file mode 100644
index 6d4fe076..00000000
--- a/_pages/call.md
+++ /dev/null
@@ -1,73 +0,0 @@
----
-layout: page
-title: call for blogposts
-permalink: /call/
-description:
-nav: true
-nav_order: 2
----
-
-**Announcements**:
-- Instructions on getting started with building your blog posts have now been posted [here]({{ '/submitting' | relative_url }})!
- - The deadlines for submitting an abstract and your final blog post have been updated to December 11th and 17th (00:00 GMT), respectively!
-
-**Submit** your blogpost on [Openreview](https://openreview.net/group?id=ICLR.cc/2024/BlogPosts&referrer=%5BHomepage%5D(%2F))
-
-# Call for blog posts
-
-We invite all researchers and practitioners to submit a blog post discussing work previously published at a top-tier venue to the ICLR 2024 blog post track.
-The format and process for this blog post track are described below.
-
-
-### Content
-
-Write a post on a subject that has been published at a top-tier venue (ICLR, ICML, NeurIPS, AAAI, UAI, CVPR, SIGGRAPH, ECCV, ICCV, etc.) relatively recently.
-Past blog posts can be accessed [here](https://iclr-blogposts.github.io/2023/about).
-
-
-### Conflict of interest
-
-The authors of the blog posts will have to declare their conflicts of interest (positive or negative) with the paper (and their authors) they write about.
-Conflicts of interest include:
-
- - Recent collaborators (less than 3 years)
- - Current institution
-
-Reviewers will be asked to judge if the submission is sufficiently critical and objective of the papers addressed in the blog post.
-**Blog Posts must not be used to highlight or advertise past publications of the authors or of their lab**.
-
-
-### Publication
-
-##### Blog post
-
-The posts will be created and published under a unified template; see [the submission instructions]({{ '/submitting' | relative_url }}) and the [sample post]({{ '/blog/2024/distill-example' | relative_url }}) hosted on the blog of this website.
-
-##### Poster
-Additionally, accepted posts will have the option to present their work as a poster during the main poster session. For more information about the main poster session (time, poster format, etc.) please refer to the ICLR homepage.
-
-### Review
-
-Blogs will be peer-reviewed (double-blind) for quality and novelty of the content: clarity and pedagogy of the exposition, new theoretical or practical insights, reproduction/extension of experiments, etc.
-The review is dual-anonymous assuming good faith from both submitters and reviewers (see [the submission instructions]({{ '/submitting' | relative_url }}) for more details).
-
-
-## Key Dates
-- **Abstract deadline**: December 11th 00:00GMT, 2023 ([submit to OpenReview](https://openreview.net/group?id=ICLR.cc/2024/BlogPosts&referrer=%5BHomepage%5D(%2F))).
-
-
-- **Submission deadline**: December 17th 00:00GMT, 2023 (any modifications to your blog post, via a pull request on github).
-
-
-- **Notification of acceptance**: ~~January 30th, 2024~~ UPDATED: February 15th, 2024
-
-
-- **Camera-ready merge**: March 15th, 2024
-
-
-### Contact
-
-For answers to many common questions, please refer to the ICLR [FAQ](https://iclr.cc/FAQ).
-
-Should you have other inquiries, please don't hesitate to reach out via email at: [blog.track.chairs@gmail.com](mailto:blog.track.chairs@gmail.com)
-
diff --git a/_pages/dropdown.md b/_pages/dropdown.md
deleted file mode 100644
index 0eb85be8..00000000
--- a/_pages/dropdown.md
+++ /dev/null
@@ -1,13 +0,0 @@
----
-layout: page
-title: past iterations
-nav: true
-nav_order: 99
-dropdown: true
-children:
- - title: 2023
- permalink: https://iclr-blogposts.github.io/2023/about
- - title: divider
- - title: 2022
- permalink: https://iclr-blog-track.github.io/home/
----
\ No newline at end of file
diff --git a/_pages/dropdown/index.html b/_pages/dropdown/index.html
new file mode 100644
index 00000000..44ce61cb
--- /dev/null
+++ b/_pages/dropdown/index.html
@@ -0,0 +1 @@
+ past iterations | ICLR Blogposts 2024
\ No newline at end of file
diff --git a/_pages/reviewer_guidelines.md b/_pages/reviewer_guidelines.md
deleted file mode 100644
index 0958cd1c..00000000
--- a/_pages/reviewer_guidelines.md
+++ /dev/null
@@ -1,25 +0,0 @@
----
-layout: page
-title: reviewing
-permalink: /reviewing/
-description:
-nav: true
-nav_order: 4
----
-
-### Reviewing Process
-
-Reviewers will be required to view only the live content of the blog.
-We ask that they act in good faith, and refrain from digging into the repository's logs and closed Pull Requests to find any identifying information on the authors.
-
-Reviewers should motivate their final decision based on the following points:
-
-- Is there a significant added value in comparison to the cited papers?
-- Is this added value supported by accurate, convincing, and clear arguments?
-- If the blogpost does not directly relate to a paper, does it address a relevant research topic from a novel perspective?
- - In case the field *Conflict Of Interest* is marked as *YES*, reviewers are asked to pay specific attention to how the related work mentioned in the field *ICLR Papers* is treated: is the blogpost *too positive* (self-advertisement) or *too negative* (an unfair assessment of this related work)?
-
-In order to access the submissions, please follow these steps:
-
-1. Go to the OpenReview submission page.
-2. To see the blogpost submission, go to the blogpost URL specified in the field 'Blogpost Url'.
\ No newline at end of file
diff --git a/_pages/submitting.md b/_pages/submitting.md
deleted file mode 100644
index 5e24980c..00000000
--- a/_pages/submitting.md
+++ /dev/null
@@ -1,360 +0,0 @@
----
-layout: page
-title: submitting
-permalink: /submitting/
-description:
-nav: true
-nav_order: 3
----
-
-### A more open process
-
-As with the previous edition of the Blog Post track, we forgo the requirement for total anonymity.
-The blog posts **must be anonymized for the review process**, but users will submit their anonymized blog posts via a pull request to the blog track's repository (in addition to a submission on OpenReview).
-The pull request will trigger an automated pipeline that will build and deploy your post onto a website dedicated to the reviewing process.
-
-Reviewers will be able to access the posts directly through a public URL (generated by the Github action), and will submit their reviews on OpenReview.
-Reviewers should refrain from looking at the git history for the post, which may reveal information about the authors.
-
-This still largely follows the Double-Blind reviewing principle; it is no less double-blind than when reviewers are asked to score papers that have previously been released to [arXiv](https://arxiv.org/), an overwhelmingly common practice in the ML community.
-This approach was chosen to lower the burden on both the organizers and the authors; in 2022, many submissions had to be reworked once deployed due to a variety of reasons.
-By allowing the authors to render their websites to Github Pages prior to the review process, we hope to avoid this issue entirely.
-
-
-However, we understand the desire for total anonymity.
-Authors who wish to have a fully double-blind process might consider creating new GitHub accounts without identifying information, which they will only use for this track.
-For an example of a submission in the past which used an anonymous account in this manner, you can check out the [World Models blog post (Ha and Schmidhuber, 2018)](https://worldmodels.github.io/) and the [accompanying repository](https://github.com/worldmodels/worldmodels.github.io).
-
-### Template
-
-The workflow you will use to participate in this track should be relatively familiar to you if you have used [Github Pages](https://pages.github.com/). Specifically, our website uses the [Al-Folio](https://github.com/alshedivat/al-folio) template.
-This template uses Github Pages as part of its process, but it also utilizes a separate build step using [Github Actions](https://github.com/features/actions) and intermediary [Docker Images](https://www.docker.com/).
-
-**We recommend paying close attention to the steps presented in this guide.
-Small mistakes here can have very hard-to-debug consequences.**
-
-### Contents
-
-- [Quickstart](#quickstart)
-- [Download the Blog Repository](#download-the-blog-repository)
-- [Creating a Blog Post](#creating-a-blog-post)
-- [Local Serving](#local-serving)
- - [Method 1: Using Docker](#method-1-using-docker)
- - [Method 2: Using Jekyll Manually](#method-2-using-jekyll-manually)
- - [Installation](#installation)
- - [Manual Serving](#manual-serving)
-- [Submitting Your Blog Post](#submitting-your-blog-post)
-- [Reviewing Process](#reviewing-process)
-- [Camera Ready (TBD)](#camera-ready)
-
-
-### Quickstart
-
-This section provides a summary of the workflow for creating and submitting a blog post.
-For more details about any of these steps, please refer to the appropriate section.
-
-
-1. Fork or download our [repository](https://github.com/iclr-blogposts/2024).
-
-2. Create your blog post content as detailed in the [Creating a Blog Post](#creating-a-blog-post) section.
- In summary, to create your post, you will:
- - Create a Markdown or HTML file in the `_posts/` directory with the format `_posts/2024-05-07-[SUBMISSION NAME].md`. If you choose to write the post in HTML, then the extension of this last file should be .html instead of .md. NOTE: HTML posts are not officially supported, use at your own risk!
- - Add any static image to `assets/img/2024-05-07-[SUBMISSION NAME]/`.
- - Add any interactive HTML figures to `assets/html/2024-05-07-[SUBMISSION NAME]/`.
- - Put your citations into a bibtex file in `assets/bibliography/2024-05-07-[SUBMISSION NAME].bib`.
-
- **DO NOT** touch anything else in the repository.
- We will utilize an automated deployment action which will filter out all submissions that modify more than the list of files that we just described above.
- Read the [relevant section](#creating-a-blog-post) for more details.
- **Make sure to omit any identifying information for the review process.**
-
-3. To render your website locally, you can build and run a docker container via `$ ./bin/docker_run.sh`.
- Alternatively, you can set up your local environment to render the website via conventional `$ bundle exec jekyll serve --future` commands.
- More information for both of these configurations can be found in the [Local Serving](#local-serving) section.
-
-4. To submit your website, create a pull request to the main repository. Make sure that this PR's title is `_posts/2024-05-07-[SUBMISSION NAME]`. This will trigger a GitHub Action that will build your blogpost and write the host's URL in a comment to your PR.
-
-5. If accepted, we will merge the accepted posts to our main repository. See the [camera ready](#camera-ready) section for more details on merging in an accepted blog post.
-
-**Should you edit ANY files other than your new post inside the `_posts` directory and your new folder inside the `assets` directory, your pull request will automatically be rejected.**
-
-You can view an example of a successful PR [here](https://github.com/iclr-blogposts/2024/pull/48). You can view an example of a PR with erroneous files [here](https://github.com/iclr-blogposts/2024/pull/51).
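-
-For intuition, here is a minimal Python sketch of the kind of path check the automated action performs. This is only an illustration written for this guide, not the actual GitHub Action, and the pattern and function names are ours:
-
-```python
-import re
-
-# Only these locations may be touched by a submission PR (illustrative patterns).
-ALLOWED_PATTERNS = [
-    r"_posts/\d{4}-\d{2}-\d{2}-[\w-]+\.(?:md|html)",
-    r"assets/img/\d{4}-\d{2}-\d{2}-[\w-]+/.+",
-    r"assets/html/\d{4}-\d{2}-\d{2}-[\w-]+/.+",
-    r"assets/bibliography/\d{4}-\d{2}-\d{2}-[\w-]+\.bib",
-]
-
-def disallowed_files(changed_files):
-    """Return the changed files that fall outside the allowed submission locations."""
-    return [
-        f for f in changed_files
-        if not any(re.fullmatch(p, f) for p in ALLOWED_PATTERNS)
-    ]
-
-# Example: the README edit would cause the PR to be closed automatically.
-print(disallowed_files([
-    "_posts/2024-05-07-deep-learning.md",
-    "assets/img/2024-05-07-deep-learning/figure1.png",
-    "README.md",
-]))  # -> ['README.md']
-```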
-
-### Download the Blog Repository
-
-Download or fork our [repository](https://github.com/iclr-blogposts/2024).
-You will be submitting a pull request to this repository.
-
-### Creating a Blog Post
-
-To create a blog post in Markdown format, you can modify the [example]({% post_url 2024-05-07-distill-example %}) Markdown post `_posts/2024-05-07-distill-example.md` and rename it to `_posts/2024-05-07-[SUBMISSION NAME].md`, where `[SUBMISSION NAME]` is the name of your submission. You can see the rendered result of the sample post via the link above.
-
-While most users will want to create a post in the Markdown format, it is also possible to create a post in HTML format. For this, modify instead the example `_posts/2024-05-08-distill-example2.html` and rename it to `_posts/2024-05-07-[SUBMISSION NAME].html`. (NOTE: HTML is not officially supported, use at your own risk).
-
-
-You must modify the file's header (or 'front-matter') as needed.
-
-
-
- ```markdown
- ---
-layout: distill
-title: [Your Blog Title]
-description: [Your blog post's abstract - no math/latex or hyperlinks!]
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# anonymize when submitting
-authors:
- - name: Anonymous
-
-# do not fill this in until your post is accepted and you're publishing your camera-ready post!
-# authors:
-# - name: Albert Einstein
-# url: "https://en.wikipedia.org/wiki/Albert_Einstein"
-# affiliations:
-# name: IAS, Princeton
-# - name: Boris Podolsky
-# url: "https://en.wikipedia.org/wiki/Boris_Podolsky"
-# affiliations:
-# name: IAS, Princeton
-# - name: Nathan Rosen
-# url: "https://en.wikipedia.org/wiki/Nathan_Rosen"
-# affiliations:
-# name: IAS, Princeton
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-distill-example.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-toc:
- - name: [Section 1]
- - name: [Section 2]
- # you can additionally add subentries like so
- subsections:
- - name: [Subsection 2.1]
- - name: [Section 3]
----
-
-# ... your blog post's content ...
-```
-
-You must change the `title`, `description`, `toc`, and eventually the `authors` fields (**ensure that the
-submission is anonymous for the review process**).
-
-
-Read our [sample blog post]({% post_url 2024-05-07-distill-example %}) carefully to see how you can add image assets, and how to write using $$\LaTeX$$!
-Read about rendering your post locally [below](#serving).
-
-**Important: make sure your post is completely anonymized before you export and submit it!**
-
-Before going any further, it will be useful to highlight exactly what folders and files you are going to add or modify.
-Even if you use one of our simpler quickstart methods, this will always be what's happening
-behind the scenes.
-
-If you clone our repo or download a release, you will find a directory structure that looks like
-the following (excluding all files and directories that are not relevant to your submission):
-
-```bash
-your_blogpost_repo/
-│
-├── _posts
-│ ├── 2024-05-07-[YOUR SUBMISSION].md # <--- Create this markdown file; this is your blogpost
-│ └── ...
-├── assets
-│ ├── bibliography
-│ │ ├── 2024-05-07-[YOUR SUBMISSION].bib # <--- Create this bibtex file
-│ │ └── ...
-│ ├── html
-│ │ ├── 2024-05-07-[YOUR SUBMISSION] # <--- Create this directory and add interactive html figures
-│ │ │ └──[YOUR HTML FIGURES].html
-│ │ └── ...
-│ ├── img
-│ │ ├── 2024-05-07-[YOUR SUBMISSION] # <--- Create this directory and add static images here
-│ │ │ └──[YOUR IMAGES].png
-│ │ └── ...
-│ └── ...
-└── ...
-```
-
-In summary, to create your post, you will:
-
-- Create a Markdown (or HTML) file in the `_posts/` directory with the format `_posts/2024-05-07-[SUBMISSION NAME].md` (`_posts/2024-05-07-[SUBMISSION NAME].html` in the case of an HTML file).
- - Add any static image assets to `assets/img/2024-05-07-[SUBMISSION NAME]/`.
- - Add any interactive HTML figures to `assets/html/2024-05-07-[SUBMISSION NAME]/`.
-- Put your citations into a bibtex file in `assets/bibliography/2024-05-07-[SUBMISSION NAME].bib`.
-
-**DO NOT** touch anything else in the blog post!
-If you do, our automated pipeline will reject your PR and you will have to undo those changes in order for it to be accepted!
-
-Note that `2024-05-07-[YOUR SUBMISSION]` serves as a tag to your submission, so it should be the
-same for all three items.
-For example, if you're writing a blog post called "Deep Learning", you'd likely want to make your
-tag `2024-05-07-deep-learning`, and the directory structure would look like this:
-
-```bash
-your_blogpost_repo/
-│
-├── _posts
-│ ├── 2024-05-07-deep-learning.md # <--- Create this markdown file; this is your blogpost
-│ └── ...
-├── assets
-│ ├── bibliography
-│ │ ├── 2024-05-07-deep-learning.bib # <--- Create this bibtex file
-│ │ └── ...
-│ ├── html
-│ │ ├── 2024-05-07-deep-learning # <--- Create this directory and add interactive html figures
-│ │ │ └──[YOUR HTML FIGURES].html
-│ │ └── ...
-│ ├── img
-│ │ ├── 2024-05-07-deep-learning # <--- Create this directory and add static images here
-│ │ │ └──[YOUR IMAGES].png
-│ │ └── ...
-│ └── ...
-└── ...
-```
-
-### Local serving
-
-So far we've talked about how to get the relevant repository and create a blog post conforming to our requirements.
-Everything you have done so far has been in Markdown, but this is not the same format as web content (typically HTML, etc.).
-You'll now need to build your static web site (which is done using Jekyll), and then *serve* it on some local webserver in order to view it properly.
-We will now discuss how you can *serve* your blog site locally, so that you can check your work before opening a pull request on the staging website and submitting it to the ICLR venue.
-
-#### Method 1: Using Docker
-
-To render your website locally, we follow the instructions for [Local setup using Docker (Recommended on Windows)](https://github.com/iclr-blogposts/iclr-blogposts.github.io/blob/master/README.md#local-setup-using-docker-recommended-on-windows), but specifically you will need to create your own docker container rather than pull it from Dockerhub (because we modified the Gemfile).
-
-Create and run the Docker image:
-
-```bash
-./bin/docker_run.sh
-```
-
-Remove the `Gemfile.lock` file if prompted.
-This will create a docker image labeled as `al-folio:latest`.
-Don't use `dockerhub_run.sh`; this may result in issues with missing jekyll dependencies.
-
-
-#### Method 2: Using Jekyll Manually
-
-For users wishing to not use a Docker container, you can install Jekyll directly to your computer and build the site using Jekyll directly.
-This is done at your own risk, as there are many potential points of error!
-Follow the instructions for rendering the website via the conventional method of `$ bundle exec jekyll serve --future`
-
-##### Installation
-
-You will need to manually install Jekyll which will vary based on your operating system.
-The instructions here are only for convenience - you are responsible for making sure it works on your system and we are not liable for potential issues that occur when adding your submissions to our repo!
-
-**Ubuntu/Debian**
-
-1. Install Ruby
-
- ```bash
- sudo apt install ruby-full
- ```
-
-2. Once installed, add the following to your `.bashrc` or whatever terminal startup script you may use (this is important because otherwise gem may complain about needing sudo permission to install packages):
-
- ```bash
- export GEM_HOME="$HOME/.gem"
- export PATH="$HOME/.gem/bin:$PATH"
- ```
-
-3. Install Jekyll and Bundler:
-
- ```bash
- gem install jekyll bundler
- ```
-
-**MacOS and Windows**
-
-Mac and Windows users can find relevant guides for installing Jekyll here:
-
-- [Windows guide](https://jekyllrb.com/docs/installation/windows/)
-- [MacOS guide](https://jekyllrb.com/docs/installation/macos/)
-
-##### Manual Serving
-
-Once you've installed jekyll and all of the dependencies, you can now serve the webpage on your local machine for development purposes using the `bundle exec jekyll serve` command.
-
-You may first need to install any project dependencies. In your terminal, from the directory containing the Jekyll project run:
-
-```bash
-bundle install
-```
-
-This will install any plugins required by the project.
-To serve the webpage locally, from your terminal, in the directory containing the Jekyll project run:
-
-```bash
-bundle exec jekyll serve --future --port=8080 --host=0.0.0.0
-```
-
-You should see something along the lines of:
-
-```
-> bundle exec jekyll serve
-Configuration file: /home/$USER/blog_post_repo/_config.yml
- Source: /home/$USER/blog_post_repo
- Destination: /home/$USER/blog_post_repo/_site
- Incremental build: disabled. Enable with --incremental
- Generating...
- Jekyll Feed: Generating feed for posts
-
- ... you may see a lot of stuff in here related to images ...
-
- done in 0.426 seconds.
- Auto-regeneration: enabled for '/home/$USER/blog_post_repo'
- Server address: http://0.0.0.0:8080/2024/
- Server running... press ctrl-c to stop.
-```
-
-If you see this, you've successfully served your web page locally!
-You can access it at the server address specified, in this case `http://0.0.0.0:8080/2024/` (and the blog posts should once again be viewable at the `blog/` endpoint).
-
-
-### Submitting your Blog Post
-
-To submit your blog post:
-
-1. **Anonymize your blog post.** Strip all identifying information from your post, including the
- author's list (replace with `Anonymous`).
-2. Double check that your post matches the formatting requirements, including (but not limited to):
- - **Only modify** files in the following locations (failure to do so will result in your PR
- automatically being closed!):
- - a Markdown (or HTML) file in `_posts/` with the format `_posts/2024-05-07-[SUBMISSION NAME].md`
- (or `.html`)
- - static image assets added to `assets/img/2024-05-07-[SUBMISSION NAME]/`
- - interactive HTML figures added to `assets/html/2024-05-07-[SUBMISSION NAME]/`
- - citations in a bibtex file in `assets/bibliography/2024-05-07-[SUBMISSION NAME].bib`
- - Have a short 2-3 sentence abstract in the `description` field of your front-matter ([example](https://github.com/iclr-blogposts/2024/blob/295ab5b4c31f2c7d421a4caf41e5481cbb4ad42c/_posts/2024-05-07-distill-example.md?plain=1#L4-L6))
- - Have a table of contents, formatted using the `toc` field of your front-matter ([example](https://github.com/iclr-blogposts/2024/blob/295ab5b4c31f2c7d421a4caf41e5481cbb4ad42c/_posts/2024-05-07-distill-example.md?plain=1#L36-L47))
- - Your bibliography uses a `.bibtex` file as per the sample post
-3. Open a pull request against the `main` branch of the [2024 repo](https://github.com/iclr-blogposts/2024).
- Fill in the checklist provided in the PR template. The title of your pull request should be
- exactly the name of your markdown/html file.
- - i.e. `_posts/2024-05-07-[SUBMISSION NAME].md` would require a PR name `2024-05-07-[SUBMISSION NAME]`
-4. (TBD) Your post will automatically run two pipelines: one to verify that you have not modified any other
- file in the repo, and another that will create a unique URL for your contributed blog post.
- - Verify that everything looks correct in the given URL.
- - If the pipelines failed, check if it was because of improper formatting (i.e. you modified
- restricted files). If this is the case, fix the issues. If the issue persists, please ping one of the repo admins.
-
-5. Submit the name of your blog post and its URL to our OpenReview through [this link](https://openreview.net/group?id=ICLR.cc/2024/BlogPosts&referrer=%5BHomepage%5D(%2F)).
-
-> **Note:** If you wish to make updates to your submission, you should update the content in the
-> PR that you already opened.
-
-### Reviewing Process
-
-Reviewers will be required to view only the live content of the reviewing website, i.e. the website to which the pull requests deploy.
-We ask that they act in good faith, and refrain from digging into the repository's logs and closed Pull Requests to find any identifying information on the authors.
-
-### Camera-ready
-
-**TBD** - instructions will be provided closer to the submission deadline.
diff --git a/_plugins/external-posts.rb b/_plugins/external-posts.rb
deleted file mode 100644
index e4fd5eb6..00000000
--- a/_plugins/external-posts.rb
+++ /dev/null
@@ -1,36 +0,0 @@
-require 'feedjira'
-require 'httparty'
-require 'jekyll'
-
-module ExternalPosts
- class ExternalPostsGenerator < Jekyll::Generator
- safe true
- priority :high
-
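- # For each feed listed under `external_sources` in _config.yml (each entry
- # providing a `name` and an `rss_url`), fetch the RSS feed and register every
- # entry as a document in the `posts` collection, alongside the local posts.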
- def generate(site)
- if site.config['external_sources'] != nil
- site.config['external_sources'].each do |src|
- p "Fetching external posts from #{src['name']}:"
- xml = HTTParty.get(src['rss_url']).body
- feed = Feedjira.parse(xml)
- feed.entries.each do |e|
- p "...fetching #{e.url}"
- slug = e.title.downcase.strip.gsub(' ', '-').gsub(/[^\w-]/, '')
- path = site.in_source_dir("_posts/#{slug}.md")
- doc = Jekyll::Document.new(
- path, { :site => site, :collection => site.collections['posts'] }
- )
- doc.data['external_source'] = src['name'];
- doc.data['feed_content'] = e.content;
- doc.data['title'] = "#{e.title}";
- doc.data['description'] = e.summary;
- doc.data['date'] = e.published;
- doc.data['redirect'] = e.url;
- site.collections['posts'].docs << doc
- end
- end
- end
- end
- end
-
-end
diff --git a/_plugins/hideCustomBibtex.rb b/_plugins/hideCustomBibtex.rb
deleted file mode 100644
index 4a852fde..00000000
--- a/_plugins/hideCustomBibtex.rb
+++ /dev/null
@@ -1,15 +0,0 @@
- module Jekyll
- module HideCustomBibtex
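- # Liquid filter that removes from its input any line containing one of the
- # keywords listed under `filtered_bibtex_keywords` in the site config.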
- def hideCustomBibtex(input)
- keywords = @context.registers[:site].config['filtered_bibtex_keywords']
-
- keywords.each do |keyword|
- input = input.gsub(/^.*#{keyword}.*$\n/, '')
- end
-
- return input
- end
- end
-end
-
-Liquid::Template.register_filter(Jekyll::HideCustomBibtex)
diff --git a/_posts/2024-05-07-alibi-mlm.md b/_posts/2024-05-07-alibi-mlm.md
deleted file mode 100644
index 5f6cdbaf..00000000
--- a/_posts/2024-05-07-alibi-mlm.md
+++ /dev/null
@@ -1,239 +0,0 @@
----
-layout: distill
-title: Masked Language Model with ALiBi and CLAP head
-description: As a new approach to positional encoding, Attention with Linear Biases (ALiBi) uses linear biases on the attention weights to encode positional information, with the capability of context length extrapolation. In their paper, however, Press et al. focus on the perplexity of autoregressive decoder-only language models, leaving open the question of downstream tasks and the applicability of ALiBi to encoder attention. In this blogpost, we attempt to bridge the gap by testing masked language models (MLMs) with encoder-attention ALiBi and a prediction head similar to the counterparts of the original ALiBi models. We find that while the simplified prediction head may be beneficial, the performance of MLMs with encoder-attention ALiBi starts to deteriorate at 2048 sequence length at larger scales. We put our results in the context of related recent experiments and tentatively identify the circumstances that are more challenging to positional encoding designs. Finally, we open-source our MLMs, with BERT-level performance and 2048 context length.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-authors:
- - name: Jason Chuan-Chih Chou
- url: https://scholar.google.com/citations?user=V7BXGawAAAAJ
- affiliations:
- name: Cohere For AI Community
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-alibi-mlm.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Attention with Linear Biases (ALiBi)
- - name: Contrastive Language Pretraining (CLAP) Head
- - name: Experiments
- subsections:
- - name: WikiText-103
- - name: The Pile
- - name: Conclusions
- - name: Model Checkpoints
----
-
-*Adapted and expanded from [EIFY/fairseq](https://github.com/EIFY/fairseq).*
-
-Unmodified and unmasked, the attention mechanism is permutation-invariant, so transformer-based language models employ positional encoding to break the symmetry and enable sequence modeling. In their ICLR 2022 paper, Press et al. introduced Attention with Linear Biases (ALiBi) as a new approach to positional encoding, where the positional information of the tokens is encoded by applying an attention weight bias proportional to the distance between tokens:
-
-{% include figure.html path="assets/img/2024-05-07-alibi-mlm/ALiBi.jpeg" class="img-fluid" %}
-
-where $$m$$ is a head-specific slope chosen to follow the geometric sequence $$\frac{1}{2^{0.5}}, \frac{1}{2^1}, \frac{1}{2^{1.5}}, \dots, \frac{1}{2^\frac{n}{2}}$$ for a model with $$n$$ attention heads. This approach is shown to enable input length extrapolation in the sense that the perplexity of the model remains stable as the inference context length exceeds the training context length. The paper, however, focuses on autoregressive decoder-only models and relies on model perplexity as the metric, and therefore leaves open the question of whether ALiBi is applicable to MLMs like BERT and RoBERTa. To help answer this question, we tested the following two changes to the RoBERTa baseline models, based on the first-party Fairseq toolkit:
-
-
-## Attention with Linear Biases (ALiBi)
-
-Since MLMs are based on encoders that attend to tokens both before and after the given position, considerations must be made regarding how to distinguish them. Press himself [suggested the following 3 options for encoder-attention ALiBi](https://github.com/ofirpress/attention_with_linear_biases/issues/5):
-
-1. Symmetric: Keep attention weight bias proportional to the distance between tokens and rely on the context to distinguish between tokens at +N and -N position.
-2. Nonsymmetric, one-sided: Make half of the heads only attend to the tokens before and half of the heads only attend to the tokens after. Weight bias is still proportional to the distance.
-3. Nonsymmetric with different slopes: Make the slopes $$m$$ different forward and backward, with either learned or fixed values.
-
-With the observation that option 2 spends about half of the attention compute on no-ops and option 3 can still result in bias value collisions (e.g. $$m_{bwd} = 2 m_{fwd}$$ and -1 vs. +2 positions), we implemented both option 1 and what we call "nonsymmetric with offset": [Shift the linear biases ahead by `0.5 * slope`](https://github.com/ofirpress/attention_with_linear_biases/issues/5#issuecomment-1213410982), i.e. the constant bias (right matrix of the figure above) becomes
-
-```
- 0 -.5 -1.5 -2.5 -3.5
--1 0 -.5 -1.5 -2.5
--2 -1 0 -.5 -1.5
--3 -2 -1 0 -.5
--4 -3 -2 -1 0
-```
-
-Unless otherwise noted, ALiBi for the following experiments means this nonsymmetric-with-offset encoder-attention ALiBi.
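-
-To make the construction concrete, here is a small PyTorch sketch, written for this post as an illustration rather than taken from the fairseq implementation, that generates the head-specific slopes described earlier and builds this nonsymmetric-with-offset bias matrix for a single head:
-
-{% highlight python %}
-import torch
-
-def alibi_slopes(n_heads):
-    # Geometric sequence 1/2^0.5, 1/2^1, ..., 1/2^(n/2) described above.
-    return [2 ** (-0.5 * (i + 1)) for i in range(n_heads)]
-
-def offset_alibi_bias(seq_len, slope):
-    pos = torch.arange(seq_len)
-    rel = (pos[None, :] - pos[:, None]).float()  # rel[i, j] = j - i
-    # Shift the forward distances by 0.5 so that +N and -N positions
-    # never receive the same bias value.
-    bias = torch.where(rel > 0, -(rel - 0.5), rel)
-    return slope * bias  # added to the attention logits of this head
-
-print(offset_alibi_bias(5, 1.0))  # reproduces the 5 x 5 matrix above
-{% endhighlight %}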
-
-## Contrastive Language Pretraining (CLAP) Head
-The prediction head is one part of the LMs that has received less attention, and it happens to differ between the ALiBi autoregressive decoder-only models and RoBERTa. Based on the configs and [training logs](https://github.com/ofirpress/attention_with_linear_biases#saved-checkpoints), the ALiBi models use the adaptive word embedding and softmax of Baevski & Auli with weight tying, whereas the RoBERTa prediction head has an additional fully-connected layer and nonlinearity on top of weight tying. Inspired by CLIP, we decided to test what we call the Contrastive Language Pretraining (CLAP) head below, as the [simplest possible prediction head with weight tying](https://github.com/EIFY/fairseq/blob/8143446dfa88d9f8e246b366bd335f6c9b018db0/fairseq/models/roberta/model.py#L527-L543) for the masked tokens plus the thermodynamic beta (inverse temperature):
-
-{% highlight python %}
-class ClapHead(nn.Module):
- """Head for masked language modeling."""
-
- def __init__(self, initial_beta, weight):
- super().__init__()
- self.beta = nn.Parameter(torch.tensor(initial_beta))
- self.weight = weight
-
- def forward(self, features, masked_tokens=None, normalize=True):
- # Only project the masked tokens while training,
- # saves both memory and computation
- if masked_tokens is not None:
- features = features[masked_tokens, :]
- w = self.weight
- if normalize:
- w = F.normalize(w, dim=-1)
- return self.beta * F.linear(features, w)
-{% endhighlight %}
-
-Compared to the [baseline RoBERTa prediction head](https://github.com/facebookresearch/fairseq/blob/da8fb630880d529ab47e53381c30ddc8ad235216/fairseq/models/roberta/model.py#L470-L495)
-
-{% highlight python %}
-class RobertaLMHead(nn.Module):
- """Head for masked language modeling."""
-
- def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
- super().__init__()
- self.dense = nn.Linear(embed_dim, embed_dim)
- self.activation_fn = utils.get_activation_fn(activation_fn)
- self.layer_norm = LayerNorm(embed_dim)
-
- if weight is None:
- weight = nn.Linear(embed_dim, output_dim, bias=False).weight
- self.weight = weight
- self.bias = nn.Parameter(torch.zeros(output_dim))
-
- def forward(self, features, masked_tokens=None, **kwargs):
- # Only project the masked tokens while training,
- # saves both memory and computation
- if masked_tokens is not None:
- features = features[masked_tokens, :]
-
- x = self.dense(features)
- x = self.activation_fn(x)
- x = self.layer_norm(x)
- # project back to size of vocabulary with bias
- x = F.linear(x, self.weight) + self.bias
- return x
-{% endhighlight %}
-
-We removed the `embed_dim x embed_dim` fully-connected layer, activation function (GELU), layer norm, and the `output_dim` trainable bias. Just like CLIP, we added the trainable thermodynamic beta and L2-normalize the token embeddings before feeding them to the transformer and computing the inner products between them and the transformer output as the softmax logits, scaled by beta.
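-
-For concreteness, here is a minimal, hypothetical usage sketch of the `ClapHead` above with its weight tied to the token embedding matrix; it assumes the class definition above is in scope, and the dimensions and `initial_beta` value are placeholders:
-
-{% highlight python %}
-import torch
-import torch.nn as nn
-
-vocab_size, embed_dim = 50265, 768
-embed_tokens = nn.Embedding(vocab_size, embed_dim)
-
-# Tie the head to the embedding matrix (weight tying).
-lm_head = ClapHead(initial_beta=1.0, weight=embed_tokens.weight)
-
-features = torch.randn(4, embed_dim)  # transformer outputs at 4 masked positions
-logits = lm_head(features)            # (4, vocab_size) logits from L2-normalized weights, scaled by beta
-{% endhighlight %}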
-
-## Experiments
-
-### WikiText-103
-We first tested the changes on the [WikiText-103 dataset](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/) using a GeForce RTX 3080 16 GB Laptop GPU, with the validation set MLM perplexity as the metric. We tested the baseline (learned positional encoding + RoBERTa prediction head), learned-clap (learned positional encoding + CLAP head), ALiBi (ALiBi + RoBERTa prediction head), and zero-clap (ALiBi + CLAP head), as well as the baseline with sinusoidal positional encoding instead of learned positional encoding:
-
-{% include figure.html path="assets/img/2024-05-07-alibi-mlm/valid_ppl_cleaned.png" class="img-fluid" %}
-
-where solid lines are what's considered "canonical" setup and dotted lines are experiments with the following variations in setup. These variations turned out to be irrelevant:
-
-1. Whether we use attention dropout or not
-2. Whether we use [symmetric ALiBi (option 1)](https://github.com/ofirpress/attention_with_linear_biases/issues/5) or nonsymmetric-with-offset ALiBi above
-3. ~~Whether we use zero vector or a separate learnable embedding for the mask embedding~~ (The intention was to test using the zero vector instead of a separate learnable embedding for the mask embedding, which in combination with ALiBi results in no non-semantic information in the input embeddings. However, a bug prevented this variation from working correctly and the end effect was merely deleting the last two words, madeupword0001 and madeupword0002, from the dictionary instead, which we don't expect to be consequential.)
-4. Whether we L2-normalize the embeddings for the CLAP head or not
-5. Whether we scale the L2-normalized embeddings by `sqrt(embed_dim)` (`no_scale_embedding=False`) or not
-
-As we can see, the dotted lines are almost on top of the solid lines. Notably, sinusoidal positional encoding underperforms significantly compared to learned positional encoding.
-
-### The Pile
-As the next step, we scaled our experiments to train on the Pile for one epoch. About half of the examples in the Pile have sequence length > 1024, so we set the sequence length to 2048. Even so, ~1/7 of the examples have sequence length > 2048 and had to be discarded. In the end, one epoch consists of 133082 updates and [we employ a cosine learning rate schedule while "overestimating" the number of training steps by 10%](https://github.com/EIFY/fairseq/blob/33fb2c306851f104cc567b7fe865b1e3fd1e6fe7/examples/roberta/config/pretraining/baseline_pile.yaml#L31-L36), as inspired by the Chinchilla paper. In addition to the validation MLM perplexity, we also fine-tuned the models on the [GLUE](https://gluebenchmark.com/) benchmark. As in the original RoBERTa paper, we tested both `roberta.base` with 125M parameters and `roberta.large` with 355M parameters. These experiments were performed on 8 x A100 40GB SXM4 GPUs, where the `roberta.base` experiments took ~3 days and the `roberta.large` experiments took ~9 days. In the table below, `PPL` is the final validation MLM perplexity, `STS-B` is the best validation loss, and all the others are the best validation accuracies over 10 epochs of finetuning.
-
-#### `roberta.base`
-```
- PPL↓ CoLA MNLI MRPC QNLI QQP RTE SST-2 STS-B↓
-baseline 2.94 83.6 84.2 90 91.6 91.3 73.6 92.1 0.028
-learned-clap 2.86 81.7 84.4 86.3 90.9 91.2 72.6 92.5 0.027
-alibi 2.93 69.2 85.1 80.9 92 91.5 63.9 93.1 0.033
-zero-clap 2.83 70.5 84.9 75.5 90.6 91.1 54.9 89.7 0.041
-```
-\**Baseline but with sinusoidal positional encoding instead of learned positional encoding failed to converge.*
-
-#### `roberta.large`
-```
- PPL↓ CoLA MNLI MRPC QNLI QQP RTE SST-2 STS-B↓
-baseline* 2.55 83.7 86.8 84.3 92.5 91.8 79.8 93.3 0.027
-learned-clap 2.5 84.1 86.3 89.7 92.8 91.7 79.8 93.7 0.023
-alibi 2.65 69.1 86.5 68.4 92.4 91.7 52.7 93.6 0.123
-zero-clap 2.54 69.1 86.7 81.9 92.2 91.6 52.7 93.1 0.031
-```
-\**Loss spiked somewhere between 24000-24500 updates and the model failed to recover. Loosely following the practice of `5.1 Training Instability` in the PaLM paper , we solved the issue by restarting the training from the 20000 updates checkpoint with the PyTorch random seed changed from `1` to `2`.*
-
-We found that ALiBi no longer helps lower the validation MLM perplexity. Furthermore, ALiBi turned out to be harmful for several specific GLUE tasks (`CoLA`, `MRPC`, and `RTE`). The CLAP head on its own, however, seems to be competitive and in fact outperforms the baseline with `roberta.large`.
-
-## Conclusions
-This seems to be another case where models with lower perplexity do not necessarily yield higher accuracies on downstream tasks, and where architectural changes beneficial for models at smaller scales do not imply the same for models at larger scales. The CLAP head, however, is simpler than the standard prediction head for MLMs, requires minimal changes, and may be worth trying, especially at larger scales.
-
-In the broader context, MosaicBERT and LittleBird are most similar to our experiments. In the MosaicBERT paper, Portes et al. also evaluate BERT-style MLMs with symmetric (option 1) encoder-attention ALiBi on the GLUE benchmark and find performance exceeding the BERT baseline within a limited training budget. However, these MosaicBERT models were trained with a much shorter (128) sequence length and so may have avoided the sequence length regime in which perplexity and performance of certain downstream tasks start to deteriorate. (The same can be said about , which also reports in Table 4 the MLM perplexity of RoBERTa large models trained on an excerpt of the Pile with various positional encodings, including symmetric (option 1) encoder-attention ALiBi with 128 sequence length.) The LittleBird architecture is designed for question answering and built with BiALiBi (Bidirectional ALiBi), a variation of option 3 (nonsymmetric with different slopes) where the model learns not only the forward and backward slopes $$m_{fwd}$$ and $$m_{bwd}$$, but also a special bias value for the attention weight of the global `[CLS]` token. Lee et al. evaluate LittleBird models on a collection of QA benchmarks for both English and Korean and report favorable performance, but leave open the question of whether they work well for other NLP tasks. Notably, we also found our ALiBi models capable of matching the baseline performance on the question answering task `QNLI`, so the reported performance is compatible with our experiments even without appealing to the other differences in architecture or pretraining task.
-
-Finally, what can we say about the original decoder-attention ALiBi and positional encodings in general? The original decoder-attention ALiBi has been shown to help not only perplexity, but also performance on evaluation suites consisting of a diverse set of tasks like the EleutherAI Language Model Evaluation Harness. This discrepancy may be explained by the causal mask, which has been proven to be sufficient for encoding positional information in theory (one caveat is that Proof C.1 of for absolute positional encoding depends on distinguishing values of the unit fractions 1/t, which eventually fails due to the precision limit; for example, 1/1464 can't be distinguished from 1/1465 in float16, well within the context length of interest), if not quite matching the performance of models with additional positional encodings in practice. Perhaps we can conclude that
-
-1. Decoder-attention positional encodings really should be considered as causal mask + additional encodings, and how they complement each other should be taken into account.
-2. Longer context length and certain downstream tasks are more challenging for positional encodings. One worthwhile direction may be to rank their difficulties systematically and iterate on the more challenging circumstances first for future positional encoding designs.
-
-## Model checkpoints
-Final checkpoints for models trained on the Pile:
-
-### `roberta.base`
-
-[baseline](https://drive.google.com/file/d/1r9VwJCU3AeuivNULRuY3Taq_3AEBg-v5/view?usp=share_link)
-[learned-clap](https://drive.google.com/file/d/1KmO3FEaawz0tHW-s581NmrkL-OZklLYk/view?usp=share_link)
-[alibi](https://drive.google.com/file/d/1s4Tcjnbawq1W6LBcknysj6NdpMfJdek6/view?usp=share_link)
-[zero-clap](https://drive.google.com/file/d/1PwE_MASg4FinuKq6DX29A8c2lPP2B6nb/view?usp=share_link)
-
-### `roberta.large`
-
-[baseline](https://drive.google.com/file/d/1XSStju8S9y1BCHpXqZ_fZcueH3A0yW2c/view?usp=share_link)
-[learned-clap](https://drive.google.com/file/d/1UyFxC3XoQ5eAhhXaAUQznLbBLa0J_45U/view?usp=share_link)
-[alibi](https://drive.google.com/file/d/1D22xJxJTI4gPAD4gHfKaN1ytjQTy2u_y/view?usp=share_link)
-[zero-clap](https://drive.google.com/file/d/1ktiRIVqz46DbV261_WxA9RELR971_2iu/view?usp=share_link)
-
-To load them, install [EIFY/fairseq](https://github.com/EIFY/fairseq) following [the original instructions](https://github.com/facebookresearch/fairseq/blob/b8ac3fa6cc95f9dc97085232d4faf125e5bcd2e7/README.md#requirements-and-installation) and download the GPT-2 fairseq dictionary:
-```
-wget -O gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
-```
-Then all of the checkpoints above except the `zero-clap` ones can load as follows:
-```
-$ python
-Python 3.8.10 (default, Jun 22 2022, 20:18:18)
-[GCC 9.4.0] on linux
-Type "help", "copyright", "credits" or "license" for more information.
->>> from fairseq.models.roberta import RobertaModel
->>> roberta = RobertaModel.from_pretrained('/checkpoint-dir', 'learned-clap-large.pt', '/dict-dir')
-(...)
->>> roberta.fill_mask('The capital of China is .', topk=3)
-[('The capital of China is Beijing.', 0.7009016871452332, ' Beijing'), ('The capital of China is Shanghai.', 0.23566904664039612, ' Shanghai'), ('The capital of China is Moscow.', 0.010170688852667809, ' Moscow')]
->>>
-```
-The `zero-clap` ones were trained without the last two `madeupword`'s (this is due to the same bug that affected the WikiText-103 variation above, and it is its only visible effect), so you need to delete them from `dict.txt` before loading, i.e.:
-
-
-
-```
-$ python
-Python 3.8.10 (default, Jun 22 2022, 20:18:18)
-[GCC 9.4.0] on linux
-Type "help", "copyright", "credits" or "license" for more information.
->>> from fairseq.models.roberta import RobertaModel
->>> roberta = RobertaModel.from_pretrained('/checkpoint-dir', 'zero-clap-large.pt', '/dict-dir')
-(...)
->>> roberta.fill_mask('The capital of China is .', topk=3)
-[('The capital of China is Beijing.', 0.7051425576210022, ' Beijing'), ('The capital of China is Shanghai.', 0.21408841013908386, ' Shanghai'), ('The capital of China is Taiwan.', 0.007823833264410496, ' Taiwan')]
->>>
-```
-
-The rest of the original [example usage](https://github.com/facebookresearch/fairseq/blob/b8ac3fa6cc95f9dc97085232d4faf125e5bcd2e7/examples/roberta/README.md#example-usage) should also just work. While these checkpoints have only been tested with this fork, the `baseline` ones should also work with the [original fairseq repo](https://github.com/facebookresearch/fairseq) with minimum changes to the state dict:
-
-```
->>> path = '/checkpoint-dir/baseline-large.pt'
->>> with open(path, 'rb') as f:
-... state = torch.load(f, map_location=torch.device("cpu"))
-...
->>>
->>> del state['cfg']['task']['omit_mask']
-(...)
->>> torch.save(state, '/checkpoint-dir/compatible.pt')
-```
diff --git a/_posts/2024-05-07-bench-hvp.md b/_posts/2024-05-07-bench-hvp.md
deleted file mode 100644
index 14175807..00000000
--- a/_posts/2024-05-07-bench-hvp.md
+++ /dev/null
@@ -1,534 +0,0 @@
----
-layout: distill
-title: How to compute Hessian-vector products?
-description: The product between the Hessian of a function and a vector, the Hessian-vector product (HVP), is a fundamental quantity to study the variation of a function. It is ubiquitous in traditional optimization and machine learning. However, the computation of HVPs is often considered prohibitive in the context of deep learning, driving practitioners to use proxy quantities to evaluate the loss geometry. Standard automatic differentiation theory predicts that the computational complexity of an HVP is of the same order of magnitude as the complexity of computing a gradient. The goal of this blog post is to provide a practical counterpart to this theoretical result, showing that modern automatic differentiation frameworks, JAX and PyTorch, allow for efficient computation of these HVPs in standard deep learning cost functions.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-authors:
- - name: Mathieu Dagréou
- url: https://matdag.github.io
- affiliations:
- name: Inria
- - name: Pierre Ablin
- url: https://pierreablin.com/
- affiliations:
- name: Apple
- - name: Samuel Vaiter
- url: https://samuelvaiter.com/
- affiliations:
- name: CNRS
- - name: Thomas Moreau
- url: https://tommoral.github.io/
- affiliations:
- name: Inria
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-bench-hvp.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
-  - name: What are HVPs and where are they useful?
-    subsections:
-      - name: Inverse Hessian-vector products (iHVPs) in optimization
-      - name: HVPs for the study of the loss landscape
-  - name: A quick detour by automatic differentiation
-    subsections:
-      - name: Computational graph
-      - name: Forward mode
-      - name: Reverse mode
-  - name: Naive computation of HVPs
-  - name: HVPs without explicit Hessian computation
-    subsections:
-      - name: Forward-over-reverse
-      - name: Reverse-over-reverse
-      - name: Reverse-over-forward
-  - name: Benchmark with deep learning architectures
-    subsections:
-      - name: Time complexity
-      - name: Memory complexity
-  - name: Conclusion
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .framed {
- border: 1px var(--global-text-color) dashed !important;
- padding: 20px;
- }
- .marge {
- margin-left: 20px;
- }
----
-
-Hessian-vector products (HVPs) play a central role in the study and the use of the geometric property of the loss function of deep neural networks, as well as in many recent bilevel optimizers.
-However, computing such a quantity is often considered prohibitive by practitioners, discouraging them from using algorithms that rely on HVPs.
-
-With this blog post, we aim to convince practitioners that, with modern automatic differentiation (AD) frameworks such as `JAX` or `PyTorch`, HVPs can be efficiently evaluated. Indeed, standard AD theory predicts that the computational cost of an HVP is of the same order as the cost of computing a gradient. After a brief introduction on why HVPs are useful for optimization and ML applications and on the basics of AD, we explain in detail the AD-based methods to compute an HVP and the reason for their efficiency. In particular, we show that one can compute HVPs without explicit Hessian computation. We then compare the different methods to compute HVPs for several deep neural network architectures in terms of time and memory, for both `JAX` and `PyTorch`. Our results illustrate the complexity predicted by the theory, showing that computing an HVP is not much more expensive than computing a gradient. This opens an avenue to develop efficient second-order informed methods for neural networks.
-
-## What are HVPs and where are they useful?
-
-Let us first introduce the notion of Hessian and HVP. We will consider in this post a twice differentiable function $$f:\mathbb{R}^d\to\mathbb{R}$$ that goes from a vector $$x$$ in space $$\mathbb{R}^d$$ to a real number in $$\mathbb{R}$$. This typically corresponds to a function that maps the value of the parameters $$\theta$$ of a neural network to the loss $$f(\theta)$$.
-For such a function, standard AD can be used to efficiently compute the gradient of the loss $$\nabla f(\theta) = \left[ \frac{\partial f}{\partial \theta_i}(\theta)\right]_{1\le i \le d} \in \mathbb{R}^d$$, using the backpropagation.
-The Hessian matrix of $$f$$ at $$\theta$$ is the matrix of its second-order partial derivatives
-
-$$
- \nabla^2 f(\theta) = \left[\frac{\partial^2f}{\partial \theta_i\partial \theta_j}(\theta)\right]_{1\leq i,j\leq d}\in\mathbb{R}^{d\times d}\enspace.
-$$
-
-This matrix corresponds to the derivative of the gradient and captures how the gradient changes when moving $$\theta$$. To evaluate the variation of the gradient when moving $$\theta$$ in the direction $$v\in\mathbb{R}^d$$, one can compute the quantity $$\nabla^2 f(\theta) v\in\mathbb{R}^d$$. This is the Hessian-vector product (HVP).
-
-Let us review some use cases of HVPs in optimization and machine learning.
-
-### Inverse Hessian-vector products (iHVPs) in optimization
-When trying to find the minimum of the function $$f$$, methods that account for the second-order information often rely on the product between the inverse Hessian and a vector to find a good update direction.
-For instance, Newton's method relies on update rules of the form
-
-$$
- \theta_{k+1} = \theta_k - \eta_k[\nabla^2f(\theta_k)]^{-1}\nabla f(\theta_k)
-$$
-
-for some step-size $$\eta_k>0$$.
-
-When evaluating the term $$[\nabla^2f(\theta_k)]^{-1}\nabla f(\theta_k)$$, it would be very inefficient to first compute the full Hessian matrix $$\nabla^2f(\theta_k)$$, then invert it and finally multiply this with the gradient $$\nabla f(\theta_k)$$.
-Instead, one computes the inverse Hessian-vector product (iHVP) by solving the following linear system
-
-\begin{equation}\label{eq:linear_system}
- \nabla^2f(\theta)v = b\enspace.
-\end{equation}
-
-where $$b = \nabla f(\theta_k)$$.
-This approach is much more efficient, as it avoids computing and storing the full Hessian matrix and only applies its inverse to the single vector $$b$$.
-
-A second use case for the iHVP in optimization is with bilevel optimization. In bilevel optimization, one wants to solve the following problem
-
-\begin{equation}\label{eq:bilevel_pb}
- \min_{x\in\mathbb{R}^d} h(x) = F(x, y^* (x))\quad\text{with}\quad y^*(x) = \arg\min_{y\in\mathbb{R}^p} G(x, y)\enspace.
-\end{equation}
-
-The gradient of the function $$h$$ can be computed using the implicit function theorem, giving the following expression
-
-$$
- \nabla h(x) = \nabla_x F(x, y^* (x)) - \nabla_{xy}G(x, y^*(x))[\nabla_{yy}G(x, y^*(x))]^{-1}\nabla_y G(x, y^*(x))\enspace.
-$$
-
-Here, the term $$\nabla^2_{yy} G(x, y)$$ is the Hessian of the function $$G$$ with respect to $$y$$. Thus, evaluating this gradient also requires computing an iHVP.
-
-To compute the iHVP, there are many methods in the literature to solve \eqref{eq:linear_system}, such as Neumann iterations, the Conjugate Gradient method, or gradient descent on the quadratic form $$v\mapsto \frac12\langle\nabla^2f(\theta)v, v\rangle - \langle b, v\rangle$$. All of these methods access the Hessian only through HVPs (for instance, the Conjugate Gradient method only needs the product of the Hessian with its current search direction). Thus, an efficient implementation of HVPs is crucial for the overall algorithm performance.
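-To make this concrete, here is a minimal conjugate-gradient sketch (a simplified illustration, not the exact solvers used in the cited works) in which the Hessian only ever appears through a user-supplied `hvp` callable computing $$v\mapsto\nabla^2f(\theta)v$$ on flat vectors:
-
-```python
-import jax.numpy as jnp
-
-def solve_hvp_linear_system(hvp, b, n_iter=20):
-    """Approximately solve H v = b with conjugate gradient,
-    accessing the Hessian H only through hvp(v) = H @ v."""
-    v = jnp.zeros_like(b)
-    r = b - hvp(v)   # residual
-    p = r            # search direction
-    rs = jnp.vdot(r, r)
-    for _ in range(n_iter):
-        hp = hvp(p)  # the only access to second-order information
-        alpha = rs / jnp.vdot(p, hp)
-        v = v + alpha * p
-        r = r - alpha * hp
-        rs_new = jnp.vdot(r, r)
-        p = r + (rs_new / rs) * p
-        rs = rs_new
-    return v
-```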
-
-
-
-### HVPs for the study of the loss landscape
-
-The study of the geometry of neural networks is an active field that aims at understanding the links between training dynamics, local geometry of the training loss and generalization. One way to study the local geometry of a neural network is to find the distribution of the eigenvalues of its Hessian matrix. Indeed, depending on the sign of the eigenvalues of the Hessian, one can for instance distinguish local minima, local maxima and saddle points. As an illustration, the following figure shows how the sign of the eigenvalues of the Hessian matrix of a function affects the shape of the function's landscape around a stationary point.
-
-{% include figure.html path="assets/img/2024-05-07-bench-hvp/hess_eig.png" class="img-fluid" %}
-
-
-In several papers, an approximation of the Hessian spectrum is computed thanks to the Lanczos algorithm. This algorithm is a modification of the power method where each new iterate is taken in the orthogonal complement of the previous iterates. It outputs an approximate factorization of the Hessian of the form $$\nabla^2 f(\theta) \approx VTV^\top$$, where $$V=(v_0,...,v_{k-1})$$ is orthogonal and
-
-$$
-T = \begin{pmatrix}
- \alpha_0& \beta_1 & 0 & \cdots & 0\\
- \beta_1 & \alpha_1 & \beta_2 & \ddots & \vdots\\
- 0 & \beta_2 & \alpha_2 & \ddots & 0\\
- \vdots & \ddots & \ddots & \ddots & \beta_{k-1}\\
- 0 & \cdots & 0 & \beta_{k-1} & \alpha_{k-1}
-\end{pmatrix}\enspace.
-$$
-
-
-
-
-We observe once again that the Hessian information is accessed through HVPs rather than the full Hessian matrix itself.
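-As an illustration, here is a minimal sketch of such a Lanczos iteration (simplified, with no re-orthogonalization, and not the exact implementation used in the papers above) where, again, the Hessian is only accessed through an `hvp` callable:
-
-```python
-import jax
-import jax.numpy as jnp
-
-def lanczos_tridiag(hvp, dim, k, key):
-    """Return the k x k tridiagonal matrix T whose eigenvalues
-    approximate the extreme eigenvalues of the Hessian."""
-    v = jax.random.normal(key, (dim,))
-    v = v / jnp.linalg.norm(v)
-    v_prev, beta = jnp.zeros(dim), 0.0
-    alphas, betas = [], []
-    for _ in range(k):
-        w = hvp(v)                      # only access to the Hessian
-        alpha = jnp.vdot(v, w)
-        w = w - alpha * v - beta * v_prev
-        beta = jnp.linalg.norm(w)
-        v_prev, v = v, w / beta
-        alphas.append(alpha)
-        betas.append(beta)
-    T = (jnp.diag(jnp.array(alphas))
-         + jnp.diag(jnp.array(betas[:-1]), 1)
-         + jnp.diag(jnp.array(betas[:-1]), -1))
-    return T
-```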
-
-
-## A quick detour by automatic differentiation
-
-Automatic differentiation (AD) is an important tool to compute exactly the derivatives of differentiable functions obtained as the composition of simple operations.
-There are two modes in AD; the forward mode that computes Jacobian-vector products (JVPs) and the reverse mode that computes vector-Jacobian products (VJPs).
-Since the gradient of a scalar function is a special case of the VJP, the reverse mode is the most frequently used in machine learning.
-It is typically used to compute the gradients of deep learning cost functions, where it is called *backpropagation*.
-
-In what follows, we briefly present the notion of a computational graph and the two AD modes. For a more detailed explanation, we refer the reader to the excellent survey by Baydin et al.
-
-### Computational graph
-
-A key ingredient of AD is a computational graph associated with the code that evaluates a function.
-It is a directed acyclic graph that represents the succession of elementary operations required to evaluate a function.
-A simple computational graph of a function $$f:\mathbb{R}^d\to\mathbb{R}^p$$ typically looks like
-
-{% include figure.html path="assets/img/2024-05-07-bench-hvp/direct_graph.png" class="img-fluid"%}
-
-In this graph, the vertices $$z_i\in\mathbb{R}^{m_i}$$ represent the intermediate states of the evaluation of $$f$$.
-To get the vertex $$z_i$$, we use the values of its parents in the graph $$z_{i-1}$$, with simple transfer functions $$z_i(z_{i-1})$$.
-The computational complexity of the function evaluation depends on the complexity of the considered graph, as one node might have more than one parent.
-The memory footprint of the evaluation of the function is also linked to the maximum number of parents that a vertex can have in the computational graph, as their values need to be stored until all child nodes have been computed.
-
-Let us take an example with a multilayer linear perceptron (MLP) with 2 layers.
-The function $$f_x:\mathbb{R}^h\times \mathbb{R}^{h\times p}\to \mathbb{R}$$ is defined for an input $$x\in\mathbb{R}^p$$ by
-
-\begin{equation}\label{eq:mlp}
- f_x(U, W) = \frac12(UWx)^2\enspace.
-\end{equation}
-
-Here, the input $$\theta$$ corresponds to the parameters of the network $$(U, W)$$ and the intermediate steps are $$z_1 = Wx$$, $$z_2 = Uz_1$$ and $$z_3 = \frac12 z_2^2$$.
-A possible computational graph to get $$f_x(U, W)$$ is the following
-
-{% include figure.html path="assets/img/2024-05-07-bench-hvp/computational_graph.png" class="img-fluid"%}
-
-and the associated Python code to compute $$f_x$$ is
-```python
-def f(U, W):
- z1 = W @ x
- z2 = U @ z1
- z3 = 0.5 * z2**2
- return z3
-```
-
-Here, the feed-forward structure of the function makes the computational graph very simple, as each node has a single intermediate result parent.
-
-AD uses this computational graph to compute the function's derivatives.
-Using the chain rule, the Jacobian $$\frac{\partial f}{\partial \theta}(\theta)$$ of $$f$$ is obtained as a product of the Jacobians of the intermediate states $$z_1, \dots, z_n$$.
-\begin{equation}\label{eq:chain_rule}
- \underbrace{\frac{\partial f}{\partial \theta}(\theta)}\_{p\times d} = \frac{\partial z_n}{\partial \theta}
- =\frac{\partial z_n}{\partial z_1}\frac{\partial z_1}{\partial \theta}=\cdots = \underbrace{\frac{\partial z_n}{\partial z_{n-1}}}\_{p\times m_{n-1}}\underbrace{\frac{\partial z_{n-1}}{\partial z_{n-2}}}\_{m_{n-1}\times m_{n-2}}\cdots\underbrace{\frac{\partial z_1}{\partial \theta}}\_{m_1\times d}\enspace.
-\end{equation}
-Depending on the order of the multiplication, one can compute the derivative of $$f$$ with respect to $$\theta$$ in two ways: the forward mode and the reverse mode.
-
-### Forward mode
-
-For a vector $$v\in\mathbb{R}^d$$, the Jacobian-vector product (JVP) corresponds to the directional derivative of $$f$$ in the direction $$v$$. It can be computed by forward-mode AD
-
-\begin{equation}\label{eq:chain_rule_jvp}
- \frac{\partial f}{\partial \theta}(\theta)\times v = \frac{\partial z_n}{\partial z_{n-1}}\frac{\partial z_{n-1}}{\partial z_{n-2}}\cdots\frac{\partial z_1}{\partial \theta}v\enspace.
-\end{equation}
-
-It consists in doing the multiplications in \eqref{eq:chain_rule_jvp} from the right to the left. It is a forward pass in the computational graph where we propagate at the same time the states $$z_i$$ and the partial derivatives $$\frac{\partial z_{i+1}}{\partial z_i}$$. If $$f$$ is real-valued, the $$i$$th coordinate of its gradient is exactly given by the product of the Jacobian of $$f$$ and the $$i$$th canonical basis vector $$e_i$$, since
-\begin{equation}
-\frac{\partial f}{\partial \theta_i}(\theta) = \lim_{t\to 0}\frac{f(\theta+te_i)-f(\theta)}{t}\enspace.
-\end{equation}
- Thus, we can get its gradient by computing each of the $$d$$ JVPs $$\left(\frac{\partial f}{\partial \theta}(\theta)\times e_i\right)_{1\leq i \leq d}$$ with forward AD.
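-As a small illustration of this (deliberately inefficient) strategy, here is a sketch in `JAX`, where the function `f` is a toy placeholder of our choosing:
-
-```python
-import jax
-import jax.numpy as jnp
-
-def f(theta):  # toy scalar-valued function (placeholder)
-    return 0.5 * jnp.sum(theta ** 2)
-
-def grad_via_jvps(f, theta):
-    # One forward-mode JVP per coordinate: d forward passes in total,
-    # which is why reverse mode is preferred for gradients of scalar functions.
-    d = theta.shape[0]
-    eye = jnp.eye(d)
-    return jnp.stack([jax.jvp(f, (theta,), (eye[i],))[1] for i in range(d)])
-
-theta = jnp.arange(4.0)
-print(grad_via_jvps(f, theta))  # here equals theta, i.e. the gradient of f
-```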
-
-To understand properly what happens when using forward differentiation, let us go back to the linear MLP defined in \eqref{eq:mlp}.
-If we implement the forward differentiation ourselves to get the JVP, we obtain the following code
-
-``` python
-def jvp(U, W, v_u, v_w):
- # Forward diff of f
- z1 = W @ x
- v_z1 = v_w @ x # Directional derivative of W -> W @ x in the direction v_w
-
- z2 = U @ z1
- v_z2 = U @ v_z1 + v_u @ z1 # Directional derivative of (U, z_1) -> z2 in the direction (v_u, v_z1)
-
- v_z3 = v_z2 @ z2 # Directional derivative of z2 -> .5*z2**2 in the direction v_z2
- return v_z3
-```
-
-In comparison with the code of the evaluation of $$f_x$$, there are two more operations corresponding to the computation of the dual variables `v_z1` and `v_z2`. In terms of memory, if we consider the computation of the JVP as coded in the previous snippet, the maximum number of parents of a vertex is four. This maximum is achieved by the vertex `v_z2` which has the vertices `U`, `v_z1`, `v_u` and `z1` as parents.
-
-In `JAX`, we get the JVP of a function $$f$$ in the direction $$v$$ with `jax.jvp(f, (params, ), (v, ))[1]`.
-
-### Reverse mode
-The reverse mode is also known as backpropagation in the context of deep learning. For $$u\in\mathbb{R}^p$$, it aims at computing VJPs
-
-\begin{equation}\label{eq:chain_rule_vjp}
- u^\top\frac{\partial f}{\partial \theta}(\theta) = u^\top\frac{\partial z_n}{\partial z_{n-1}}\frac{\partial z_{n-1}}{\partial z_{n-2}}\cdots\frac{\partial z_1}{\partial \theta}\enspace.
-\end{equation}
-
-In reverse-mode AD, the multiplications in \eqref{eq:chain_rule_vjp} are done from the left to the right. This requires one forward pass through the computational graph to compute the intermediate states $$z_i$$, followed by a backward pass to propagate the successive partial derivatives from the left to the right. Contrary to the forward mode, it has a larger memory footprint: it requires storing the values of all the states. For instance, to compute the last term $$\frac{\partial z_3}{\partial z_2}$$, one needs the value of $$z_2$$, which was computed during the forward pass. If $$f$$ is real-valued, $$u$$ is a scalar and the VJP is the gradient of $$f$$ multiplied by $$u$$. Thus, one can get the gradient of $$f$$ by using $$u=1$$ and performing only one reverse differentiation. This makes this mode more efficient for computing gradients.
-
-Let us observe what happens if we manually code the backpropagation to get the gradient of the previous function $$f_x$$ defined by $$f_x(U, W) = \frac12(UW x)^2$$.
-
-``` python
-def gradient(U, W):
- # Forward pass
- z1 = W @ x
- z2 = U @ z1
- z3 = 0.5 * z2**2
-
- # Reverse pass
- ## Transfer function: z3 = 0.5 * z2**2
- dz2 = z2 # derivative of z3 wrt z2
-
- ## Transfer function: z2 = U @ z1
- dU = jnp.outer(dz2, z1) # derivative of z3 wrt U
- dz1 = U.T @ dz2 # derivative of z3 wrt z1
-
- ## Transfer function: z1 = W @ x
- dW = jnp.outer(dz1, x) # derivative of z3 wrt W
-
- return dU, dW
-```
-
-This function returns the gradient of $$f_x$$. Reading this code, we understand that one needs to store all the intermediate values of the forward pass in the graph. Indeed, if we look at the case of `z1`, which is the first node computed, it is used four steps later for the computation of `dU`.
-
-To get the gradient in JAX, one can use `jax.grad(f)(params)`.
-
-
-## Naive computation of HVPs
-Since we are interested in computing $$\nabla^2 f(\theta)v$$, the simplest way to do it is to compute the Hessian matrix and then multiply it by the vector $$v$$. This can be achieved in `JAX` by calling `jax.hessian(f)(params) @ v`.
-
-This method is quite cumbersome, making it impossible to use for deep neural networks. Indeed, storing the full Hessian matrix has $$\mathcal{O}(d^2)$$ memory complexity, where $$d$$ is the dimension of the model's parameter set.
-
-The good news is that we can compute HVPs without computing the Hessian, thanks to a clever use of AD.
-
-
-## HVPs without explicit Hessian computation
-In 1994, Pearlmutter proposed to leverage the following observation to compute HVPs efficiently: the HVP is also the directional derivative of the gradient in the direction $$v$$:
-
-$$
-\nabla^2f(\theta) v = \lim_{\epsilon\to 0} \frac1\epsilon[\nabla f(\theta+\epsilon v)-\nabla f(\theta)] = \nabla [\langle \nabla f(.), v\rangle](\theta)\enspace.
-$$
-
-Based on this identity, AD enables to compute HVPs in three ways, as described in the [JAX documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html).
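-Before detailing them, here is a quick numerical sanity check of the identity above (our own toy sketch; `f`, `theta` and `v` are placeholders), using the forward-over-reverse implementation described just below:
-
-```python
-import jax
-import jax.numpy as jnp
-
-def f(theta):  # toy cost function (placeholder)
-    return jnp.sum(jnp.sin(theta) ** 2)
-
-theta = jnp.arange(3.0)
-v = jnp.ones(3)
-
-# HVP as the directional derivative of the gradient
-hvp = jax.jvp(jax.grad(f), (theta,), (v,))[1]
-
-# Finite-difference approximation of the same directional derivative
-eps = 1e-4
-hvp_fd = (jax.grad(f)(theta + eps * v) - jax.grad(f)(theta)) / eps
-
-print(jnp.max(jnp.abs(hvp - hvp_fd)))  # small
-```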
-
-
-### Forward-over-reverse
-The forward-over-reverse mode consists in doing forward differentiation in a computational graph of the gradient of $$f$$.
-
-Its implementation in `JAX` is only two lines of code.
-
-```python
-def hvp_forward_over_reverse(f, params, v):
- return jax.jvp(jax.grad(f), (params, ), (v, ))[1]
-```
-In this case, the gradient of $$f$$ is computed by reverse-mode AD, whose complexity is roughly twice the complexity of evaluating $$f$$.
-Thus, the time complexity of `hvp_forward_over_reverse` is roughly four times the complexity of evaluating $$f$$.
-
-To better see what happens, let us consider again our function $$f_x$$ defined by \eqref{eq:mlp}. The Python code of the `forward-over-reverse` HVP is the following.
-
-```python
-def forward_over_reverse(U, W, v_U, v_W):
- # Forward through the forward pass through f
- z1 = W @ x
- v_z1 = v_W @ x
-
- z2 = U @ z1
- v_z2 = U @ v_z1 + v_U @ z1
-
- # z3 = 0.5 * z2**2
- # Forward through the backward pass through f
- z4 = z2 # dz2
- v_z4 = v_z2 # v_dz2
-
- z5 = jnp.outer(z4, z1) # dU
- v_z5 = jnp.outer(v_z4, z1) + jnp.outer(z4, v_z1) # v_dU
-
- z6 = U.T @ z4 # dz1
- v_z6 = U.T @ v_z4 + v_U.T @ z4 # v_dz1
-
- z7 = jnp.outer(z6, x) # dW
- v_z7 = jnp.outer(v_z6, x) # v_dW
-
- return v_z5, v_z7 # v_dU, v_dW
- ```
-
-The take-home message of this part is that, after computing the gradient of $$f_x$$, one can consider a computational graph of this gradient and perform forward differentiation through this new computational graph.
-Here, the variables `z1`,..., `z7` are the vertices of a computational graph of the gradient of $$f_x$$.
-The nice thing is that this mode enables getting the gradient and the HVP at the same time.
-Indeed, in the previous snippet, `z5` and `z7` are the components of the gradient of $$f_x$$, which could also be returned if needed.
-This feature can be useful in bilevel optimization, for instance.
-
-### Reverse-over-reverse
-Instead of doing forward differentiation of the gradient, one can take the scalar product between the gradient and $$v$$, which yields a scalar. We can then backpropagate through this scalar product. This is the reverse-over-reverse mode.
-
-It can be implemented by these lines of code.
-```python
-def hvp_reverse_over_reverse(f, params, v):
- return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(params)
-```
-Since the gradients are computed by backpropagation, the complexity of `hvp_reverse_over_reverse` is twice the complexity of `jax.grad(f)`, which is roughly four times the complexity of the evaluation of $$f$$.
-
-Writing down the code of the reverse-over-reverse HVP for our function $$f_x$$ defined by \eqref{eq:mlp} helps us understand the differences between this mode and the `forward-over-reverse` mode. In particular, one can notice that there are more elementary operations in the `reverse-over-reverse` mode than in the `forward-over-reverse` mode. Moreover, in terms of memory footprint, `reverse-over-reverse` requires storing the values of the vertices of the computational graph of the gradient of $$f_x$$, while `forward-over-reverse` only needs to store the values of the vertices of the computational graph of $$f_x$$. Thus, the former is less efficient than the latter.
-
-```python
-def reverse_over_reverse(U, W, v_u, v_w):
-    # Forward pass: compute the scalar <grad f(U, W), (v_u, v_w)>
-    ## Forward through f
-    z1 = W @ x
-    z2 = U @ z1
-    z3 = 0.5 * jnp.linalg.norm(z2)**2  # loss value, not needed for the HVP
-
-    ## Reverse through f
-    z4 = z2                  # dz2
-    z5 = jnp.outer(z4, z1)   # dU
-    z6 = U.T @ z4            # dz1
-    z7 = jnp.outer(z6, x)    # dW
-
-    # Output: dot product between the gradient and (v_u, v_w)
-    z8 = jnp.sum(z5 * v_u) + jnp.sum(z7 * v_w)
-
-    # Backward pass through the computation of z8
-    ## z8 = jnp.sum(z5 * v_u) + jnp.sum(z7 * v_w)
-    dz7 = v_w
-    dz5 = v_u
-
-    ## z7 = jnp.outer(z6, x)
-    dz6 = dz7 @ x
-
-    ## z6 = U.T @ z4
-    dz4 = U @ dz6
-    ddU = jnp.outer(z4, dz6)  # Derivative of z8 wrt U
-
-    ## z5 = jnp.outer(z4, z1)
-    dz4 += dz5 @ z1
-    dz1 = dz5.T @ z4
-
-    ## z4 = z2
-    dz2 = dz4
-
-    ## z2 = U @ z1
-    dz1 += U.T @ dz2
-    # As U appears multiple times in the graph, we sum its contributions
-    ddU += jnp.outer(dz2, z1)
-
-    ## z1 = W @ x
-    ddW = jnp.outer(dz1, x)  # Derivative of z8 wrt W
-
-    return ddU, ddW
- ```
-
-### Reverse-over-forward
-What about doing forward differentiation of $$f$$ rather than reverse differentiation? This is what is done in the reverse-over-forward mode. It consists in backpropagating through the computational graph of the JVP of $$f$$ with $$v$$.
-
-```python
-def hvp_reverse_over_forward(f, params, v):
- jvp_fun = lambda params: jax.jvp(f, (params, ), (v, ))[1]
- return jax.grad(jvp_fun)(params)
-```
-
-This method is more efficient than the previous one: since we backpropagate only once, the memory burden is lower than in the `reverse-over-reverse` approach. Compared with `forward-over-reverse`, the complexity is the same. However, one can notice that `forward-over-reverse` enables computing the gradient of $$f$$ and the HVP at the same time, which is not the case for the `reverse-over-forward` mode.
-
-The code of the `reverse-over-forward` HVP for the MLP $$f_x$$ defined by \eqref{eq:mlp} is the following.
-
-```python
-def reverse_over_forward(U, W, v_U, v_W):
- # Forward diff of f to
- z1 = W @ x
- z6 = v_W @ x # v_z1
-
- z2 = U @ z1
- z5 = U @ z6 + v_U @ z1 # v_z2
-
- # output
- z4 = z5 @ z2 # v_z3
-
- # Backward pass through
- ## z4 = z5 @ z2
- dz2 = z5
- dz5 = z2 # dv_z2
-
- ## z5 = U @ z6 + v_U @ z1
- dz1 = v_U.T @ dz5
- dz6 = U.T @ dz5 # dv_z1
- ddU = jnp.outer(dz5, z6) # derivative of z4 wrt U
-
- ## z2 = U @ z1
- # As U and dz1 appear multiple times, we sum their contributions
- dz1 += U.T @ dz2
- ddU += jnp.outer(dz2, z1)
-
- ## z1 = W @ x
- ddW = jnp.outer(dz1, x)
- return ddU, ddW
-```
-
-## Benchmark with deep learning architectures
-
-While these three methods compute the same outputs, the different ways of traversing the computational graph change their overall time and memory complexities. We now compare the computation of HVPs with these three methods for various deep-learning architectures. To cover a broad range of use cases, we consider a residual network ([ResNet34](https://huggingface.co/docs/transformers/model_doc/resnet)) and a transformer-based architecture ([ViT-base](https://huggingface.co/docs/transformers/model_doc/vit)) for image classification, as well as a transformer for natural language processing ([Bert-base](https://huggingface.co/docs/transformers/model_doc/bert#transformers.FlaxBertForTokenClassification)).
-We use the `Flax` and `PyTorch` implementations of these architectures available in the [transformers package](https://huggingface.co/docs/transformers/) provided by [Hugging Face 🤗](https://huggingface.co).
-
-All computations were run on an Nvidia A100 GPU with 40 GB of memory. We used version 0.4.21 of `JAX` and version 2.1.1 of `torch`.
-
-The code of the benchmark is available on [this repo](https://github.com/MatDag/bench_hvp/).
-
-### Time complexity
-
-We first compare, in terms of wall-clock time, the different ways to compute HVPs against the computation of a gradient by backpropagation. For each architecture, we compute the gradient of the loss with respect to the parameters by backpropagation, and we compute the HVPs in the `forward-over-reverse`, `reverse-over-forward` and `reverse-over-reverse` modes. For each computation, we measure the time taken. Specifically for the HVPs, we subtract the time taken by a gradient computation, to isolate the overhead of the HVP computation itself.
-The inputs for each architecture are generated randomly. For the ResNet34 architecture, we generated a batch of images of size 224x224x3. To limit out-of-memory issues in the experiments, we generated images of size 96x96x3 for the ViT architecture. For the BERT architecture, we generated a batch of sequences of length 32.
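-For reference, here is a rough sketch of how such a wall-clock measurement can be set up in `JAX` (a simplified illustration, not necessarily the exact code of the benchmark):
-
-```python
-import time
-import jax
-
-def measure(fun, *args, n_repeat=90):
-    fun(*args)  # warm-up call, triggers jit compilation
-    times = []
-    for _ in range(n_repeat):
-        start = time.perf_counter()
-        # block_until_ready forces JAX's asynchronous dispatch to finish
-        jax.block_until_ready(fun(*args))
-        times.append(time.perf_counter() - start)
-    return times
-```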
-
-We first use `JAX` with just-in-time compilation. Each computation is run 90 times. On the left of the figure, we plot the median computation time together with the 20% and 80% percentiles in black. The computations are done with a batch size of 128. We observe that, in practice, the overhead of the HVP computation over the gradient computation is between one and two times the time of a gradient computation for the three architectures. Consequently, a whole HVP computation takes between two and three times the time of a gradient calculation. This is consistent with the theory. One can notice that the `reverse-over-reverse` mode is slightly slower than the others in all cases, while `forward-over-reverse` and `reverse-over-forward` are very close in terms of time.
-
-We also report on the right figure the computational time of each method with respect to the batch size for the ResNet34 architecture. We observe, as expected, that the computational time scales linearly with the batch size.
-
-{% include figure.html path="assets/img/2024-05-07-bench-hvp/bench_hvp_time_jax.png" class="img-fluid" %}
-
-We run a similar experiment with [`torch.func`](https://pytorch.org/docs/stable/func.html), the functional API available in `PyTorch`, which is similar to the one `JAX` offers. The results we get are more mixed.
-
-In the case of ResNet34, the scaling between the different methods is similar to the one we get with `JAX`. Also, during our experiments, we figured out that batch normalization made the forward computation slow and induced out-of-memory issues. Thus, we removed the batch normalization layers from the ResNet34 architecture.
-
-For ViT and BERT, the `forward-over-reverse` mode is surprisingly slower than the `reverse-over-reverse` method. Moreover, the scaling between the gradient and HVP computation times differs from the one we get with `JAX`: for these architectures, the HVP computations take between four and five times longer than the gradient computations. This is a discrepancy with what the theory predicts. It might be because, at the time we are writing this blog post, the functional API of `PyTorch` is still in its early stages. In particular, we could not use compilation with `torch.compile` because it does not work with some operators of `torch.func`, such as `torch.func.jvp`.
-
-{% include figure.html path="assets/img/2024-05-07-bench-hvp/bench_hvp_time_torch.png" class="img-fluid" %}
-
-### Memory complexity
-
-We also compare the memory footprint of each approach. The following figure provides the results we get with `JAX` jitted code. On the left, we show the result for each method and model with a batch size of 64. On the right, we show the evolution of the memory footprint of each method with the batch size for the ResNet34. Surprisingly, the memory footprint of the different methods to compute HVPs does not vary for a given model. This is counterintuitive, since we expect the `reverse-over-reverse` method to have a larger memory footprint due to the double backpropagation.
-
-{% include figure.html path="assets/img/2024-05-07-bench-hvp/bench_hvp_memory_jax.png" class="img-fluid" %}
-
-However, when we run the same experiment with *JIT compilation disabled*, the results corroborate the theory. Indeed, one can observe in the following figure that the memory footprint of the `reverse-over-reverse` method is larger than that of the `forward-over-reverse` and `reverse-over-forward` methods. This is because `reverse-over-reverse` involves two successive backward differentiations, while the other two involve only one reverse differentiation. Moreover, the footprint scales linearly with the batch size, which was not the case in the previous figure in the small batch size regime.
-
-In light of these two results, the clever memory allocation performed during just-in-time compilation significantly reduces the memory footprint of the HVP computations.
-
-{% include figure.html path="assets/img/2024-05-07-bench-hvp/bench_hvp_memory_jax_without_jit.png" class="img-fluid" %}
-
-In the following figure, we plot the results we get with the `PyTorch` implementation. One can observe that, in all cases, `forward-over-reverse` consumes more memory than the `reverse-over-forward` mode. It is almost at the same level as the `reverse-over-reverse` mode, which is quite unexpected.
-
-The right plot shows that, for the ResNet34 architecture, the memory footprint evolves linearly with the batch size, as expected.
-
-{% include figure.html path="assets/img/2024-05-07-bench-hvp/bench_hvp_memory_torch.png" class="img-fluid" %}
-
-## Conclusion
-
-In this blog post, we have explored the different ways to compute HVPs from theoretical and practical perspectives. The three take-home messages to keep in mind are the following:
-
-* We can compute HVPs without computing Hessian matrices.
-
-* In practice, computing an HVP takes between twice and four times the time taken by a gradient computation and requires two to three times more memory than computing a gradient.
-
-* The AD framework and whether just-in-time compilation is used both affect the practical time and memory performance of HVP computations.
-
diff --git a/_posts/2024-05-07-diffusion-theory-from-scratch.md b/_posts/2024-05-07-diffusion-theory-from-scratch.md
deleted file mode 100644
index e7c71591..00000000
--- a/_posts/2024-05-07-diffusion-theory-from-scratch.md
+++ /dev/null
@@ -1,457 +0,0 @@
----
-layout: distill
-title: "Building Diffusion Model's theory from ground up"
-description: "Diffusion Models, a new generative model family, have taken the world by storm after the seminal paper by Ho et al. [2020]. While diffusion models are often described as a probabilistic Markov Chains, their underlying principle is based on the decade-old theory of Stochastic Differential Equations (SDE), as found out later by Song et al. [2021]. In this article, we will go back and revisit the 'fundamental ingredients' behind the SDE formulation and show how the idea can be 'shaped' to get to the modern form of Score-based Diffusion Models. We'll start from the very definition of the 'score', how it was used in the context of generative modeling, how we achieve the necessary theoretical guarantees and how the critical design choices were made to finally arrive at the more 'principled' framework of Score-based Diffusion. Throughout this article, we provide several intuitive illustrations for ease of understanding."
-date: 2024-05-07
-htmlwidgets: true
-
-authors:
- - name: Ayan Das
- url: "https://ayandas.me/"
- affiliations:
- name: "University of Surrey UK, MediaTek Research UK"
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-diffusion-theory-from-scratch.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-toc:
- - name: Introduction
- subsections:
- - name: Motivation
- - name: Generative Modeling
- - name: Existing Frameworks
- - name: Diffusion is no different
- - name: "The 'Score'"
- - name: Generative Modeling with Scores
- subsections:
- - name: Langevin Equation and Brownian Motion
- - name: Fokker-Planck Equation
- - name: A probability path
- - name: Estimating the "score" is hard
- - name: The "forward process"
- - name: Finite time & the "schedule"
- - name: Estimating the Score
- subsections:
- - name: Implicit Score Matching
- - name: Denoising Score Matching
- - name: Probing the learning objective
- - name: Denoising as inverse problem
- - name: Last few bits
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-## Introduction
-
-### Motivation
-
-Not only has generative modeling been around for decades, but a few promising model families emerged and dominated the field for several years in the recent past. VAEs dominated the generative modeling landscape from 2014 onwards, until GANs took off in 2015-16; Normalizing Flows (NFs) never really made it into mainstream generative modeling due to their restrictive architectural requirements. However, it is quite clear at this point that the magnitude of their impact is smaller than that of barely 2-3 years of Diffusion Models. This is mostly attributed to one of the seminal papers (by Jonathan Ho et al.), now popularly referred to as "Denoising Diffusion Probabilistic Models" or DDPM. With the exponential explosion of works following DDPM, it is very hard, or rather unnecessary, to look beyond this pivotal point.
-
-In this article, we look back at the conceptual and theoretical ideas that had been in development for a long time, even outside the field of core machine learning. We will show in later sections that some of the theoretical 'pillars' holding up Diffusion Models have their roots deep in statistical physics and other fields. A significant part of this theory was presented afresh in the ICLR paper (which won a best paper award). Lastly, even though the ideas presented in this article are quite theoretical, we made our best attempt to convey them with intuitive explanations, diagrams and figures, thereby expanding the article's potential audience. To encourage further exploration, we provide all code used in producing the figures (and experiments) of this article in [this repository](https://github.com/dasayan05/iclr24_blog_code).
-
-This article notes that, historically, there were two distinct roads of development that merged in order for modern diffusion models to emerge -- "scalable estimation of the score" and "using the score for generative modeling". The former is relatively short, while the latter traces its origin back to ~1900, if not earlier. This article explores these two paths independently -- the latter one first, while assuming knowledge of the former. The rest of this introductory section is spent on defining the general modeling problem and the very notion of 'score' -- the primary quantity of interest. The next section deals with how we can use the score in generative modeling, assuming access to an oracle for the true score. The last section dives solely into the problem of estimating the score in a scalable manner. It is worth mentioning that, in this article, we explain only the "sufficient and necessary" concepts needed to build the diffusion model framework, and hence our exposition may not directly resemble the typical formalism seen in most papers.
-
-
-### Generative Modeling
-
-The problem of generative modeling, in most cases, is posed as *parametric density estimation* using a finite set of samples $$\{ x^{(n)} \}_{n=1}^N$$ from a "true but unknown" data distribution $$q_{data}(x)$$. With a suitable model family chosen as $$p_{\theta}(x)$$, with unknown parameters $$\theta$$, the problem boils down to maximizing the average (log-)likelihood (w.r.t $$\theta$$) of all the samples under the model
-
-$$
-\theta^* = \arg\max_{\theta} \mathbb{E}_{x \sim q_{data}(x)} \left[ \log p_{\theta}(x) \right] \approx \arg\max_{\theta} \frac{1}{N} \sum_{n=1}^N \log p_{\theta}(x^{(n)})
-$$
-
-It turned out however, that defining an arbitrary parametric density $$p_{\theta}(x)$$ is not as easy as it looks. There was one aspect of $$p_{\theta}$$ that is widely considered to be the evil behind this difficulty -- _the normalizing constant_ that stems from the axiom of probability
-
-$$
-p_{\theta}(x) = \frac{\tilde{p}_{\theta}(x)}{\color{purple} \int_x \tilde{p}_{\theta}(x)}
-$$
-
-### Existing Frameworks
-
-It was understood quite early on that any promising generative model family must have one property -- _ease of sampling_, i.e. generating new data samples. Sampling was so essential to generative modeling, that the model families that followed were all geared towards effective sampling, even if it was at the expense of other not-so-important properties. It was also well understood that there was one common underlying principle most effective for crafting "sampling-centric" generative models -- _transforming simple probability densities_. This formed the backbone of every single generative model family so far; be it VAEs, GANs or NFs, their generative process is a density transformation of this form
-
-$$
-x = f_{\theta}(z),\text{ where } z \sim \mathcal{N}(0, I)
-$$
-
-that suggests to start with a simple density (often just standard normal) followed by a functional transformation $$f_{\theta}$$, typically a neural network with parameters $$\theta$$. For VAEs, the function $$f_{\theta}$$ is the decoder; for GANs, it's the generator network and for NFs, it's the entire flow model. It is to be noted however, that the way they differ is mostly _how they are trained_, which may involve more parametric functions (e.g. VAE's encoder or GAN's discriminator) and additional machinery. This way of building generative models turned out to be an effective way of sidestepping the notorious normalizing constant.
-
-### Diffusion is no different
-
-Diffusion Models, at its core, follow the exact same principle, but with a slightly clever design. For diffusion models, the transformation $$f_{\theta}$$ is rather complicated. It is a sequence of invocations of a neural function (denoted as $$s_{\theta}$$) along with some additional computation (denoted as $$g(\cdot)$$)
-
-\begin{equation} \label{eq:diffusion_general_parametric_structure}
-x = g_1(g_2(g_3(\cdots z \cdots, s_{\theta}), s_{\theta}), s_{\theta}), \text{ where } z \sim \mathcal{N}(0, I)
-\end{equation}
-
-This is a big difference between Diffusion Models and other generative model families. Prior generative families tried to learn the exact transformation directly via one parametric neural function $$f_{\theta}$$. Diffusion Models on the other hand, try to learn $$s_{\theta}$$, a quantity very _fundamental and intrinsic_ to any true data distribution $$q_{data}(x)$$. The quantity in question has historically been called the "_Score_".
-
-### The 'Score'
-
-The term 'Score' is simply defined as the _gradient of the log-density of a distribution_, i.e. $$\nabla \log p(\cdot)$$. In statistics, it is also known (though not very popularly) as the 'Informant'. One might argue that 'Score' is rather a strange name for such a quantity. It so happens that the origin of this term can be traced (thanks to a StackOverflow answer by @ben) to a 1935 paper by Ronald Fisher, where he used the term in a very generic sense in order to "rank" some quantities. In the context of diffusion models, however, we stick to the modern definition of the score. The _true score_ of our data distribution is therefore defined as the gradient of the log of the _true density_ of the data, w.r.t. the data variable
-
-\begin{equation} \label{eq:data_score_defn}
-\nabla_x \log q_{data}(x) \triangleq s(x)
-\end{equation}
-
-The quantity in Eq.\eqref{eq:data_score_defn} is unknown, just like the true data density $$q_{data}(x)$$. It does have a meaning though: the "_true score_" refers to the _direction of steepest increase_ in log-likelihood at any given point in the data space. See the gray arrows in the figure below.
-
-
-{% include figure.html path="assets/img/2024-05-07-diffusion-theory-from-scratch/score_def.png" class="col-8" %}
-
-
-Simply put, at a point $$x$$, it tells us the best direction to step in (with a little step-size $$\delta$$) if we would like to reach a point $$x'$$ with slightly higher likelihood
-
-\begin{equation} \label{eq:naive_score_steps}
-x' = x + \delta \cdot \left. \nabla_x \log q_{data}(x) \right|_{x = x}
-\end{equation}
-
-Please note that this stems just from the definition of the gradient operator $$\nabla$$ in score. If you are familiar with gradient descent, you may find conceptual resemblance.
-
-Now, there are two burning questions here:
-
-1. Considering we have access to the true score, is Eq.\eqref{eq:naive_score_steps} enough to define a generative process with appropriate convergence guarantee ?
-2. How do we actually get the true score ?
-
-The following two sections answer these questions respectively. Luckily, as we now understand, these two questions are somewhat decoupled and can be studied independently. The first section analyzes the first question, _assuming_ we have access to the true score $$\nabla_x \log q_{data}(x)$$. The second section explores how to get the true score, or rather, an approximation of it.
-
-
-## Generative Modeling with Scores
-
-As explained before, we would like to sample from the true data distribution $$q_{data}(x)$$, but all we have access to (we assume) is its score $$s(x)$$ as defined in Eq.\eqref{eq:data_score_defn}. One may define a naive generative process as the iterative application of Eq.\eqref{eq:naive_score_steps}. Intuitively, it is very similar to gradient ascent, where we greedily climb the log-density surface to attain a local maximum. If so, we can already see a possible instance of the general structure of Diffusion's generative process as hinted in Eq.\eqref{eq:diffusion_general_parametric_structure}, with $$g(\cdot)$$ being
-
-$$
-g(z, s(\cdot)) = z + \delta \cdot s(z) = z + \delta \cdot \nabla_x \log q_{data}(x)
-$$
-
-With a little reshuffling of Eq.\eqref{eq:naive_score_steps} and considering $$\delta \rightarrow 0$$, one can immediately reveal the underlying ODE (Ordinary Differential Equation -- ODEs describe how a process evolves over time through its infinitesimal change) that describes the infinitesimal change
-
-\begin{equation} \label{eq:ode_with_score}
-dx = \nabla_x \log q_{data}(x) dt
-\end{equation}
-
-BUT, please note that this is only an intuitive attempt and is entirely based on the definition of the score. It comes with **absolutely no guarantee** that this process converges to samples from the true data distribution. In fact, this process is **greedy**, i.e. it only seeks to go uphill, converging exactly at the _modes_ (local maxima of the probability density). The figure below shows samples $$x$$ subjected to the process in Eq.\eqref{eq:ode_with_score} and their density $$p_t(x)$$ evolving over time. The density in red is the target density whose score (which we assume we know) is being used.
-
-
-{% include figure.html path="assets/img/2024-05-07-diffusion-theory-from-scratch/greedy_wo_noise.gif" class="img-fluid" %}
-
-
-In this case, at $$t=\infty$$, all samples will converge to the state with _the highest_ likelihood (i.e. exactly at the center). This isn't really desirable, as it doesn't "explore" at all. Just like any other sampling algorithm, we need noise injection!
-
-### Langevin Equation and Brownian Motion
-
-It turned out that this problem was explored long ago in molecular dynamics by the French physicist Paul Langevin, in the context of analyzing the movements of particles suspended in a fluid. He described the overall dynamics of particles, i.e. how the position of a particle changes over time $$t$$ when in a _potential energy_ field $$U(x)$$
-
-\begin{equation} \label{eq:original_langevin_dyn}
-dx = - \nabla_x U(x) dt + \sqrt{2} dB_t
-\end{equation}
-
-The term $$dB_t$$ is called "Brownian Motion" and is effectively the source of noise -- we will talk about this later in this subsection. Energy is considered "bad", i.e. particles do not want to stay in a state with high energy. So they try to go downhill and settle in low-energy states using the gradient of the energy surface. The langevin equation (i.e. Eq.\eqref{eq:original_langevin_dyn}) happened to provide sufficient "exploration" abilities so that the particles visit states with probability $$\propto e^{-U(x)}$$. This suggests that we can treat "negative energy" as log-likelihood
-
-$$
-q_{data}(x) \propto e^{-U(x)} \implies \log q_{data}(x) = -U(x) + C \implies \nabla_x \log q_{data}(x) = - \nabla_x U(x)
-$$
-
-By using the above substitution in the Langevin equation, we can move out of physics and continue with our ML perspective
-
-\begin{equation} \label{eq:langevin_dyn}
-dx = \nabla_x \log q_{data}(x) dt + \sqrt{2} dB_t
-\end{equation}
-
-Note that this isn't very different from our "intuitive" and greedy process in Eq.\eqref{eq:ode_with_score}, except for the noise term $$dB_t$$ and a strange $$\sqrt{2}$$. But this makes a difference! Brownian motion is an old construct from particle physics used to describe the random motion of particles in a fluid or gas. It is simply a Gaussian noise with infinitesimally small variance (in practice, the smaller the step you take, the smaller the noise you get)
-
-$$
-dB_t \sim \mathcal{N}(0, dt) \implies dB_t = \sqrt{dt} \cdot z,\text{ where } z \sim \mathcal{N}(0, I)
-$$
-
-With that, we can simulate our new langevin equation _with noise_ (i.e. Eq.\eqref{eq:langevin_dyn}) just like the noiseless case. You can see now that the noise is keeping the process from entirely converging into the mode. If you notice carefully, we have added a little "tail" to each point to help visualize their movement.
-
-{% include figure.html path="assets/img/2024-05-07-diffusion-theory-from-scratch/langevin_dyn_basic.gif" class="img-fluid" %}
-
-### Fokker-Planck Equation
-
-The simulation is convincing; but it would be even better if we could _theoretically verify_ that the process in Eq.\eqref{eq:langevin_dyn} indeed converges to $$q_{data}(x)$$. The key to this proof is figuring out $$p_t(x)$$ and making sure that it stabilizes as $$t\rightarrow \infty$$, i.e. $$p_{\infty}(x) = q_{data}(x)$$. It turns out that a stochastic process of the form $$dx = \mu_t(x) dt + \sigma_t(x) dB_t$$, acting on a random variable $$x$$, induces a time-varying distribution that can be described by this ODE
-
-\begin{equation}
-\frac{\partial}{\partial t}p_t(x) = -\frac{\partial}{\partial x} \Big[ p_t(x)\mu_t(x) \Big] + \frac{1}{2} \frac{\partial^2}{\partial x^2} \Big[ p_t(x) \sigma^2_t(x) \Big]
-\end{equation}
-
-This is a well-celebrated result known as the "Fokker-Planck equation", which even predates the Langevin equation. So, the solution of this ODE is exactly what we are seeing in the above figure (middle). One can easily verify the convergence of Eq.\eqref{eq:langevin_dyn} by first observing $$\mu_t(x) = \nabla_x \log q_{data}(x), \sigma_t(x) = \sqrt{2}$$ and then using $$\frac{\partial}{\partial t} p_{\infty}(x) = \frac{\partial}{\partial t} q_{data}(x) = 0$$.
-
-$$\begin{eqnarray*}
-\frac{\partial}{\partial t}p_{\infty}(x) &=& -\frac{\partial}{\partial x} \Big[ p_{\infty}(x) \nabla_x \log q_{data}(x) \Big] + \frac{(\sqrt{2})^2}{2} \frac{\partial^2}{\partial x^2} \Big[ p_{\infty}(x) \Big] \\
-\frac{\partial}{\partial t} q_{data}(x) &=& -\frac{\partial}{\partial x} \Big[ q_{data}(x) \nabla_x \log q_{data}(x) \Big] + \frac{(\sqrt{2})^2}{2} \frac{\partial^2}{\partial x^2} \Big[ q_{data}(x) \Big] \\
-0 \text{ (LHS)} &=& -\frac{\partial}{\partial x} \Big[ \nabla_x q_{data}(x) \Big] + \frac{\partial}{\partial x} \Big[ \nabla_x q_{data}(x) \Big] = 0\text{ (RHS)}
-\end{eqnarray*}$$
-
-
-The LHS holds due to the fact that after a long time (i.e. $$t = \infty$$) the distribution stabilizes (it is then called a "stationary" or "equilibrium" distribution). Please also note that the proof above is for the 1-dimensional case and is included for illustrative purposes only -- the general case is slightly more complicated.
-
-So, we're all good. Eq.\eqref{eq:langevin_dyn} is a provable way of sampling given we have access to the true score. In fact, the very work (by Song et al.) that immediately precedes DDPM, used exactly Eq.\eqref{eq:langevin_dyn} in its discrete form
-
-\begin{equation}
-x_{t+\delta} = x_t + \delta \cdot \nabla_x \log q_{data}(x_t) + \sqrt{2\delta} \cdot z
-\end{equation}
-
-where $$\delta$$ (a small constant) is used as a practical proxy for the theoretical $$dt$$.
-
-If you are already familiar with Diffusion Models, specifically their reverse process, you might be scratching your head. That is because the generative process in Eq.\eqref{eq:langevin_dyn} isn't quite the same as what modern diffusion models do. We need to cross a few more hurdles before we get there.
-
-### A probability path
-
-More than just a proof, the Fokker-Planck ODE provides us with a key insight -- gradually transforming one distribution into another is equivalent to traveling (over time) on a "path" in the _space of probability distributions_. Imagine a space of all possible probability distributions $$p$$ (each distribution varies over the data space $$x$$ too, but let's hide that for now and imagine them to be just vectors). The Fokker-Planck ODE for Eq.\eqref{eq:langevin_dyn}, therefore, represents a specific dynamics on this probability space whose solution trajectory $$p_t$$ ends at $$q_{data}$$ at $$t = \infty$$.
-
-Speaking of ODEs, there is something we haven't talked about yet -- the initial distribution at $$t=0$$, i.e. $$p_0$$. In the simulation above, I quietly used a standard normal $$\mathcal{N}(0, I)$$ as the starting distribution (you can notice this in the first few frames of the animation) without ever discussing it. It turns out that our Fokker-Planck ODE does not have any specific requirement for $$p_0$$, i.e. it always converges to $$p_{\infty} = q_{data}$$ no matter where you start. Here's an illustration that shows two different starting distributions $$p_0$$; both of their "paths" over time, i.e. $$p_t$$ in probability space, ultimately converge to $$q_{data}$$.
-
-{% include figure.html path="assets/img/2024-05-07-diffusion-theory-from-scratch/fokker-plank-multiple.gif" class="img-fluid" %}
-
-So theoretically, given the score function $$\nabla_x \log q_{data}(x)$$ of a target distribution $$q_{data}(x)$$, one can "travel to" it from _any_ distribution. However, keeping in mind our need for _sampling_, it's best to choose an initial distribution that is sampling-friendly. Strictly speaking, there are a couple of reasonable choices, but the diffusion model community ended up with the _Isotropic Gaussian_ (i.e. $$\mathcal{N}(0, I)$$). This is not only due to its goodwill across machine learning and statistics, but also due to the fact that in the context of SDEs with Brownian motions (remember, these are infinitesimal Gaussian noises), Gaussians arise quite naturally.
-
-### Estimating the "score" is hard
-
-So far, what we've talked about is just the _generative process_, or as the diffusion model literature calls it, the "reverse process". But we haven't really talked about the "forward process" yet, in case you are familiar with it. The forward process, in simple terms, is an _ahead-of-time description_ of the "probability path" that the reverse process intends to take. But the question is, why do we need to know the path ahead of time -- the reverse process seems quite spontaneous (in the sense that, given a score function, it just travels to the correct target distribution on its own), no? Sadly, this can't be answered with theory alone.
-
-The problem lies in Eq.\eqref{eq:langevin_dyn} -- let's write it again with a little more verbosity
-
-\begin{equation}
-dx_t = \nabla_x \left. \log q_{data}(x) \right|_{x = x_t}\ dt + \sqrt{2} dB_t
-\end{equation}
-
-Even though we wished to estimate $$\nabla_x \log q_{data}(x)\vert_{x = x_t}$$ with a neural network $$s_{\theta}(x = x_t)$$, this turned out to be **extremely hard** in practice. It was understood that one neural network is not enough to capture the richness of the score function at all values of $$x$$. There were two options before us -- one, make the neural network expressive enough, or two, learn the network **only where it's needed**. The community settled on the second one because it was easier to solve.
-
-So, what some of the pioneering works did is first fix a path (in probability space, as shown above) and then learn the score only _on that path_. It's all about specializing the neural network $$s_{\theta}(x_t, t)$$ over $$t \in [0, \infty]$$. The neural score estimator is capable of producing the right score if we provide the time $$t$$, which we can, of course. We will see in [the next section](#estimating-the-score) that, to learn the score of any distribution, we need samples from it. This begs the question: how do we get samples $$x_t$$ (for all $$t$$) for training purposes? It certainly can't be with Eq.\eqref{eq:langevin_dyn}, since it requires the score. The answer is that we need to run this process the other way -- this is what Diffusion Models call the "Forward Process".
-
-### The "forward process"
-
-Going _the other way_ requires us to run a simulation from $$q_{data}(x)$$ at $$t=0$$ to $$t=\infty$$, just the opposite of the animation above. Recall that we already saw how to do this: to go to any distribution at $$t=\infty$$, all you need is its score and the Langevin equation. So how about we start from $$q_0 = q_{data}(x)$$ this time (remember that the starting point doesn't matter!) and run the Langevin simulation again with a _known_ end target $$q_{\infty} = \mathcal{N}(0, I)$$?
-
-$$\begin{eqnarray}
-dx &=& \nabla_x \log \mathcal{N}(0, I) dt + \sqrt{2} dB_t \\
-\label{eq:forward_sde}
-&=& -x dt + \sqrt{2 dt} z
-\end{eqnarray}$$
-
-It is interesting to note that, since the target distribution is known in closed form, we do not see any awkward scores dangling around. The score of $$\mathcal{N}(0, I)$$ is simply $$-x$$ (we encourage the reader to verify this as an exercise). The discretized version of Eq.\eqref{eq:forward_sde}, i.e.
-
-$$\begin{eqnarray*}
-x_{t+dt} &=& x_t - x_t \cdot dt + \sqrt{2 dt}\ z \\
-&=& (1 - dt) x_t + \sqrt{2 dt}\ z
-\end{eqnarray*}$$
-
-... may resemble DDPM's forward process (hint: compare $$dt$$ with DDPM's $$\beta_t$$).
-
-> NOTE: A little subtlety here is that we only fixed the _end point_ of the forward process, but not the _exact path_. It seems that running the Langevin equation in the forward direction chose one path on its own. It turns out that this is the "isotropic path", where all dimensions of the variable $$x$$ evolve in time in the exact same way. Some recent works uncovered _non-isotropic_ diffusion, where it is indeed possible to travel on other paths, but this is outside the scope of this article.
-
-We can simulate the above equation just like we did for the reverse process, in order to get samples $$x_t \sim q_t$$. Below we show a simulation of the forward process.
-
-
-{% include figure.html path="assets/img/2024-05-07-diffusion-theory-from-scratch/forward_process_2.gif" class="col-10" %}
-
-
-While the reverse process is inherently sequential due to the arbitrary nature of the score, the forward process (in Eq.\eqref{eq:forward_sde}) is entirely known, and this can be exploited to ease the sequentiality. We can see a way out if we try to simplify the expression for $$x_{t+2dt}$$ using $$x_{t+dt}$$ (using the standard assumption $$dt^2 = 0$$)
-
-$$\begin{eqnarray*}
-x_{t+2dt} &=& (1 - dt) {\color{blue} x_{t+dt}} + \sqrt{2dt}\ z_2 \\
-&=& (1 - dt) {\color{blue} \left[(1 - dt) x_t + \sqrt{2 dt}\ z_1\right]} + \sqrt{2dt}\ z_2 \\
-&=& (1 - 2dt) x_t + \sqrt{2dt(1-dt)^2 + 2dt}\ z_{12} \\
-&=& (1 - 2 \cdot dt) x_t + \sqrt{2 \cdot 2dt}\ z_{12} \\
-\implies x_{t+2dt} &\sim& \mathcal{N}((1 - 2 \cdot dt) x_t, 2 \cdot 2dt I)
-\end{eqnarray*}$$
-
-The above simplification suggests that we can jump to any time $$t$$, without going through the entire sequence, in order to sample $$x_t \sim q_t$$. In fact, $$q_t(x_t\vert x_0)$$ is Gaussian! This result opens up an interesting interpretation -- generating $$x_0 \sim q(x_0 \vert x_t)$$ can be seen as solving a "Gaussian inverse problem", which we explore [in a later section](#denoising-as-inverse-problem).
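-
-To see this concretely, here is a minimal sketch (with a hypothetical step size and horizon) comparing the two routes to $$x_t$$: iterating the discretized forward update many times versus jumping directly with the Gaussian $$q_t(x_t \vert x_0)$$ read off from the pattern above.
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-dt, k = 1e-3, 100                            # jump to time t = k * dt = 0.1
-x0 = rng.uniform(-4.0, 4.0, size=100_000)    # stand-in for samples from q_data
-
-# route 1: iterate the forward update k times
-x_iter = x0.copy()
-for _ in range(k):
-    x_iter = (1 - dt) * x_iter + np.sqrt(2 * dt) * rng.standard_normal(x_iter.shape)
-
-# route 2: jump directly using q_t(x_t | x_0) ~= N((1 - k dt) x_0, 2 k dt I)
-x_jump = (1 - k * dt) * x0 + np.sqrt(2 * k * dt) * rng.standard_normal(x0.shape)
-
-print(x_iter.mean(), x_iter.std())           # the two routes agree in distribution
-print(x_jump.mean(), x_jump.std())           # (up to O(dt) discretization error)
-```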
-
-All good for now, but there is one more thing we need to deal with.
-
-### Finite time & the "schedule"
-
-The forward and reverse processes we discussed so far require infinite time to reach their end states. This is a direct consequence of using the Langevin equation and is, of course, unacceptable in practice. But it so happens that there exists quite an elegant fix, well known in mathematics -- we simply _re-define what time means_. We may choose a re-parameterization of time as, for example, $$t' = \mathcal{T}(t) = 1 - e^{-t} \in [0, 1]$$ (note that $$t = 0 \implies t' = 0$$ and $$t = \infty \implies t' = 1$$, so the range $$[0, \infty]$$ is converted to $$[0, 1]$$). Plugging $$dt = \mathcal{T}'(t)^{-1} dt' = e^t dt'$$ (since $$t' = 1 - e^{-t} \implies dt' = e^{-t} dt \implies dt = e^t dt'$$) into the forward equation brings us even closer to DDPM's forward process
-
-$$
-x_{t' + dt'} = (1 - {\color{blue}e^t dt'}) x_{t'} + \sqrt{2 {\color{blue}e^t dt'}}\ z
-$$
-
-This suggests that in the world where time runs from $$t' = 0 \rightarrow 1$$, we need to _accelerate_ the forward process by replacing $$dt$$ with $$e^t dt'$$. The quantity $$\mathcal{T}'(t)^{-1} dt' = e^t dt'$$ is analogous to what diffusion models call a "schedule". Recall that DDPM uses a small but increasing "schedule" $$\beta_t$$ -- here, $$e^t dt'$$ is small because of $$dt'$$ and increasing because of $$e^t$$.
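-
-A tiny numerical illustration of this analogy (with an arbitrary grid over $$t'$$): under the re-parameterization above, $$e^t = \frac{1}{1 - t'}$$, so the effective step size starts small and grows as $$t' \rightarrow 1$$, much like an increasing $$\beta_t$$ schedule.
-
-```python
-import numpy as np
-
-n = 1000
-t_prime = np.linspace(0.0, 0.999, n)       # rescaled time t' in [0, 1)
-dt_prime = t_prime[1] - t_prime[0]
-beta = dt_prime / (1.0 - t_prime)          # e^t dt', since e^t = 1 / (1 - t')
-print(beta[:3])                            # small at the start ...
-print(beta[-3:])                           # ... and increasing towards t' = 1
-```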
-
-
-{% include figure.html path="assets/img/2024-05-07-diffusion-theory-from-scratch/ddpm_forward_kernel.png" class="col-6 z-depth-1"%}
-
-
-Of course, our choice of the exact end time (i.e. $$t' = 1$$) and the re-parameterization $$\mathcal{T}$$ are somewhat arbitrary. Different choices of $$\mathcal{T}$$, and consequently of $$\mathcal{T}'(t)^{-1} dt'$$, lead to different schedules (e.g. linear, cosine, etc.).
-
-> NOTE: Choosing a different schedule does not mean the process takes a different path on the probability space; it simply changes the _speed_ of movement along that path towards the end state.
-
-#### Summary
-
-To summarize: in this section, we started with the definition of the 'score' and arrived at a stochastic process (thanks to an old result by Langevin) that, at infinite time, converges to the density associated with the score. We saw that this process is provably correct and can be interpreted as a "path" on the probability space. We argued that, due to the difficulty of estimating the score everywhere, we need samples at intermediate times $$t$$ in order to specialize the score estimates along the path. To get those samples, we had to traverse the path in the opposite direction (the forward process), which can be done in closed form. We also saw how this process, even though it theoretically takes infinite time, can be shrunk down to a finite interval, opening up a design choice known as "schedules".
-
-## Estimating the Score
-
-The last section, while explaining the "sampling" part of score-based diffusion models, assumed that we have access to the true score $$\nabla_x \log q_{data}(x)$$ via some oracle. That is, of course, untrue in practice. In fact, accessing the true score of an arbitrary distribution is just not possible (we only have access to the true score for distributions with a closed form, e.g. Gaussians). So the way forward, as mentioned before, is to estimate/learn it with a parametric neural network $$s_{\theta}(x)$$. Recall, however, that all we have access to is samples from $$q_{data}(x)$$.
-
-A curious reader may question how realistic it is to estimate the score $$\nabla_x \log q_{data}(x)$$ when we usually can NOT estimate the density $$q_{data}(x)$$ itself -- after all, the score is a quantity derived from the density! The answer becomes clear once you make the _normalization constant_ explicit
-
-$$\begin{eqnarray*}
-\nabla_x \log q_{data}(x) &=& \nabla_x \log \frac{\tilde{q}_{data}(x)}{\int_{x} \tilde{q}_{data}(x) dx} \\
-&=& \nabla_x \log \tilde{q}_{data}(x) - {\color{red}\nabla_x \log \int_{x} \tilde{q}_{data}(x) dx} \\
-&=& \nabla_x \log \tilde{q}_{data}(x)
-\end{eqnarray*}$$
-
-The part in red is zero since it has no dependence on $$x$$. So the score very cleverly **sidesteps the normalization constant**. This is the reason score estimation gained momentum in the research community.
-
-### Implicit Score Matching
-
-The first notable attempt at this problem was by Aapo Hyvärinen back in 2005. His idea was to start from a loss function that, when minimized, leads to an estimator of the true score
-
-\begin{equation}
-J(\theta) = \frac{1}{2} \mathbb{E}_{x\sim q\_{data}(x)}\Big[ \vert\vert s\_{\theta}(x) - \nabla_x \log q\_{data}(x) \vert\vert^2 \Big]
-\end{equation}
-
-It is simply an $$L_2$$ loss between a parametric model and the true score, weighted by the probability of individual states (hence the expectation). But of course, it is not computable in this form, as it contains the true score. Hyvärinen's contribution was to show that, theoretically, the minimization problem remains equivalent when the loss function is
-
-\begin{equation} \label{eq:impl_score_match}
-J_{\mathrm{I}}(\theta) = \mathbb{E}_{x\sim q\_{data}(x)}\Big[ \mathrm{Tr}(\nabla\_x s\_{\theta}(x)) + \frac{1}{2} \vert\vert s\_{\theta}(x) \vert\vert^2 \Big]
-\end{equation}
-
-In the literature, this is known as "_Implicit Score Matching_". The derivation is relatively simple and involves only algebraic manipulations -- please see Appendix A of . The remarkable nature of this result stems from the fact that $$J_{\mathrm{I}}$$ no longer contains the true score; the only dependency on $$q_{data}$$ is via the expectation, which can be approximated by a sample average over our dataset.
-
-But the key challenge with Implicit Score Matching was the $$\mathrm{Tr}(\nabla_x s_{\theta}(x))$$ term, i.e. the trace of the Jacobian of the neural score model (equivalently, the Hessian of the underlying log-density), which is costly to compute. This prompted several follow-up works in the race towards scalable score matching, one of which (namely Denoising Score Matching) is used in diffusion models to this day.
-
-For the sake of completeness, we would like to mention the work of Yang Song et al. around 2019, which proposed an engineering trick to alleviate the Hessian computation. They used the "Hutchinson trace estimator" (a stochastic way of computing a trace: $$\mathrm{Tr}(M) = \mathbb{E}_{v\sim p_v} \big[ v^T M v \big]$$, where $$p_v$$ can be one of several distributions, most notably $$\mathcal{N}(0, I)$$) to replace the $$\mathrm{Tr}(\cdot)$$ in Eq.\eqref{eq:impl_score_match}, which eased the computation a bit. This approach, however, did not end up being widely used in practice.
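-
-To make this concrete, here is a minimal PyTorch-style sketch of Eq.\eqref{eq:impl_score_match} with the trace term replaced by a Hutchinson estimate. The score model `score_net` and the data batch `x` are hypothetical placeholders, and a single Gaussian probe vector per sample is used for brevity.
-
-```python
-import torch
-
-def implicit_score_matching_loss(score_net, x):
-    """J_I(theta) = E[ Tr(grad_x s_theta(x)) + 0.5 ||s_theta(x)||^2 ], Hutchinson-style."""
-    x = x.clone().requires_grad_(True)
-    s = score_net(x)                                  # (B, D) score estimates
-    sq_norm = 0.5 * (s ** 2).sum(dim=1)               # 0.5 ||s_theta(x)||^2
-    v = torch.randn_like(x)                           # Gaussian probe vector
-    # v^T (grad_x s_theta(x)) v  ==  grad_x (v . s_theta(x)) . v
-    jvp = torch.autograd.grad((s * v).sum(), x, create_graph=True)[0]
-    trace_est = (jvp * v).sum(dim=1)                  # Hutchinson trace estimate
-    return (trace_est + sq_norm).mean()
-```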
-
-### Denoising Score Matching
-
-The most valuable contribution came from Pascal Vincent in 2011, when he showed that the score matching problem has yet another equivalent objective, which was called "Denoising" score matching
-
-\begin{equation} \label{eq:deno_score_match}
-J_{\mathrm{D}}(\theta) = \mathbb{E}_{x\sim q\_{data}(x), \epsilon\sim\mathcal{N}(0, I)}\left[ \frac{1}{2} \left|\left| s\_{\theta}(\ \underbrace{x + \sigma\epsilon}\_{\tilde{x}}\ ) - (- \frac{\epsilon}{\sigma}) \right|\right|^2 \right]
-\end{equation}
-
-We deliberately wrote it in a way that exposes its widely accepted interpretation. Denoising score matching simply adds some _known_ noise $$\sigma\epsilon$$ to the data points $$x$$ and learns (in the mean-squared sense), from the "noisy" point $$\tilde{x}$$, the direction of comeback, i.e. $$(-\epsilon)$$, scaled by $$\frac{1}{\sigma}$$. In a way, it acts like a "de-noiser", hence the name. It is theoretically guaranteed that $$J_{\mathrm{D}}$$ leads to an estimate of the true score (more precisely, of the score of the noise-smoothed data density, which approaches the true score as $$\sigma \rightarrow 0$$). Below we show a visualization of the score estimate as it learns from data.
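-
-In code, the objective is pleasantly simple. Below is a minimal PyTorch-style sketch of Eq.\eqref{eq:deno_score_match}, again with a hypothetical `score_net` and data batch `x`; a comment notes the equivalent noise-prediction form discussed next.
-
-```python
-import torch
-
-def denoising_score_matching_loss(score_net, x, sigma):
-    """J_D(theta) for a single fixed noise level sigma."""
-    eps = torch.randn_like(x)
-    x_tilde = x + sigma * eps        # noisy sample
-    target = -eps / sigma            # direction of comeback, scaled by 1/sigma
-    s = score_net(x_tilde)
-    # the equivalent noise-prediction form trains eps_theta(x_tilde) = -sigma * s against eps
-    return 0.5 * ((s - target) ** 2).sum(dim=1).mean()
-```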
-
-
-{% include figure.html path="assets/img/2024-05-07-diffusion-theory-from-scratch/deno_score_learning.gif" class="col-10" %}
-
-
-A little algebraic manipulation of Eq.\eqref{eq:deno_score_match}, demonstrated by Ho et al. , leads to an equivalent form which turned out to be training friendly.
-
-$$\begin{eqnarray}
-J_{\mathrm{D}}(\theta) &=& \mathbb{E}_{x\sim q_{data}(x), \epsilon\sim\mathcal{N}(0, I)}\left[ \frac{1}{2\sigma^2} \left|\left| {\color{blue} - \sigma s_{\theta}}(\tilde{x}) - \epsilon \right|\right|^2 \right] \\
-&=& \mathbb{E}_{x\sim q_{data}(x), \epsilon\sim\mathcal{N}(0, I)}\left[ \frac{1}{2\sigma^2} \left|\left| {\color{blue} \epsilon}_{\theta}(\tilde{x}) - \epsilon \right|\right|^2 \right]\label{eq:deno_eps_match}
-\end{eqnarray}$$
-
-We simply change the _interpretation_ of what the network learns. In this form, the "noise estimator" network learns _just_ the original pure Gaussian noise vector $$\epsilon$$ that was added while crafting the noisy sample. So, from a noisy sample, the network $$\epsilon_{\theta}$$ learns roughly a unit-variance direction that points away from the clean sample; subtracting it (suitably scaled) takes us back towards the clean sample.
-
-There is yet another re-interpretation of Eq.\eqref{eq:deno_score_match} that leads to a slightly different perspective
-
-$$\begin{eqnarray}
-J_{\mathrm{D}}(\theta) &=& \mathbb{E}_{x\sim q_{data}(x), \epsilon\sim\mathcal{N}(0, I)}\left[ \frac{1}{2\sigma^4} \left|\left| {\color{blue}\tilde{x} + \sigma^2 s_{\theta}}(\tilde{x}) - (\underbrace{\tilde{x} - \sigma\epsilon}_{x}) \right|\right|^2 \right] \\
-&=& \mathbb{E}_{x\sim q_{data}(x), \epsilon\sim\mathcal{N}(0, I)}\left[ \frac{1}{2\sigma^4} \left|\left| {\color{blue} x_{\theta}}(\tilde{x}) - x \right|\right|^2 \right]\label{eq:deno_endpoint_match}
-\end{eqnarray}$$
-
-Eq.\eqref{eq:deno_endpoint_match} shows that, instead of the noise direction, we can also use the clean sample itself as the learning target. This is "denoising" in its truest sense. We will get back to this in [the next subsection](#probing-the-learning-objective).
-
-### Probing the learning objective
-
-If you are still puzzled about how Eq.\eqref{eq:deno_eps_match} is related to learning the score, there is a way to probe exactly what the network is learning at an arbitrary input point $$\tilde{x}$$. We note that the clean sample $$x$$ and the noisy sample $$\tilde{x}$$ come from a joint distribution that factorizes
-
-$$
-q(x, \tilde{x}) = q(\tilde{x} \vert x) q_{data}(x) = \mathcal{N}(\tilde{x}; x, \sigma^2 I) q_{data}(x).
-$$
-
-We then factorize this joint in a slightly different way, i.e.
-
-$$
-q(x, \tilde{x}) = q(x \vert \tilde{x}) q(\tilde{x})
-$$
-
-where $$q(x \vert \tilde{x})$$ can be thought of as the distribution of all clean samples that could have led to the given $$\tilde{x}$$. Eq.\eqref{eq:deno_eps_match} can therefore be written as
-
-$$\begin{eqnarray*}
-J_{\mathrm{D}}(\theta) &=& \mathbb{E}_{(x, \tilde{x}) \sim q(x,\tilde{x})}\left[ \frac{1}{2\sigma^2} \left|\left| \epsilon_{\theta}(\tilde{x}) - \epsilon \right|\right|^2 \right] \\
-&=& \mathbb{E}_{\tilde{x} \sim q(\tilde{x}), x \sim q(x\vert \tilde{x})}\left[ \frac{1}{2\sigma^2} \left|\left| \epsilon_{\theta}(\tilde{x}) - \frac{\tilde{x} - x}{\sigma} \right|\right|^2 \right] \\
-&=& \mathbb{E}_{\tilde{x} \sim q(\tilde{x})}\left[ \frac{1}{2\sigma^2} \left|\left| \epsilon_{\theta}(\tilde{x}) - \frac{\tilde{x} - \mathbb{E}_{x \sim q(x\vert \tilde{x})}[x]}{\sigma} \right|\right|^2 \right] \\
-\end{eqnarray*}$$
-
-In the last step, the expectation $$\mathbb{E}_{q(x\vert\tilde{x})}\left[ \cdot \right]$$ was pushed inside, up to the only quantity that involves $$x$$ (this changes the objective only by an additive constant independent of $$\theta$$, since the minimizer of a squared error is the conditional mean). Looking at it, you may realize that the network $$\epsilon_{\theta}$$, given an input $$\tilde{x}$$, learns the _average noise direction_ that leads to that input point $$\tilde{x}$$. It also exposes the quantity $$\mathbb{E}_{x \sim q(x\vert \tilde{x})}[x]$$, which is the _average clean sample_ that led to the given $$\tilde{x}$$.
-
-Below we visualize this process with a toy example, followed by a short explanation.
-
-
-{% include figure.html path="assets/img/2024-05-07-diffusion-theory-from-scratch/probing_deno_estimation.gif" class="col-10" %}
-
-
-Explanation: We have 10 data points $$x\sim q_{data}(x)$$ in two clusters (big red dots), and we run the learning process by generating noisy samples $$\tilde{x}\sim q(\tilde{x})$$ (small red dots). Instead of learning a neural mapping over the entire space, we learn a tabular map at only three chosen input points $$\tilde{x}_1, \tilde{x}_2, \tilde{x}_3$$ (blue, magenta and green crosses). Every time we sample near one of those three points (practically, it is impossible to randomly sample an exact point, so we assume a little ball around each), we note which clean data point it came from (shown by a connecting dotted line of the same color) and maintain a running average of those clean points (bold cross of the same color), which is nothing but $$\mathbb{E}_{x \sim q(x\vert \tilde{x})}[x]$$. We also show the average noise direction at each $$\tilde{x}$$, i.e. $$\frac{\tilde{x} - \mathbb{E}_{x \sim q(x\vert \tilde{x})}[x]}{\sigma}$$, with gray arrows. As training progresses, the gray arrows start to resemble the score estimate of the data.
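-
-The essence of this toy experiment can be reproduced with a few lines of code. This is a rough sketch: the cluster locations, probe points and ball radius are arbitrary stand-ins, not the exact values used in the animation.
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-sigma = 1.0
-data = np.concatenate([rng.normal(-3, 0.3, 5), rng.normal(3, 0.3, 5)])  # 10 points, two clusters
-probes = np.array([-2.0, 0.0, 2.0])       # three chosen input points
-ball = 0.1                                # little ball around each probe point
-
-sums, counts = np.zeros(3), np.zeros(3)
-for _ in range(200_000):
-    x = rng.choice(data)                  # clean sample
-    x_tilde = x + sigma * rng.standard_normal()
-    hit = np.abs(probes - x_tilde) < ball
-    sums[hit] += x
-    counts[hit] += 1
-
-posterior_mean = sums / counts            # running average ~= E[x | x_tilde] at each probe
-avg_noise_dir = (probes - posterior_mean) / sigma
-print(posterior_mean, avg_noise_dir)
-```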
-
-### Denoising as inverse problem
-
-A similar treatment, when applied to Eq.\eqref{eq:deno_endpoint_match}, yields the following
-
-$$\begin{eqnarray*}
-J_{\mathrm{D}}(\theta) &=& \mathbb{E}_{(x, \tilde{x}) \sim q(x,\tilde{x})}\left[ \frac{1}{2\sigma^4} \left|\left| {\color{blue}x_{\theta}}(\tilde{x}) - x \right|\right|^2 \right] \\
-&=& \mathbb{E}_{\tilde{x} \sim q(\tilde{x})}\left[ \frac{1}{2\sigma^4} \left|\left| {\color{blue}\tilde{x} + \sigma^2 s_{\theta}}(\tilde{x}) - \mathbb{E}_{x \sim q(x\vert \tilde{x})}[x] \right|\right|^2 \right] \\
-\end{eqnarray*}$$
-
-Notice that we brought back the original form of $$x_{\theta}(\cdot)$$ that involves the score. If we had the true score instead of a learned estimate, we would have
-
-$$
-\mathbb{E}_{x \sim q(x\vert \tilde{x})}[x] = \tilde{x} + \sigma^2 \nabla_{\tilde{x}} \log q(\tilde{x})
-$$
-
-In "Inverse problem" and Bayesian literature, this is a very well celebrated result named "_Tweedie's Formula_", first published by Robbins but credited to statistician Maurice Tweedie. This theorem is applied in the context of bayesian posterior estimation of a "true" quantity $$x$$ which we only observe through a (gaussian) noisy measurement $$\tilde{x}$$. Tweedie's formula tells us that the _posterior mean_ of the inverse problem $$q(x\vert \tilde{x})$$ can be computed without ever knowing the actually density, as long as we have access to the score at the noisy measurement.
-
-#### Summary
-
-In this section, we explored the problem of scalable score matching. We looked at the notable attempts in the literature and learned that score can be estimated from samples only. We also looked at several interpretations of the learning objective and the connections they expose.
-
-## Last few bits
-
-#### Incorporating time
-
-In the last section, we expressed and explained everything in terms of one known noise level $$\sigma$$ and the noisy sample $$\tilde{x}$$. We did so to avoid cluttering concepts that aren't necessary to explain each other. In [a previous section](#estimating-the-score-is-hard), however, we learned that the score must be estimated along every timestep of the forward process. Simply augmenting Eq.\eqref{eq:deno_score_match} with an additional time variable $$t \sim \mathcal{U}[0, 1]$$ is sufficient to induce the time dependency in the score matching problem
-
-\begin{equation} \label{eq:deno_score_match_with_time}
-J_{\mathrm{D}}(\theta) = \mathbb{E}_{x_0, \epsilon, t \sim \mathcal{U}[0, 1], x_t\sim q_t(x_t\vert x_0) }\left[ \frac{1}{2} \left|\left| s\_{\theta}(x_t, t) - (- \frac{\epsilon}{\sigma_t}) \right|\right|^2 \right]
-\end{equation}
-
-.. where $$q_t(x_t \vert x_0)$$ is defined in a [previous section](#the-forward-process) and $$\sigma_t$$ is its standard deviation.
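-
-A minimal sketch of this time-conditioned objective is shown below. The helpers `mean_coef(t)` and `std(t)` are hypothetical stand-ins for the mean scaling and standard deviation of $$q_t(x_t \vert x_0)$$; for the Ornstein-Uhlenbeck forward SDE used in this article they would be $$e^{-t}$$ and $$\sqrt{1 - e^{-2t}}$$ in the original (unrescaled) time variable.
-
-```python
-import torch
-
-def time_conditioned_dsm_loss(score_net, x0, mean_coef, std):
-    """Denoising score matching averaged over t ~ U[0, 1]; x0 is a (B, D) batch."""
-    t = torch.rand(x0.shape[0], device=x0.device)           # one timestep per sample
-    eps = torch.randn_like(x0)
-    sigma_t = std(t).unsqueeze(-1)
-    x_t = mean_coef(t).unsqueeze(-1) * x0 + sigma_t * eps    # sample from q_t(x_t | x_0)
-    target = -eps / sigma_t
-    s = score_net(x_t, t)
-    return 0.5 * ((s - target) ** 2).sum(dim=1).mean()
-```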
-
-
-#### We took a different approach
-
-We would like to highlight that, in this article, we first explored the reverse process and then showed why the forward process emerges out of necessity. Typical diffusion model papers start from a forward process specification of the form
-
-$$
-dx_t = f(t)x_t dt + g(t) {dB}_t
-$$
-
-.. and then use Anderson's SDE reversal to explain the reverse process, which also involves the score
-
-$$
-dx_t = \left[ f(t) x_t - g(t)^2 \underbrace{\nabla_{x_t} \log q_t(x_t)}_{s_{\theta}(x_t, t)} \right] dt + g(t) dB_t
-$$
-
-We argue that our approach is more "organic" in the sense that it builds up the theory _chronologically_, exploring the exact path the community went through over time.
-
-#### Conclusion
-
-In this article, we dived deep into the theoretical fundamentals of diffusion models, which are often ignored by practitioners. We started from the 'heart' of diffusion models, i.e. scores, and built the concepts up almost chronologically. We hope this article will serve as a conceptual guide towards understanding diffusion models from the score SDE perspective. We intentionally avoided the 'probabilistic Markov model' view of diffusion, since more and more works embrace the SDE formalism.
\ No newline at end of file
diff --git a/_posts/2024-05-07-distill-example.md b/_posts/2024-05-07-distill-example.md
deleted file mode 100644
index 8cc9b578..00000000
--- a/_posts/2024-05-07-distill-example.md
+++ /dev/null
@@ -1,452 +0,0 @@
----
-layout: distill
-title: Sample Blog Post
-description: Your blog post's abstract.
- Please add your abstract or summary here and not in the main body of your text.
- Do not include math/latex or hyperlinks.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-# authors:
-# - name: Anonymous
-
-authors:
- - name: Albert Einstein
- url: "https://en.wikipedia.org/wiki/Albert_Einstein"
- affiliations:
- name: IAS, Princeton
- - name: Boris Podolsky
- url: "https://en.wikipedia.org/wiki/Boris_Podolsky"
- affiliations:
- name: IAS, Princeton
- - name: Nathan Rosen
- url: "https://en.wikipedia.org/wiki/Nathan_Rosen"
- affiliations:
- name: IAS, Princeton
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-distill-example.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Equations
- - name: Images and Figures
- subsections:
- - name: Interactive Figures
- - name: Citations
- - name: Footnotes
- - name: Code Blocks
- - name: Diagrams
- - name: Tweets
- - name: Layouts
- - name: Other Typography?
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-Note: please use the table of contents as defined in the front matter rather than the traditional markdown styling.
-
-## Equations
-
-This theme supports rendering beautiful math in inline and display modes using [MathJax 3](https://www.mathjax.org/) engine.
-You just need to surround your math expression with `$$`, like `$$ E = mc^2 $$`.
-If you leave it inside a paragraph, it will produce an inline expression, just like $$ E = mc^2 $$.
-
-To use display mode, again surround your expression with `$$` and place it as a separate paragraph.
-Here is an example:
-
-$$
-\left( \sum_{k=1}^n a_k b_k \right)^2 \leq \left( \sum_{k=1}^n a_k^2 \right) \left( \sum_{k=1}^n b_k^2 \right)
-$$
-
-Note that MathJax 3 is [a major re-write of MathJax](https://docs.mathjax.org/en/latest/upgrading/whats-new-3.0.html)
-that brought a significant improvement to the loading and rendering speed, which is now
-[on par with KaTeX](http://www.intmath.com/cg5/katex-mathjax-comparison.php).
-
-
-## Images and Figures
-
-It's generally a better idea to avoid linking to images hosted elsewhere - links can break and you
-might lose important information in your blog post.
-To include images in your submission in this way, you must do something like the following:
-
-```markdown
-{% raw %}{% include figure.html path="assets/img/2024-05-07-distill-example/iclr.png" class="img-fluid" %}{% endraw %}
-```
-
-which results in the following image:
-
-{% include figure.html path="assets/img/2024-05-07-distill-example/iclr.png" class="img-fluid" %}
-
-To ensure that there are no namespace conflicts, you must save your asset to your unique directory
-`/assets/img/2024-05-07-[SUBMISSION NAME]` within your submission.
-
-Please avoid using the direct markdown method of embedding images; they may not be properly resized.
-Some more complex ways to load images (note the different styles of the shapes/shadows):
-
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/9.jpg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/7.jpg" class="img-fluid rounded z-depth-1" %}
-
-
-
- A simple, elegant caption looks good between image rows, after each row, or doesn't have to be there at all.
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/8.jpg" class="img-fluid z-depth-2" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/10.jpg" class="img-fluid z-depth-2" %}
-
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/11.jpg" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/12.jpg" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/7.jpg" class="img-fluid" %}
-
-
-
-### Interactive Figures
-
-Here's how you could embed interactive figures that have been exported as HTML files.
-Note that we will be using plotly for this demo, but anything built off of HTML should work
-(**no extra javascript is allowed!**).
-All that's required is for you to export your figure into HTML format, and make sure that the file
-exists in the `assets/html/[SUBMISSION NAME]/` directory in this repository's root directory.
-To embed it into any page, simply insert the following code anywhere into your page.
-
-```markdown
-{% raw %}{% include [FIGURE_NAME].html %}{% endraw %}
-```
-
-For example, the following code can be used to generate the figure underneath it.
-
-```python
-import pandas as pd
-import plotly.express as px
-
-df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/earthquakes-23k.csv')
-
-fig = px.density_mapbox(
- df, lat='Latitude', lon='Longitude', z='Magnitude', radius=10,
- center=dict(lat=0, lon=180), zoom=0, mapbox_style="stamen-terrain")
-fig.show()
-
-fig.write_html('./assets/html/2024-05-07-distill-example/plotly_demo_1.html')
-```
-
-And then include it with the following:
-
-```html
-{% raw %}
-
-
{% endraw %}
-```
-
-Voila!
-
-
-
-
-
-## Citations
-
-Citations are then used in the article body with the `<d-cite>` tag.
-The key attribute is a reference to the id provided in the bibliography.
-The key attribute can take multiple ids, separated by commas.
-
-The citation is presented inline like this: (a number that displays more information on hover).
-If you have an appendix, a bibliography is automatically created and populated in it.
-
-Distill chose a numerical inline citation style to improve readability of citation dense articles and because many of the benefits of longer citations are obviated by displaying more information on hover.
-However, we consider it good style to mention author last names if you discuss something at length and it fits into the flow well — the authors are human and it’s nice for them to have the community associate them with their work.
-
-***
-
-## Footnotes
-
-Just wrap the text you would like to show up in a footnote in a `<d-footnote>` tag.
-The number of the footnote will be automatically generated. This will become a hoverable footnote.
-
-***
-
-## Code Blocks
-
-This theme implements a built-in Jekyll feature, the use of Rouge, for syntax highlighting.
-It supports more than 100 languages.
-This example is in C++.
-All you have to do is wrap your code in a liquid tag:
-
-{% raw %}
-{% highlight c++ linenos %} code code code {% endhighlight %}
-{% endraw %}
-
-The keyword `linenos` triggers display of line numbers. You can try toggling it on or off yourself below:
-
-{% highlight c++ %}
-
-#include <iostream>
-#include <string>
-using namespace std;
-
-int main(int argc, char const *argv[])
-{
-    string myString;
-
-    cout << "input a string: ";
-    getline(cin, myString);
-    int length = myString.length();
-
-    // copy the string into a dynamically allocated char array
-    char *charArray = new char[length];
-    for (int i = 0; i < length; ++i) {
-        charArray[i] = myString[i];
-        cout << charArray[i] << " ";
-    }
-
-    delete[] charArray;
-    return 0;
-}
-
-{% endhighlight %}
-
-***
-
-## Diagrams
-
-This theme supports generating various diagrams from a text description using [jekyll-diagrams](https://github.com/zhustec/jekyll-diagrams){:target="\_blank"} plugin.
-Below, we generate a few examples of such diagrams using languages such as [mermaid](https://mermaid-js.github.io/mermaid/){:target="\_blank"}, [plantuml](https://plantuml.com/){:target="\_blank"}, [vega-lite](https://vega.github.io/vega-lite/){:target="\_blank"}, etc.
-
-**Note:** different diagram-generation packages require external dependencies to be installed on your machine.
-Also, be mindful that, because of diagram generation, the first time you build your Jekyll website after adding new diagrams will be SLOW.
-For any other details, please refer to [jekyll-diagrams](https://github.com/zhustec/jekyll-diagrams){:target="\_blank"} README.
-
-**Note:** This is not supported for local rendering!
-
-The diagram below was generated by the following code:
-
-{% raw %}
-```
-{% mermaid %}
-sequenceDiagram
- participant John
- participant Alice
- Alice->>John: Hello John, how are you?
- John-->>Alice: Great!
-{% endmermaid %}
-```
-{% endraw %}
-
-{% mermaid %}
-sequenceDiagram
-participant John
-participant Alice
-Alice->>John: Hello John, how are you?
-John-->>Alice: Great!
-{% endmermaid %}
-
-***
-
-## Tweets
-
-An example of displaying a tweet:
-{% twitter https://twitter.com/rubygems/status/518821243320287232 %}
-
-An example of pulling from a timeline:
-{% twitter https://twitter.com/jekyllrb maxwidth=500 limit=3 %}
-
-For more details on using the plugin visit: [jekyll-twitter-plugin](https://github.com/rob-murray/jekyll-twitter-plugin)
-
-***
-
-## Blockquotes
-
-
- We do not grow absolutely, chronologically. We grow sometimes in one dimension, and not in another, unevenly. We grow partially. We are relative. We are mature in one realm, childish in another.
- —Anais Nin
-
-
-***
-
-
-## Layouts
-
-The main text column is referred to as the body.
-It is the assumed layout of any direct descendants of the `d-article` element.
-
-
-
.l-body
-
-
-For images you want to display a little larger, try `.l-page`:
-
-
-
.l-page
-
-
-All of these have an outset variant if you want to poke out from the body text a little bit.
-For instance:
-
-
-
.l-body-outset
-
-
-
-
.l-page-outset
-
-
-Occasionally you’ll want to use the full browser width.
-For this, use `.l-screen`.
-You can also inset the element a little from the edge of the browser by using the inset variant.
-
-
-
.l-screen
-
-
-
.l-screen-inset
-
-
-The final layout is for marginalia, asides, and footnotes.
-It does not interrupt the normal flow of `.l-body`-sized text except on mobile screen sizes.
-
-
-
.l-gutter
-
-
-***
-
-## Other Typography?
-
-Emphasis, aka italics, with *asterisks* (`*asterisks*`) or _underscores_ (`_underscores_`).
-
-Strong emphasis, aka bold, with **asterisks** or __underscores__.
-
-Combined emphasis with **asterisks and _underscores_**.
-
-Strikethrough uses two tildes. ~~Scratch this.~~
-
-1. First ordered list item
-2. Another item
-⋅⋅* Unordered sub-list.
-1. Actual numbers don't matter, just that it's a number
-⋅⋅1. Ordered sub-list
-4. And another item.
-
-⋅⋅⋅You can have properly indented paragraphs within list items. Notice the blank line above, and the leading spaces (at least one, but we'll use three here to also align the raw Markdown).
-
-⋅⋅⋅To have a line break without a paragraph, you will need to use two trailing spaces.⋅⋅
-⋅⋅⋅Note that this line is separate, but within the same paragraph.⋅⋅
-⋅⋅⋅(This is contrary to the typical GFM line break behavior, where trailing spaces are not required.)
-
-* Unordered lists can use asterisks
-- Or minuses
-+ Or pluses
-
-[I'm an inline-style link](https://www.google.com)
-
-[I'm an inline-style link with title](https://www.google.com "Google's Homepage")
-
-[I'm a reference-style link][Arbitrary case-insensitive reference text]
-
-[I'm a relative reference to a repository file](../blob/master/LICENSE)
-
-[You can use numbers for reference-style link definitions][1]
-
-Or leave it empty and use the [link text itself].
-
-URLs and URLs in angle brackets will automatically get turned into links.
-http://www.example.com or <http://www.example.com> and sometimes
-example.com (but not on Github, for example).
-
-Some text to show that the reference links can follow later.
-
-[arbitrary case-insensitive reference text]: https://www.mozilla.org
-[1]: http://slashdot.org
-[link text itself]: http://www.reddit.com
-
-Here's our logo (hover to see the title text):
-
-Inline-style:
-![alt text](https://github.com/adam-p/markdown-here/raw/master/src/common/images/icon48.png "Logo Title Text 1")
-
-Reference-style:
-![alt text][logo]
-
-[logo]: https://github.com/adam-p/markdown-here/raw/master/src/common/images/icon48.png "Logo Title Text 2"
-
-Inline `code` has `back-ticks around` it.
-
-```javascript
-var s = "JavaScript syntax highlighting";
-alert(s);
-```
-
-```python
-s = "Python syntax highlighting"
-print(s)
-```
-
-```
-No language indicated, so no syntax highlighting.
-But let's throw in a <b>tag</b>.
-```
-
-Colons can be used to align columns.
-
-| Tables | Are | Cool |
-| ------------- |:-------------:| -----:|
-| col 3 is | right-aligned | $1600 |
-| col 2 is | centered | $12 |
-| zebra stripes | are neat | $1 |
-
-There must be at least 3 dashes separating each header cell.
-The outer pipes (|) are optional, and you don't need to make the
-raw Markdown line up prettily. You can also use inline Markdown.
-
-Markdown | Less | Pretty
---- | --- | ---
-*Still* | `renders` | **nicely**
-1 | 2 | 3
-
-> Blockquotes are very handy in email to emulate reply text.
-> This line is part of the same quote.
-
-Quote break.
-
-> This is a very long line that will still be quoted properly when it wraps. Oh boy let's keep writing to make sure this is long enough to actually wrap for everyone. Oh, you can *put* **Markdown** into a blockquote.
-
-
-Here's a line for us to start with.
-
-This line is separated from the one above by two newlines, so it will be a *separate paragraph*.
-
-This line is also a separate paragraph, but...
-This line is only separated by a single newline, so it's a separate line in the *same paragraph*.
diff --git a/_posts/2024-05-07-distill-example2.html b/_posts/2024-05-07-distill-example2.html
deleted file mode 100644
index d291070c..00000000
--- a/_posts/2024-05-07-distill-example2.html
+++ /dev/null
@@ -1,442 +0,0 @@
----
-layout: distill
-title: Sample Blog Post (HTML version)
-description: Your blog post's abstract.
- Please add your abstract or summary here and not in the main body of your text.
- Do not include math/latex or hyperlinks.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-# authors:
-# - name: Anonymous
-
-authors:
- - name: Albert Einstein
- url: "https://en.wikipedia.org/wiki/Albert_Einstein"
- affiliations:
- name: IAS, Princeton
- - name: Boris Podolsky
- url: "https://en.wikipedia.org/wiki/Boris_Podolsky"
- affiliations:
- name: IAS, Princeton
- - name: Nathan Rosen
- url: "https://en.wikipedia.org/wiki/Nathan_Rosen"
- affiliations:
- name: IAS, Princeton
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-distill-example.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Equations
- - name: Images and Figures
- subsections:
- - name: Interactive Figures
- - name: Citations
- - name: Footnotes
- - name: Code Blocks
- - name: Diagrams
- - name: Tweets
- - name: Layouts
- - name: Other Typography?
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;\
- }
----
-
-
- This is a sample blog post written in HTML (while the other sample post is written in Markdown). Authors have the choice to write in HTML or Markdown. While Markdown is easier to write, HTML gives you more control over the layout of your post. Furthermore, Markdown often interacts in unexpected ways with MathJax and other HTML widgets. If you are having trouble with Markdown, try writing in HTML instead.
-
-
-
- Note: please use the table of contents as defined in the front matter rather than the traditional markdown styling.
-
-
-
Equations
-
-
This theme supports rendering beautiful math in inline and display modes using MathJax 3 engine.
-You just need to surround your math expression with $$, like $$ E = mc^2 $$.
-If you leave it inside a paragraph, it will produce an inline expression, just like \(E = mc^2\).
-
-
To use display mode, again surround your expression with $$ and place it as a separate paragraph.
-Here is an example:
-$$
-\left( \sum_{k=1}^n a_k b_k \right)^2 \leq \left( \sum_{k=1}^n a_k^2 \right) \left( \sum_{k=1}^n b_k^2 \right)
-$$
-
It's generally a better idea to avoid linking to images hosted elsewhere - links can break and you
-might lose important information in your blog post.
-You can display images from this repository using the following code:
-
-
{% raw %}{% include figure.html path="assets/img/2024-05-07-distill-example/iclr.png" class="img-fluid" %}{% endraw %}
-
-
which results in the following image:
-
-{% include figure.html path="assets/img/2024-05-07-distill-example/iclr.png" class="img-fluid" %}
-
-
-
- To ensure that there are no namespace conflicts, you must save your asset to your unique directory
- `/assets/img/2024-05-07-[SUBMISSION NAME]` within your submission.
-
-
-
- Please avoid using the direct HTML method of embedding images; they may not be properly resized.
- Some more complex ways to load images (note the different styles of the shapes/shadows):
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/9.jpg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/7.jpg" class="img-fluid rounded z-depth-1" %}
-
-
-
- A simple, elegant caption looks good between image rows, after each row, or doesn't have to be there at all.
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/8.jpg" class="img-fluid z-depth-2" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/10.jpg" class="img-fluid z-depth-2" %}
-
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/11.jpg" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/12.jpg" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-distill-example/7.jpg" class="img-fluid" %}
-
-
-
-
Interactive Figures
-
-
- Here's how you could embed interactive figures that have been exported as HTML files.
- Note that we will be using plotly for this demo, but anything built off of HTML should work.
- All that's required is for you to export your figure into HTML format, and make sure that the file
- exists in the `assets/html/[SUBMISSION NAME]/` directory in this repository's root directory.
- To embed it into any page, simply insert the following code anywhere into your page.
-
-
-
{% raw %}{% include [FIGURE_NAME].html %}{% endraw %}
-
-
-For example, the following code can be used to generate the figure underneath it.
-
- Citations are then used in the article body with the <d-cite> tag.
- The key attribute is a reference to the id provided in the bibliography.
- The key attribute can take multiple ids, separated by commas.
-
-
-
- The citation is presented inline like this: (a number that displays more information on hover).
- If you have an appendix, a bibliography is automatically created and populated in it.
-
-
-
- Distill chose a numerical inline citation style to improve readability of citation dense articles and because many of the benefits of longer citations are obviated by displaying more information on hover.
- However, we consider it good style to mention author last names if you discuss something at length and it fits into the flow well - the authors are human and it's nice for them to have the community associate them with their work.
-
-
-
-
Footnotes
-
-
- Just wrap the text you would like to show up in a footnote in a <d-footnote> tag.
- The number of the footnote will be automatically generated. This will become a hoverable footnote.
-
-
-
-
Code Blocks
-
-
- This theme implements a built-in Jekyll feature, the use of Rouge, for syntax highlighting.
- It supports more than 100 languages.
- This example is in C++.
- All you have to do is wrap your code in a liquid tag as follows:
-
-
-
{% raw %}
-{% highlight c++ linenos %} code code code {% endhighlight %}
-{% endraw %}
-
-
-The keyword `linenos` triggers display of line numbers. You can try toggling it on or off yourself below:
-
-{% highlight c++ %}
-
-#include <iostream>
-#include <string>
-using namespace std;
-
-int main(int argc, char const *argv[])
-{
-    string myString;
-
-    cout << "input a string: ";
-    getline(cin, myString);
-    int length = myString.length();
-
-    // copy the string into a dynamically allocated char array
-    char *charArray = new char[length];
-    for (int i = 0; i < length; ++i) {
-        charArray[i] = myString[i];
-        cout << charArray[i] << " ";
-    }
-
-    delete[] charArray;
-    return 0;
-}
-
-{% endhighlight %}
-
-
-
-
Diagrams
-
-
- This theme supports generating various diagrams from a text description using jekyll-diagrams plugin.
- Below, we generate a few examples of such diagrams using languages such as mermaid, plantuml, vega-lite, etc.
-
-
-
- Note: different diagram-generation packages require external dependencies to be installed on your machine.
- Also, be mindful that, because of diagram generation, the first time you build your Jekyll website after adding new diagrams will be SLOW.
- For any other details, please refer to the jekyll-diagrams README.
-
-
-
- Note: This is not supported for local rendering!
-
-
-
- The diagram below was generated by the following code:
-
-
-
{% raw %}{% mermaid %}
-sequenceDiagram
- participant John
- participant Alice
- Alice->>John: Hello John, how are you?
- John-->>Alice: Great!
-{% endmermaid %}
-{% endraw %}
-
-
-{% mermaid %}
-sequenceDiagram
-participant John
-participant Alice
-Alice->>John: Hello John, how are you?
-John-->>Alice: Great!
-{% endmermaid %}
-
-
-
Tweets
-
-
- An example of displaying a tweet:
- {% twitter https://twitter.com/rubygems/status/518821243320287232 %}
-
-
-
- An example of pulling from a timeline:
- {% twitter https://twitter.com/jekyllrb maxwidth=500 limit=3 %}
-
- We do not grow absolutely, chronologically. We grow sometimes in one dimension, and not in another, unevenly. We grow partially. We are relative. We are mature in one realm, childish in another.
- —Anais Nin
-
-
-
-
Layouts
-
-The main text column is referred to as the body.
-It's the assumed layout of any direct descendants of the `d-article` element.
-
-
-
.l-body
-
-
-For images you want to display a little larger, try `.l-page`:
-
-
-
.l-page
-
-
-All of these have an outset variant if you want to poke out from the body text a little bit.
-For instance:
-
-
-
.l-body-outset
-
-
-
-
.l-page-outset
-
-
-Occasionally you'll want to use the full browser width.
-For this, use `.l-screen`.
-You can also inset the element a little from the edge of the browser by using the inset variant.
-
-
-
.l-screen
-
-
-
.l-screen-inset
-
-
-The final layout is for marginalia, asides, and footnotes.
-It does not interrupt the normal flow of `.l-body`-sized text except on mobile screen sizes.
-
-
-
.l-gutter
-
-
-
-
Other Typography?
-
-
- Emphasis, aka italics, with the <i></i> tag emphasis.
-
-
-
- Strong emphasis, aka bold, with <b></b> tag bold.
-
-
-
- Strikethrough can be accomplished with the <s></s> tag. Scratch this.
-
-
-
-
First ordered list item
-
Another item
-
-
Unordered sub-list.
-
-
And another item.
-
-
-
-
-
- For code, the language can be specified in the class. For example, use language-javascript for Javascript and language-python for Python code.
-
-
-
var s = "JavaScript syntax highlighting";
- alert(s);
-
-
s = "Python syntax highlighting"
- print(s)
-
-
No language indicated, so no syntax highlighting.
-
-
- A table can be created with the <table> element. Below is an example
-
-
-
-
-
-
Tables
-
Are
-
Cool
-
-
-
-
-
col 3 is
-
right-aligned
-
$1600
-
-
-
col 2 is
-
centered
-
$12
-
-
-
zebra stripes
-
are neat
-
$1
-
-
-
-
-
-
-
Blockquotes can be defined with the <blockquote> tag.
-
\ No newline at end of file
diff --git a/_posts/2024-05-07-double-descent-demystified.md b/_posts/2024-05-07-double-descent-demystified.md
deleted file mode 100644
index d78ba984..00000000
--- a/_posts/2024-05-07-double-descent-demystified.md
+++ /dev/null
@@ -1,736 +0,0 @@
----
-layout: distill
-title: Double Descent Demystified
-description: Identifying, Interpreting & Ablating the Sources of a Deep Learning Puzzle
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-authors:
- - name: Rylan Schaeffer
- url: "https://scholar.google.com/citations?user=6tMEGz8AAAAJ&hl=en"
- affiliations:
- name: Stanford University
- - name: Zachary Robertson
- url: "https://scholar.google.com/citations?user=769PIisAAAAJ&hl=en&oi=ao"
- affiliations:
- name: Stanford University
- - name: Akhilan Boopathy
- url: "https://scholar.google.com/citations?user=21alU7EAAAAJ&hl=en"
- affiliations:
- name: MIT
- - name: Mikail Khona
- url: "https://scholar.google.com/citations?user=K5f0SYQAAAAJ&hl=en&oi=ao"
- affiliations:
- name: MIT
- - name: Kateryna Pistunova
- url: "https://scholar.google.com/citations?user=V7QY5j0AAAAJ&hl=en"
- affiliations:
- name: Stanford University
- - name: Jason W. Rocks
- url: "https://scholar.google.com/citations?user=rFHAzMUAAAAJ"
- affiliations:
- name: Boston University
- - name: Ila R. Fiete
- url: "https://scholar.google.com/citations?user=uE-CihIAAAAJ&hl=en&oi=ao"
- affiliations:
- name: MIT
- - name: Andrey Gromov
- url: "https://scholar.google.com/citations?user=D056qfMAAAAJ&hl=en&oi=ao"
- affiliations:
- name: UMD & Meta AI FAIR
- - name: Sanmi Koyejo
- url: "https://scholar.google.com/citations?user=EaaOeJwAAAAJ&hl=en&oi=ao"
- affiliations:
- name: Stanford University
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-double-descent-demystified.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Introduction
- - name: Double Descent in Ordinary Linear Regression
- subsections:
- - name: Empirical Evidence
- - name: Notation and Terminology
- - name: Mathematical Analysis
- - name: Factor 1 - Low Variance in Training Features
- - name: Factor 2 - Test Features in Training Feature Subspace
- - name: Factor 3 - Errors from Best Possible Model
- - name: Divergence at the Interpolation Threshold
- - name: Generalization in Overparameterized Linear Regression
- - name: Adversarial Data
- subsections:
- - name: Adversarial Test Examples
- - name: Adversarial Training Data
- - name: Intuition for Nonlinear Models
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-## Introduction
-
-Machine learning models, while incredibly powerful, can sometimes act unpredictably. One of the most intriguing
-behaviors is when the test loss suddenly diverges at the interpolation threshold, a phenomenon
-distinctly observed in **double descent** .
-
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/california_housing/unablated.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/diabetes/unablated.png" class="img-fluid rounded z-depth-1" %}
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/student_teacher/unablated.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/who_life_expectancy/unablated.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 1. Double descent in ordinary linear regression.
- Three real datasets (California Housing, Diabetes, and WHO Life Expectancy) and one synthetic dataset (Student-Teacher) all exhibit double descent,
- with test loss spiking at the interpolation threshold.
- Blue is training error. Orange is test error.
-
-
-
-While significant theoretical work has been done to comprehend why double descent occurs, it can be difficult
-for a newcomer to gain a general understanding of why the test loss behaves in this manner, and under what conditions
-one should expect similar misbehavior. In this blog post, when we say double descent, we mean the divergence at the interpolation
-threshold, and not whether overparameterized models generalize (or fail to generalize).
-
-In this work, we intuitively and quantitatively explain why the test loss diverges at the interpolation threshold,
-with as much generality as possible and with as simple of mathematical machinery as possible, but also without sacrificing rigor.
-To accomplish this, we focus on the simplest supervised model - ordinary linear regression - using the most
-basic linear algebra primitive: the singular value decomposition. We identify three distinct interpretable
-factors which, when collectively present, trigger the divergence.
-Through practical experiments on real datasets, we confirm that the models' test losses diverge at the
-interpolation threshold, and that this divergence vanishes when even one of the three factors is removed.
-We complement our understanding by offering a geometric picture that reveals linear models perform
-representation learning when overparameterized, and conclude by shedding light on recent results in nonlinear
-models concerning superposition.
-
-
-## Double Descent in Ordinary Linear Regression
-
-### Empirical Evidence of Double Descent in Ordinary Linear Regression
-
-
-
-Before studying ordinary linear regression mathematically, does our claim that it exhibits double descent
-hold empirically? We show that it indeed does, using one synthetic and three real datasets:
-World Health Organization Life Expectancy , California Housing , Diabetes ;
-these three real datasets were selected on the basis of being easily accessible through sklearn or Kaggle.
-As shown in [Fig 1](#fig_unablated_all), all display a spike in test mean squared error at the interpolation threshold. Our simple Python code is [publicly available]().
-
-
-
-### Notation and Terminology
-
-Consider a regression dataset of $N$ training data with features $\vec{x}_n \in \mathbb{R}^D$ and targets $y_n \in \mathbb{R}$.
-We sometimes use matrix-vector notation to refer to the training data:
-
-$$X \in \mathbb{R}^{N \times D} \quad , \quad Y \in \mathbb{R}^{N \times 1}.$$
-
-In ordinary linear regression, we want to learn parameters $\hat{\vec{\beta}} \in \mathbb{R}^{D}$ such that:
-
-$$\vec{x}_n \cdot \hat{\vec{\beta}} \approx y_n.$$
-
-We will study three key parameters:
-1. The number of model parameters $P$
-2. The number of training data $N$
-3. The dimensionality of the data $D$
-
-We say that a model is _overparameterized_ if $N < P$ and _underparameterized_ if $N > P$.
-The _interpolation threshold_ refers to $N=P$, because when $N\leq P$, the model can perfectly interpolate the training points.
-Recall that in ordinary linear regression, the number of parameters $P$ equals the dimension $D$ of the covariates.
-Consequently, rather than thinking about changing the number of parameters $P$, we'll instead think about changing
-the number of data points $N$.
-
-
-### Mathematical Analysis of Ordinary Linear Regression
-
-To understand under what conditions and why double descent occurs at the interpolation threshold in linear regression,
-we'll study the two parameterization regimes.
-If the regression is _underparameterized_, we estimate the linear relationship between covariates $\vec{x}_n$
-and target $y_n$ by solving the least-squares minimization problem:
-
-
-$$
-\begin{align*}
-\hat{\vec{\beta}}_{under} \, &:= \, \arg \min_{\vec{\beta}} \frac{1}{N} \sum_n ||\vec{x}_n \cdot \vec{\beta} - y_n||_2^2\\
-\, &:= \, \arg \min_{\vec{\beta}} ||X \vec{\beta} - Y ||_2^2.
-\end{align*}
-$$
-
-The solution is the ordinary least squares estimator based on the second moment matrix $X^T X$:
-
-$$\hat{\vec{\beta}}_{under} = (X^T X)^{-1} X^T Y.$$
-
-If the model is overparameterized, the optimization problem is ill-posed since we have fewer constraints than parameters.
-Consequently, we choose a different (constrained) optimization problem that asks for the minimum norm parameters that
-still perfectly interpolate the training data:
-
-
-$$
-\begin{align*}
-\hat{\vec{\beta}}_{over} \, &:= \, \arg \min_{\vec{\beta}} ||\vec{\beta}||_2^2\\
-\text{s.t.} \quad \quad \forall \, n \in &\{1, ..., N\}, \quad \vec{x}_n \cdot \vec{\beta} = y_n.
-\end{align*}
-$$
-
-We choose this optimization problem because it is the one gradient descent implicitly minimizes.
-The solution to this optimization problem uses the Gram matrix $X X^T \in \mathbb{R}^{N \times N}$:
-
-$$\hat{\vec{\beta}}_{over} = X^T (X X^T)^{-1} Y.$$
-
-One way to see why the Gram matrix appears is via constrained optimization: define the Lagrangian
-$\mathcal{L}(\vec{\beta}, \vec{\lambda}) \, := \, \frac{1}{2}||\vec{\beta}||_2^2 + \vec{\lambda}^T (Y - X \vec{\beta})$
-with Lagrange multipliers $\vec{\lambda} \in \mathbb{R}^N$, then differentiate with respect to the parameters
-and Lagrange multipliers to obtain the overparameterized solution.
-
-After being fit, for test point $\vec{x}_{test}$, the model will make the following predictions:
-
-$$\hat{y}_{test, under} = \vec{x}_{test} \cdot \hat{\vec{\beta}}_{under} = \vec{x}_{test} \cdot (X^T X)^{-1} X^T Y$$
-
-
-$$\hat{y}_{test, over} = \vec{x}_{test} \cdot \hat{\vec{\beta}}_{over} = \vec{x}_{test} \cdot X^T (X X^T)^{-1} Y.$$
-
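-Before dissecting these predictions, here is a minimal numpy sketch (on hypothetical synthetic data, not the datasets of Figure 1) of the two estimators and of the resulting test-error spike near the interpolation threshold $N = P = D$.
-
-```python
-import numpy as np
-
-def fit_linear(X, Y):
-    N, D = X.shape
-    if N > D:    # underparameterized: ordinary least squares
-        return np.linalg.inv(X.T @ X) @ X.T @ Y
-    else:        # overparameterized: minimum-norm interpolator
-        return X.T @ np.linalg.inv(X @ X.T) @ Y
-
-rng = np.random.default_rng(0)
-D = 20
-beta_star = rng.normal(size=(D, 1))
-X_test = rng.normal(size=(1000, D))
-y_test = X_test @ beta_star
-
-for N in [5, 10, 15, 19, 20, 21, 25, 40, 80]:
-    X = rng.normal(size=(N, D))
-    Y = X @ beta_star + 0.1 * rng.normal(size=(N, 1))   # targets with "uncapturable" error E
-    beta_hat = fit_linear(X, Y)
-    print(N, float(np.mean((X_test @ beta_hat - y_test) ** 2)))   # spikes near N = D
-```
-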
-
-
-Hidden in the above equations is an interaction between three quantities that can, when all grow extreme, create a
-divergence in the test loss!
-
-To reveal the three quantities, we'll rewrite the regression targets by introducing
-a slightly more detailed notation. Unknown to us, there are some ideal linear parameters
-$\vec{\beta}^* \in \mathbb{R}^P = \mathbb{R}^D$ that truly minimize the test mean squared error.
-We can write any regression target as the inner product of the data $\vec{x}_n$ and the ideal parameters $\vec{\beta}^*$,
-plus an additional error term $e_n$ that is an
-"uncapturable" residual from the "viewpoint" of the model class
-
-$$y_n = \vec{x}_n \cdot \vec{\beta}^* + e_n.$$
-
-In matrix-vector form, we will equivalently write:
-
-$$Y = X \vec{\beta}^* + E,$$
-
-with $E \in \mathbb{R}^{N \times 1}$.
-To be clear, we are _not_ imposing assumptions. Rather, we are introducing notation to express that
-there are (unknown) ideal linear parameters, and possibly non-zero errors $E$ that even the ideal model might
-be unable to capture; these errors $E$ could be random noise or could be fully deterministic patterns that this
-particular model class cannot capture. Using this new notation, we rewrite the model's predictions to show how
-the test datum's features $\vec{x}_{test}$,
-training data's features $X$ and training data's regression targets $Y$ interact.
-
-Let $y_{test}^* := \vec{x}_{test} \cdot \vec{\beta}^*$. In the underparameterized regime:
-
-$$
-\begin{align*}
-\hat{y}_{test,under} &= \vec{x}_{test} \cdot \hat{\vec{\beta}}_{under}\\
-&=\vec{x}_{test} \cdot (X^T X)^{-1} X^T Y\\
-&=\vec{x}_{test} \cdot (X^T X)^{-1} X^T (X \vec{\beta}^* + E)\\
-&=\vec{x}_{test} \cdot \vec{\beta}^* + \, \vec{x}_{test} \cdot (X^T X)^{-1} X^T E\\
-\hat{y}_{test,under} - y_{test}^* &= \vec{x}_{test} \cdot (X^T X)^{-1} X^T E.
-\end{align*}
-$$
-
-This equation is important, but opaque. To extract the intuition,
-replace $X$ with its singular value decomposition $X = U S V^T$.
-Let $R \, := \, \text{rank}(X)$ and let $\sigma_1 > \sigma_2 > ... > \sigma_R > 0$ be
-$X$'s (non-zero) singular values. Let $S^+$ denote the [Moore-Penrose inverse](https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse);
-in this context, this means that if a singular value $\sigma_r$ is non-zero, then in $S^+$, it becomes its reciprocal
-$1/\sigma_r$, but if the singular value is zero, then in $S^+$, it remains $0$.
-We can decompose the underparameterized prediction error
-along the orthogonal singular modes:
-
-$$
-\begin{align*}
-\hat{y}_{test, under} - y_{test}^* &= \vec{x}_{test} \cdot V S^{+} U^T E\\
-&= \sum_{r=1}^R \frac{1}{\sigma_r} (\vec{x}_{test} \cdot \vec{v}_r) (\vec{u}_r \cdot E).
-\end{align*}
-$$
-
-This equation will be critical! The same term will appear in the overparameterized regime (plus one additional term):
-
-$$
-\begin{align*}
-\hat{y}_{test,over} &= \vec{x}_{test} \cdot \hat{\vec{\beta}}_{over}\\
-&= \vec{x}_{test} \cdot X^T (X X^T)^{-1} Y\\
-&= \vec{x}_{test} \cdot X^T (X X^T)^{-1} (X \beta^* + E)\\
-\hat{y}_{test,over} - y_{test}^* &= \vec{x}_{test} \cdot (X^T (X X^T)^{-1} X - I_D) \beta^* \\
-&\quad\quad + \quad \vec{x}_{test} \cdot X^T (X X^T)^{-1} E\\
- &= \vec{x}_{test} \cdot (X^T (X X^T)^{-1} X - I_D) \beta^* \\
-&\quad\quad + \quad \sum_{r=1}^R \frac{1}{\sigma_r} (\vec{x}_{test} \cdot \vec{v}_r) (\vec{u}_r \cdot E),
-\end{align*}
-$$
-
-where the last step again replaced $X$ with its SVD $X = U S V^T$. Thus, the prediction errors
-in the overparameterized and underparameterized regimes will be:
-
-$$
-\begin{align*}
-\hat{y}_{test,over} - y_{test}^* &= \sum_{r=1}^R \frac{1}{\sigma_r} (\vec{x}_{test} \cdot \vec{v}_r) (\vec{u}_r \cdot E)\\
-&\quad \quad + \quad \vec{x}_{test} \cdot (X^T (X X^T)^{-1} X - I_D) \beta^*\\
-\hat{y}_{test,under} - y_{test}^* &= \sum_{r=1}^R \frac{1}{\sigma_r} (\vec{x}_{test} \cdot \vec{v}_r) (\vec{u}_r \cdot E).
-\end{align*}
-$$
-
-The shared term in the two prediction errors causes the divergence:
-
-$$
-\begin{equation}
-\sum_{r=1}^R \frac{1}{\sigma_r} (\vec{x}_{test} \cdot \vec{v}_r) (\vec{u}_r \cdot E).
-\label{eq:variance}
-\end{equation}
-$$
-
-Eqn. \ref{eq:variance} is critical. It reveals that our test prediction error (and thus, our
-test squared error!) will depend on an interaction between 3 quantities:
-
-1. How much the training features vary in each direction.
-More formally, the inverse (non-zero) singular values of the _training features_ $X$:
-
- $$\frac{1}{\sigma_r}$$
-
-2. How much, and in which directions, the test features vary relative to the training features.
-More formally: how $\vec{x}_{test}$ projects onto $X$'s right singular vectors $V$:
-
- $$\vec{x}_{test} \cdot \vec{v}_r$$
-
-3. How well the best possible model in the model class can correlate the variance in the training features with the training regression targets.
-More formally: how the residuals $E$ of the best possible model in the model class (i.e. insurmountable "errors" from the "perspective" of the model class) project onto $X$'s left singular vectors $U$:
-
- $$\vec{u}_r \cdot E$$
-
-We use the term "vary" when discussing $\vec{v}_r$ because $V$ can be related to the empirical (or sample) covariance
-matrix oftentimes studied in Principal Component Analysis. That is, if the SVD of $X$ is $U S V^T$, then
-$\frac{1}{N} X^T X = \frac{1}{N} V S^2 V^T$. If the training data are centered
-(a common preprocessing step), then this is the empirical covariance
-matrix and its eigenvectors $\vec{v}_1, ..., \vec{v}_R$ identify the orthogonal directions of variance. We'll return
-to this in [Fig 6](#fig_geometric_smallest_nonzero_singular_value).
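-
-A quick numpy check of this identity on illustrative synthetic data:
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(5)
-N, D = 200, 6
-X = rng.normal(size=(N, D))
-X = X - X.mean(axis=0)                          # center the training features
-
-U, S, Vt = np.linalg.svd(X, full_matrices=False)
-assert np.allclose(X.T @ X / N, Vt.T @ np.diag(S**2) @ Vt / N)
-```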
-
-**Why does the test error diverge?** When (1) and (3) are both present, i.e. a singular mode of the training features has a small non-zero singular value and the ideal model's residuals have a sizable projection onto that mode, the model's
-parameters along that singular mode are likely incorrect.
-When (2) is added to the mix by a test datum $\vec{x}_{test}$ with a large projection along that same mode,
-the model is forced to extrapolate significantly beyond what it saw in the training data, in a direction where
-the training features had an error-prone relationship with the training targets, using
-parameters that are likely wrong. As a consequence, the test squared error explodes!
-
-### Factor 1 - Low Variance in Training Features
-
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/california_housing/no_small_singular_values.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/diabetes/no_small_singular_values.png" class="img-fluid rounded z-depth-1" %}
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/student_teacher/no_small_singular_values.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/who_life_expectancy/no_small_singular_values.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 2. Required Factor #1: How much training features vary in each direction.
- The test loss diverges at the interpolation threshold only if training features $X$ contain small (non-zero)
- singular values. Ablation: By removing all singular values below a cutoff, the divergence at the interpolation threshold is diminished or disappears entirely.
-    Blue is training error. Orange is test error.
-
-
-
-The test loss will not diverge if any of the three required factors is absent. What could cause that?
-One possibility is that small-but-nonzero singular values do not appear in the training features, which we can
-enforce by setting all singular values below a selected threshold to exactly 0. To test our understanding,
-we independently ablate all small singular values in the training features. Specifically, as we run the
-ordinary linear regression fitting process and sweep the number of training data, we also sweep different
-singular value cutoffs and remove all singular values of the training features $X$ below the cutoff ([Fig 2](#fig_factor_1_small_singular_values)).
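-
-A minimal sketch of this ablation (on illustrative synthetic data; the actual experiments use the datasets above):
-
-```python
-import numpy as np
-
-def truncate_singular_values(X, cutoff):
-    """Set all singular values of X below `cutoff` to exactly zero (the ablation)."""
-    U, S, Vt = np.linalg.svd(X, full_matrices=False)
-    return U @ np.diag(np.where(S >= cutoff, S, 0.0)) @ Vt
-
-def fit_min_norm(X, Y):
-    """Minimum-norm least-squares fit; covers both regimes via the pseudoinverse."""
-    return np.linalg.pinv(X) @ Y
-
-# Usage sketch on synthetic data near the interpolation threshold:
-# larger cutoffs typically tame the variance term.
-rng = np.random.default_rng(2)
-N, D = 8, 10
-X_train, X_test = rng.normal(size=(N, D)), rng.normal(size=(500, D))
-beta_star = rng.normal(size=(D, 1))
-Y_train = X_train @ beta_star + 0.5 * rng.normal(size=(N, 1))
-Y_test = X_test @ beta_star
-
-for cutoff in [0.0, 0.5, 1.0]:
-    beta_hat = fit_min_norm(truncate_singular_values(X_train, cutoff), Y_train)
-    print(cutoff, float(np.mean((X_test @ beta_hat - Y_test) ** 2)))
-```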
-
-### Factor 2 - Test Features in Training Feature Subspace
-
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/california_housing/test_feat_in_train_feat_subspace.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/diabetes/test_feat_in_train_feat_subspace.png" class="img-fluid rounded z-depth-1" %}
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/student_teacher/test_feat_in_train_feat_subspace.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/who_life_expectancy/test_feat_in_train_feat_subspace.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 3. Required Factor #2: How much, and in which directions, test features vary relative to training features.
- The test loss diverges only if the test features $\vec{x}_{test}$ have a large projection onto the training
- features $X$'s right singular vectors $V$. Ablation: By projecting the test features into the subspace of the
- leading singular modes, the divergence at the interpolation threshold is diminished or disappears entirely.
-    Blue is training error. Orange is test error.
-
-
-
-Double descent should not occur if the test datum does not vary in different directions than the training features.
-Specifically, if the test datum lies entirely in the subspace of just a few of the leading singular directions, then the divergence is unlikely to occur.
-To test our understanding, we force the test data features to lie in the training features subspace: as we run the
-ordinary linear regression fitting process, and as we sweep the number of training data, we project the test features
-$\vec{x}_{test}$ onto the subspace spanned by the training features $X$'s singular modes ([Fig 3](#fig_test_feat_in_train_feat_subspace)).
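-
-A minimal sketch of this projection ablation (the helper and its arguments are illustrative):
-
-```python
-import numpy as np
-
-def project_onto_leading_modes(x_test, X_train, k):
-    """Project a test point onto the span of X_train's top-k right singular vectors."""
-    _, _, Vt = np.linalg.svd(X_train, full_matrices=False)
-    V_k = Vt[:k].T                    # D x k orthonormal basis of the leading subspace
-    return V_k @ (V_k.T @ x_test)
-
-# Usage sketch: with small k, the projected test point has no component along the
-# small-singular-value directions, so the (x_test . v_r) / sigma_r terms stay benign.
-```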
-
-
-### Factor 3 - Errors from Best Possible Model
-
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/california_housing/no_residuals_in_ideal.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/diabetes/no_residuals_in_ideal.png" class="img-fluid rounded z-depth-1" %}
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/student_teacher/no_residuals_in_ideal.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/who_life_expectancy/no_residuals_in_ideal.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 4. Required Factor #3: How well the best possible model in the model class can correlate variance in training
- features with training targets. The test loss diverges only if the residuals $E$ from the best possible model
- in the model class on the training data have a large projection onto the training features $X$'s left singular
- vectors $U$. Ablation: By ensuring the true relationship between features and targets is within the model class
- i.e. linear, the divergence at the interpolation threshold disappears.
-    Blue is training error. Orange is test error.
-
-
-
-Double descent should not occur if the best possible model in the model class makes no errors on the training data.
-For example, if we use a linear model class on data where the true relationship is a noiseless linear relationship,
-then at the interpolation threshold, we will have $N=P$ data and $P=D$ parameters, our line of best fit will exactly match
-the true relationship, and no divergence will occur. To test our understanding, we ensure no residual errors exist in
-the best possible model: we first use the entire dataset to fit a linear model, then replace all target values
-with the predictions made by the ideal linear model. We then rerun our typical fitting process using these
-new labels, sweeping the number of training data ([Fig 4](#fig_no_residuals_in_ideal)).
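-
-A minimal sketch of this relabeling ablation (illustrative helper, not the exact experimental code):
-
-```python
-import numpy as np
-
-def relabel_with_ideal_linear_model(X_all, Y_all):
-    """Fit a linear model on the full dataset and return its predictions as new targets,
-    so that the best model in the (linear) model class has zero residuals E."""
-    beta_ideal = np.linalg.pinv(X_all) @ Y_all
-    return X_all @ beta_ideal
-
-# Usage sketch: rerun the usual sweep over the number of training data with
-# Y_relabel = relabel_with_ideal_linear_model(X_all, Y_all) in place of Y_all.
-```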
-
-As a short aside, what could cause residual errors in the best possible model in the model class?
-
-1. __Noise__: If the data is noisy, then the best possible model in the model class will have residual errors.
-2. __Model Misspecification__: If the data is generated by a nonlinear model, but we use a linear model class (or vice versa), then the best possible model in the model class will have residual errors.
-3. __Missing Features__: Even if the data is noiseless and our model belongs to the correct model class, if we are missing covariates, the best possible model in the model class will still have residual errors.
-
-### Divergence at the Interpolation Threshold
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/california_housing/least_informative_singular_value.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/diabetes/least_informative_singular_value.png" class="img-fluid rounded z-depth-1" %}
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/student_teacher/least_informative_singular_value.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/who_life_expectancy/least_informative_singular_value.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 5. The training features are most likely to obtain their smallest non-zero singular value when approaching the interpolation threshold.
-
-
-
-Why does this divergence happen near the interpolation threshold? The answer is that the first factor
-(small non-zero singular values in the training features $X$) is likely to occur at the interpolation
-threshold ([Fig 5](#fig_least_informative_singular_value)), but why?
-
-Suppose we're given a single
-training datum $$\vec{x}_1$$. So long as this datum isn't exactly zero, that datum varies in a single
-direction, meaning we gain information about the variance in that direction, but the variance in all
-orthogonal directions is exactly 0. With the second training datum $$\vec{x}_2$$, so long as this datum
-isn't exactly zero, that datum varies, but now, some fraction of $$\vec{x}_2$$ might have a positive
-projection along $$\vec{x}_1$$; if this happens (and it likely will, since the two vectors are unlikely
-to be exactly orthogonal), the shared direction gives us _more_ information about the variance
-in this shared direction, but _less_ information about the second orthogonal direction of variation.
-Ergo, the training data's smallest non-zero singular value after 2 samples is probabilistically smaller than
-after 1 sample. As we approach the interpolation threshold, it becomes increasingly unlikely that each additional datum
-has large variance in a new direction orthogonal to all previous directions
-([Fig 6](#fig_geometric_smallest_nonzero_singular_value)), but as we move beyond the interpolation threshold, the variance
-in each covariate dimension becomes increasingly clear.
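-
-A small numpy sketch (with illustrative synthetic Gaussian features) makes this concrete:
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(3)
-D = 20
-X_full = rng.normal(size=(100, D))
-
-# Smallest non-zero singular value of the training features as data accumulate.
-for N in [1, 5, 10, 15, 19, 20, 21, 25, 50, 100]:
-    S = np.linalg.svd(X_full[:N], compute_uv=False)
-    print(N, S[S > 1e-10].min())
-# The printed value typically dips near the interpolation threshold N = D = 20
-# and recovers as N grows beyond it.
-```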
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/smallest_nonzero_singular_value/data_distribution.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/smallest_nonzero_singular_value/data_distribution_num_data=1.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/smallest_nonzero_singular_value/data_distribution_num_data=2.png" class="img-fluid rounded z-depth-1" %}
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/smallest_nonzero_singular_value/data_distribution_num_data=3.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/smallest_nonzero_singular_value/data_distribution_num_data=8.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/smallest_nonzero_singular_value/data_distribution_num_data=100.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 6. Geometric intuition for why the smallest non-zero singular value reaches its lowest value near the interpolation threshold.
- If $1$ datum is observed, variance exists in only 1 direction. If $2$ data are observed, a second axis of
- variation appears, but because the two data are likely to share some component, the second axis is likely to have
- less variance than the first. At the interpolation threshold (here, $D=P=N=3$), because the three data are
- likely to share components along the first two axes, the third axis is likely to have even less variance.
- Beyond the interpolation threshold, additional data contribute additional variance to these three axes.
-
-
-
-
-### Generalization in Overparameterized Linear Regression
-
-You might be wondering why three of the datasets have low test squared error in the overparameterized regime (California
-Housing, Diabetes, Student-Teacher) but one (WHO Life Expectancy) does not. Recall that the overparameterized regime's prediction
-error $$\hat{y}_{test,over} - y_{test}^*$$ contains an additional term that is not present in the underparameterized regime:
-
-$$
-\begin{equation}
-\vec{x}_{test} \cdot (X^T (X X^T)^{-1} X - I_D) \beta^*.
-\label{eq:bias}
-\end{equation}
-$$
-
-To understand why this bias exists, recall that our goal is to correlate fluctuations in the covariates
-$\vec{x}$ with fluctuations in the targets $y$. In the overparameterized regime, there are more parameters
-than data; consequently, for $N$ data points in $D=P$ dimensions, the model can "see" fluctuations in at
-most $N$ dimensions, but has no "visibility" into the remaining $P-N$ dimensions. This causes information
-about the optimal linear relationship $\vec{\beta}^*$ to be lost, thereby increasing the overparameterized
-prediction error.
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/overparameterized_generalization.jpg" class="img-fluid rounded z-depth-1"%}
-
-
-
- Figure 7. Geometry of Generalization in Overparameterized Ordinary Linear Regression.
- The rowspace of the training features $X$ forms a subspace (here, $\mathbb{R}^1$) of the ambient space
- (here, $\mathbb{R}^2$). For test datum $\vec{x}_{test}$, the linear model forms an internal representation
- of the test datum $\hat{\vec{x}}_{test}$ by orthogonally projecting the test datum onto the rowspace via
- projection matrix $X^T (X X^T)^{-1} X$. The generalization error will then increase commensurate with the
- inner product between $\hat{\vec{x}}_{test} - \vec{x}_{test}$ and the best possible parameters for the
- function class $\vec{\beta}^*$. Three different possible $\vec{\beta}^*$ are shown with
- low (blue), medium (green)
- and high (red) generalization errors.
-
-
-
-We previously saw that away from the interpolation threshold, the variance term (Eqn. \ref{eq:variance}) is unlikely to affect the
-discrepancy between the overparameterized model's predictions and the ideal model's predictions,
-meaning most of the discrepancy must emerge from the bias (Eqn. \ref{eq:bias}).
-This bias term yields an intuitive geometric picture ([Fig 7](#fig_overparameterized_generalization)) that
-also reveals a surprising fact: _overparameterized linear regression does representation learning!_
-Specifically, for test datum $$\vec{x}_{test}$$, a linear model creates a representation of the test datum
-$$\hat{\vec{x}}_{test}$$ by orthogonally projecting the test datum onto the row space of the training
-covariates $$X$$ via the projection matrix $$X^T (X X^T)^{-1} X$$:
-
-$$
-\begin{equation*}
-\hat{\vec{x}}_{test} := X^T (X X^T)^{-1} X \; \vec{x}_{test}.
-\end{equation*}
-$$
-
-Seen this way, the bias can be rewritten as the inner product between (1) the difference between the model's representation of the test datum and the test datum itself and (2) the ideal linear model's fit parameters:
-
-$$
-\begin{equation}\label{eq:overparam_gen_bias}
-(\hat{\vec{x}}_{test} - \vec{x}_{test}) \cdot \vec{\beta}^*.
-\end{equation}
-$$
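-
-A small numpy sketch (on illustrative synthetic data) of the internal representation and the bias term:
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(4)
-N, D = 5, 12                                   # overparameterized: N < D
-X = rng.normal(size=(N, D))
-beta_star = rng.normal(size=(D, 1))
-x_test = rng.normal(size=(D,))
-
-# Internal representation: orthogonal projection of x_test onto the rowspace of X.
-P_row = X.T @ np.linalg.solve(X @ X.T, X)       # D x D projection matrix
-x_test_hat = P_row @ x_test
-
-# Bias term: (x_hat - x_test) . beta_star, equal to the original expression.
-bias = (x_test_hat - x_test) @ beta_star
-bias_direct = x_test @ (P_row - np.eye(D)) @ beta_star
-assert np.allclose(bias, bias_direct)
-```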
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/california_housing/test_bias_squared.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/diabetes/test_bias_squared.png" class="img-fluid rounded z-depth-1" %}
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/student_teacher/test_bias_squared.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_ablations/who_life_expectancy/test_bias_squared.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 8. Test Error of Overparameterized Models. Large inner product between the ideal model's parameters and
- the difference between the fit model's internal representations of the test data and the test data creates
- large test squared error for overparameterized models.
-
-
-
-
-Intuitively, an overparameterized model will generalize well if the model's representations capture the essential
-information necessary for the best model in the model class to perform well ([Fig. 8](#fig_test_bias_squared)).
-
-## Adversarial Test Data and Adversarial Training Data
-
-Our key equation (Eqn. \ref{eq:variance}) also reveals _why_ adversarial test data and adversarial training data exist
-(at least in linear regression) and _how_ mechanistically they function. For convenience, we repeat the equation:
-
-$$
-\begin{equation*}
-\sum_{r=1}^R \frac{1}{\sigma_r} (\vec{x}_{test} \cdot \vec{v}_r) (\vec{u}_r \cdot E).
-\end{equation*}
-$$
-
-Adversarial test examples are a well-known phenomenon in machine learning that we can see in this equation.
-Adversarial test features correspond to $$\vec{x}_{test} \cdot \vec{v}_r$$ being large: one can drastically increase
-the test squared error by moving the test example in the direction of the right singular vector(s) with the smallest non-zero
-singular values ([Fig 9](#fig_adversarial_test_datum)).
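-
-A minimal sketch of this construction (the helper and its `scale` argument are illustrative):
-
-```python
-import numpy as np
-
-def adversarial_test_datum(X_train, x_test, scale=10.0):
-    """Push a test point along the right singular vector belonging to the smallest
-    non-zero singular value of the training features (illustrative construction)."""
-    _, S, Vt = np.linalg.svd(X_train, full_matrices=False)
-    r = int(np.argmin(np.where(S > 1e-10, S, np.inf)))
-    return x_test + scale * Vt[r]
-```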
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_adversarial/california_housing/adversarial_test_datum.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_adversarial/diabetes/adversarial_test_datum.png" class="img-fluid rounded z-depth-1" %}
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_adversarial/student_teacher/adversarial_test_datum.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_adversarial/who_life_expectancy/adversarial_test_datum.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 9. Adversarial Test Examples in Linear Regression. Adversarial examples arise by pushing
- $\vec{x}_{test}$ far along the trailing singular modes in the training features $X$.
-    Blue is training error. Orange is test error.
-
-
-
-
-Less well-known are adversarial training data, akin to dataset poisoning
-or backdoor attacks.
-Adversarial training examples correspond to $$\vec{u}_r \cdot E$$ being large: one can drastically
-increase the test squared error by moving the training errors $E$ in the direction of the left singular vector(s) with the smallest
-non-zero singular value. This gives a practical way to construct _adversarial training data_: training features and targets
-whose training loss is unchanged from that of the unaltered training data, but whose test loss is 1-3 orders of magnitude
-larger ([Fig 10](#fig_adversarial_train_data)).
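-
-A minimal sketch of this construction (again with an illustrative helper and `scale` argument):
-
-```python
-import numpy as np
-
-def adversarial_training_targets(X_train, Y_train, scale=100.0):
-    """Corrupt the training targets along the left singular vector belonging to the
-    smallest non-zero singular value. The corruption lies in the column space of
-    X_train, so it is fit exactly: the training residuals are unchanged, but the
-    fitted parameters (and hence the test error) change drastically."""
-    U, S, _ = np.linalg.svd(X_train, full_matrices=False)
-    r = int(np.argmin(np.where(S > 1e-10, S, np.inf)))
-    return Y_train + scale * U[:, [r]]
-```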
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_adversarial/california_housing/adversarial_train_data.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_adversarial/diabetes/adversarial_train_data.png" class="img-fluid rounded z-depth-1" %}
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_adversarial/student_teacher/adversarial_train_data.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/real_data_adversarial/who_life_expectancy/adversarial_train_data.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 10. Adversarial Training Dataset in Linear Regression. By manipulating the residual errors $E$
- that the best possible model in the model class achieves on the training data, we construct training datasets
- that increase the test error of the learned model by 1-3 orders of magnitude without affecting its training
-    error. Blue is training error. Orange is test error.
-
-
-
-## Intuition for Nonlinear Models
-
-
-Although we mathematically studied ordinary linear regression, the intuition for why the test loss diverges extends
-to nonlinear models, such as polynomial regression and certain classes of deep neural networks.
-For a concrete example about how our intuition can shed
-light on the behavior of nonlinear models, Henighan et al. 2023
-recently discovered interesting properties of shallow nonlinear autoencoders: depending on the number of training data,
-(1) autoencoders either store data points or features, and (2) the test loss increases sharply between these two
-regimes ([Fig. 11](#fig_henighan)).
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-double-descent-demystified/henighan2023superposition.png" class="img-fluid rounded z-depth-1"%}
-
-
-
- Figure 11. Superposition, Memorization and Double Descent in Nonlinear Shallow Autoencoders.
-    Figure from Henighan et al. 2023.
-
-
-
-Our work sheds light on the results in two ways:
-
-
-1. Henighan et al. 2023 write, "It’s interesting to note that we’re observing double descent in the absence of label noise." Our work clarifies that noise, in the sense of a random quantity, is _not_ necessary to produce double descent. Rather, what is necessary is _residual errors from the perspective of the model class_ ($E$, in our notation). Those errors could be entirely deterministic, such as a nonlinear model attempting to fit a noiseless linear relationship, or other model misspecifications.
-
-2. Henighan et al. 2023 write, "[Our work] suggests a naive mechanistic theory of overfitting and memorization: memorization and overfitting occur when models operate on 'data point features' instead of 'generalizing features'." Our work hopefully clarifies that this dichotomy is incorrect: when overparameterized, data point features are akin to the Gram matrix $X X^T$, and when underparameterized, generalizing features are akin to the second moment matrix $X^T X$. Data point features can and very often do generalize, and there is a deep connection between the two, i.e., their shared spectra.
-
-
-## Conclusion
-
-In this work, we intuitively and quantitatively explained why the test loss misbehaves based on three interpretable
-factors, tested our understanding via ablations, connected our understanding to adversarial test examples and
-adversarial training datasets, and added conceptual clarity to recent discoveries in nonlinear models.
\ No newline at end of file
diff --git a/_posts/2024-05-07-dpi-fsvi.md b/_posts/2024-05-07-dpi-fsvi.md
deleted file mode 100644
index a8d084ac..00000000
--- a/_posts/2024-05-07-dpi-fsvi.md
+++ /dev/null
@@ -1,1267 +0,0 @@
----
-layout: distill
-title: "Bridging the Data Processing Inequality and Function-Space Variational Inference"
-description: >-
- This blog post explores the interplay between the Data Processing Inequality (DPI), a cornerstone concept in information theory, and Function-Space Variational Inference (FSVI) within the context of Bayesian deep learning. The DPI governs the transformation and flow of information through stochastic processes, and its unique connection to FSVI is employed to highlight FSVI's focus on Bayesian predictive posteriors over parameter space. Throughout the post, theoretical concepts are intertwined with intuitive explanations and mathematical rigor, offering a comprehensive understanding of these complex topics. The post concludes by bringing together various ideas to explain why the choice of predictive priors (initial probability distributions assumed for model predictions before training) is important for training machine learning models and preventing overfitting. It also discusses the practical implications of these concepts in areas such as continual learning and knowledge distillation. By examining these concepts in depth, the post provides valuable insights for both theory and practice in machine learning, making it an informative resource for researchers and practitioners.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-authors:
- - name: Andreas Kirsch
- url: "https://www.blackhc.net"
- affiliations:
- name: University of Oxford (Former Affiliation)
-
-# authors:
-# - name: Albert Einstein
-# url: "https://en.wikipedia.org/wiki/Albert_Einstein"
-# affiliations:
-# name: IAS, Princeton
-# - name: Boris Podolsky
-# url: "https://en.wikipedia.org/wiki/Boris_Podolsky"
-# affiliations:
-# name: IAS, Princeton
-# - name: Nathan Rosen
-# url: "https://en.wikipedia.org/wiki/Nathan_Rosen"
-# affiliations:
-# name: IAS, Princeton
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-dpi-fsvi.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Introduction
- - name: "Background: Information-Theoretic Notation"
- - name: "Data Processing Inequality"
- subsections:
- - name: "Example: Image Processing Pipeline"
- - name: "Example: Supervised Learning"
- - name: "Example: Autoencoders"
- - name: "Proof of the DPI"
- - name: "🥬 Data Processing Inequality"
- subsections:
- - name: "Example: Comparing Image Distributions"
- - name: "Counter-Example: Bayesian Inference"
- - name: "Proofs of the 🥬 DPI"
- - name: Overall Statement
- - name: "Other Data Processing Inequalities"
- subsections:
- - name: "Jensen-Shannon Divergence"
- - name: "JSD-DPI"
- - name: "Mutual Information"
- - name: "Function-Space Variational Inference"
- subsections:
- - name: "Problem Setting & Notation"
- - name: "Chain Rule of the 🥬 Divergence & DPI"
- - name: "Deriving the Functional ELBO"
- - name: "Choosing the \"Coreset\""
- - name: "Application to Continual Learning"
- - name: Comparison to FSVI in the Literature
- - name: The Equality Case and Equivalence Classes
- subsections:
- - name: "Equivalence Classes"
- - name: "Consistency"
- - name: "Equality & Symmetries"
- - name: "Predictive Prior"
- - name: "Parameter Priors vs. Predictive Priors"
- subsections:
- - name: "Label Entropy Regularization"
- - name: "Knowledge Distillation"
- - name: Conclusion
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
- .box-note, .box-warning, .box-error, .box-important {
- padding: 15px 15px 15px 10px;
- margin: 20px 20px 20px 5px;
- border: 1px solid #eee;
- border-left-width: 5px;
- border-radius: 5px 3px 3px 5px;
- }
- d-article .box-note {
- background-color: #eee;
- border-left-color: #2980b9;
- }
- d-article .box-warning {
- background-color: #fdf5d4;
- border-left-color: #f1c40f;
- }
- d-article .box-error {
- background-color: #f4dddb;
- border-left-color: #c0392b;
- }
- d-article .box-important {
- background-color: #d4f4dd;
- border-left-color: #2bc039;
- }
- html[data-theme='dark'] d-article .box-note {
- background-color: #333333;
- border-left-color: #2980b9;
- }
- html[data-theme='dark'] d-article .box-warning {
- background-color: #3f3f00;
- border-left-color: #f1c40f;
- }
- html[data-theme='dark'] d-article .box-error {
- background-color: #300000;
- border-left-color: #c0392b;
- }
- html[data-theme='dark'] d-article .box-important {
- background-color: #003300;
- border-left-color: #2bc039;
- }
- html[data-theme='dark'] d-article blockquote {
- color: var(--global-text-color) !important;
- }
- html[data-theme='dark'] d-article summary {
- color: var(--global-text-color) !important;
- }
- d-article aside * {
- color: var(--global-text-color) !important;
- }
- d-article p {
- text-align: justify;
- text-justify: inter-word;
- -ms-hyphens: auto;
- -moz-hyphens: auto;
- -webkit-hyphens: auto;
- hyphens: auto;
- }
- d-article aside {
- border: 1px solid #aaa;
- border-radius: 4px;
- padding: .5em .5em 0;
- font-size: 90%;
- }
- d-article aside p:first-child {
- margin-top: 0;
- }
- d-article details {
- border: 1px solid #aaa;
- border-radius: 4px;
- padding: .5em .5em 0;
- }
- d-article summary {
- font-weight: bold;
- margin: -.5em -.5em 0;
- padding: .5em;
- display: list-item;
- }
- d-article details[open] {
- padding: .5em;
- }
- d-article details[open] summary {
- border-bottom: 1px solid #aaa;
- margin-bottom: .5em;
- }
-categories:
-- Data Processing Inequality
-- Information Theory
-- Data Processing Inequality
-- Information Theory
-- Function-Space Variational Inference
-- Parameter Equivalence Classes
-- Entropy Regularization
-- Label Entropy Regularization
----
-
-{% raw %}
-
-{% endraw %}
-
-## Introduction
-
-In information theory, the **data processing inequality (DPI)** expresses a fundamental idea: processing data (stochastically) cannot increase information. The DPI provides us with a powerful intuition about what information processing systems can do and what the limitations of data processing are.
-
-In this blog post, we first study the DPI, developing intuition through vivid examples and detailed proofs---especially the equality case, which is arguably the best way to understand inequalities. We will consider classic forms of the DPI as well as DPIs relating probability distributions more broadly.
-Then, we explore the intriguing connection between DPI and **function-space variational inference (FSVI)**, a modern Bayesian deep learning technique that focuses on the Bayesian predictive posterior rather than the parameter space. Exploring this connection is important because it can provide new insights into FSVI on a fundamental level. We apply the DPI to recover several interesting results from the literature in a simple form and build intuitions for the relationship between parameter and functional priors.
-
-Most importantly, we consider how FSVI can measure a *predictive* divergence between the approximate and true posterior which is independent of parameter symmetries. (With parameter symmetries, I refer to different parameters that yield the same predictions, which is very common in over-parameterized neural networks: think of parameter symmetries like different paths leading to the same destination; they might look different but end up at the same predictions. Thanks to ChatGPT for this analogy! 🤗.) Explaining this connection is one of the main goals of this article and will help you understand the relationships between DPI, FSVI, and other deep learning methods.
-As a concrete example and application, we relate FSVI to training with knowledge distillation and label entropy regularization: potentially more meaningful priors than the ones usually used in Bayesian neural networks. (In many papers, an isotropic Gaussian is used because of its simplicity. Indeed, there are better alternatives, see Fortuin et al (2022) and Fortuin (2022).) This connection highlights the practical relevance of the theoretical concepts discussed in this post and will hopefully inspire the reader to view Bayesian deep learning from a new point of view.
-
-### TL;DR
-
-The following sections summarize the key takeaways of this blog post. If they don't make sense, don't worry: they will after reading this post.
-
-#### Data Processing Inequality
-
-The data processing inequality examines how information cannot increase due to processing. In information theory, it is usually stated based on a Markov chain of random variables $$X \rightarrow Y \rightarrow Z$$ and their mutual information. We will look at different data processing inequalities that relate different distributions instead of different random variables. However, this blog post in particular looks at the DPI when formulated using Kullback-Leibler (KL) divergences between distributions. I will use "🥬 divergence" in headings to add a bit of color. 😊
-
-Concretely, this KL DPI states that processing data stochastically can only reduce information. More formally:
-
-$$
-\Kale{\qof{\W}}{\pof{\W}} \ge \Kale{\qof{Y}}{\pof{Y}}.
-$$
-
-That is, the KL divergence between $$\qof{Y}$$ and $$\pof{Y}$$ cannot be larger than the one between the original $$\qof{\W}$$ and $$\pof{\W}$$. Intuitively, the stochastic mapping $$\opf$$ induces a bottleneck that reduces how well we can distinguish between $$\opp$$ and $$\opq$$. Finally we have equality when $$\Kale{\qof{\W \given Y}}{\pof{\W \given Y}} = 0$$.
-
-The paper "*Understanding Variational Inference in Function-Space*" by Burt et al. (2021) succinctly summarizes the DPI as follows:
-
-
-The data processing inequality states that if two random variables are transformed in this way, they cannot become easier to tell apart.
-
-
-#### Function-Space Variational Inference
-
-Generally, *variational inference* is a powerful technique for approximating complex Bayesian posteriors with simpler distributions. In its usual form, it optimizes an approximate, *variational* distribution to match the *Bayesian **parameter** posterior* as closely as possible. This way, it transforms the problem of Bayesian inference into an optimization problem.
-
-However, especially for deep neural networks, obtaining a good approximation of the parameter space can be difficult. One reason is the sheer size of the parameter space. Additionally, the parameterization of a neural network often contains many symmetries---different parameter configurations can lead to the same predictions of the model---that are not taken into account either.
-
-Here, **Function-space variational inference (FSVI)** side-steps some of these restrictions by only requiring that the variational distribution matches the *Bayesian **predictive** posterior*:
-Whereas regular variational inference regularizes towards a parameter prior, FSVI regularizes towards a data prior. This is especially useful when the parameter prior is not very meaningful, e.g. an isotropic Gaussian prior, which is often used in Bayesian neural networks.
-
-
-
-## Background: Information-Theoretic Notation
-
-Information theory deals with the communication of information. (See the excellent "Visual Information Theory" by Chris Olah for a visual introduction to information theory.) In this blog post, we use a unified information-theoretic notation to express various quantities related to probability distributions and their relationships. (It largely follows "A Practical & Unified Notation for Information-Theoretic Quantities in ML".) Here are some key concepts we will use:
-
-The **information content** of an event $$x$$ is denoted as $$\Hof{x}$$ and is defined as $$-\log \pof{x}$$. It represents the minimum amount of information needed to describe the occurrence of $$x$$ given an underlying probability distribution.
-In machine learning, this information content is often used as a minimization objective, represented as the negative log-likelihood or cross-entropy when averaged over a dataset.
-
-The **entropy** $$\Hof{X}$$ of a random variable $$X$$ is the expectation of its information content:
-
-$$
-\Hof{X} \triangleq \E{\pof{x}}{\Hof{x}} = \E{\pof{x}}{-\log \pof{x}}.
-$$
-
-The entropy measures the average amount of information needed to describe the random variable $$X$$. It provides a measure of uncertainty or randomness associated with $$X$$. We can similarly define the entropy of a conditional distribution $$\Hof{X \given Y}$$ and the joint entropy $$\Hof{X, Y}$$.
-
-The **mutual information** $$\MIof{X;Y}$$ between two random variables $$X$$ and $$Y$$ is a measure of the amount of information that one random variable contains about the other. It is defined as:
-
-$$
-\begin{aligned}
-\MIof{X;Y} & \triangleq \Hof{X} - \Hof{X \given Y} \\
-&= \Hof{Y} - \Hof{Y \given X} \\
-&= \Hof{X} + \Hof{Y} - \Hof{X, Y}.
-\end{aligned}
-$$
-
-We will also use the **Kullback-Leibler divergence** $$\Kale{\pof{X}}{\qof{X}}$$ and the **cross-entropy** $$\CrossEntropy{\pof{X}}{\qof{X}}$$:
-
-$$
-\begin{aligned}
-\CrossEntropy{\pof{X}}{\qof{X}} & = \E{\pof{x}}{-\log \qof{x}}\\
-\Kale{\pof{X}}{\qof{X}} & = \CrossEntropy{\pof{X}}{\qof{X}} - \Hof{X}
-\end{aligned}
-$$
-
-The cross-entropy quantifies the average number of bits needed to encode samples drawn from the true distribution $$\pof{X}$$ using a different distribution $$\qof{X}$$. The Kullback-Leibler divergence is a measure of the difference between two probability distributions and captures the additional bits needed to encode samples from $$\pof{X}$$ using $$\qof{X}$$ compared to encoding them with the true distribution $$\pof{X}$$ itself.
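-
-For concreteness, here is a minimal numpy sketch of these quantities for discrete distributions, using natural logarithms (nats) rather than bits:
-
-```python
-import numpy as np
-
-def entropy(p):
-    """H[P] = E_p[-log p]; assumes p is a normalized array with full support."""
-    return float(-(p * np.log(p)).sum())
-
-def cross_entropy(p, q):
-    """CE[P || Q] = E_p[-log q]."""
-    return float(-(p * np.log(q)).sum())
-
-def kl(p, q):
-    """KL[P || Q] = CE[P || Q] - H[P]."""
-    return cross_entropy(p, q) - entropy(p)
-
-def mutual_information(p_xy):
-    """I[X;Y] = H[X] + H[Y] - H[X,Y] for a joint probability table p_xy."""
-    p_x, p_y = p_xy.sum(axis=1), p_xy.sum(axis=0)
-    return entropy(p_x) + entropy(p_y) - entropy(p_xy.ravel())
-```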
-
-Now that we have covered the notation, let's delve into the data processing inequality.
-
-## Data Processing Inequality
-
-The **data processing inequality (DPI)** is a fundamental inequality in information theory that states the mutual information between two random variables cannot increase through processing. The original DPI is typically stated for a Markov chain of random variables $$X \rightarrow Y \rightarrow Z$$ and relates the mutual information terms as follows:
-
-$$
-\MIof{X;Y} \ge \MIof{X;Z}.
-$$
-
-We can view $$\rightarrow$$ as a processing or transition step that maps $$X$$ to $$Y$$ and $$Y$$ to $$Z$$, whereas the mapping can be deterministic or stochastic.
-The inequality tells us that processing the random variable $$X$$ to obtain $$Y$$ and further processing $$Y$$ to obtain $$Z$$ cannot increase the mutual information between $$X$$ and $$Z$$ compared to the mutual information between $$X$$ and $$Y$$.
-
-The following three scenarios illustrate the data processing inequality using different mappings:
-
-### Example: Image Processing Pipeline
-
-Consider an image processing pipeline with the following steps. Let:
-
-* $$X$$ be the original image data;
-* $$Y$$ be a compressed version of the image; and
-* $$Z$$ be $$Y$$ after adding blur and pixelation.
-
-In this case, $$X$$ has more mutual information with $$Y$$ than with $$Z$$. The compression reduces information, but the image is still recognizable. However, after the additional processing of blurring and pixelating, the mutual information between $$X$$ and $$Z$$ is further reduced. This gives an intuitive example of how additional processing on data reduces the mutual information with the original data. Each processing step results in some loss of information.
-
-### Example: Supervised Learning
-Consider a supervised learning pipeline with the following steps. Let
-
-* $$X$$ be the input features;
-* $$Y$$ be the intermediate representations learned by the model; and
-* $$Z$$ be the model predictions.
-
-Here, $$X \rightarrow Y \rightarrow Z$$ forms a Markov chain. The data processing inequality tells us that the mutual information between the inputs $$X$$ and predictions $$Z$$ cannot exceed the mutual information between the inputs $$X$$ and intermediate representations $$Y$$:
-
-$$\MIof{X; Y} \geq \MIof{X; Z}.$$
-
-This makes intuitive sense---the intermediate representations $$Y$$ are obtained by processing the raw inputs $$X$$, so they cannot contain more information about $$X$$ than $$X$$ itself. The predictions $$Z$$ are obtained by further processing $$Y$$, so additional information may be lost, reducing the mutual information with the original inputs $$X$$.
-
-As a more concrete example, consider an image classification model. Let:
-
-* $$X$$ be the input images;
-* $$Y$$ be the activations of the convolutional layers; and
-* $$Z$$ be predicted image labels.
-
-The convolutional layers will extract features from the input images, but cannot extract more information than present in the original images. The predicted labels are obtained by further processing these convolutional features, so they may lose some fine-grained information about the original inputs.
-
-### Example: Autoencoders
-An autoencoder compresses the input $$X$$ into a latent code $$Y$$ and then tries to reconstruct the original input from the code, producing $$\hat{X}$$. Let:
-
-* $$X$$ be the input;
-* $$Y$$ be the latent code; and
-* $$\hat{X}$$ be the reconstruction;
-
-The data processing inequality tells us again:
-
-$$\MIof{X; Y} \geq \MIof{X; \hat{X}}.$$
-
-The latent code $$Y$$ is obtained by compressing $$X$$, so it cannot contain more information about $$X$$ than $$X$$ itself. The reconstruction $$\hat{X}$$ tries to recover $$X$$ from $$Y$$, but some information may be lost, reducing the mutual information with $$X$$.
-
-Intuitively, autoencoders try to preserve as much mutual information between inputs $$X$$ and reconstructions $$\hat{X}$$ as possible by learning latent representations $$Y$$ that compress inputs without losing too much information. The data processing inequality quantifies this information bottleneck.
-
-### Proof of the DPI
-
-The proof is simple and connects the DPI to another important inequality.
-
-First we note that the Markov Chain implies the following factorization of the joint distribution:
-
-$$
-\pof{x, y, z} = \pof{x} \pof{y \given x} \pof{z \given y}.
-$$
-
-Using this factorization, we can express the mutual information terms:
-
-$$
-\begin{aligned}
-\MIof{X;Y} &= \Hof{X} - \Hof{X \given Y} \\
-&\ge \Hof{X} - \Hof{X \given Z} \\
-&= \MIof{X;Z}.
-\end{aligned}
-$$
-
-This relies on $$\Hof{X \given Y} \le \Hof{X \given Z}$$. Why is this true?
-
-We have the following chain of inequalities:
-
-$$
-\Hof{X \given Y} = \underbrace{\MIof{X ; Z \given Y}}_{\overset{(1)}{=}0} + \Hof{X \given Y, Z} \overset{(2)}{\le} \Hof{X \given Z}.
-$$
-
-**(1)** follows from the Markov chain property: when $$X \rightarrow Y \rightarrow Z$$, $$X$$ does not depend on $$Z$$ at all when conditioned on $$Y$$; and **(2)** follows from the fact that conditioning reduces entropy, i.e. $$\Hof{A \given B} \le \Hof{A}.$$
-
-The equality gap $$\Hof{X \given Z} - \Hof{X \given Y, Z}$$ corresponds to the mutual information $$\MIof{X ; Y \given Z}$$. This mutual information measures the extra information about $$X$$ contained in $$Y$$ that is not already conveyed by $$Z$$. It is zero if and only if $$X \rightarrow Z \rightarrow Y$$ forms a Markov chain, indicating that $$Z$$ is a sufficient statistic for $$X$$.
-
-
-Proof of (2) "Conditioning Reduces Entropy":
-We can easily show that conditioning reduces entropy by using the non-negative property of the mutual information:
-
-$$
-\begin{aligned}
-0 &\le \Kale{\pof{X,Y}}{\pof{X}\pof{Y}} \\
-&= \MIof{X;Y} \\
-&= \Hof{X} - \Hof{X \given Y} \\
-\implies \Hof{X \given Y} &\le \Hof{X}.
-\end{aligned}
-$$
-
-
-The fact that conditioning reduces entropy, $$\Hof{X} \ge \Hof{X \given Y}$$, is an important property by itself and is reminiscent of the data processing inequality.
-The conditional entropy $$\Hof{X \given Y}$$ quantifies the remaining uncertainty about $$X$$ after observing $$Y$$. If $$X$$ and $$Y$$ are independent, then $$\Hof{X} = \Hof{X \given Y}$$, as knowing $$Y$$ does not provide any information about $$X$$. On the other hand, if $$Y$$ completely determines $$X$$, then $$\Hof{X \given Y} = 0$$, as there is no remaining uncertainty about $$X$$ once $$Y$$ is known. In general, conditioning can only reduce the uncertainty about $$X$$, but it does not necessarily reduce it to zero.
-
-Let's move on and consider the KL data processing inequality.
-
-## 🥬 Data Processing Inequality
-
-A similar DPI can be expressed for different distributions $$\pof{x}$$ and $$\qof{x}$$ of the same random variable and the KL divergence between them.
-This DPI states that if we evolve two distributions using the same *transition function*, they cannot become less similar. The KL divergence is sometimes also referred to as "relative entropy", so we could also call this the "*relative data processing inequality*".
-
-This can be formalized for distributions $$\pof{x}$$ and $$\qof{x}$$ and a stochastic transition function $$X \overset{\fof{y \given x}}{\longrightarrow} Y$$. Here, we use that such a stochastic mapping $$Y = \fof{X}$$ is equivalent to having a probability (density) $$\fof{y \given x}$$:
-
-$$
-\Kale{\pof{X}}{\qof{X}} \ge \Kale{\pof{Y}}{\qof{Y}},
-$$
-
-where $$\pof{y \given x} = \fof{y \given x} = \qof{y \given x}$$. The marginals after the transition are $$\pof{y} = \E{\pof{x}}{\fof{y \given x}}$$ and $$\qof{y} = \E{\qof{x}}{\fof{y \given x}}$$, so more explicitly:
-
-$$
-\Kale{\pof{X}}{\qof{X}} \ge \Kale{\E{\pof{x}}{\fof{Y \given x}}}{\E{\qof{x}}{\fof{Y \given x}}}.
-$$
-
-In their book [Elements of Information Theory](https://www.wiley.com/en-us/Elements+of+Information+Theory%2C+2nd+Edition-p-9780471241959), Cover and Thomas describe this as "relative entropy never increases" and relate it to the second law of thermodynamics.
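-
-A quick numerical sanity check of this inequality in the discrete case (with illustrative random distributions and a random transition function):
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(6)
-
-def kl(p, q):
-    return float((p * np.log(p / q)).sum())
-
-# Two distributions over X and a shared stochastic transition function f(y | x).
-n_x, n_y = 6, 4
-p_x = rng.dirichlet(np.ones(n_x))
-q_x = rng.dirichlet(np.ones(n_x))
-f_y_given_x = rng.dirichlet(np.ones(n_y), size=n_x)    # row x holds f(. | x)
-
-p_y = p_x @ f_y_given_x                                # marginals after the transition
-q_y = q_x @ f_y_given_x
-
-assert kl(p_x, q_x) >= kl(p_y, q_y)                    # the KL data processing inequality
-```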
-
-### Example: Comparing Image Distributions
-
-As an example, let:
-
-* $$\pof{x}$$ be the true distribution of images in a dataset;
-* $$\qof{x}$$ be a generative model that tries to mimic $$\pof{x}$$; and
-* $$\fof{y \given x}$$ be a function that thresholds images $$x$$ into bilevel black and white images $$y$$.
-
-Then $$\pof{y}$$ and $$\qof{y}$$ will be more difficult to distinguish after the thresholding operation than $$\pof{x}$$ and $$\qof{x}$$. Converting to black and white images has lost information that could help distinguish the real and generated distributions.
-
-This provides some intuition for why the KL divergence between distributions decreases under a shared stochastic mapping, as formalized by the KL data processing inequality. Processing through $$\fof{y \given x}$$ makes the distributions harder to tell apart.
-
-### Counter-Example: Bayesian Inference
-
-It might be tempting to think that this data processing inequality also applies to Bayesian inference, that is, updating the model parameters based on new evidence. Then, we could argue that if two agents start with different prior beliefs but update based on the same evidence, their posterior beliefs will become more similar. However, this intuition is flawed: the data processing inequality does not apply to Bayesian inference.
-
-Let's walk through why. Consider:
-
-* $$\pof{\w}$$ be an agent's prior belief;
-* $$\qof{\w}$$ be another agent's different prior;
-* $$\pof{\w\given x}$$ is the posterior after observing data $$x$$; and
-* $$\qof{\w\given x}$$ is the other agent's posterior.
-
-The priors $$\pof{\w}$$ and $$\qof{\w}$$ may have large divergence, representing very different initial beliefs. However, when conditioning on the same data $$x$$, the KL divergence between $$\pof{\w \given x}$$ and $$\qof{\w \given x}$$ could increase or decrease---the data processing inequality does not give us any guarantee.
-
-This is because $$\pof{\w}$$ and $$\qof{\w}$$ are not evolving under the same stochastic mapping. Rather, each prior is mapped to its respective posterior via Bayes' rule, which operates differently on $$\opp$$ and $$\opq$$:
-
-$$
-\begin{aligned}
-\pof{\w \given x} &= \frac{\pof{x \given \w}}{\pof{x}} \, \pof{\w}\\
-\qof{\w \given x} &= \frac{\qof{x \given \w}}{\qof{x}} \, \qof{\w}.
-\end{aligned}
-$$
-
-Even assuming that both agents have the same internal model, that is they use the same likelihood $$\pof{x \given \w} = \qof{x \given \w}$$, the priors $$\pof{\w}$$ and $$\qof{\w}$$ will still influence the posterior distributions differently because they lead to different evidence terms $$\pof{x}$$ and $$\qof{x}$$:
-
-$$
-\begin{aligned}
-\pof{x} &= \E{\pof{\w}}{\pof{x \given \w}}\\
-\qof{x} &= \E{\qof{\w}}{\qof{x \given \w}}.
-\end{aligned}
-$$
-
-Thus, the correct intuition is that observing the same data $$x$$ does not necessarily bring the posterior beliefs closer together---they depend on the interplay between their specific priors and likelihoods. The data processing inequality does not directly apply to this Bayesian updating scenario:
-
-$$
-\Kale{\qof{\W}}{\pof{\W}} {\color{red}{\not\ge}} \Kale{\qof{\W \given \mathcal{D}}}{\pof{\W \given \mathcal{D}}}.
-$$
-
-This counterexample highlights the importance of precisely understanding the assumptions underlying conceptual principles like the DPI. While the DPI provides insight about information dynamics in many cases, it does not universally apply, as exemplified here by Bayesian updating under different priors. As always, bear in mind that:
-
-
-
-As we currently also seem to experience a world of increasing polarization, this counterexample might also serve as a reminder that different priors can lead to different beliefs, even when observing the same evidence. This is a fundamental aspect of Bayesian inference and the scientific method.
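-
-A concrete discrete counterexample (with illustrative numbers) shows that the divergence between beliefs can indeed grow after updating on the same observation:
-
-```python
-import numpy as np
-
-def kl(p, q):
-    return float((p * np.log(p / q)).sum())
-
-# Two hypotheses w in {0, 1}, a shared likelihood, and very different priors.
-likelihood_x1 = np.array([0.5, 0.99])          # p(x = 1 | w) = q(x = 1 | w)
-p_prior = np.array([0.5, 0.5])
-q_prior = np.array([0.99, 0.01])
-
-def posterior(prior):
-    unnorm = prior * likelihood_x1             # both agents observe x = 1
-    return unnorm / unnorm.sum()
-
-print(kl(q_prior, p_prior))                            # ~0.64 nats
-print(kl(posterior(q_prior), posterior(p_prior)))      # ~0.98 nats: larger after updating!
-```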
-
-### Proofs of the 🥬 DPI
-
-We will prove this inequality in two different ways. First, we will develop a "brute-force" proof, and then we will look at a more elegant proof that follows Cover and Thomas. Importantly, we will also consider the equality case in detail.
-
-#### Brute-force Proof
-
-If $$\opp$$ does not have support in $$\opq$$, the inequality is trivially true because then $$\Kale{\pof{X}}{\qof{X}}=\infty$$.
-
-Thus, let's now assume that $$\opp$$ has support in $$\opq$$. Then, we can brute-force using the definitions, starting from the cross-entropy:
-
-$$
-\begin{aligned}
-\CrossEntropy{\pof{Y}}{\qof{Y}}&=\CrossEntropy{\pof{Y}}{\E{\qof{x}}{\pof{Y \given x}}}\\
-&=\CrossEntropy{\pof{Y}}{\E{\qof{x}}{\frac{\pof{x \given Y}\pof{Y}}{\pof{x}}}}\\
-&=\CrossEntropy{\pof{Y}}{\E{\pof{x \given Y}}{\frac{\qof{x}}{\pof{x}}}}+\CrossEntropy{\pof{Y}}{\pof{Y}}\\
-&\overset{(1)}{=}\CrossEntropy{\pof{Y}}{\E{\pof{x \given Y}}{\frac{\qof{x}}{\pof{x}}}}+\xHof{\pof{Y}}\\
-&\overset{(2)}{\le}\CrossEntropy{\pof{X, Y}}{\frac{\qof{X}}{\pof{X}}}+\xHof{\pof{Y}}\\
-&\overset{(3)}{=}\CrossEntropy{\pof{X}}{\frac{\qof{X}}{\pof{X}}}+\xHof{\pof{Y}}\\
-&\overset{(4)}{=}\Kale{\pof{X}}{\qof{X}}+\xHof{\pof{Y}}\\
-\iff \Kale{\pof{Y}}{\qof{Y}}&\le\Kale{\pof{X}}{\qof{X}},
-\end{aligned}
-$$
-
-where we have used **(1)** that the cross-entropy of a distribution with itself is just the entropy, **(2)** that $$-\log$$ is convex and we can apply Jensen's inequality, **(3)** that the right-hand side of the cross-entropy does not depend on $$Y$$ and we can trivially marginalize it out, and **(4)** that the Kullback-Leibler divergence is equivalent to an (unnormalized) cross-entropy over a fraction.
-
-This makes it difficult to extract the case for equality, however.
-
-#### Equality Case
-
-We have only one inequality in above proof, and it stems from applying Jensen's inequality. Remembering the equality case for Jensen's inequality, we recall:
-
-
-
-For **(2)**, this is sadly slightly more complex than it might seem on first glance.
-Let's unwrap the term:
-
-$$
-\CrossEntropy{\pof{Y}}{\E{\pof{x \given Y}}{\frac{\qof{x}}{\pof{x}}}} = \E{\pof{y}}{-\log \E{\pof{x \given y}}{\frac{\qof{x}}{\pof{x}}}}.
-$$
-
-We take an expectation over $$\pof{y}$$, so we need to look at almost all $$\pof{x \given y} \not= 0$$ for (almost all) $$\pof{y} \not= 0$$ separately to consider equality. $$-\log x$$ is strictly convex---and thus not linear---so we need the ratio $$\frac{\qof{x}}{\pof{x}}$$ to be constant in $$x$$ for any fixed $$y$$ with $$\pof{y} \not= 0$$; only then do we have equality in Jensen's inequality.
-
-In the following, I will limit myself to the discrete case to avoid having to deal with measure theory. (I currently don't have a good 'toolbox' to express simple ideas cleanly in measure theory. I'm working on it.)
-To obtain equality, for all $$y$$ with $$\pof{y} \not= 0$$ (i.e. we have support) and for all $$x_1, x_2$$ with $$\pof{x_1 \given y}, \pof{x_2 \given y} \not= 0$$, we need $$\frac{\qof{x_1}}{\pof{x_1}} = \frac{\qof{x_2}}{\pof{x_2}}$$.
-Equivalently (for the reader, why is then $$\pof{x_1} \not= 0?$$):
-
-$$
-\begin{aligned}
-\frac{\qof{x_1}}{\pof{x_1}} &= \frac{\qof{x_2}}{\pof{x_2}} \\
-\iff \qof{x_1} &= \frac{\qof{x_2}}{\pof{x_2}} \, \pof{x_1} \\
-\end{aligned}
-$$
-
-This means that $$\qof{x} = C_y \pof{x}$$ piecewise for all $$x$$ for which $$\pof{x \given y} \not= 0$$ for some fixed $$y$$ with $$\pof{y} \not= 0$$. That is if we keep $$y$$ fixed, all the $$x$$ for which $$\pof{x \given y} \not= 0$$ have the same constant factor $$C_y$$. Then for all $$y$$ with $$\pof{y} \not= 0$$, we have equality and overall equality in **(2)**.
-
-If for any $$x$$ there are multiple $$y$$, e.g. $$y_1, y_2$$ for which $$\pof{x \given y} \not= 0$$, then we have $$C_{y_1} = C_{y_2}$$.
-
-As an example, at the simplest, if this is the case for all $$y$$, then $$C_y = 1$$ constant.
-
-As a side-note, this is a great reason why we often require full support for distributions as we then can avoid these piecewise constant factors (and the headaches they might cause).
-
-#### Simpler Elegant Proof
-
-Cover and Thomas provide a beautifully simple proof:
-
-
-
-What does this mean? Whereas $$\fof{y \given x}$$ is the 'forward' transition function, $$\pof{x \given y}$$ and $$\qof{x \given y}$$ are the 'backward' transition functions. We only have equality when the backward transition functions are equal (almost everywhere).
-
-The statement on equality is not very informative yet though, so we have to put in a bit more work. Again, this is written for the discrete case.
-
-This time we explicitly use Bayes' rule to connect the forward and backward transition functions.
-First, we have to fix $$y$$ such that $$\pof{y} \not= 0$$ (i.e. $$y$$ is in the support of $$\pof{y}$$) and then $$\qof{y} \not=0$$.
-We have:
-
-$$
-\begin{aligned}
-\pof{x \given y} &= \qof{x \given y} \\
-\overset{\text{ass. }\pof{y} \not= 0}{\iff} \frac{\fof{y \given x}\pof{x}}{\pof{y}} &= \frac{\fof{y \given x}\qof{x}}{\qof{y}} \\
-\overset{\text{ass. }\fof{y \given x}\not= 0}{\iff} \frac{\pof{x}}{\pof{y}} &= \frac{\qof{x}}{\qof{y}} \\
-\iff \pof{x} &= \frac{\pof{y}}{\qof{y}} \, \qof{x}.
-\end{aligned}
-$$
-
-For a given $$y$$ with $$\pof{y} \not=0$$, for the equality case, we see that for all $$x$$ with $$\fof{y \given x} \not= 0$$, $$\pof{x}$$ and $$\qof{x}$$ have to be coupled via piecewise constant factors.
-
-As another example, if $$\fof{y \given x} \not=0$$ (has full support) for all possible $$x$$, for the equality case we have $$\pof{x} = \qof{x}$$.
-
-Compared to the previous equality case, we went a bit deeper and rewrote the conditions to consider the ratios between $$x$$ and $$y$$. Note we could have shown the same thing in the "brute-force" proof, too.
-
-Altogether, we have seen that both $$x$$ and $$y$$ are modulated by the same constant factor between $$\pof{\cdot}$$ and $$\qof{\cdot}$$. Essentially, this tells us that we could split our support into unconnected sub-domains and examine each individually for the equality case.
-
-
-
-### Overall Statement
-We have the following overall statement:
-
-
-($$\pof{x} \ll \qof{x}$$ means that $$\pof{x} > 0$$ implies $$\qof{x} > 0$$, i.e. $$\qof{x} = 0$$ implies $$\pof{x} = 0$$, so the KL divergence is not $$\infty$$.) But more precisely, for $$\pof{x} \ll \qof{x}$$, we have equality when:
-
-$$
-\forall y, \pof{y} \not= 0 \; \exists C_y \in \mathbb{R}_{> 0} \; \forall x, \fof{y \given x}\not=0\colon \quad \pof{x} = C_y \, \qof{x}.
-$$
-
-## Other Data Processing Inequalities
-
-Now, we can use these ideas to derive a few additional results and even close the circle to the original data processing inequality.
-
-### Jensen-Shannon Divergence
-
-The KL divergence is not a metric: the triangle inequality does not hold, and it is not symmetric.
-
-However, we can symmetrize it to obtain the [Jensen-Shannon divergence (JSD)](https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence). The JSD is defined as the mean of the two KL divergences of the two distributions from their average. In essence, it makes the KL divergence symmetric:
-
-$$
-\begin{aligned}
-\fof{x} &= \frac{\pof{x} + \qof{x}}{2}\\
-\JSD{\pof{x}}{\qof{x}} &= \frac{1}{2} \Kale{\pof{x}}{\fof{x}} + \frac{1}{2} \Kale{\qof{x}}{\fof{x}}.
-\end{aligned}
-$$
-
-Similar approaches can be used to "symmetrize" other concepts; for example, matrices: $$\frac{1}{2} A + \frac{1}{2} A^T$$ is symmetric by construction for any square matrix $$A$$.
-
-The JSD is still not a metric, but its square root is symmetric and satisfies the triangle inequality, which gives us the *Jensen-Shannon distance*, a metric.
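-
-To make the definitions concrete, here is a minimal NumPy sketch of the JSD and the Jensen-Shannon distance for two toy discrete distributions (the values are arbitrary):
-
-```python
-import numpy as np
-
-def kl(p, q):
-    mask = p > 0
-    return np.sum(p[mask] * np.log(p[mask] / q[mask]))
-
-def jsd(p, q):
-    m = 0.5 * (p + q)  # the average distribution f(x) from above
-    return 0.5 * kl(p, m) + 0.5 * kl(q, m)
-
-p = np.array([0.9, 0.1])
-q = np.array([0.2, 0.8])
-print(jsd(p, q))           # symmetric in p and q, and bounded by log 2
-print(np.sqrt(jsd(p, q)))  # the Jensen-Shannon distance, a metric
-```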
-
-### JSD-DPI
-
-We can also obtain a data processing inequality for the Jensen-Shannon divergence and the Jensen-Shannon distance:
-
-
-
-The proof uses the KL data processing inequality:
-
-$$
-\begin{aligned}
-\JSD{\pof{X}}{\qof{X}} &= \frac{1}{2} \Kale{\pof{X}}{\fof{X}} + \frac{1}{2} \Kale{\qof{X}}{\fof{X}}\\
-&\ge \frac{1}{2} \Kale{\pof{Y}}{\fof{Y}} + \frac{1}{2} \Kale{\qof{Y}}{\fof{Y}}\\
-&= \JSD{\pof{Y}}{\qof{Y}}.
-\end{aligned}
-$$
-
-We verify $$\fof{y} = \frac{\pof{y} + \qof{y}}{2}$$ is the average of $$\pof{y}$$ and $$\qof{y}$$:
-
-$$
-\begin{aligned}
-\fof{y} &= \E{\fof{x}}{\fof{y \given x}}\\
-&= \E{\frac{\pof{x}+\qof{x}}{2}}{\fof{y \given x}}\\
-&= \frac{1}{2} \E{\pof{x}}{\fof{y \given x}} + \frac{1}{2} \E{\qof{x}}{\fof{y \given x}}\\
-&= \frac{1}{2} \pof{y} + \frac{1}{2} \qof{y}.
-\end{aligned}
-$$
-
-Finally, $$\pof{x}, \qof{x} \ll \fof{x}$$, and the equality condition of the KL data processing inequality gives us:
-
-$$
-\begin{aligned}
-&\Kale{\pof{X \given Y}}{\fof{X \given Y}} = 0 &\\
-\land \quad &\Kale{\qof{X \given Y}}{\fof{X \given Y}} = 0 &\\
-\iff &\pof{x \given y} = \fof{x \given y} \land \qof{x \given y} = \fof{x \given y}& \forall x,y \\
-\iff &\pof{x \given y} = \qof{x \given y}& \forall x,y.
-\end{aligned}
-$$
-
-### Mutual Information
-
-The JSD can also be expressed as a mutual information. For
-$$
-\begin{aligned}
-Z &\sim \mathrm{Bernoulli}(\frac{1}{2}) = \fof{Z} \\
-X \given Z = 0 &\sim \pof{x}\\
-X \given Z = 1 &\sim \qof{x},
-\end{aligned}
-$$
-
-we have:
-
-$$
-\JSD{\pof{X}}{\qof{X}} = \MIof{X;Z}.
-$$
-
-This follows from rewriting the mutual information as a KL divergence:
-
-$$
-\begin{aligned}
-\MIof{X;Z} &= \Kale{\fof{X \given Z}}{\fof{X}}\\
-&= \E{\fof{z}} {\Kale{\fof{X \given Z = z}}{\fof{X}}}\\
-&= \frac{1}{2} \Kale{\pof{x}}{\fof{x}} + \frac{1}{2} \Kale{\qof{x}}{\fof{x}}\\
-&= \JSD{\pof{X}}{\qof{X}}.
-\end{aligned}
-$$
-
-We can generalize this to the Markov chain $$Z \rightarrow X \rightarrow Y$$ with $$\fof{z, x, y} = \fof{z} \fof{x \given z} \fof{y \given x}$$ for any distribution $$\fof{z}$$:
-
-$$
-\begin{aligned}
-\MIof{X;Z} &= \Kale{\fof{X \given Z}}{\fof{X}}\\
-&= \E{\fof{z}} {\Kale{\fof{X \given z}}{\fof{X}}}\\
-&\overset{(1)}{\ge} \E{\fof{z}} {\Kale{\fof{Y \given z}}{\fof{Y}}}\\
-&= \Kale{\fof{Y \given Z}}{\fof{Y}}\\
-&= \MIof{Y;Z},
-\end{aligned}
-$$
-
-where $$(1)$$ follows from the KL data processing inequality.
-
-This is just the data processing inequality we presented initially. We have gone full circle!
-
-The equality gap (*Jensen gap*) is $$\Kale{\fof{X \given Y, Z}}{\fof{X \given Y}}$$, and we have equality when:
-
-$$
-\begin{aligned}
-\Kale{\fof{X \given Y, Z}}{\fof{X \given Y}} &= 0\\
-\iff \MIof{X;Z \given Y} &= 0.
-\end{aligned}
-$$
-
-This is exactly when $$X$$ is independent of $$Z$$ given $$Y$$. ($$Y$$ is a sufficient statistic in that case.)
-
-## Function-Space Variational Inference
-
-So far we've explored the foundational aspects of the data processing inequality (DPI) and its extended forms, in particular the KL data processing inequality. Through detailed derivations and intuitive examples, we've demonstrated how these inequalities can be applied, emphasizing their significance and limitations. Specifically, we've shown how the KL data processing inequality relates to the reduction in information as data is processed. The examples and counterexample have hopefully demonstrated the nuances of applying these inequalities in different contexts.
-
-This exploration sets the stage for diving into function-space variational inference and building up a robust understanding of it, leveraging the insights gained about the DPI and its implications in Bayesian deep learning.
-
-### Problem Setting & Notation
-
-In the following, we will consider a classification task with cross-entropy loss, and we will use the following random variables and distributions:
-
-- $$\y$$ is the label,
-- $$\x$$ is the input,
-- $$\qof{\y \given \x}$$ is the predictive distribution we want to learn,
-- $$\pdata{\y \given \x}$$ is the data distribution,
-- $$\Dany$$ is the (training) dataset, and
-- $$C$$ is the number of classes.
-
-The probabilistic model is:
-
-$$\pof{\y, \w \given \x} = \pof{\y \given \x, \w} \, \pof{\w}.$$
-
-As before, I use upper-case letters for random variables, which we take an expectation over, e.g. in the KL divergence, and lower-case letters when I'm referring to specific observations or values that could be substituted (with the exception of $$\Dany$$).
-
-
-### Chain Rule of the 🥬 Divergence & DPI
-
-An important property of the KL divergence is the chain rule:
-
-$$
-\begin{aligned}
-&\Kale{\qof{\Y_n,...,\Y_1}}{\pof{\Y_n,...,\Y_1}} \\
-&\quad = \sum_{i=1}^n \Kale{\qof{\Y_i \given
-\Y_{i-1}, ..., \Y_1}}{\pof{\Y_i \given \Y_{i-1}, ..., \Y_1}}.
-\end{aligned}
-$$
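-
-As a quick numerical sanity check of the chain rule, here is a minimal sketch for two discrete variables with randomly generated toy joints:
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-
-def norm(a):
-    return a / a.sum()
-
-def kl(p, q):
-    return np.sum(p * np.log(p / q))
-
-# Random joint distributions q(Y1, Y2) and p(Y1, Y2) over a 3 x 4 outcome grid.
-q = norm(rng.random((3, 4)))
-p = norm(rng.random((3, 4)))
-
-joint = kl(q, p)
-marginal_y1 = kl(q.sum(axis=1), p.sum(axis=1))
-conditional_y2 = sum(q.sum(axis=1)[i] * kl(norm(q[i]), norm(p[i]))
-                     for i in range(q.shape[0]))
-print(np.isclose(joint, marginal_y1 + conditional_y2))  # True
-```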
-
-The chain rule yields a *chain inequality* for the DPI as well:
-
-$$
-\begin{aligned}
-\Kale{\qof{\W}}{\pof{\W}} &\ge \Kale{\qof{\Y_n,...,\Y_1}}{\pof{\Y_n,...,\Y_1}}\\
-&\ge \Kale{\qof{\Y_{n-1},...,\Y_1}}{\pof{\Y_{n-1},...,\Y_1}}\\
-&\ge \Kale{\qof{\Y_1}}{\pof{\Y_1}},
-\end{aligned}
-$$
-
-where we start from the KL DPI and then apply the chain rule.
-
-### Deriving the Functional ELBO
-
-The DPI has an intriguing connection to FSVI. Let's say we want to approximate a Bayesian posterior $$\pof{\w \given \Dany}$$ with a variational distribution $$\qof{\w}$$. In standard VI, we would minimize $$\Kale{\qof{\W}}{\pof{\W \given \Dany}}$$ to match the variational distribution to the Bayesian posterior. Specifically:
-
-$$
-\begin{aligned}
-&\Kale{\qof{\W}}{\pof{\W \given \Dany}} =\\
-&\quad = \underbrace{\E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\W}}{\pof{\W}}}_{\text{Evidence}\ \text{Bound}} + \log \pof{\Dany} \ge 0 \\
-&\iff \underbrace{-\log \pof{\Dany}}_{=\xHof{\pof{\Dany}}} \le \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\W}}{\pof{\W}}.
-\end{aligned}
-$$
-
-This is an information-theoretic evidence (upper) bound on the information content $$-\log \pof{\Dany}$$ of the data $$\Dany$$ under the variational distribution $$\qof{\w}$$, which we can minimize as an objective to approximate $$\pof{\w \given \Dany}$$ via $$\qof{\w}$$.
-
-In more probability-theory inspired literature, the negative of this bound is called the *evidence lower bound (ELBO)* and is maximized.
-
-Both the ELBO and the information-theoretic evidence upper bound are equivalent, and we can use either objective, but the information-theoretic perspective is obviously superior 🙃 I'll refer to this as the evidence bound from now on.
-
-In FSVI (with a caveat I detail below), we apply the DPI to the prior KL divergence term and obtain a "functional" version of the evidence bound:
-
-$$
-\begin{aligned}
-\Kale{\qof{\W}}{\pof{\W}} \ge \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}},
-\end{aligned}
-$$
-
-where $$\Y... \given \x...$$ are (finite or infinite) sets of samples. That is, we do not only optimize marginal distributions but also joint distributions.
-
-
-
-
-The resulting objective:
-
-$$
-\begin{aligned}
-\E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}}
-\end{aligned}
-$$
-
-is equal to the (negative) *functional ELBO (fELBO)* in "*Functional variational Bayesian neural networks*" by Sun et al. (2019)---with caveats that we discuss below.
-
-### Choosing the "Coreset" $$\x...$$
-
-One important detail is the question of how to choose the $$\x...$$:
-
-Ideally, we want to choose them such that the DPI inequality is as tight as possible.
-
-Given the chain inequality, it is obvious that the larger the set $$\x...$$, the tighter the inequality will be.
-Hence, if we could choose an infinite set of points well, we might be able to get the tightest possible inequality.
-However, this might not be tractable, and in practice, it is often not.
-
-Some works take a supremum over finite subsets of a certain size, essentially building a core-set as an approximation (Rudner et al., 2022a/b);
-others take an expectation over finite sets of input samples (Sun et al., 2019), which does not necessarily yield the tightest inequality but provides an unbiased estimate; while other works focus on finite datasets for which all points can be taken into account (Klarner et al., 2023).
-
-We will discuss the tightness of the inequality and the implications in the data limit below.
-
-Focusing on the most important aspect of FSVI, we observe:
-
-
-
-### Application to Continual Learning
-
-When we directly optimize the KL divergence on a finite input dataset, for example, we align $$\opq$$ with the prior of $$\opp$$ where it matters most: on the predictions of the observed data.
-
-This is of particular interest in continual learning, where the prior for the next task is chosen to be the posterior from the previous task. In this case, the functional ELBO can be used to approximate the posterior of the previous model while incorporating new data.
-
-For two great papers that are very readable and provide further insights, see "*Continual learning via sequential function-space variational inference*" and "*Tractable function-space variational inference in Bayesian neural networks*", both by Rudner et al. (2022).
-
-## Comparison to FSVI in the Literature
-
-
-
-In practice, both works by Rudner et al. (2022) linearize the logits (the final activations of the neural network before applying the softmax function in multi-class classification; not to be confused with the pre-logits, e.g. embeddings before the final linear layer), similar to a Laplace approximation, and use the DPI to show (in their notation):
-
-$$
-\mathbb{D}_{\mathrm{KL}}\left(q_{f(\cdot ; \boldsymbol{\Theta})} \| p_{f(\cdot ; \boldsymbol{\Theta})}\right) \leq \mathbb{D}_{\mathrm{KL}}\left(q_{\Theta} \| p_{\Theta}\right)
-$$
-
-which in my notation is equivalent to the first application of the DPI above:
-
-$$
-\Kale{\qof{\L...\given \x...}}{\pof{\L...\given \x...}} \le \Kale{\qof{\W}}{\pof{\W}}.
-$$
-
-They maximize the fELBO objective:
-
-$$
-\begin{aligned}
-\mathcal{F}\left(q_{\boldsymbol{\Theta}}\right) &=\mathbb{E}_{q_{f\left(\mathbf{x}_{\mathcal{D}} ; \boldsymbol{\Theta}\right)}}\left[\log p_{\mathbf{y} \mid f(\mathbf{X} ; \boldsymbol{\Theta})}\left(\mathbf{y}_{\mathcal{D}} \mid f\left(\mathbf{X}_{\mathcal{D}} ; \boldsymbol{\theta}\right)\right)\right]\\
-&\quad -\sup _{\mathbf{X} \in \mathcal{X}_{\mathbb{N}}} \mathbb{D}_{\mathrm{KL}}\left(q_{f(\mathbf{X} ; \boldsymbol{\Theta})} \| p_{f(\mathbf{X} ; \boldsymbol{\Theta})}\right),
-\end{aligned}
-$$
-
-which is equivalent to minimizing the information-theoretic objective:
-
-$$
-\E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\L... \given \x...}}{\pof{\L... \given \x...}},
-$$
-
-if we choose the $$\x...$$ to tighten the DPI inequality as much as possible (i.e. by "finding" the supremum).
-
-Using the inequality chain from above, we can sandwich their objective between the regular (negative) ELBO and the (negative) functional ELBO that we derived above:
-
-$$
-\begin{aligned}
-&\E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\W}}{\pof{\W}} \\
-&\quad \ge \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\L... \given \x...}}{\pof{\L... \given \x...}} \\
-&\quad \ge \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}}.
-\end{aligned}
-$$
-
-**Why are they using logits instead of probabilities?** In practice, using the probabilities instead of logits when performing linearization is often cumbersome due to the non-linearity of the softmax functions, which requires Monte-Carlo sampling of the logits to obtain an approximation of the final probabilities. Furthermore, I speculate that sampling the logits can be more benign given that we often use ReLUs in the underlying neural networks. (Don't quote me too strongly on this, though.)
-
-Conceptually, this explains the derivation of their ELBO objective and also relates them to the 'purer' and simpler functional evidence bound derived above, but this raises the question of how these inequalities are different and what the gap between them tells us. Let's address this question next.
-
-## The Equality Case and Equivalence Classes
-
-When do we have equality? That is, when do we have:
-
-$$\Kale{\qof{\W}}{\pof{\W}} = \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}}?$$
-
-And what does it tell us?
-
-As we have seen in the first part of this post, we have equality in the DPI if and only if:
-
-$$\Kale{\qof{\W \given \Y..., \x...}}{\pof{\W \given \Y..., \x...}}=0$$.
-
-Given that we are trying to approximate the Bayesian posterior $$\pof{\w \given \Y..., \x...}$$ using $$\qof{\w}$$, this equality condition tells us that we would have to find the exact posterior for equality.
-Hence, it is unlikely that we will have equality in practice. From this, the next question immediately follows: what does this predictive prior term
-
-$$\Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}}$$
-
-provide us with?
-
-Another way to think about the gap between the two KL divergences is that one is parameter-based and the other one is not. This points to a deeper truth about overparameterized models used in deep learning:
-
-
-
-The functional KL divergences won't be affected by this as they are parameter-free and do not take into account the parameters of the model but only the predictions.
-The regular parameter-based KL divergence, however, would be affected by this---depending on the prior $$\pof{\w}$$, it might express differences between the parameter distributions that have no effect on the outputs.
-
-In other words, if the prior assigns different probability to otherwise equivalent parameters, this obviously changes the parameter posterior, while the outputs are invariant to these changes if the overall assigned probability to a given output remains the same.
-
-
-
-For example, the paper "Deep Ensembles: A Loss Landscape Perspective" by Fort et al. (2020) examines the similarity of the predictions of models trained from different initializations and shows that the loss landscape is multi-modal, with different modes corresponding to diverse predictions. In the language of FSVI, this is similar to analyzing the function-space distances between different models.
-
-### Equivalence Classes
-
-Unless there are other considerations, it makes sense to use priors that assign the same density to parameters that are equivalent.
-Hence, for a given function $$\fof{\x ; \w}$$, which determines the likelihood $$\pof{\y \given \x, \w} \triangleq \pof{y \given \fof{\x ; \w}}$$, we can define an equivalence relation such that $$\w \sim \w'$$ if and only if $$\fof{\x; \w} = \fof{\x; \w'}$$ *for all* $$\x$$.
-This equivalence relation partitions the parameter space into equivalence classes:
-
-$$[\w] \triangleq \{\w' : \fof{\x ; \w'} = \fof{\x ; \w} \quad \forall \x \}.$$
-
-A prior $$\pof{\w}$$ induces a prior $$\hpof{[\w]}$$ over the equivalence classes:
-
-$$\hpof{[\w]} \triangleq \sum_{\w' \in [\w]} \pof{\w'}.$$
-
----or $$\int_{[\w]} \pof{\w'} \, d \w'$$ for continuous $$\w$$---with the corresponding model:
-
-$$
-\begin{aligned}
-\hpof{\y, [\w] \given \x} &\triangleq \hpof{\y \given \x, [\w]} \, \hpof{[\w]} \\
-&= \pof{\y \given \x, \w} \, \hpof{[\w]}.
-\end{aligned}
-$$
-
-
-
-
-### Consistency
-
-Importantly, the definition of the equivalence classes above is consistent with Bayesian inference:
-
-
-
-This is easy to show using Bayes' rule:
-
-$$
-\begin{aligned}
-\hpof{[\w] \given \Dany} &= \hpof{\Dany \given [\w]} \, \hpof{[\w]} / \hpof{\Dany} \\
-&= \pof{\Dany \given \w} \sum_{\w' \in [\w]} \pof{\w'} / \hpof{\Dany} \\
-&= \sum_{\w' \in [\w]} \pof{\Dany \given \w'} \, \pof{\w'} / \hpof{\Dany} \\
-&= \sum_{\w' \in [\w]} \pof{\w' \given \Dany} \, \pof{\Dany} / \hpof{\Dany} \\
-&= \sum_{\w' \in [\w]} \pof{\w' \given \Dany}.
-\end{aligned}
-$$
-
-The last step follows from $$\hpof{\Dany}=\pof{\Dany}$$:
-
-$$
-\begin{aligned}
-\hpof{\Dany} &= \sum_{[\w]} \hpof{\Dany, [\w]} \\
-&= \sum_{[\w]} \sum_{\w' \in [\w]} \pof{\Dany, \w'} \\
-&= \sum_{\w'} \pof{\Dany, \w'} \\
-&= \pof{\Dany}.
-\end{aligned}
-$$
-
-This also tells us that, for any $$\x$$ and $$\y$$:
-
-$$\pof{\y... \given \x...} = \hpof{\y... \given \x...}$$.
-
-Given this consistency, we don't have to differentiate between $$\hat\opp$$ and $$\opp$$ and can use them interchangeably.
-The same holds for $$\opq$$.
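-
-As a small sanity check, here is a minimal discrete sketch of this consistency (the toy parameter space, prior, and likelihood values are our own assumptions):
-
-```python
-import numpy as np
-
-# 4 parameters: w0 and w1 implement the same function (one equivalence class),
-# while w2 and w3 each form their own class.
-classes = [[0, 1], [2], [3]]
-p_w = np.array([0.1, 0.2, 0.3, 0.4])  # prior p(w)
-lik = np.array([0.5, 0.5, 0.2, 0.9])  # p(D | w); equal within each class
-
-post_w = p_w * lik / np.sum(p_w * lik)  # p(w | D) via Bayes' rule
-
-# Posterior over equivalence classes, computed two ways:
-via_params = np.array([post_w[c].sum() for c in classes])  # sum of p(w' | D)
-hat_prior = np.array([p_w[c].sum() for c in classes])      # induced prior over [w]
-hat_lik = np.array([lik[c[0]] for c in classes])           # likelihood given [w]
-via_classes = hat_prior * hat_lik / np.sum(hat_prior * hat_lik)
-print(np.allclose(via_params, via_classes))  # True
-```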
-
-
-
-### Equality & Symmetries
-
-We can view $$\w \mapsto [\w]$$ as a projection from a parameter to its equivalence class. The DPI then gives us:
-
-$$
-\Kale{\qof{\W}}{\pof{\W}} \ge \Kale{\qof{[\W]}}{\pof{[\W]}}.
-$$
-
-And again: what does the gap between the two terms tell us?
-
-
-
-Let's look at a few examples to get a better understanding of this.
-
-#### 1. Trivial Constant Case
-
- Let $$\fof{\x ; \w} = 0$$ be constant, independent of $$\w$$. Then $$[\w] = [\w']$$ for any $$\w$$, $$\w'$$.
-
- For any approximate distribution $$\qof{\w}$$, the induced $$\Kale{\qof{[\W]}}{\pof{[\W]}}=0$$, while $$\Kale{\qof{\W}}{\pof{\W}}$$ only measures superfluous divergence that has no effect on the predictions.
-
-#### 2. Unused Parameter
-
- Let $$\y \given (\w_1, \w_2) = \w_1$$ be deterministic in $$\w_1$$ and independent of $$\w_2$$. Then $$[(\w_1, \w_2)] = [(\w_1, {\w'}_2)]$$ for any $${\w'}_2$$ and $$[(\w_1,*)]\not=[({\w'}_1, *)]$$ for any $$\w_1 \not= \w'_1$$.
-
- $$\Kale{\qof{[\W]}}{\pof{[\W]}}=\Kale{\qof{\W_1}}{\pof{\W_1}}$$ captures the meaningful divergence between approximate and true distribution, while $$\Kale{\qof{\W}}{\pof{\W}}$$ also includes any divergence across $$\w_2$$ that has no effect on the predictions.
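-
-To make this concrete, here is a minimal sketch with independent Gaussian components for $$\w_1$$ and $$\w_2$$ (the specific means and variances are arbitrary): the KL divergence factorises, and only the $$\w_1$$ term corresponds to $$\Kale{\qof{[\W]}}{\pof{[\W]}}$$.
-
-```python
-import numpy as np
-
-# Closed-form KL between univariate Gaussians: KL(N(mu_q, var_q) || N(mu_p, var_p)).
-def kl_gauss(mu_q, var_q, mu_p, var_p):
-    return 0.5 * (np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
-
-kl_w1 = kl_gauss(0.5, 0.5, 0.0, 1.0)  # divergence that affects the predictions
-kl_w2 = kl_gauss(2.0, 0.1, 0.0, 1.0)  # divergence in the unused parameter w2
-print(kl_w1, kl_w1 + kl_w2)  # KL over [W] vs. KL over W (it factorises here)
-```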
-
-#### 3. Periodic Parameter Space
-
- Finally, let's assume that the predictions are periodic in some way. That is, for example $$\y = \sin \w$$. We then have $$[\w] = [\w + 2\pi]$$.
-
- Further, let $$\pof{\w} = \operatorname{U}(\w; [0,2\pi \, N))$$ for some $$N$$ that determines the number of periods. Then, if we introduce another random variable $$K$$, that captures which period we are in, we can (again) use the chain rule to write:
-
- $$
- \begin{aligned}
- \Kale{\qof{\W}}{\pof{\W}} &= \Kale{\qof{\W \given \W \in [K\,2\pi, (K+1)\,2\pi]}}{\pof{\W \given \W \in [K\,2\pi, (K+1)\,2\pi]}} \\
- &\quad + \Kale{\qof{\W \in [K\,2\pi, (K+1)\,2\pi]}}{\pof{\W \in [K\,2\pi, (K+1)\,2\pi]}} \\
- &= \Kale{\qof{[\W]}}{\pof{[\W]}} \\
- &\quad + \Kale{\qof{\W \in [K\,2\pi, (K+1)\,2\pi]}}{\pof{\W \in [K\,2\pi, (K+1)\,2\pi]}}.
- \end{aligned}
- $$
-
- This follows from the setup of this specific example. Finally, we have:
-
- $$\Kale{\qof{\W \in [K\,2\pi, (K+1)\,2\pi]}}{\pof{\W \in [K\,2\pi, (K+1)\,2\pi]}} \le \log N.$$
-
- So, if $$\opq$$ only had support in a single period for example, the difference between $$\Kale{\qof{\W}}{\pof{\W}}$$ and $$\Kale{\qof{[\W]}}{\pof{[\W]}}$$ would be $$\log N$$: the redundancy.
-
-### Predictive Prior
-
-How does the predictive prior term fit into this? The DPI again yields the answer:
-
-
-
-This tells us that the predictive prior term can at best measure the KL divergence between the equivalence classes of the parameters---and not between the parameters themselves---but luckily, this is the more meaningful divergence anyway!
-
-For the equality cases, we observe that:
-
-1. we need a 1:1 mapping between parameters and equivalence classes for the first bound to be tight, and
-2. we need $$\Kale{\qof{[\W] \given \Y_n,\x_n,...,\Y_1,\x_1}}{\pof{[\W] \given \Y_n,\x_n,...,\Y_1,\x_1}} \to 0$$ for $$n \to \infty$$ for the second bound to be tight.
-
-For **2.**: we know from the chain rule that
-
-$$\Kale{\qof{\Y_n,...\Y_1\given\x_n,...,\x_1}}{\pof{\Y_n,...\Y_1\given\x_n,...,\x_1}}$$
-
-is monotonically increasing in $$n$$ and bounded from above by $$\Kale{\qof{[\W]}}{\pof{[\W]}}$$, so it *must* converge (it is a bounded, monotonically increasing sequence). But when does it close the gap?
-
-To give intuition that it might do that, and without attempting to prove this formally, we can appeal to the [*Bernstein-von Mises* theorem](https://en.wikipedia.org/wiki/Bernstein%E2%80%93von_Mises_theorem), which states that the posterior distribution of the parameters converges to a Gaussian centered at the maximum likelihood estimate (MLE), with shrinking covariance, as the number of data points tends to infinity, *as long as the model parameters are identifiable, that is, the true parameters we want to learn are unique, and they have prior support*.
-
-For the evidence bound to be meaningful, we already know that we need the prior $$\opp$$ to have support wherever the approximate distribution $$\opq$$ does---otherwise, the KL term (and thus the bound) is $$\infty$$. Moreover, realizing that we take an expectation over $$\qof{\Y_n ,..., \Y_1 \given \x_n ,..., \x_1}$$, we can decompose the KL term for the gap as:
-
-$$
-\begin{aligned}
-&\Kale{\qof{[\W] \given \Y_n,\x_n,...,\Y_1,\x_1}}{\pof{[\W] \given \Y_n,\x_n,...,\Y_1,\x_1}} \\
-&\quad = \E{\qof{\y_n,...,\y_1\given\x_n,...,\x_1}}{\Kale{\qof{[\W]\given \y_n, \x_n, ..., \y_1, \x_1}}{\pof{[\W]\given \y_n, \x_n, ..., \y_1, \x_1}}} \\
-&\quad = \simpleE{\qof{[\w']}}{\E{\qof{\y_n,...,\y_1\given\x_n,...,\x_1, [\w']}}{\Kale{\qof{[\W]\given \y_n, \x_n, ..., \y_1, \x_1}}{\pof{[\W]\given \y_n, \x_n, ..., \y_1, \x_1}}}}.
-\end{aligned}
-$$
-
-That is, we sample a $$[\w'] \sim \qof{[\w']}$$ and then sample $$\y_n,...,\y_1\given\x_n,...,\x_1$$ from the corresponding $$\qof{\y_n,...,\y_1\given\x_n,...,\x_1, [\w']}$$ and marginalize over these. Crucially, $$[\w']$$ are the true parameters of the data-generating process for the inner KL divergence term. We thus take an expectation over KL terms fulfilling the conditions of the Bernstein-von Mises theorem:
-
-$$
-\begin{aligned}
-\Kale{\qof{[\W] \given \y_n, \x_n, ..., \y_1, \x_1}}{\pof{[\W] \given \y_n, \x_n, ..., \y_1, \x_1}} \to 0.
-\end{aligned}
-$$
-
-In other words, for a given $$[\w']$$, in the space of equivalence classes as defined previously, the equivalence class of all MLE solutions in the data limit, $$[\mathrm{MLE}]$$, will be unique by definition---the model is identifiable---and match $$[\w']$$ (this follows from the consistency of MLE estimators, but also from Bernstein-von Mises with a flat/uninformative prior). As the MLE is prior-independent once there is support for it, both $$\opq$$ and $$\opp$$ will converge to the MLE $$[\w']$$ with sufficient data. Taking the expectation, this yields $$\Kale{\qof{[\W]\given \Y..., \x...}}{\pof{[\W] \given \Y..., \x...}} \to 0$$ for $$n \to \infty$$, and thus, we have:
-
-$$
-\begin{aligned}
-& \Kale{\qof{[\W]}}{\pof{[\W]}} = \\
-&\quad = \sup_{n\in \mathbb{N}} \Kale{\qof{\Y_n,...,\Y_1\given\x_n,...,\x_1}}{\pof{\Y_n,...,\Y_1\given\x_n,...,\x_1}}.
-\end{aligned}
-$$
-
-(Again, this is not a formal proof but an intuition for why the gap might close in the data limit.)
-
-In my opinion, this is a great result. We have shown both that the predictive prior term converges given our assumptions and that it converges to the symmetry-free parameter-based divergence in the data limit. This is a strong argument for the predictive prior term being meaningful and not just a technical trick.
-
-Let's appreciate one more thing: the predictive prior can consist of infinitely many data points and still converge to a finite value.
-
-## Parameter Priors vs. Predictive Priors
-
-What is the advantage of this all?
-
-In Bayesian deep learning, we often use parameter priors that are not meaningful and which also do not take parameter symmetries into account. For example, a unit Gaussian prior over the parameters of a neural network does not necessarily induce different predictions for different parameters. While this prior can be sensible from a parameter compression perspective (e.g. see Hinton and van Camp (1993)), this does not have to be the only consideration guiding us.
-
-With function priors and predictive priors, we can specify more meaningful priors because we can focus on the predictions and ignore the parameters. More importantly, this connects Bayesian approaches to data augmentation and other regularization techniques as we will see next.
-
-Given that priors over equivalence classes are difficult to express explicitly though, using the DPI to obtain a functional ELBO can be an easier way to express and approximate them.
-
-### Label Entropy Regularization
-
-All this also helps us gain a new perspective on label entropy regularization. The functional evidence bound can be lower-bounded using the chain rule by:
-
-$$
-\begin{aligned}
-\E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \Kale{\qof{\Y... \given \x...}}{\pof{\Y... \given \x...}} \\
-\ge \E{\qof{\w}}{-\log \pof{\Dany \given \w}} + \E{\pdata{\x}}{\Kale{\qof{\Y \given \x}}{\pof{\Y \given \x}}},
-\end{aligned}
-$$
-
-where we can expand the term under the second expectation to:
-
-$$
-\Kale{\qof{\Y \given \x}}{\pof{\Y \given \x}}=\CrossEntropy{\qof{\Y \given \x}}{\pof{\Y \given \x}} - \xHof{\qof{\Y \given \x}}.
-$$
-
-*Assuming that our prior yields a uniform distribution over the labels*, we can drop the cross entropy term because it is constant and obtain:
-
-$$
-\E{\qof{\w}}{-\log \pof{\Dany \given \w}} - \E{\pdata{\x}}{\xHof{\qof{\Y \given \x}}}.
-$$
-
-This is the same as a maximum-likelihood (negative log-likelihood) objective with an additional entropy regularization term $$-\xHof{\qof{\Y \given \x}}$$ for different $$\x$$ that prevents the model from overfitting to the labels and collapsing to the one-hot encoding of the labels.
-
-Thus, in the simplest approximation, the DPI and functional variational inference give us a new perspective on label entropy regularization.
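-
-As a minimal sketch of this objective (the helper names and toy inputs are our own assumptions; we use a softmax classifier for concreteness):
-
-```python
-import numpy as np
-
-def softmax(logits):
-    z = logits - logits.max(axis=-1, keepdims=True)
-    e = np.exp(z)
-    return e / e.sum(axis=-1, keepdims=True)
-
-def entropy_regularised_objective(logits_data, labels, logits_extra, reg_weight=1.0):
-    # Negative log-likelihood on the labelled data.
-    probs_data = softmax(logits_data)
-    nll = -np.mean(np.log(probs_data[np.arange(len(labels)), labels]))
-    # Negative predictive entropy on additional inputs x (to be minimised,
-    # i.e. high predictive entropy is rewarded).
-    probs_extra = softmax(logits_extra)
-    entropy = -np.sum(probs_extra * np.log(probs_extra + 1e-12), axis=-1).mean()
-    return nll - reg_weight * entropy
-
-rng = np.random.default_rng(0)
-print(entropy_regularised_objective(rng.normal(size=(8, 10)),
-                                    rng.integers(0, 10, size=8),
-                                    rng.normal(size=(16, 10))))
-```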
-
-### Knowledge Distillation
-
-Obviously, assuming non-uniform prior predictions, $$\E{\pdata{\x}}{\Kale{\qof{\Y \given \x}}{\pof{\Y \given \x}}}$$ can be related to knowledge distillation in deep neural networks as introduced by Hinton et al. (2015).
-
-The main technical difference is that knowledge distillation uses the reverse KL divergence instead of the forward KL divergence, while the conceptual difference is that we are not distilling the knowledge from a teacher model but from the prior, which we downweigh while also training our model on the data itself. However, the connection between knowledge distillation and continual learning using informative priors is manifest.
-
-## Conclusion
-
-In this blog post, we took a deep dive into the data processing inequality (DPI) and its surprisingly far-reaching implications for modern Bayesian deep learning. By carefully examining the assumptions, equality conditions, and chain rule of the DPI, we arrived at an intuitive understanding of why function-space variational inference (FSVI) can be such a powerful tool. The DPI perspective illuminates how FSVI side-steps issues with high-dimensional parameter spaces by focusing on matching Bayesian predictive posteriors.
-
-Reasoning about parameter equivalence classes under the lens of the DPI, we saw how predictive KL divergences can capture meaningful differences between models while ignoring superficial discrepancies due to symmetries. This provides a fresh perspective on the advantages of predictive priors over standard parameter priors commonly used in Bayesian neural networks.
-
-While our treatment only scratched the surface of the full mathematical story, the intuitions we developed allowed us to re-derive key results from the literature and uncover deep connections between seemingly disparate methods like entropy regularization, continual learning, and knowledge distillation. The examples and proofs peppered throughout solidified the core concepts.
-
-More than a bag of technical tricks, the DPI reveals itself to be a powerful conceptual tool for reasoning about models, objectives, and algorithms. I hope this post inspires the reader to seek the fundamental principles underpinning machine learning innovations and to use those principles as a guide for future research. With a solid grasp of foundational tools like the DPI, we can all contribute to demystifying and unifying the rapidly evolving field of Bayesian deep learning.
-
----
-
-**Acknowledgements.** Many thanks to [Freddie Bickford Smith](https://fbickfordsmith.com/) for very helpful comments and feedback on this post and to [Tim Rudner](https://timrudner.com/) for additional pointers to relevant literature and feedback on the FSVI section in particular 🤗
-
diff --git a/_posts/2024-05-07-exploring-meta-learned-curiosity-algorithms.md b/_posts/2024-05-07-exploring-meta-learned-curiosity-algorithms.md
deleted file mode 100644
index 4d3bcab5..00000000
--- a/_posts/2024-05-07-exploring-meta-learned-curiosity-algorithms.md
+++ /dev/null
@@ -1,484 +0,0 @@
----
-layout: distill
-title: Exploring Meta-learned Curiosity Algorithms
-description: This blog post delves into Alet et al.'s ICLR 2020 paper, Meta-learning curiosity algorithms, which introduces a unique approach to meta-learning curiosity algorithms. Instead of meta-learning neural network weights, the focus is on meta-learning pieces of code, allowing it to be interpretable by humans. The post explores the two meta-learned algorithms, namely Fast Action Space Transition (FAST) and Cycle-Consistency Intrinsic Motivation (CCIM).
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-authors:
- - name: Batsirayi Mupamhi Ziki
- affiliations:
- name: University of Cape Town
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-exploring-meta-learned-curiosity-algorithms.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Introduction
- - name: Background
- subsections:
- - name: Reinforcement Learning
- - name: Meta-learning and Meta-RL
- - name: Random Network Distillation
- - name: BYOL-Explore
- - name: Meta-learning curiosity algorithms
- subsections:
- - name: Meta-Learned Components and their DAGs
- - name: Method
- - name: FAST
- - name: CCIM
- - name: Experiments
- subsections:
- - name: Empirical Design
- - name: Empty grid-world
- - name: Deep sea
- - name: Results
- - name: Discussion
- - name: Conclusion
----
-
-## Introduction
-
-Dealing with environments with sparse rewards, i.e., environments where feedback comes at a low frequency, requires meaningful exploration in reinforcement learning (RL).
-One way to encourage the RL agent to perform meaningful exploration is by instilling intrinsic motivation into the agent. This intrinsic motivation usually comes in the form of curiosity. As Schmidhuber highlighted: one becomes curious as soon as one believes there's something about the world that one does not know. It is because of this that curiosity or intrinsic rewards are usually predictive errors. For instance, an RL agent equipped with a world model is given the current state of the environment, $$s_t$$, and attempts to predict the next state, $$s_{t+1}$$. The error in this prediction is the intrinsic reward. As the world model improves, one should expect the intrinsic rewards to decrease as the agent's knowledge about the environment increases. This is known as curiosity-driven exploration.
-
-Now, there has been success with curious agents solving environments with sparse rewards. Curiosity algorithms such as Random Network Distillation (RND) and BYOL-Explore are hand-designed and are able to perform well across different environments.
-However, in their 2020 paper, Meta-learning curiosity algorithms, Alet et al. took a unique approach to discovering new curiosity algorithms: they meta-learned pieces of code,
-similar to the code segments used by researchers when crafting curiosity algorithms, such as neural networks with gradient-descent mechanisms, trained objective functions, ensembles, buffers, and various regression models.
-Two new interpretable algorithms were learned by meta-learning these pieces of code: Fast Action Space Transition (FAST) and Cycle-Consistency Intrinsic Motivation (CCIM).
-It is these two algorithms that we will explore and compare their behaviour to our baselines: RND and BYOL-Explore.
-
-The roadmap for exploring FAST and CCIM is organised as follows. We begin with a brief introduction to RL, meta-learning, and meta-reinforcement learning (meta-RL). Next, we provide concise explanations of how curiosity-driven exploration baselines, RND and BYOL-Explore, operate. Subsequently, we delve into the discovery process of FAST and CCIM. Following that, we explore the intricacies of FAST and CCIM, evaluating their performance and studying their behaviour in both the empty grid-world environment and the `bsuite` deep sea environment. We then compare them to curiosity-driven baselines and a non-curious agent. Finally, we conclude our journey.
-
-## Background
-
-### Reinforcement Learning
-
-RL is inspired by how biological systems learn, as animals are able to learn through trial and error. In RL we have an agent that tries to maximise the sum of rewards it receives by learning from its interactions with the environment. This agent-environment interaction is usually modelled as a Markov decision process (MDP). Figure 1 below illustrates this agent-environment interaction.
-
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/MDP.png" class="img-fluid" width="100px" %}
-
- Figure 1. The agent-environment interaction as an MDP. Taken from .
-
-
-From the figure we can see that the agent observes a state and then takes an action. The agent can then decide on its next action based on the next state it observes and the reward it receives from the critic in the environment. The critic decides what reward the agent receives at every time-step by evaluating its behaviour.
-
-As Sutton et al. highlighted, Figure 1 can be misleading though. It implies that the agent-environment boundary is similar to the physical boundary between an organism's entire body and the outside world. In RL, we consider anything that the agent cannot change through its actions to be part of the environment. For example, if a human were an RL agent, their skeletal structure or their muscles could be considered part of the environment. We can then see that when it comes to RL we have two types of environments: the internal environment, such as the sensory organs of an animal, and the external environment. Also, the reward the agent receives does not always come from the external environment. Rewards can be seen as reward signals, like a human's brain releasing dopamine when one achieves an objective.
-Thus, the critic can also be inside the RL agent.
-The figure below shows an extended view of the agent-environment interactions.
-
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/extended_mdp.png" class="img-fluid" width="100px" %}
-
- Figure 2. The extended agent-environment interaction. Taken from .
-
-
-Singh et al. highlighted that Figure 2 shows that an RL agent has a motivational system, since the critic can be within the internal environment of the agent. This motivational system should ideally remain consistent across a wide range of diverse environments. Since we can view the critic as being inside the agent, we can instil intrinsic motivation into the agent. This means that the agent can receive two types of rewards, namely extrinsic rewards from the external environment and intrinsic rewards from the internal environment.
-Singh et al. also highlighted the advantages of endowing an agent with intrinsic motivation. They pointed out that an agent equipped with a collection of skills learned through intrinsic reward can more easily adapt to and learn a wide variety of extrinsically rewarded tasks compared to an agent lacking these skills.
-
-### Meta-learning and Meta-RL
-
-The next stop on our journey takes us to meta-learning. Meta-learning is about learning how to learn. The goal is for meta-learning agents to enhance their learning abilities over time, enabling them to generalise to new, unseen tasks. Meta-learning involves two essential loops: the inner loop and the outer loop. In the inner loop, our learning algorithm adapts to a new task using experiences obtained from solving other tasks in the outer loop, which is referred to as meta-training.
-
-The inner loop addresses a single task, while the outer loop deals with the distribution of tasks. Figure 3 illustrates this concept of meta-learning.
-
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/meta-learning.png" class="img-fluid" %}
-
-
- Figure 3. An illustration of meta-learning. Taken from .
-
-Moving into the intersection of meta-learning and reinforcement learning (RL) is meta-RL, where the agent learns how to reinforcement learn. In meta-RL, the agent aims to maximise the sum of rewards from a distribution of MDPs.
-
-In basic RL, we have an algorithm $$f$$ that outputs a policy, mapping states to actions. In meta-RL, however, we have an algorithm with meta-parameters $$\theta$$ that outputs $$f$$, and $$f$$ then produces a policy when faced with a new MDP.
-Figure 4 illustrates the meta-RL process. Note that in the outer loop the meta-parameters $$\theta$$ are updated.
-
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/meta-rl.png" class="img-fluid" %}
-
-
- Figure 4. An illustration of meta-RL. Taken from .
-
-
-### Random Network Distillation
-
-We now move on to our curiosity-driven exploration baselines. The first baseline that we will briefly discuss is RND. RND works by having two neural networks: a predictor network and a target network. The target network is randomly initialised and its parameters stay fixed during training. Given a state, $$s_t$$, it outputs the feature representation of that state, $$f_t$$. The predictor network then tries to predict $$f_t$$ given $$s_t$$ as well. The error in this prediction is the intrinsic reward, $$r_i$$, given to the agent, and it is given by the following formula,
-
-$$
-r_i=\|\hat{f}_t - f_t\|_2^2,
-$$
-
-where $$ \hat{f}_t$$ is the output of the predictor network. The formula above also serves as the loss function of the predictor network.
-We normalise $$r_i$$ by dividing it by a running estimate of the standard deviation of
-the intrinsic returns. We do this because the intrinsic rewards can be very different in various environments. Normalising the intrinsic rewards makes it easier to pick hyperparameters that work across a wide range of environments. As the agent explores more, the predictor network will get better and the intrinsic rewards will decrease. The key idea in RND is that the predictor network is trying to predict the output of a network that is deterministic: the target network.
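-
-To make this concrete, here is a minimal NumPy sketch of the RND intrinsic reward (the tiny MLPs, shapes, and initialisation are our own toy assumptions, not the original implementation):
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-
-def init_mlp(in_dim, hidden, out_dim):
-    return [rng.normal(scale=0.1, size=(in_dim, hidden)),
-            rng.normal(scale=0.1, size=(hidden, out_dim))]
-
-def mlp(params, x):
-    return np.tanh(x @ params[0]) @ params[1]
-
-target = init_mlp(4, 32, 8)     # randomly initialised, parameters stay fixed
-predictor = init_mlp(4, 32, 8)  # trained to match the target's features
-
-def intrinsic_reward(state):
-    f_t = mlp(target, state)       # target features f_t
-    f_hat = mlp(predictor, state)  # predicted features
-    return np.sum((f_hat - f_t) ** 2)  # also the predictor's loss for this state
-
-print(intrinsic_reward(rng.normal(size=4)))
-```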
-The figure below illustrates the process of RND.
-
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/RND.png" class="img-fluid" %}
-
- Figure 5. The process of RND. Taken from .
-
-
-### BYOL-Explore
-
-BYOL-Explore builds upon Bootstrap Your Own Latent (BYOL), a self-supervised learning algorithm used in computer vision and representation learning. BYOL-Explore is similar to RND in that there's a network that tries to predict the output of a target network. In BYOL-Explore we have an online network that consists of an encoder, a closed-loop recurrent neural network (RNN) cell, an open-loop RNN cell, and a predictor, while the target network just consists of an encoder. The key difference is that the target network's parameters do not stay fixed like in RND. We update the target network's parameters using the exponential moving average (EMA) of the online network's predictor parameters. The update is performed using the formula below:
-
-$$
-\phi \leftarrow \alpha\phi + (1-\alpha)\theta.
-$$
-
-In the above equation, $$\phi$$ denotes the target network's parameters, $$\theta$$ the online network's predictor parameters, and $$\alpha$$ the EMA smoothing factor. In our implementation of BYOL-Explore we do not make use of the RNN cells as we are dealing with simple environments; we call our implementation BYOL-Explore Lite.
-In our implementation the online network is composed of a multilayer perceptron (MLP) encoder and a predictor. The target network, $$h$$, is just composed of an MLP encoder. In the BYOL-Explore Lite process the current state of the environment, $$s_t$$, is inputted into the encoder $$f$$, which outputs a feature representation of the state, $$f(s_t)$$. This feature representation is then passed to both the RL agent and the predictor $$g$$. The RL agent uses $$f(s_t)$$ to decide on its next action and determine the value of that state. The predictor uses $$f(s_t)$$ to predict $$h(s_{t+1})$$, i.e., the predictor is attempting to predict the target network's output for the next state. There are two losses namely the encoder loss and the predictor loss. The predictor loss is given by,
-
-$$
-\mathcal{L}_p=\left\|\frac{g(f(s_{t}))}{\|g(f(s_{t}))\|_2}-\frac{h(s_{t+1})}{\|h(s_{t+1})\|_2}\right\|_2^2.
-$$
-
-Since the RL agent and the predictor both make use of the online network's encoder its loss is given by the sum of the RL loss and the predictor loss. Importantly, the loss $$\mathcal{L}_p$$ serves as the intrinsic reward that the RL agent receives at each step. We normalise the intrinsic rewards by dividing it by the EMA estimate of their standard deviation.
-
-BYOL-Explore Lite also makes use of something known as reward prioritisation. Reward prioritisation involves focusing on parts of the environment where the agent receives high intrinsic rewards while disregarding those with low intrinsic rewards. This enables the agent to concentrate on areas it understands the least. Over time, the previously ignored areas with low intrinsic rewards become the priority for the agent. To do this, we take the EMA mean, $\mu$, over successive batches of normalised intrinsic rewards. Note that $\mu$ is used as a threshold
-to separate the high intrinsic rewards from the low intrinsic rewards. Therefore, the intrinsic reward that the agent obtains after reward prioritisation is,
-
-$$
-i_t=\max(ri_t-\mu,\,0),
-$$
-
-where $ri_t$ is the normalised intrinsic reward.
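-
-As a small illustration, here is a minimal sketch of the EMA target update and the reward prioritisation described above (function names and toy values are our own assumptions):
-
-```python
-import numpy as np
-
-def ema_update(target_params, online_params, alpha=0.99):
-    # phi <- alpha * phi + (1 - alpha) * theta, applied parameter-wise.
-    return [alpha * t + (1.0 - alpha) * o for t, o in zip(target_params, online_params)]
-
-def prioritised_reward(normalised_ri, ema_mean):
-    # Only normalised intrinsic rewards above the running EMA mean are kept.
-    return np.maximum(normalised_ri - ema_mean, 0.0)
-
-print(prioritised_reward(np.array([0.5, 1.2, 2.0]), ema_mean=1.0))  # [0.  0.2 1. ]
-```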
-
-## Meta-learning curiosity algorithms
-
-Alet et al. view curiosity as a mechanism that is found through natural selection. As a result they turn to meta-learning to discover new curiosity algorithms.
-In this case the outer loop searches over the curiosity algorithm space while the inner loop performs the standard RL procedure.
-
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/mlc.png" class="img-fluid" %}
-
- Figure 6. The process of how the meta-learned curiosity algorithm should work. Taken from .
-
-
-In the above figure we can see that the curiosity algorithm, $$\mathcal{C}$$, takes in the state and reward from the environment and then feeds a proxy reward $$\hat{r}$$ to the RL agent. The RL algorithm used is a fully-specified algorithm, i.e., all its hyperparameters are specified. There were two stages in the authors' search because the module $$\mathcal{C}$$ is made of two components.
-The first component, $$\mathcal{I}$$, calculates the intrinsic reward given the current state, next state and the action taken. The second component, $$\chi$$, then takes the extrinsic reward, the intrinsic reward and the current normalised time step to combine them and output $$\hat{r}$$.
-
-### Meta-Learned Components and their DAGs
-
-As mentioned earlier, Alet et al. focused on meta-learning pieces of code, or rather meta-learning in a space of programs or operations. The programs and operations are represented in a domain-specific language (DSL). The DSL used to find component $$\chi$$ consisted of operations such as arithmetic, Min, Max and more.
-While the DSL used to find component $$\mathcal{I}$$ consisted of programs such as neural networks complete with gradient-descent mechanisms, L2 distance calculation, and ensembles of neural networks and more. Component $$\mathcal{I}$$'s DSL can describe many other hand-designed curiosity algorithms in literature, such as RND.
-
-The components $$\mathcal{I}$$ and $$\chi$$ are represented as Directed Acyclic Graphs (DAGs). The DAGs consist of the following types of modules:
-- Input modules: These are the inputs we put in each component of module $$\mathcal{C}$$.
-- Parameter and Buffer modules: This module either consists of the weights of a neural network which can be updated via back-propagation or First In, First Out queues that output a finite list of the most recent $$k$$ inputs.
-- Functional modules: This type of module calculates the output given some input.
-- Update modules: These modules can add real-valued outputs to the loss function of the neural network or add variables to buffers.
-
-The DAGs also have an output node, which is a single node whose output is the output of the entire program. To make these ideas more concrete, let us look at the DAG that describes RND.
-
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/RND_DAG.png" class="img-fluid" %}
-
- Figure 7. The DAG of RND. Taken from .
-
-
-The blue rectangles represent the input modules, and we can see from the figure that the inputs are states from the environment.
-The parameter modules are the gray rectangles and these are the parameters of the target network and the predictor network.
-Note that the target network's parameters are given by $$\theta$${1} and the predictor network's parameters are given by $$\theta$${2}.
-The functional modules are the white rectangles and these are the neural networks. The update module is the pink rectangle which is the loss function.
-
-The output node is the green rectangle and is the L2 distance between the outputs of the predictor network and the target network. This is the loss function described in the RND section. Note that the $$\theta$${2} rectangle has a pink border and a pink arrow; this indicates that it can be updated via back-propagation. The $$\theta$${1} rectangle, on the other hand, has a black border and a black arrow, indicating that its parameters are not updated via back-propagation. Also note that the functional module that makes use of those parameters has the word "Detach", indicating that gradient information does not flow back. Recall that $$\theta$${1} represents the parameters of the target network, which remain fixed, and $$\theta$${2} represents the parameters of the predictor network, which are updated during training.
-
-Now a very important idea is that the DAGs used in the paper have polymorphic types for the inputs and outputs. There are four types:
-- $$\mathbb{R}$$, the real numbers.
-- $$\mathbb{S}$$, the state space of the environment.
-- $$\mathbb{A}$$, the action space of the environment.
-- $$\mathbb{F}$$, the feature space.
-
-The instantiation of some types depends on the environment. For example, in Figure 7, if $$\mathbb{S}$$ is an image then both the target network and the predictor network are instantiated as convolutional neural networks.
-If $$\mathbb{S}$$ is just an array of numbers, then the target network and the predictor network are fully connected neural networks. We now look at the method used to find the components $$\mathcal{I}$$ and $$\chi$$.
-
-### Method
-
-We now turn our attention to how component $$\mathcal{I}$$ was searched for. Alet et al. decided to focus on an environment with sparse rewards. They chose an image-based grid-world. In this environment the agent is tasked with finding the goal position and only obtains a reward if it finds the goal position. The rewards are sparse because the agent only receives feedback once it finds the goal position. They limited the number of operations that component $$\mathcal{I}$$ could perform to 7 so that the search space remains manageable and the algorithms stay interpretable. They focused on finding a component $$\mathcal{I}$$ that optimises the number of distinct cells visited. From the search, 13 of the top 16 components found were variants of FAST and 3 of them were variants of CCIM. We will cover FAST and CCIM in the upcoming sections.
-
-For the component $$\chi$$ they focused on the Lunar Lander environment as it has a strong external reward signal. The algorithm used to output the intrinsic reward was a variant of RND. The main difference was that instead of a single neural network for the predictor network, an ensemble is used. This algorithm came from a preliminary set of algorithms that all resemble RND. The best reward combiner found was,
-
-$$
-\hat{r}_t = \frac{(1+ri_t-t/T)\cdot ri_t+ r_t\cdot t/T}{1+ri_t}.
-$$
-
-Here $$r_t$$ is the external reward, $$t$$ is the current time-step, $$T$$ is the maximum steps possible in the episode, and $$ri_t$$ is the intrinsic reward.
-However, in this blog post we decided not to focus on the reward combiner $$\chi$$ but instead on FAST and CCIM (this decision arises because we felt our exploration of the reward combiner was not exhaustive enough).
-
-
-
-### FAST
-
-FAST is a very simple algorithm in that it only contains one neural network. Below is the DAG of FAST.
-
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/FAST_diagram.png" class="img-fluid" %}
-
- Figure 8. The DAG of FAST. Taken from .
-
-
-This single neural network in FAST is a policy-mimicking network, $$\hat{\pi}$$. The network $$\hat{\pi}$$ tries to predict what action the agent took given a state of the environment (we assume the environment has a discrete action space, but this need not be the case). The loss of the policy-mimicking network is then the negative log likelihood (NLL) loss. Note that, looking at the DAG, the output of FAST is not the same as the loss function of the policy-mimicking network. The output is given by,
-
-$$
-ri_t=\|\hat{\pi}(s_{t+1})-\hat{\pi}(s_{t})\|_2.
-$$
-
-This is different from RND and BYOL-Explore Lite. The intrinsic reward is not given by a predictive error or the loss function of one of the networks in the program.
-We understood the above formula as the L2 difference between the logits of the current state and the next state.
-The agent is then rewarded if the next state's logits are different from the current state's.
-Importantly, the agent isn't rewarded for taking a different action in the next state. Alet et al. pointed out that if the policy-mimicking network has a uniform distribution over the action space in all states, the agent will receive an intrinsic reward of zero. Therefore, in environments where the action probability distributions outputted by the policy-mimicking network vary across states, we expect this algorithm to generate intrinsic rewards.
-We hypothesize that this algorithm may not perform well in environments where the optimal policy requires the agent to visit states with very similar action probability distributions.
-While the agent explores by going to different states, ideally we wish for the intrinsic rewards to decrease as the agent explores. Looking at the output of FAST, it is not clear to us how the intrinsic reward decreases, and we expect that this could cause issues.
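-
-To make our reading of the output concrete, here is a minimal sketch of the FAST intrinsic reward with a toy linear policy-mimicking network (shapes and names are our own assumptions):
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-W = rng.normal(scale=0.1, size=(4, 3))  # toy linear policy-mimicking network (3 actions)
-
-def logits(state):
-    return state @ W
-
-def fast_reward(s_t, s_next):
-    # L2 distance between the logits at the next state and the current state.
-    return np.linalg.norm(logits(s_next) - logits(s_t), ord=2)
-
-print(fast_reward(rng.normal(size=4), rng.normal(size=4)))
-```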
-
-
-
-### CCIM
-
-CCIM took us quite a while to understand and process. Let us first go through its DAG below.
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/CCIM_diagram.png" class="img-fluid" %}
-
- Figure 9. The DAG of CCIM. Taken from .
-
-
-We can see that there are 3 neural networks: a random network, a random and forward network, and a backward network. The parameters $$\theta$${1} are the parameters of the random network, $$\theta$${2} are the parameters of the backward network, and $$\theta$${3} are the parameters of the random and forward network. Looking at the black border of $$\theta$${1}'s rectangle we can see that the random network's parameters stay fixed during training like in RND. Let us denote the random network as
-$$ r_{\theta_1}$$, the backward network as $$b_{\theta_2}$$, and the random and forward network as $$ fr_{\theta_3}$$.
-Let us look at the loss functions of $$b_{\theta_2}$$ and $$fr_{\theta_3}$$. The loss function of $$b_{\theta_2}$$ is given by,
-
-$$
-\mathcal{L}_b=\|b_{\theta_2}(fr_{\theta_3}(s_t))-r_{\theta_1}(s_t)\|_2+\|b_{\theta_2}(fr_{\theta_3}(s_{t+1}))-fr_{\theta_3}(s_t)\|_2,
-$$
-
-and the loss function for $$fr_{\theta_3}$$ is
-
-$$
-\mathcal{L}_f=\|b_{\theta_2}(fr_{\theta_3}(s_t))-r_{\theta_1}(s_t)\|_2.
-$$
-
-Note the first term in $$\mathcal{L}_b$$ is the same as $$\mathcal{L}_f$$. The intrinsic reward, i.e., the output of this program is given by,
-
-$$
-ri_t=\|b_{\theta_2}(fr_{\theta_3}(s_{t+1}))-b_{\theta_2}(fr_{\theta_3}(s_t))\|_2.
-$$
-
-Looking at the equations, we can see that CCIM borrows ideas from the cycle-consistency seen in the Image-to-Image Translation literature. The cycle-consistency ensures that if you translate from space $$A$$ to space $$B$$, then given space $$B$$, you should be able to translate back to space $$A$$. To see how CCIM applies this, let us turn our attention to $$\mathcal{L}_f$$'s equation. The $$fr_{\theta_3}$$ network applies a random embedding to state $$s_t$$. It then forwards this random embedding to the "next state". The $$b_{\theta_2}$$ network then takes this forwarded random embedding of state $$s_t$$ and undoes the forward transformation so that we end up again with just the random embedding of state $$s_t$$. Now, the random embedding that $$fr_{\theta_3}$$ applied should match the random embedding that $$r_{\theta_1}$$ applied to the state $$s_t$$ for the loss to be minimised.
-In other words, once we apply a forward transformation to the random embedding of the state, we should be able to undo that transformation and end up where we started.
-
-Let us look at the second term in $$\mathcal{L}_b$$ given by $$\|b_{\theta_2}(fr_{\theta_3}(s_{t+1}))-fr_{\theta_3}(s_t)\|_2$$. We apply a forward and then a backward transformation to the random embedding of state $$s_{t+1}$$, so we should end up with just the random embedding of state $$s_{t+1}$$. We then apply $$fr_{\theta_3}$$ to state $$s_t$$ and end up with the forwarded random embedding of state $$s_t$$, which should equal the random embedding of $$s_{t+1}$$.
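-
-To summarise our reading of the equations, here is a minimal sketch of the CCIM losses and intrinsic reward with toy linear networks (shapes and names are our own assumptions):
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-R = rng.normal(scale=0.1, size=(4, 8))   # random network r_{theta_1} (kept fixed)
-FR = rng.normal(scale=0.1, size=(4, 8))  # random-and-forward network fr_{theta_3}
-B = rng.normal(scale=0.1, size=(8, 8))   # backward network b_{theta_2}
-
-r = lambda s: s @ R
-fr = lambda s: s @ FR
-b = lambda z: z @ B
-
-def ccim(s_t, s_next):
-    loss_f = np.linalg.norm(b(fr(s_t)) - r(s_t))               # L_f
-    loss_b = loss_f + np.linalg.norm(b(fr(s_next)) - fr(s_t))  # L_b
-    ri = np.linalg.norm(b(fr(s_next)) - b(fr(s_t)))            # intrinsic reward
-    return loss_f, loss_b, ri
-
-print(ccim(rng.normal(size=4), rng.normal(size=4)))
-```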
-
-The intrinsic reward confuses us. Looking at the DAG of CCIM, we see that the output is given by the L2 distance between $$\mathcal{L}_f$$ and $$\mathcal{L}_b$$; hence, we initially thought the intrinsic reward was given by $$\|b_{\theta_2}(fr_{\theta_3}(s_{t+1}))-fr_{\theta_3}(s_t)\|_2$$. The difference between this equation and the original intrinsic reward equation is that the backward model, $$b_{\theta_2}$$, is not applied to the $$fr_{\theta_3}(s_t)$$ term. Looking at the original formula of the intrinsic reward, we can see that it is just the difference between the random embedding of
-the current state and the next state (if we assume that the backward network can undo the forward transformation), so it is not clear to us how the intrinsic reward
-will decrease as the agent explores.
-Not only that, but we also noticed unexpected behaviour in the loss function of the $$fr_{\theta_3}$$ network in our experiments. We then watched Alet et al.'s presentation of their paper to see where we went wrong, and we noticed in the presentation they swapped the labels for $$fr_{\theta_3}$$ and $$b_{\theta_2}$$ networks.
-After reaching out to them about this discrepancy, they did confirm that the equations in the paper are correct, and the labels in the talk are wrong. So for our implementation, we used the equations as found in the paper.
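-
-To make the equations above concrete, here is a minimal sketch of how the CCIM losses and intrinsic reward could be computed for a single transition. The function and network names (`random_net`, `backward_net`, `forward_random_net`) are our own stand-ins for $$r_{\theta_1}$$, $$b_{\theta_2}$$ and $$fr_{\theta_3}$$, not the names used in the original code.
-
-```python
-import jax.numpy as jnp
-
-def ccim_losses_and_reward(random_net, backward_net, forward_random_net, s_t, s_tp1):
-    """Sketch of the CCIM losses and intrinsic reward for one transition."""
-    r_t = random_net(s_t)               # r_{theta_1}(s_t), parameters stay frozen
-    fr_t = forward_random_net(s_t)      # fr_{theta_3}(s_t)
-    fr_tp1 = forward_random_net(s_tp1)  # fr_{theta_3}(s_{t+1})
-    b_fr_t = backward_net(fr_t)         # b_{theta_2}(fr_{theta_3}(s_t))
-    b_fr_tp1 = backward_net(fr_tp1)     # b_{theta_2}(fr_{theta_3}(s_{t+1}))
-
-    # L_b: cycle back to the random embedding and to the forwarded embedding.
-    loss_b = jnp.linalg.norm(b_fr_t - r_t) + jnp.linalg.norm(b_fr_tp1 - fr_t)
-    # L_f: shares its only term with the first term of L_b.
-    loss_f = jnp.linalg.norm(b_fr_t - r_t)
-    # Intrinsic reward: distance between the cycled embeddings of s_t and s_{t+1}.
-    ri_t = jnp.linalg.norm(b_fr_tp1 - b_fr_t)
-    return loss_b, loss_f, ri_t
-```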
-
-#### CCIM-slimmed
-
-Through our communication with them, Alet et al. recommended we try ablations of CCIM, and they suggested the following slimmed-down version:
-- Network $$r_{\theta_1}$$ remains unchanged and its parameters stay fixed.
-- Network $$fr_{\theta_3}$$ changes to just being a forward network, $$f_{\theta_3}$$.
-- The loss function of the $$f_{\theta_3}$$ is now $$\mathcal{L}_f=\|f_{\theta_3}(r_{\theta_1}(s_t))-r_{\theta_1}(s_{t+1})\|_2^2$$.
-- Network $$b_{\theta_2}$$'s loss function, $$\mathcal{L}_b$$, also changes. $$\mathcal{L}_b=\|b_{\theta_2}(r_{\theta_1}(s_{t+1}))-r_{\theta_1}(s_{t})\|_2^2$$.
-- The intrinsic reward is now $$\mathcal{L}_f+\mathcal{L}_b$$.
-
-This slimmed-down version of CCIM was much easier to implement. Since the sum of the loss functions also acts as the intrinsic reward, it is clearer to us how the intrinsic rewards will decrease as the agent explores: as the agent explores, the forward and backward networks become better at predicting the random embedding of the next state and the previous state, respectively.
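-
-Following the same conventions as the sketch above (our own function names, with $$f_{\theta_3}$$ and $$b_{\theta_2}$$ now acting directly on the random embeddings), the slimmed-down losses and intrinsic reward look roughly like this:
-
-```python
-import jax.numpy as jnp
-
-def ccim_slimmed_losses_and_reward(random_net, backward_net, forward_net, s_t, s_tp1):
-    """Sketch of the CCIM-slimmed losses; their sum is the intrinsic reward."""
-    r_t = random_net(s_t)      # r_{theta_1}(s_t), parameters stay frozen
-    r_tp1 = random_net(s_tp1)  # r_{theta_1}(s_{t+1})
-
-    # Forward network predicts the random embedding of the next state.
-    loss_f = jnp.sum((forward_net(r_t) - r_tp1) ** 2)
-    # Backward network predicts the random embedding of the previous state.
-    loss_b = jnp.sum((backward_net(r_tp1) - r_t) ** 2)
-    # The intrinsic reward is simply the sum of the two prediction errors.
-    ri_t = loss_f + loss_b
-    return loss_f, loss_b, ri_t
-```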
-
-## Experiments
-
-### Empirical Design
-
-
-In devising the methodology for our experiments, we sought guidance from the principles outlined in Patterson et al.'s cookbook, "Empirical Design in Reinforcement Learning" . Our codebase is derived from PureJaxRL and can be found [here](https://github.com/Ziksby/MetaLearnCuriosity).
-Specifically, we leverage PureJaxRL's Proximal Policy Optimization (PPO) implementation as our chosen reinforcement learning (RL) algorithm.
-We compare each meta-learned curiosity algorithm to a non-curious agent (normal PPO) and our baselines.
-The foundation of our experiments is a JAX implementation of Minigrid's grid-world environment , which uses gymnax's API . We also make use of gymnax's implementation of the deep sea environment.
-
-Each RL agent undergoes training for 500,000 time steps across four vectorized environments, employing 30 seeds for each RL algorithm.
-To assess performance in each environment, we calculate the average episode return across seeds at the end of training, with a 95% confidence interval determined through the percentile bootstrap method.
-We are not just interested in how well these curiosity algorithms perform but also in understanding the behaviour of these algorithms.
-We therefore also visualise the sample standard deviation during training to see the variation in performance. This helps us see how consistent the behaviour of each curiosity algorithm and the normal PPO algorithm is.
-
-Since we are not testing the reward combiner that was found, it is not clear how we should combine the external reward and the intrinsic reward. We treat both the external reward and the intrinsic reward as episodic and therefore use the following formula, $$ \hat{r}_t = r_t + \lambda ri_t $$, where $$\lambda$$ is some weight factor.
-These are the best values we found for $$\lambda$$ for each curiosity algorithm:
-
-- FAST: $$\lambda = 0.003$$.
-- CCIM-slimmed: $$\lambda = 0.17$$.
-- CCIM: $$\lambda = 0.003$$.
-- BYOL-Explore Lite: $$\lambda = 0.006$$.
-- RND: $$\lambda = 0.2$$.
-
-For FAST, CCIM, and CCIM-slimmed we normalise the intrinsic reward using the same method as RND. Next we describe the environments we use in more detail.
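-
-As a rough sketch of how the two reward streams are combined (the normalisation follows the spirit of RND, dividing the intrinsic reward by a running estimate of the standard deviation of the intrinsic returns; the helper below is our own simplification, not the exact code):
-
-```python
-def combine_rewards(extrinsic_reward, intrinsic_reward, intrinsic_return_std, lam):
-    """Combine rewards as r_hat = r_t + lambda * (normalised ri_t)."""
-    normalised_ri = intrinsic_reward / (intrinsic_return_std + 1e-8)
-    return extrinsic_reward + lam * normalised_ri
-
-# Example with the weight we used for FAST (lambda = 0.003).
-r_hat = combine_rewards(extrinsic_reward=1.0, intrinsic_reward=0.3,
-                        intrinsic_return_std=0.5, lam=0.003)
-```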
-
-### Empty grid-world
-
-The empty grid-world is a very simple environment. As mentioned earlier, the agent's task is to reach the goal position. The size is $$16\times 16$$ and the maximum number of steps is 1024.
-In our implementation the agent starts at the bottom left corner and has to reach the top right corner. The reward the agent receives if it finds the goal is `1 - 0.9 * (step_count / max_steps)`. The GIF below shows an RL agent exploring the environment to reach the goal.
-
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/anim_BYOL_0.gif" class="img-fluid" %}
-
-The empty grid-world environment.
-
-
-### Deep sea
-
-The deep sea environment is one of the `bsuite` environments developed by Google
-DeepMind .
-This is an $$ N \times N$$ grid environment that focuses on testing the exploration capabilities of an RL algorithm. The figure below shows the environment.
-{% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/deepsea.png" class="img-fluid" %}
-
- Figure 10. The Deep sea environment. Taken from .
-
-The agent starts at the top left corner and its goal is to reach the bottom right corner.
-At each time step the agent descends one row, choosing to go either left or right. There is a small penalty of $$ -0.01/N $$ for going right, while going left gives a reward of zero. The agent receives a reward of 1 if it finds the treasure at the bottom right corner.
-The maximum number of steps in the environment is $$N$$. Therefore, the optimal policy is to go right at every time step, even though the greedy action is to go left and avoid the small penalty. In our experiments we set $$N=10$$.
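-
-A minimal sketch of the reward structure described above (our own summary, not the gymnax implementation):
-
-```python
-def deep_sea_reward(went_right: bool, found_treasure: bool, n: int = 10) -> float:
-    """Reward for one transition in an N x N deep sea grid."""
-    reward = -0.01 / n if went_right else 0.0  # small penalty for moving right
-    if found_treasure:
-        reward += 1.0                          # treasure in the bottom right corner
-    return reward
-```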
-
-### Results
-
-#### CCIM
-
-We start with the deep sea environment. The left of Figure 11 shows the sample standard deviation during training. We only show the first 10,000 steps because after that the curves plateau. We see that RND and BYOL-Explore Lite produce the most consistent agents in the deep sea environment, and CCIM-slimmed produces more consistent agents than CCIM and PPO. Looking at the right of Figure 11 we can see the mean episode return across the 30 seeds with the 95% confidence intervals. RND, BYOL-Explore Lite, and CCIM-slimmed all perform better than PPO. However, CCIM performs roughly the same as PPO at the end of training. From our experiments we also noticed that the intrinsic rewards produced by CCIM increase and then plateau. The CCIM random and forward network's loss also continued to increase during training.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/DeepSea-bsuite_CCIM_mean_seeds_std.png" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/DeepSea-bsuite_ccim_mean_seeds_CI.png" class="img-fluid" %}
-
-
-
-
- Figure 11. The sample standard deviation during training (left) and the average episode return (right) in deep sea environment.
-
-
-Next we move on to the empty grid-world. Looking at the left of Figure 12 we can see that all curiosity algorithms produce more consistent agents than PPO, as their sample
-standard deviations are lower. CCIM and CCIM-slimmed both produce more consistent agents than RND and PPO in this environment. The right of Figure 12 also indicates that CCIM performed much better in the empty grid-world and was closer to the baselines. However, in this environment we once again noticed that the raw intrinsic reward
-increased and then plateaued, and that the loss of the random and forward network increased during training. It should also be noted that the confidence intervals of all the RL algorithms overlap in the empty grid-world environment.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/Empty-misc_CCIM_mean_seeds_std.png" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/Empty-misc_CCIM_mean_seeds_CI.png" class="img-fluid" %}
-
-
-
-
- Figure 12. The sample standard deviation during training (left) and the average episode return (right) in empty grid-world environment.
-
-
-Next we plot the RND, BYOL-Explore Lite, normal PPO, CCIM, and CCIM-slimmed heatmaps in Figures 13 and 14. To make the heatmaps we looked at the best 15 seeds for
-each algorithm and kept track of the paths each seed took. Looking at Figures 13 and 14, we can see that CCIM and CCIM-slimmed covered more of the map than RND and BYOL-Explore Lite. However, they covered only slightly more of the map than PPO.
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/heatmap_rnd_30.png" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/heatmap_byol_lite_30.png" class="img-fluid" %}
-
-
-
- Figure 13. Heatmaps of the RND agent (left) and the BYOL-Explore Lite agent (right) in empty grid-world.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/heatmap_ccim_30.png" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/heatmap_ccim_slimmed_30.png" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/heatmap_dis_ppo_30.png" class="img-fluid" %}
-
-
-
-
-
- Figure 14. Heatmaps of the CCIM agent (left), CCIM-slimmed agent (middle), and the normal PPO agent (right) in empty grid-world.
-
-
-
-
-#### FAST
-
-Let us now turn our attention to how FAST performed, starting again with the deep sea environment. In Figure 15 we plot the sample standard deviation for the first 10,000 steps, as we observe no significant difference beyond this point.
-The left side of Figure 15 indicates that PPO and our curiosity-driven baselines produce more consistent agents than FAST, as they exhibit a lower sample standard deviation.
-
-On the right side of Figure 15, we see that FAST, similar to CCIM, performs poorly on this environment compared to our baselines. Notably, during training we noticed the intrinsic reward of the FAST agents also increased.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/DeepSea-bsuite_FAST_mean_seeds_std.png" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/DeepSea-bsuite_FAST_mean_seeds_CI.png" class="img-fluid" %}
-
-
-
-
- Figure 15. The sample standard deviation during training (left) and the average episode return (right) in deep sea environment.
-
-
-The right side of Figure 16 shows that FAST's performance in the empty grid-world is better than its performance in the deep sea environment; it is now comparable to our baselines despite its intrinsic rewards also increasing over time. Once again, similar to CCIM's results, we observe overlapping confidence intervals in the empty grid-world. The left side of Figure 16 shows that not only has its performance improved in the empty grid-world, but it also produces more consistent agents than RND and PPO, as its sample standard deviation is lower.
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/Empty-misc_FAST_mean_seeds_std.png" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/Empty-misc_FAST_mean_seeds_CI.png" class="img-fluid" %}
-
-
-
- Figure 16. The sample standard deviation during training (left) and the average episode return (right) in empty grid-world environment.
-
-
-We once again plot the heatmap of FAST and compare it to PPO's heatmap using the best 15 seeds. When comparing Figure 17 (left) with both Figure 17 (right) and Figure 13, we observe that FAST covered more of the grid-world than PPO, BYOL-Explore Lite, and RND.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/heatmap_fast_30.png" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-exploring-meta-learned-curiosity-algorithms/heatmap_dis_ppo_30.png" class="img-fluid" %}
-
-
-
- Figure 17. Heatmaps of the FAST agent (left) and the normal PPO (right) in empty grid-world.
-
-
-
-## Discussion
-
-Alet et al. provided a unique approach to meta-learning. The performance of CCIM and FAST in the empty grid-world therefore did not surprise us, as that was the environment used to search for the algorithms. Note in Figure 17 that the 15 best seeds of FAST covered more of the map, i.e., most of the seeds took different paths to the goal compared to PPO.
-However, for the CCIM and CCIM-slimmed heatmaps we notice that these algorithms covered only slightly more of the map than PPO. It should be noted that, looking at the heatmaps,
-CCIM-slimmed, CCIM, and FAST all covered more of the map than our baselines, which makes sense given that Alet et al. optimised for the number of distinct cells visited when searching for the curiosity algorithms.
-
-From the sample standard deviation plots, we can see that FAST and CCIM do not produce more consistent agents than PPO and the curiosity-driven baselines in the deep sea environment, while CCIM-slimmed produced more consistent agents than PPO but not than the baselines. However, in the empty grid-world environment FAST, CCIM, and CCIM-slimmed are able to produce more consistent agents than PPO and RND.
-In the mean episode return plots, CCIM, CCIM-slimmed, and FAST perform better than PPO and RND in the empty grid-world environment, which makes sense as the empty grid-world was used to find these curiosity algorithms. However, in the deep sea environment the meta-learned curiosity algorithms perform worse than our curiosity-driven baselines.
-
-From the mean episode return plots we can see that BYOL-Explore Lite is the best performing algorithm. Even in the empty grid-world environment it performs better than the meta-learned curiosity algorithms.
-We believe this is because of the reward prioritisation implemented in BYOL-Explore. This could explain why its performance is better than that of the meta-learned curiosity algorithms and why it produces the most consistent agents.
-
-One major concern we still have is that the intrinsic rewards for FAST and CCIM did not decrease during training in either environment used in our experiments. However, we noted that the
-intrinsic rewards for CCIM-slimmed decreased during training. We believe the decrease in intrinsic rewards as training progresses is one of the main reasons why BYOL-Explore and RND are
-effective and why we see the improved performance of the CCIM-slimmed algorithm. Even with the reward combiner, we still believe that the intrinsic rewards not decreasing could cause an issue, as it did in the deep-sea environment. Recall that the reward combiner has the following formula,
-
-$$
-\hat{r}_t = \frac{(1+ri_t-t/T)\cdot ri_t+ r_t\cdot t/T}{1+ri_t}.
-$$
-
-Now if $$t=T$$ then $$\hat{r}_t \approx r_t $$ provided $$ 0 \leq ri_t \ll 1$$. However, for us the intrinsic rewards were not much less than one during training. We believe it is important for curiosity algorithms that the intrinsic reward decreases as the agent becomes more familiar with its environment, and that this is why CCIM-slimmed performed better than CCIM and FAST in the deep sea environment. Another concern we have is that the CCIM random and forward network's loss increased during training. It is possible that there is a bug somewhere in our code which we have not found yet.
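-
-As a concrete illustration of the point about the combiner above, at $$t=T$$ it reduces to (our own simplification of the formula)
-
-$$
-\hat{r}_T = \frac{ri_T^2 + r_T}{1+ri_T},
-$$
-
-so with $$r_T = 1$$, a small intrinsic reward such as $$ri_T = 0.05$$ gives $$\hat{r}_T \approx 0.95 \approx r_T$$, whereas $$ri_T = 2$$ gives $$\hat{r}_T \approx 1.67$$, i.e., the combined reward stays dominated by the intrinsic term even at the end of training.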
-
-In the future we think it will be interesting to repeat this experiment, using the deep sea environment to search for the curiosity algorithms that output the intrinsic reward.
-Additionally, exploring the use of a variant of FAST or CCIM to find a reward combiner is also of interest to us. We wonder why a variant of FAST or CCIM wasn't employed for this purpose, as a variant of RND was used to find the reward combiner. As stated earlier, FAST, CCIM and CCIM-slimmed do not make use of reward prioritisation like BYOL-Explore Lite does. Therefore, repeating the experiments with meta-learned curiosity algorithms where some form of reward prioritisation is implemented is another path we hope to explore. We would also like to increase the number of seeds used to narrow the confidence intervals. Since we are training end-to-end in JAX in simple environments, increasing the number of seeds should not be much of an issue.
-
-## Conclusion
-
-In this blog post, we studied two meta-learned curiosity algorithms, namely FAST and CCIM. We compared them to a non-curious agent and our baselines for the curiosity algorithms: RND and BYOL-Explore. Our experiments were conducted using both the empty grid-world environment and the deep-sea environment.
-
-FAST and CCIM both performed well in the empty grid-world, covering more of the map than the baselines when examining their heatmaps. This aligns with our expectations since this was the environment used to search for the curiosity algorithms. However, in the deep-sea environment, both algorithms did not perform well compared to the baselines. Conversely, CCIM-slimmed, a slimmed down version of CCIM, showed performance comparable to the baselines.
-We suspect that this is because the intrinsic reward decreased as the agent explored more. This behaviour was not observed in FAST and CCIM, which we believe is not ideal and consider it the main flaw of these algorithms.
-
-This approach of meta-learning curiosity algorithms is novel, and we believe there's interesting work that can be done following the same approach as Alet et al., trying it with different environments to search for curiosity algorithms, such as the deep-sea environment. Moreover, BYOL-Explore makes use of reward prioritisation. Therefore, in the future, we hope to include reward prioritisation in our FAST, CCIM, and CCIM-slimmed implementations to see if it improves performance. Another avenue is using the meta-learned curiosity algorithms to search for the reward combiner.
diff --git a/_posts/2024-05-07-fairness-ai-two-phil-or-just-one.md b/_posts/2024-05-07-fairness-ai-two-phil-or-just-one.md
deleted file mode 100644
index 3dd47b8b..00000000
--- a/_posts/2024-05-07-fairness-ai-two-phil-or-just-one.md
+++ /dev/null
@@ -1,217 +0,0 @@
----
-layout: distill
-title: "Fairness in AI: two philosophies or just one?"
-description: The topic of fairness in AI has garnered more attention over the last year, recently with the arrival of the EU's AI Act. Achieving fairness in AI is often done in one of two ways, namely through counterfactual fairness or through group fairness. These research strands originate from two vastly differing ideologies. However, with the use of causal graphs, it is possible to show that they are related and even that satisfying a group fairness measure means satisfying counterfactual fairness.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-# authors:
-# - name: Anonymous
-
-authors:
- - name: MaryBeth Defrance
- url: https://orcid.org/my-orcid?orcid=0000-0002-6570-8857
- affiliations:
- name: University of Ghent
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-fairness-ai-two-phil-or-just-one.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Why fairness?
- - name: What is fairness?
- subsections:
- - name: Explainable AI
- - name: Group fairness
- - name: Unifying these philosophies
- subsections:
- - name: Measurement error - Demographic parity
- - name: Selection on label - Equalized odds
- - name: Selection on predictor - conditional use accuracy equality
- - name: Confirmation with experiments
- - name: What can we take away?
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-This blog post is based on the paper by Anthis and Veitch. Here, the original paper is enriched with a broad overview of fairness concepts used in research and with visuals that aid the reader in gaining a deeper understanding. The blog post aims to raise questions about the dichotomy between procedural and outcome fairness, and whether they should be treated as separate research fields, as is currently often the case.
-
-## Why fairness?
-The spread of AI exposed some of the dark patterns present in society. Some well-known examples are the COMPAS case, which showed discrimination against black defendants, and the Amazon hiring tool, which showed a preference for men over women. However, these AI systems were most likely not the source of this disparate treatment. The behavior stems from the data used to train the systems, and thus from the people behind the creation of that data.
-
-Fairness in AI is a research strand which aims to remove the biases in AI models that result in such disparate treatment. The goal is that people are treated more fairly, perhaps even more fairly than under a human decision-maker.
-
-## What is fairness?
-The question of what is fair does not have a single answer. Even when stepping away from the computer science context, a universal definition that can be used to determine whether something is fair cannot be found. The concept of fairness is heavily influenced by a person's, but also society's, biases. The fluidity of the notion therefore gives rise to multiple philosophies of what a fair AI system would be.
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Two_categories.png" class="img-fluid" %}
-
-
- Figure 1: Some examples of the concepts used in the respective philosophies.
-
-
-Two main philosophies can be found in research. The first one, often called explainable AI, aims to either create explainable models or to create explanations for the results obtained from a model. This can also be described as aiming for procedural fairness. The second philosophy is called group fairness. Group fairness focusses on outcome fairness. This means that the predictions from the AI system should have similar properties across groups that only differ in a certain personal attribute.
-
-### Explainable AI
-The most famous example of explainable AI is __fairness through unawareness__. Fairness through unawareness means that no personal attributes are passed into the system, unless these are relevant for the prediction. The system therefore does not have access to the personal attributes, which means it cannot directly discriminate. Fairness through unawareness is often used as the baseline model for fairness. However, the systems in both the COMPAS and Amazon examples used fairness through unawareness and still exhibited disparate treatment. The personal attributes that were removed from the data still had an influence on the dataset itself. For instance, a ZIP code can function as a proxy for race, or someone's gender can influence their writing style.
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Feature_selection.png" class="img-fluid" %}
-
-
- Figure 2: Examples of Fairness Through Unawareness (FTU) and fair feature selection on the Adult dataset.
-
-
-Related to fairness through unawareness is __fair feature selection__ . Instead of removing the personal attributes, only features that are deemed appropriate remain in the dataset. It should be noted that universal agreement on which features are fair to use is unlikely, due to the aforementioned biases of people and cultures. Oftentimes, there is an overlap between the features removed in fairness through unawareness and fair feature selection, as is evident in Figure 2.
-
-__Counterfactual fairness__ is a currently popular type of explainable AI. Counterfactual fairness stems from systems that check for direct discrimination, meaning that simply changing a personal attribute would change a person's prediction. An example of direct discrimination can be found in Figure 3, where changing the sex would result in a different prediction. From a legal standpoint it is clear that if a model exhibits this behavior, it can be deemed unfair.
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Direct_discrimination.png" class="img-fluid" %}
-
-
- Figure 3: Example of direct discrimination where changing the personal attribute of sex changes the prediction a person would receive.
-
-
-Models for counterfactual fairness change a person's personal attributes and also adjust the other features according to a causal model related to those attributes. For example, changing someone's race might also require changing their ZIP code or the high school they went to. Figure 4 contains an example of creating counterfactuals. That system is unfair, as some of the counterfactuals have a different prediction from the original. Satisfying counterfactual fairness can also be achieved by requiring independence between the personal attributes and the prediction itself. A more stringent constraint is to require that the prediction is independent of all proxy features in the dataset.
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Counterfactual_fairness.png" class="img-fluid" %}
-
-
- Figure 4: Imaginary examples of a system that would not satisfy counterfactual fairness. Changing features in accordance with the personal attributes and data distribution results in a different prediction.
-
-
-### Group Fairness
-Group fairness is a different philosophy regarding the fairness of an AI system. Instead of requiring that the process of the system is fair, it requires the outcome of the model to be fair. This verdict of fairness is based on the equality of a chosen statistical measure between groups. People are divided into these groups based on their personal attributes. Three definitions are most commonly used for group fairness, namely demographic parity, equalized odds, and conditional use accuracy equality.
-
-__Demographic parity__ requires that the selection rate is equal across groups. This means that an equal percentage of people from both groups receives a positive prediction. This definition is independent of the ground truth, which means that, for example, a perfect predictor could never satisfy demographic parity if the base rates differ between groups. Therefore, from observing the dataset it must appear that the prediction is independent of the personal attributes.
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Demographic_Parity.png" class="img-fluid" %}
-
-
- Figure 5: A representation of demographic parity. Two groups are distinguished, one male and one female. The circled individuals are the ones to receive a positive prediction.
-
-
-A second fairness measure used in group fairness is __equalized odds__. This fairness measure requires that both the true positive and true negative rates are equal across groups. This means that, given the ground truth, there is an equal chance of giving a positive prediction irrespective of a person's group. In other words, equalized odds requires that the prediction is independent of the personal attribute given the ground truth. Unlike demographic parity, equalized odds depends on the ground truth.
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Equalized_odds.png" class="img-fluid" %}
-
-
- Figure 6: A representation of predictions which satisfy equalized odds. Two groups are distinguished, one male and one female. The circled individuals are the ones to receive a positive prediction. The color of each individual indicates the ground truth of the sample. The male group has a base rate of 0.8 and the female group a base rate of 0.6.
-
-
-The final common fairness measure in group fairness is __conditional use accuracy equality__. In order to satisfy it, the precision and the false omission rate must be equal between groups. Put differently, given the prediction, there is an equal chance that this prediction is correct regardless of the group a person belongs to. Conditional use accuracy equality is therefore defined similarly to equalized odds; the roles of the prediction and the ground truth are simply reversed. The same symmetry holds for the independence condition: conditional use accuracy equality requires that the ground truth is independent of the personal attribute given the prediction.
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Conditional_use_accuracy_equality.png" class="img-fluid" %}
-
-
- Figure 7: A representation of predictions which satisfy conditional use accuracy equality. Two groups are distinguished, one male and one female. The circled individuals are the ones to receive a positive prediction. The color of each individual indicates the ground truth of the sample. The male group has a base rate of 0.8 and the female group a base rate of 0.6.
-
-
-## Unifying these philosophies
-The previous two sections discussed the different concepts used for explainable AI and group fairness. It is clear that they employ a different basis for their philosophy of fairness. However, when looking at these definitions, the concept of independence returns in both counterfactual fairness and the fairness measures used for group fairness. This shared requirement of independence allows us to unify these notions and show that they accomplish the same result. Table 1 provides an overview of the fairness measures and the respective independence they require.
-
-In the following section $$ Y $$ symbolises the perceived label, $$ D $$ the prediction, $$ A $$ the personal attributes, $$ S $$ the selection of a sample in the dataset, $$ X^{\bot}_A $$ the data independent of the personal attributes, $$ X^{\bot}_Y $$ the data independent of the label, and $$ \tilde{Y} $$ the real label.
-
-
- Table 1: A summary of the independence requirement of the fairness notions discussed.
-
-
-
-| Name | Probability definition | Independence |
-| ------------- |:-------------:| -----:|
-| Demographic parity | $$ P(D=1\vert A=1) = P(D=1\vert A=0) $$ | $$ D \bot A $$ |
-| Equalized odds | $$P(D=1 \vert A=1, Y=y) = P(D=1 \vert A=0, Y=y) $$ | $$ D \bot A \vert Y $$ |
-| Conditional use accuracy equality | $$ P(Y=1\vert A=1, D=y) = P(Y=1 \vert A=0, D=y) $$ | $$ Y \bot A \vert D $$ |
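-
-To make these definitions concrete, the three measures can be estimated from binary labels, predictions and group membership roughly as follows (a minimal sketch with our own variable names; for brevity, equalized odds is checked only through the true positive rates and conditional use accuracy equality only through precision):
-
-```python
-import numpy as np
-
-def group_fairness_gaps(y_true, y_pred, group):
-    """Between-group differences for the three measures (0 means the measure is satisfied)."""
-    y_true, y_pred, group = map(np.asarray, (y_true, y_pred, group))
-
-    def positive_rate(mask):
-        return y_pred[mask].mean()  # P(D=1 | mask)
-
-    def precision(mask):
-        return y_true[mask].mean()  # P(Y=1 | mask)
-
-    # Demographic parity: P(D=1 | A=1) - P(D=1 | A=0)
-    dp = positive_rate(group == 1) - positive_rate(group == 0)
-    # Equalized odds (true positive rates): gap in P(D=1 | A, Y=1)
-    eo = positive_rate((group == 1) & (y_true == 1)) - positive_rate((group == 0) & (y_true == 1))
-    # Conditional use accuracy equality (precision): gap in P(Y=1 | A, D=1)
-    cuae = precision((group == 1) & (y_pred == 1)) - precision((group == 0) & (y_pred == 1))
-    return dp, eo, cuae
-```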
-
-
-### Measurement error - Demographic parity
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Measurement_error.png" class="img-fluid" %}
-
-
- Figure 8: A directed acyclic graph showing the relation between the prediction and the data, in the situation of measurement error.
-
-
-Measurement error is a first type of dependence that can be resolved in order to be counterfactually fair. Measurement error means that there is some bias in the perceived ground truth of the dataset. Consider, for example, a system that determines whether pulling a car over is justified or not (whether a crime was committed or not). More crimes can be uncovered if a full car search happens; however, a car search is not always undertaken, resulting in a bias of more positive samples for a population where a car search is more likely to happen. In this situation the label is whether or not a crime was detected, not whether a crime was committed. The imbalance in car searches for a group with a certain personal attribute will then have an effect on the label. This influence of the personal attributes on the label, but not on the ground truth, is shown in Figure 8.
-
-A second example of measurement error can be found in healthcare prediction. Predicting someone's health is abstract, as health is not directly quantifiable. A proxy for health is the cost of the healthcare an individual receives. However, costs are not universal for each group in society. Certain groups can thus have lower costs while dealing with more health problems, due to the care that they receive or perhaps do not receive. This faulty proxy is another example of measurement error.
-
-This system is thus made counterfactually fair if the dependence between the personal attribute and the label is removed. This is the same independence that is required to satisfy demographic parity.
-
-### Selection on label - Equalized odds
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Selection_on_label.png" class="img-fluid" %}
-
-
- Figure 9: A directed acyclic graph showing the relation between the prediction and the data, in the situation of selection on label.
-
-
-Selection on label is a type of bias that arises when not only someone's label but also their personal attribute affects their inclusion in the dataset. A subtype of this bias is self-selection bias. This means that certain groups of the population are over-represented in a dataset because they are more likely to interact with the data collection system. An example of this is found in voluntary studies, where certain groups are more likely to participate than others, leading to a dataset skewed in favor of the participating group. A study of self-selection bias in nutrition trials also found that a person's ground truth influences their participation in the trial (healthy eaters were more likely to apply for the trial).
-
-The directed acyclic graph in Figure 9 shows how to decouple the label from the personal attribute by introducing the selection variable $$ S $$, which is observed. $$ A $$ and $$ X^{\bot}_A $$ are only connected through a path that includes $$ Y $$, which means that given $$ Y $$, $$ A $$ and $$ X^{\bot}_A $$ are independent, which is the condition of equalized odds.
-
-### Selection on predictor - conditional use accuracy equality
-
-
-{% include figure.html path="assets/img/2024-05-07-fairness-ai-two-phil-or-just-one/Selection_on_predictor.png" class="img-fluid" %}
-
-
- Figure 10: A directed acyclic graph showing the relation between the prediction and the data, in the situation of selection on predictor.
-
-
-Selection on predictor is similar to selection on label, but instead of the label influencing the selection, it is the features themselves that influence the selection together with the personal attributes. An example of this can be seen in the student population of engineering degrees. A relevant feature, such as what a person studied in high school, influences their choice to do engineering. However, there is a large discrepancy in the number of male versus female students who pursue engineering, even though that difference does not exist when graduating high school. This shows that both relevant features and personal attributes influence a person's presence in a dataset about engineering students.
-
-The acyclic graph in Figure 10 for selection on predictor is similar to that for selection on label; the features and label are simply reversed. This is also in accordance with the similarity seen between equalized odds and conditional use accuracy equality. $$ A $$ and $$ Y $$ are connected only through $$ X^{\bot}_A $$, which means that if the prediction, which is captured in $$ X^{\bot}_A $$, is known, then $$ A $$ and $$ Y $$ are independent, which is necessary to satisfy conditional use accuracy equality.
-
-### Confirmation with experiments
-This relation between counterfactual fairness and group fairness is supported by experiments. These experiments were done on a synthetic version of the Adult dataset. A simulated protected class A was added with balanced incidence (50/50 odds of belonging to the protected class or not). If someone belongs to the protected class, then there is a causal effect of A on X: $$P(race=other) = 0.8 $$. This means that A will loosely relate to someone's race being recorded as "other". This dataset serves as the target distribution for the biased datasets.
-
-A counterfactually fair model is achieved by taking the average prediction of an instance as if it were part of the protected class and as if it were not. Three biased datasets are created based on the directed acyclic graphs in Figures 8, 9, and 10. Table 2 shows that satisfying counterfactual fairness for a certain type of dataset also satisfies the corresponding fairness measure, confirming the theoretical results above.
-
-
- Table 2: The results of applying counterfactual fairness to a model with its performance on different fairness measures.
-
-
-
-| | Demographic parity difference | Equalized odds difference | Conditional use accuracy equality |
-| ------------- | ------------- | ------------- | ------------- |
-| Measurement Error | __-0.0005__ | 0.0906 | -0.8158 |
-| Selection on Label | 0.1321 | __-0.0021__ | 0.2225 |
-| Selection on Predictors | 0.1428 | 0.0789 | __0.0040__ |
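-
-As an illustration of the averaging step described above, a counterfactually fair prediction can be sketched as follows, assuming a `counterfactual` helper that applies the causal effect of $$ A $$ on $$ X $$ (both names are hypothetical):
-
-```python
-def counterfactually_fair_prediction(model, x, counterfactual):
-    """Average the model's predictions over both values of the protected attribute A."""
-    x_a0 = counterfactual(x, a=0)  # features as if the instance were outside the protected class
-    x_a1 = counterfactual(x, a=1)  # features as if the instance were inside the protected class
-    return 0.5 * (model(x_a0) + model(x_a1))
-```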
-
-## What can we take away?
-
-Procedural and outcome fairness have tended to coexist in research. They are each their own field, with their own philosophy, sharing the common goal of creating fairer AI systems. The strength of techniques like counterfactual fairness lies in their explainability, which allows for an easier determination of whether they are fair or not. Group fairness techniques have many implementations and have proven to be powerful. However, they are not very interpretable. In order to determine what is fair, a first abstraction must be made by converting the meaning of fairness into a mathematical fairness measure. The determination of whether the system is fair is thus dependent on the interpretation of the fairness measure and the quality of the dataset. If the dataset is not representative, then there is no guarantee that the system will have a fair outcome.
-
-This relation between procedural fairness and outcome fairness opens up certain research possibilities, perhaps allowing the strength of outcome fairness techniques to be combined with the interpretability of procedural fairness concepts. A future research possibility is to investigate whether the techniques used to satisfy a fairness measure also satisfy some explainability notions, or what adjustments would be needed.
\ No newline at end of file
diff --git a/_posts/2024-05-07-hidden-convex-relu.md b/_posts/2024-05-07-hidden-convex-relu.md
deleted file mode 100644
index 6b44cba5..00000000
--- a/_posts/2024-05-07-hidden-convex-relu.md
+++ /dev/null
@@ -1,659 +0,0 @@
----
-layout: distill
-title: The Hidden Convex Optimization Landscape of Two-Layer ReLU Networks
-description: In this article, we delve into the research paper titled 'The Hidden Convex Optimization Landscape of Regularized Two-Layer ReLU Networks'. We put our focus on the significance of this study and evaluate its relevance in the current landscape of the theory of machine learning. This paper describes how solving a convex problem can directly give the solution to the highly non-convex problem that is optimizing a two-layer ReLU Network. After giving some intuition on the proof through a few examples, we will observe the limits of this model as we might not yet be able to throw away the non-convex problem.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-authors:
- - name: Victor Mercklé
- url: "https://victormerckle.fr/"
- affiliations:
- name: LabHC, LJK - France
- - name: Franck Iutzeler
- url: "https://iutzeler.org/"
- affiliations:
- name: Institut de Mathématiques de Toulouse, Université de Toulouse, CNRS
- - name: Ievgen Redko
- url: "https://ievred.github.io/"
- affiliations:
- name: Paris Noah's Ark lab
-
-#authors:
-# - name: Albert Einstein
-# url: "https://en.wikipedia.org/wiki/Albert_Einstein"
-# affiliations:
-# name: IAS, Princeton
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-hidden-convex-relu.bib
-
-#TODO make sure that TOC names match the actual section names - they do
-toc:
- - name: I. Overview and Motivation
- subsections:
- - name: Problem and notation
- - name: Research context
- - name: II. Convex Reformulation
- subsections:
- - name: Small example walkthrough
- - name: Specifics about equivalence
- - name: Activation patterns
- - name: Extensions of the convex reformulation to other settings
- - name: III. Can we Forget the Non-Convex Problem?
- subsections:
- - name: Solving the convex problem efficiently is hard
- - name: Activation patterns are not a constant in the non-convex problem
- - name: On large initialization scale
- - name: On very small initialization
- - name: Conclusion
-
-_styles: >
-
- .remark {
- display: block;
- margin: 12px 0;
- font-style: italic;
- }
- .remark:before {
- content: "Remark.";
- font-weight: bold;
- font-style: normal;
- }
- .remark[text]:before {
- content: "Remark (" attr(text) ") ";
- }
-
- .center {
- display: block;
- margin-left: auto;
- margin-right: auto;
- }
-
- .legend {
- display: block;
- margin-left: 50px;
- margin-right: 50px;
- }
-
- .framed {
- border: 1px var(--global-text-color) dashed !important;
- padding: 20px;
- }
-
- d-article {
- overflow-x: visible;
- }
-
- .underline {
- text-decoration: underline;
- }
-
----
-
-
-
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/teaser_movie.gif" class="img-fluid" %}
-
-
There exists an equivalent convex formulation to the classical non-convex ReLU two-layer network training. That sounds like great news but is it the case in practice? Let's find out together.
-
-The code for _this plot_ is available and reproducible on this __[Jupyter Notebook]({{'assets/html/2024-05-07-hidden-convex-relu/hidden-convex-relu.ipynb' | relative_url}})__ (or in __[HTML]({{'assets/html/2024-05-07-hidden-convex-relu/hidden-convex-relu.html' | relative_url}})__).
-
-## I. Overview and Motivation
-
-50 years ago, two-layer networks with non-linear activations were known to be universal approximators; however, they did not catch on as they were hard to train. The recent years have been marked by deeper networks running on dedicated hardware with very large datasets. Those networks have since been at the top of the benchmarks in many applications, including self-driving and text generation. The pragmatic method to train such models is to run stochastic gradient descent on the non-convex optimization problem, which concretely means tuning the weights (and biases) until the model is accurate enough. The best models usually require billions of parameters and very large datasets. The training, in turn, requires millions of dollars of hardware and electricity to run gradient descent and train a single model.
-
-Deep learning is not without faults. Even though the test performance can surpass that of many machine learning models, it is very hard to know what the network has learned because of its black-box nature. Interpretability in neural networks is crucial for creating trustworthy AI systems, one of the biggest obstacles to AI adoption. It may also lead us to simpler models that are cheaper to run, are more robust, generalize better, and are easier to adapt to specific tasks.
-
-To figure out what a neural network learns, we will focus in this post on the training of a shallow ReLU network by vanilla gradient descent, using the full batch of data at each step, in a regression setting. More precisely, we will investigate how the construction of a convex equivalent to the non-convex training problem can enlighten us on how neurons evolve during the training phase, with a specific focus on the activation of the ReLU functions and their consequences.
-
-### Problem and notation
-
-Our problem of interest will be the training of a simple two-layer neural network with ReLU activation. We focus on a classical regression problem with a mean squared error loss and we add a weight decay term (whose importance will be underlined later). This leads to the following full-batch gradient method (note that we make a slight abuse of notation by denoting by $\nabla$ the gradient with respect to the parameters, obtained, for instance, by backpropagation).
-
-Because there are only two layers, we will integrate the biases of the neurons directly into the data by adding a dimension filled with ones.
-
-
-
-Even the simplest ReLU models have non-trivial non-convexity, as depicted in the figure below. We plot the loss function $$\mathcal{L}$$ of a network with two neurons on one-dimensional data. We only optimize the first layer here, so we have a total of two parameters to optimize. Despite the simple setup, a gradient descent starting from a random initialization can converge to three different values, two of them being bigger than zero. However, there always exists a path of non-increasing loss from initialization to the global minimum (as predicted by a known result ).
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/threed.png" class="img-fluid" %}
-
-
Loss landscape of a network with two parameters, one for each ReLU neuron, and two data points: $(x_1, y_1) = (-1, 1)$ and $(x_2, y_2) = (1, 2)$ are fixed. Since all labels are positive, we fix the second layer $\alpha_1, \alpha_2$ to 1 to plot the loss in 2D without a loss of generality. The black lines represent the loss for only one neuron (since the other is equal to 0). The red lines (critical points) are paths of parameters for which the loss is constant and the gradient is zero. They represent the parameters for which the neuron fits exactly one data point and is deactivated for the other, and thus suffers a loss of $(y_1)^2$ for the red line on the left and $(y_2)^2$ for the other. The exact formula to compute each point of the loss landscape is:
-
-\begin{equation*}
-\begin{split}
-\mathcal{L}(w_1, w_2) =&\ \left(\max(0, x_1 w_1) + \max(0, x_1 w_2) - y_1\right)^2 \\
-+&\ \left(\max(0, x_2 w_1) + \max(0, x_2 w_2) - y_2\right)^2
-\end{split}
-\end{equation*}
-
-
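-The surface above can be reproduced in a few lines (a simplified NumPy sketch, separate from the linked notebook):
-
-```python
-import numpy as np
-
-x = np.array([-1.0, 1.0])  # data points x_1, x_2
-y = np.array([1.0, 2.0])   # labels y_1, y_2
-
-def loss(w1, w2):
-    """Two-neuron ReLU network loss with the second layer fixed to 1."""
-    pred = np.maximum(0.0, x * w1) + np.maximum(0.0, x * w2)
-    return np.sum((pred - y) ** 2)
-
-w1_grid, w2_grid = np.meshgrid(np.linspace(-3, 3, 200), np.linspace(-3, 3, 200))
-landscape = np.vectorize(loss)(w1_grid, w2_grid)  # evaluate the loss on the grid
-```
-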
-To avoid the local minima, one idea is to add constraints to the parameters. The constrained problem, where $w_1$ has to be positive and $w_2$ has to be negative, _is_ convex, and a simple gradient descent will find the global minimum of the original unconstrained problem. In , they find a more general way to build an equivalent convex problem to our ReLU shallow network training problem.
-
-In this blog post, we will first work out the intuition needed to understand why an equivalent, finite convex problem even exists. Then we will study the exact links between the problem in practice and the convex problem, and go over the limits of such an approach both in theory and in practice.
-
-### Research context
-
-The question of how neural networks learn is a very active domain of research with many different paths of investigation. Its main goal is to lay a mathematical foundation for deep learning and for that goal, shallow neural networks act as a stepping stone for understanding deeper and more complex networks.
-
-For networks with a hidden layer of infinite width, it is proven that gradient descent converges to one of the global minima under the _NTK regime_, or by considering them as Wasserstein gradient flows. Studying the NTK amounts to analyzing the first-order Taylor expansion of the network, treating the network as a linear regression over a feature map. This approximation is accurate if the neurons are initialized at a large scale (far from zero), large enough that the neurons do not move far from their initialization. This is also called the _lazy regime_ , in contrast with the _feature learning regime_ where neurons align themselves to a finite number of directions. While the lazy regime is remarkable, we are interested here in the feature-learning regime with small initialization, where we can observe genuinely non-convex behavior such as neuron alignment, incremental learning and saddle-to-saddle dynamics.
-
-Examining the loss landscape reveals that shallow networks with more neurons than data points always have a non-increasing path to a global minimum. This is a favorable property for (stochastic) gradient convergence. In '_The Hidden Convex Optimization Landscape of Regularized Two-Layer ReLU Networks_', the authors extend those results by adding the weight decay regularization.
-
-Regularization plays a pivotal role as it let us influence which local minimum we will reach with gradient descent, usually to favor a simpler solution. Even if no explicit regularization is used, it is known that there is an implicit bias of gradient descent for linear activations, and more recently for ReLU networks using the convex reformulation.
-
-Other convex approaches are limited to an infinite number of neurons, or to optimization in a neuron-by-neuron fashion, which requires solving many non-convex problems. The setting studied here allows for any number of neurons.
-
-To sum up, the convex reformulation approach described in this post contrasts with what precedes by presenting results for a shallow network with __finite width layers__, in a __regression__ setting with __ReLU__ activation and __weight decay__ regularization.
-
-## II. Convex reformulation
-
-### Small example walkthrough
-
-First, let's get familiar with and understand the inherent convexity caused by ReLU and the second layer. To do so, we will take simple yet non-convex examples and find their global minima using a convex problem.
-
-#### One ReLU, no second layer, no regularization
-
-Below is the loss of a single ReLU neuron ($$w_1 \in \RR$$) trained on two data points: $$(x_1, y_1)=(-1, 1)$$ and $$(x_2, y_2) = (1, 0.5)$$
-
-
-
-Because our only trainable parameter is one-dimensional, we can directly plot the entire loss landscape.
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/redloss.png" class="img-fluid" %}
-
-
\(\color{cvred}{\mathcal{L}}\) is non-convex in a strong sense: two local minima exist and have distinct values (\((y_1)^2\) and \((y_2)^2\)). In practice, a gradient descent will never be able to switch from fitting one data point to the other (switching from positive to a negative weight $w_1$ can only be done by increasing the loss).
-
-We say that the ReLU neuron can _activate_ one or more data points if the output of its ReLU is non-zero when evaluated on said data. The output of a one-neuron ReLU network is $$\color{cvblue}{\max(0, x ~ w_1)}$$, we can plot both the output and the two data points on the same graph.
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/blueoutput.png" class="img-fluid" %}
-
-
Plot of the output of a one-neuron ReLU network with a positive weight $w_1$. The ReLU only activates the second data point (as $x_2>0$ and $w_1 > 0$) so the network can fit the second data point. However, doing so means it cannot activate $x_1$ and will incur a constant loss $(y_1)^2$. Overall, depending on the sign of $w_1$, we will have a loss consisting of a constant term for not activating one example and a quadratic term for matching the label of the activated data point.
-
-
-Before moving on, the important fact here is that we have a true non-convexity of the loss (the difference between the two local minima, $\vert (y_1)^2 - (y_2)^2 \vert$, can be made arbitrarily large), even without a second layer or regularization. Now we will explore the corresponding convex problems.
-
-#### Activation
-
-We want to find the global minima of the one-neuron ReLU network loss function\eqref{eq:one_neuron_loss}. Recall that the loss has two local minima: $(y_2)^2$ for $w_1=y_1/x_1$ and $(y_1)^2$ for $w_1=y_2/x_2$.
-
-Which data points are activated plays a crucial role in the loss. In the specific example above, $x_2>0$ is activated and $x_1<0$ is not. If we fix the ReLU's activation to this pattern and __replace the max operators__ with $$\czero$$ or $$\cone$$:
-
-
-
-This problem is convex. A gradient descent from any initialization will converge to the optimal loss $(y_1)^2$ with the parameter $u_1 =y_2/x_2$. This parameter directly corresponds to one of the two local minima of the non-convex loss\eqref{eq:one_neuron_loss} by taking $w_1 = u_1$.
-
-
-
-Similarly, this convex problem's optimal solution directly corresponds to the second local minima: $(y_2)^2$ for $u_2 =-y_1/x_1$.
-
-All seems good. But keep in mind that we want to build an equivalent problem. If $u_2$ is positive, taking $w_1 = u_2$ does not lead to the same loss value in the original problem because a positive parameter will never activate the first data point.
-
-To make the issue obvious, consider this convex problem obtained by replacing the two $\max$ operators by $$\cone$$:
-
-
-
-While it is convex, there is no link between the ReLU parameter $w_1$ and this new problem's parameter $u_3$: it is not possible to activate both data points. This issue comes from the fact that replacing a $\max$ by $$\cone$$ only makes sense if what is inside the $\max$ is indeed positive. In other words, as long as $$x_1 ~ w_1$$ is positive we have that $$\max(x_1 ~ w_1, 0) = \cone x_1 ~ w_1$$.
-
-
-
-We added the constraints corresponding to the activation, and they adequately restrict $u_3$ to be in $\{0\}$.
-
-As a simple reformulation of \eqref{eq:firsttry}, we vectorize (in the number of data points) the convex loss and we add the constraints:
-
-
-
-The diagonal activation matrix (named $$D_i \in \{0, 1\}^{n \times n}$$) summarizes the on/off behavior of _one_ ReLU for _all_ data points. The constraints on $u_1$ are directly given by this activation matrix:
-
-$$\begin{bmatrix} -1 & 0 \\ 0 & 1 \end{bmatrix} = 2 \begin{bmatrix} \czero & 0 \\ 0 & \cone \end{bmatrix}- I_2 \qquad \text{$I_2$ the identity matrix of $\RR^2$}$$
-
-The other way around, we can define the activation pattern vector for a specific parameter $$u$$: $$(\mathbb{1}_{u ~ x_j \geq 0})_{j=1\dots n} \in \{0,1\}^n$$ with $n$ the number of data points. The activation matrix of $$u$$ is simply the matrix that has this vector for its diagonal.
-
-So we have exactly four possible activation matrices. $$D_1 = (\begin{smallmatrix} \czero & 0 \\ 0 & \czero \end{smallmatrix})$$ and $$D_2 = (\begin{smallmatrix} \cone & 0 \\ 0 & \cone \end{smallmatrix})$$ will have constraints that reduce to $w_1 = 0$, making them uninteresting. The other two lead to convex problems with convex constraints. Solving them will give the parameters that correspond to the two local minima of the loss of a ReLU neural network with only a single neuron\eqref{eq:one_neuron_loss}.
-
-
For any number $n$ of 1-D data points, there are $2^n$ distinct activation matrices but only two of them will be interesting: activating all positive data points, or only activating negative data points. Only some $D_i$ are interesting in higher dimensions, but finding all of them is not obvious.
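-
-One common way to collect the realisable patterns is simply to sample random directions $$u$$ and record the distinct patterns they induce on the data; the sketch below is our own illustration of that idea, not necessarily the procedure used in the paper:
-
-```python
-import numpy as np
-
-def sample_activation_patterns(X, n_samples=1000, seed=0):
-    """Collect distinct activation matrices D_i induced by random directions u."""
-    rng = np.random.default_rng(seed)
-    n, d = X.shape
-    patterns = set()
-    for _ in range(n_samples):
-        u = rng.normal(size=d)
-        patterns.add(tuple((X @ u >= 0).astype(int)))  # diagonal of D_i
-    return [np.diag(p) for p in patterns]
-
-# 1-D example with x_1 = -1 and x_2 = 1: only the two informative patterns show up.
-X = np.array([[-1.0], [1.0]])
-print(len(sample_activation_patterns(X)))  # 2
-```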
-
-Replacing everything with the usual matrices ($$X=(\begin{smallmatrix}x_1 \\x_2\end{smallmatrix})$$, $$Y=(\begin{smallmatrix}y_1 \\y_2\end{smallmatrix})$$) will get us the equivalent convex problem to a one-neuron ReLU network, whose activation pattern is $D_i$:
-
-
-\begin{equation*}
-\min_{\substack{u_1 \in \RR\\ (2 D_i - I_2) X u_1 \geq 0}} \ \
-\big\| D_i X u_1 - Y \big\|_2^2
-\end{equation*}
-
-
-
-Later sections will investigate what we can say about a ReLU network with more than one neuron.
-
-#### Multiplicative non-convexity from the second layer
-
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/vraitroisd.png" class="img-fluid" %}
-
-
-
-\eqref{eq:ncvxlin} is not convex: it has two local minima. However, they are symmetric. Simply replace the term $x ~ y$ by a new variable $z$, and use a simple mapping such as $z \rightarrow (1, z)$ to get the solution of \eqref{eq:ncvxlin} from the solution of the convex problem: $$\min_{z \in \RR} (z-1)^2$$.
-
-The initial problem\eqref{eq:ncvxlin} with L2 regularization is non-convex as well:
-
-
-\begin{equation*}
-\min_{(x, y) \in \RR^2} (x ~ y - 1)^2 + \frac{\lambda}{2} ( \vert x \vert^2 + \vert y \vert^2)
-\end{equation*}
-
-
-The convex reformulation with one variable is:
-
-\begin{equation*}
-\min_{z \in \RR} (z - 1)^2 + \lambda \vert z \vert
-\end{equation*}
-
-
-We have to use a different mapping $$z \rightarrow (\sgn(z) \sqrt{\vert z \vert}, \sqrt{\vert z \vert})$$. One can verify that plugging this mapping into the non-convex problem gives the same value. Therefore, you can solve the convex problem in lieu of the non-convex one.
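-
-A quick numerical sanity check of this mapping (our own, with an arbitrary $$\lambda$$):
-
-```python
-import numpy as np
-
-lam = 0.1
-
-def nonconvex(x, y):
-    return (x * y - 1.0) ** 2 + 0.5 * lam * (x ** 2 + y ** 2)
-
-def convex(z):
-    return (z - 1.0) ** 2 + lam * abs(z)
-
-for z in [-2.0, -0.3, 0.5, 1.7]:
-    x, y = np.sign(z) * np.sqrt(abs(z)), np.sqrt(abs(z))
-    assert np.isclose(nonconvex(x, y), convex(z))  # both objectives coincide
-```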
-
-Back to non-linear activations, consider the non-convex problem of training a single ReLU neuron with a second layer ($$\alpha_1$$) and an L2 regularization:
-
-
-
-We fix the activation to only activate $x_1$ (as could be done for any activation pattern) and add the corresponding constraint as in the previous section:
-
-
-
-\eqref{eq:ncvx1} is a non-convex problem because we are multiplying $w_1$ and $\alpha_1$ together (along with some constants). However, this non-convexity can be circumvented by considering an equivalent convex problem, in a very similar way to the $(x ~ y - 1)^2$ problem.
-
-
-
-$z_1$ takes the role of the product $w_1 ~ \alpha_1$. We can solve \eqref{eq:cvx1} to get an optimal $z_1$ and then use the mapping $$(w_1, \alpha_1) = (\sgn(z_1) ~ \sqrt{\vert z_1 \vert}, \sqrt{\vert z_1\vert})$$. However, the two problems do not have the same expressivity: $$\max(0, x_1 ~ w_1) ~ \alpha_1$$ can be negative, but $$\cone ~ x_1 ~ z_1$$ cannot because of the constraint. Let's add a second variable with the same constraint as $z_1$ that will take the role of a negative second layer $\alpha_1$.
-
-
-
-The variable $$z_1$$ represents a neuron with a positive second layer and $$v_1$$ a neuron with the same activation pattern but with a negative second layer. This is a convex problem (adding a convex regularization preserves convexity) with convex constraints. At the optimum, only one of the two variables will be non-zero. We consider this mapping:
-
-
-\begin{align*}
-(w_1, \alpha_1) &= (\sgn(z_1) ~ \sqrt{\vert z_1 \vert}, \sqrt{\vert z_1 \vert}) & \text{ if $z_1$ is non-zero}\\
-(w_1, \alpha_1) &= (\sgn(v_1) ~ \sqrt{\vert v_1 \vert}, - \sqrt{\vert v_1 \vert}) & \text{ if $v_1$ is non-zero}
-\end{align*}
-
-
-One can verify that this mapping gives the same value when plugged into \eqref{eq:ncvx1}. The two problems share the same global minima since we can map back and forth without altering the loss. Because they have the same expressivity, their global minima have the same value, and we can say the two problems are equivalent in the sense that solving one gives the solution of the other through a simple mapping.
-
-To summarize, here is the convex problem equivalent (through the above mapping) to a one-neuron ReLU network with a second layer and regularization, whose activation pattern is $D_i$:
-
-
-\begin{equation*}
-\min_{\substack{(2 D_i - I_2) X u_1 \geq 0\\
-(2 D_i - I_2) X v_1 \geq 0}} \ \
-\big\| D_i ~ X (u_1 - v_1) - Y \big\|_2^2 + \lambda \left( \vert u_1 \vert + \vert v_1 \vert \right)
-\end{equation*}
-
-
-#### Equivalent Convex problem with two neurons
-
-Before moving on to the general results, we want to fit two data points, *i.e.* have both data points activated. To do so, we need at least two neurons. The usual non-convex problem is as follows (with $$X=(\begin{smallmatrix}x_1 \\x_2\end{smallmatrix})$$, $$Y=(\begin{smallmatrix}y_1 \\y_2\end{smallmatrix})$$ and $m=2$):
-
-
-
-The best choice of activation matrices (only obvious in this 1-D case) is $$D_1 = (\begin{smallmatrix} \czero & 0 \\ 0 & \cone \end{smallmatrix})$$ and $$D_2 = (\begin{smallmatrix} \cone & 0 \\ 0 & \czero \end{smallmatrix})$$.
-
-Solving this problem and mapping back its solutions gives the *global* optimum of the problem of fitting two data points with a two-neuron ReLU network. More insight into why this is true is given after the general case section, and the complete proof can be found in the paper.
-
-#### General Case
-
-Let us consider a general two-layer ReLU network with an input of dimension $d$, an output of dimension 1 (vector output requires a similar but parallel construction) and a hidden layer of size $m$. With $n$ data points, the full regularized loss is
-
-
-This is the same loss as presented at the beginning of the article \eqref{eq:theloss}, but written with matrices and vectors. $$\pmb{X} \in \RR^{n \times d}$$ is the data matrix and $$\pmb{y} \in \RR^n$$ is the vector of labels. Each neuron has a first-layer parameter $$\pmb{w}_i \in \RR^d$$ and a second-layer parameter $$\alpha_i \in \RR$$.
-
-By analogy with what we saw earlier, an equivalent convex problem can be found. Multiplications are replaced by scalar products in the definition of activation matrices and thus most insights about activation hold.
-
-
-
-$$\pmb{D}_i$$ are the activation matrices. The constraint set $$\mathcal{K}$$ is the concatenation of the constraints of all neurons. Each constraint can be written succinctly: $$(2 \pmb{D}_i - \pmb{I}_n) \pmb{X} \pmb{u}_i \geq 0$$. If $$\pmb{u}_i$$ respects its constraint, its activation pattern is exactly $$\pmb{D}_i$$, which is crucial to retrieve the optimal solution of the non-convex loss \eqref{eq:theloss} from the solution of the convex reformulation \eqref{eq:thecvx}.
-
-A conceptually easy way to make the two problems have the same global loss is to consider a ReLU network with $$2^n$$ neurons and to formulate the convex problem using all $$2^n$$ distinct activation matrices $$D_i$$; in that case, it is easy to see that they both have the same expressivity. In the paper, it is proved that in theory only $$n$$ neurons and activation patterns are required (using Carathéodory's theorem), but the patterns are not given explicitly. The next section gives more insight into when the two problems are equivalent.
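-
-To make the construction concrete, here is a minimal cvxpy sketch of the convex problem restricted to a randomly sampled subset of activation patterns (sampling a subset is the practical shortcut discussed later in this post). It implements a squared loss plus a group-norm penalty on the convex neurons with the constraints written explicitly; the helper names, the sampling heuristic and the regularization strength are our own illustrative choices, not the paper's implementation:
-
-```python
-import cvxpy as cp
-import numpy as np
-
-def sample_patterns(X, num_samples=50, seed=0):
-    """Sample candidate activation matrices D_i from random directions."""
-    rng = np.random.default_rng(seed)
-    n, d = X.shape
-    patterns = {tuple((X @ rng.standard_normal(d) >= 0).astype(int))
-                for _ in range(num_samples)}
-    return [np.diag(p) for p in patterns]
-
-def convex_relu_fit(X, y, lam=1e-3, num_samples=50):
-    """Solve the convex reformulation restricted to the sampled patterns."""
-    n, d = X.shape
-    Ds = sample_patterns(X, num_samples)
-    U = [cp.Variable(d) for _ in Ds]          # neurons with a positive second layer
-    V = [cp.Variable(d) for _ in Ds]          # neurons with a negative second layer
-    out = sum(D @ X @ (u - v) for D, u, v in zip(Ds, U, V))
-    reg = sum(cp.norm(u, 2) + cp.norm(v, 2) for u, v in zip(U, V))
-    constraints = []
-    for D, u, v in zip(Ds, U, V):
-        G = (2 * D - np.eye(n)) @ X           # (2 D_i - I_n) X
-        constraints += [G @ u >= 0, G @ v >= 0]
-    problem = cp.Problem(cp.Minimize(cp.sum_squares(out - y) + lam * reg), constraints)
-    problem.solve()
-    return Ds, [u.value for u in U], [v.value for v in V], problem.value
-```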
-
-From a solution of the convex problem \eqref{eq:thecvx}, the *convex neurons* $$u_i$$ and $$v_i$$ can be mapped to the *non-convex neurons* $$(w_i, \alpha_i)$$ as follows:
-
-
-\begin{align*}
-(w_i, \alpha_i) &= (\frac{u_i}{\sqrt{\| u_i \|_2}}, \sqrt{\| u_i \|_2}) & \text{ if $u_i$ is non-zero}\\
-(w_i, \alpha_i) &= (\frac{v_i}{\sqrt{\| v_i \|_2}}, -\sqrt{\| v_i \|_2}) & \text{ if $v_i$ is non-zero}
-\end{align*}
-
-
-We use the same mapping as in the 1-D case, except that the direction of the neuron ($$u_i$$) is now a vector in $$\RR^d$$.
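-
-A NumPy version of this mapping (the helper name and the zero-tolerance are ours) could look like this; at the optimum, at most one of $$u_i$$ and $$v_i$$ is non-zero for each pattern:
-
-```python
-import numpy as np
-
-def convex_to_nonconvex(us, vs, tol=1e-8):
-    """Map convex neurons (u_i, v_i) to non-convex neurons (w_i, alpha_i)."""
-    neurons = []
-    for u, v in zip(us, vs):
-        if np.linalg.norm(u) > tol:   # positive second layer
-            neurons.append((u / np.sqrt(np.linalg.norm(u)), np.sqrt(np.linalg.norm(u))))
-        if np.linalg.norm(v) > tol:   # negative second layer
-            neurons.append((v / np.sqrt(np.linalg.norm(v)), -np.sqrt(np.linalg.norm(v))))
-    return neurons
-```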
-
-
-This is a very simple mapping from the convex solution to the non-convex neurons. We call *convex neurons* the sets of parameters of the convex problem that each correspond to a neuron of the original, non-convex problem. One can expect similar trajectories between the non-convex and convex neurons during gradient descent.
-
-
-Here, we fixed the number of neurons and the corresponding activations. A few questions are left unanswered: how many different activation patterns need to be considered, and how many neurons should we consider for both convex and non-convex problems?
-
-### Specifics about equivalence
-
-Two problems are considered equivalent when their global optima can be seamlessly mapped back and forth.
-
-As seen before, there are only two *interesting* possible activation patterns in the one-dimensional case (a single neuron can either activate all the positive data points and none of the negative, or the opposite), but there are close to $$2^n$$ _interesting_ patterns when the data dimension is higher. An activation pattern is interesting if there exists a non-zero vector that respects the constraints and thus realizes the activation pattern.
-
-The (unique) optimal loss of the convex problem \eqref{eq:thecvx} with all possible activation patterns $$D_i$$ (for fixed data) is the best loss any non-convex network can reach. The following sections are dedicated to understanding why adding more neurons than there are activation patterns does not improve the loss.
-
-However, if we only consider a subset of all patterns, the convex problem will in general correspond to a local optimum of the non-convex network. Indeed, it is not as expressive as before. This would either correspond to a non-convex network with not enough neurons, or with too many neurons concentrated in the same regions.
-
-To explore this idea, we go back to one-dimensional data.
-
-#### 1-D EXAMPLE, ONE NEURON
-
-In the non-convex problem with only one neuron, there are exactly two local minima.
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/oneneuron.png" class="img-fluid" %}
-
-
-Plot of the output of a one-neuron ReLU network, one plot for each of the parameter's two local minima. The parameter on the left can be obtained as the solution of a convex problem with one convex neuron using the activation matrix \((\begin{smallmatrix} \czero & 0 \\ 0 & \cone\end{smallmatrix})\), and the one on the right using \((\begin{smallmatrix} \cone & 0 \\ 0 & \czero \end{smallmatrix})\).
-
-As seen in the previous section, each local minimum can be found exactly by solving the convex problem restricted to a subset of all possible activation patterns: one pattern for the minimum on the left, the other for the one on the right. Here we cannot say that the convex problem (which considers only one pattern) is equivalent to the non-convex one, because the global minimum of the non-convex problem cannot be achieved by the convex problem. However, once non-convex gradient descent reaches a local minimum, that minimum can be described by a convex problem using one pattern or the other.
-
-#### 1-D EXAMPLE, TWO NEURONS
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/twoneuron.png" class="img-fluid" %}
-
-
-The non-convex problem initialized with two random neurons and optimized with gradient descent has three possible local minima (if there is some regularization; otherwise there is an infinite number of them). Either one neuron is initialized for each activation pattern and gradient descent reaches the global optimum (left), or the two neurons end up in the same pattern (right), activating the same data point.
-
-In the case of two neurons, the convex equivalent problem is as follows:
-
-
-
-is equivalent to the non-convex problem, i.e. solving it gives the global optimum of the non-convex objective (the negative neurons $v_i$ are zero at the optimum and are omitted here for clarity).
-
-#### 1-D EXAMPLE, MANY NEURONS
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/manyneurons.png" class="img-fluid" %}
-
-
-Plotting the positive part of many ReLU neurons. Summed up, they form a network output that perfectly fits the data.
-
-We draw one example of a typical local minimum of gradient descent in the specific case where there are more neurons than existing patterns. In practice (with more data in higher dimensions) there are far fewer neurons than possible activations. However, there are many situations in which neurons end up with the same activation pattern, and in the experiment section we will see how to force such dynamics.
-
-Note that we can merge neurons that are in the same activation pattern by summing them up (even in higher dimensions), creating a new neuron, and keeping both the output and the loss unchanged (although regularization might decrease). The fact that having more than one neuron in one pattern does not decrease the loss is at the core of the proof.
-
-### Activation patterns
-
-The equivalence proof is heavily based on ReLU, specifically on the fact that a ReLU unit divides the input space into two regions: one where it outputs zero, and one where it is the identity. If you consider a finite set of samples and a single ReLU, it activates some samples and deactivates the others: this is called an activation pattern. A diagonal matrix $$\pmb{D}_i \in \{0,1\}^{n \times n}$$ describes one activation pattern, but not all of them are realizable for a given dataset. There is a finite number of such possible patterns, exponential in the dimension of the data.
-
-This section is important for understanding the final animations in the experimental section and helps explain how active activation patterns evolve in the non-convex problem.
-
-#### Two-Dimensional Data
-
-In the previous part, we considered data to be one-dimensional which resulted in only two possible activation patterns. Let us consider two-dimensional data. To do so in the simplest way possible, we will consider regular one-dimensional data and a dimension filled with $$1$$s. This will effectively give the neural network a _bias_ to use without modifying the formulas.
-
-We consider two data points: $$\color{cvred}{\pmb{x}_1} = (-0.2, 1)$$ and $$\color{cvred}{\pmb{x}_2} = (1, 1)$$, each associated with its label $$y_1 = 0.5$$ and $$y_2 = 1$$. We plot the output of one ReLU unit initialized at $$\pmb{w}_1 = (0.3, -0.15)$$, $$\alpha_1 = 1$$. Therefore we have
-
-
-
-The activation pattern of $$\pmb{w}_1$$ is $$\pmb{D}_1=\left(\begin{smallmatrix} \czero & 0 \\ 0 & \cone \end{smallmatrix}\right)$$. There are only three other possible activation patterns, activating both data points: $$\pmb{D}_2=\left(\begin{smallmatrix} 1 & 0 \\ 0 & 1 \end{smallmatrix}\right)$$, activating only the first one with $$\pmb{D}_3=\left(\begin{smallmatrix} 1 & 0 \\ 0 & 0 \end{smallmatrix}\right)$$ and activating no data point with a zero matrix.
-
-One point of interest is the input at which the ReLU is exactly 0. This is where the output changes its slope: $$a_1 = -w_1^2/w_1^1$$ where $$w_1^i$$ is the $i$-th coordinate of $$\pmb{w}_1$$. Here, $$a_1 = -(-0.15)/0.3 = 0.5$$. We call this the _activation point_ of the neuron $$\pmb{w}_1$$.
-
-We plot the output, $$\color{cvblue}{\max(0, (x, 1) ~ \pmb{w}_1^\top)}$$, of the network as a function of the first dimension of the data $$x^1$$ (here simply written $$x$$):
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/twodim.png" class="img-fluid" %}
-
-
-A neuron initialized so that it activates only one data point: its activation point lies between the two samples, and its slope tells us whether it activates data to the left or to the right (to the right in this case).
-
-__Illustration__.
-
-In the animation below, we train this network using vanilla gradient descent on the two data points $$\color{cvred}{\pmb{x}_1}$$ and $$\color{cvred}{\pmb{x}_2}$$, represented by the red crosses. We plot its $$\color{cblue}{\text{output}}$$ in blue for every possible data point (omitting the second dimension as it is always 1 in this example, playing the role of the bias), and we plot in red the label associated with the two data points. Each frame corresponds to one step of full-batch gradient descent with a small learning rate. We mark the $$\color{cgreen}{\text{activation point}}$$ of the neuron with a green triangle, pointing toward the side the neuron activates. The green triangle's height is the slope of the ReLU's output, equal to $$u_1^1 = w_1^1 \alpha_1$$, allowing us to visualize how important one neuron is for the output of the network.
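-
-The quantities shown by the green triangles can be read directly from the parameters; here is a small helper (ours), under the convention that the second input coordinate is the constant $1$ playing the role of the bias:
-
-```python
-def triangle_coordinates(w, alpha):
-    """Activation point, signed slope and active side of one neuron on inputs (x, 1)."""
-    a = -w[1] / w[0]               # activation point: where (x, 1) . w = 0
-    slope = w[0] * alpha           # slope of the ReLU's output on its active side
-    side = 1 if w[0] > 0 else -1   # +1: activates to the right, -1: to the left
-    return a, slope, side
-```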
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/firstgif_movie.gif" class="img-fluid" %}
-
-
-Training a single-neuron network with gradient descent until it exactly fits two data points. It starts by fitting the only point it activates, \(\color{cvred}{\pmb{x}_2}\). As training progresses, the activation point represented by a green triangle shifts position. As soon as the activation point reaches \(\color{cvred}{\pmb{x}_1}\), it activates it and starts fitting both points at the same time. Its activation pattern shifts from \(\left(\begin{smallmatrix} \czero & 0 \\ 0 & \cone \end{smallmatrix}\right)\) to \(\left(\begin{smallmatrix} \cone & 0 \\ 0 & \cone \end{smallmatrix}\right)\) and then stays the same until convergence.
-
-Adding more neurons does not create additional activation patterns; only adding more data points does. With only two data points $$\pmb{x}_1$$ and $$\pmb{x}_2$$, we had four possible patterns; with four data points, there are ten.
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/annoying.png" class="img-fluid" %}
-
-
-We plot the individual output and activation points of each of the ReLU neurons associated with the ten _interesting_ activation patterns in blue. Those are the 10 (20 with negative ones) neurons that need to be considered to get the global optima using the convex equivalent. When moving the activation point \(a_i\) of a neuron between two data points, its activation pattern does not change.
-
-
-Notice that it is not possible to activate only the data points in the middle. However, if we increase the data's dimension, this becomes possible; it is also possible with a second layer of ReLUs. In higher dimensions we cannot visualize the activation patterns as easily, but we can understand that as dimensionality increases, more patterns become possible because it is easier to separate data points.
-
-### Extensions of the convex reformulation to other settings
-
-Batch Normalization (BN) is a key process that adjusts a batch of data to have a mean of zero and a standard deviation of one, using two trainable parameters. In the convex equivalent, we replace $$\pmb{D}_i \pmb{X}$$ with $$\pmb{U}_i$$, the first matrix in the Singular Value Decomposition (SVD) of $$\pmb{D}_i \pmb{X} = \pmb{U}_i \pmb{\Sigma}_i \pmb{V}_i$$. If the output is a vector rather than a scalar, the regularization becomes a nuclear norm in the convex equivalent. Three-layer networks also have a convex equivalent, built using all possible combinations of two activation matrices. Moreover, parallel networks are also linked to a convex problem. Lastly, in Wasserstein Generative Adversarial Network (WGAN) problems, the adversarial games played by two-layer discriminators are identified as instances of convex-concave games.
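-
-For the batch-norm case, the substitution amounts to one (thin) SVD per activation pattern; a minimal NumPy sketch (ours):
-
-```python
-import numpy as np
-
-def bn_block(D_i, X):
-    """Replace D_i @ X by the orthogonal factor U_i of its thin SVD."""
-    U_i, _, _ = np.linalg.svd(D_i @ X, full_matrices=False)
-    return U_i
-```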
-
-## III. Can We Forget the Non-Convex Problem?
-
-### Solving the convex problem efficiently is hard
-
-In the last ten years, deep neural networks have been trained using (stochastic) gradient descent on the non-convex problem. The algorithm, the implementation, and even the hardware running the training have been heavily optimized, supported, and pushed by industrial and scientific applications. Such networks were practically abandoned for years after their discovery because no efficient way to train them existed. Nowadays it takes a few lines of code to train a network on dedicated hardware, which might make us forget how much engineering made this possible. This should be kept in mind when comparing any new approach to the problem.
-
-Training a network with the non-convex problem can be time-consuming, as it requires tuning hyperparameters and rollbacks (retrieving a previous state) to get out of a bad minimum. In contrast, the convex approach deals with far fewer parameters to tune and has only one global minimum.
-
-In complexity terms, the convex reformulation with all possible activation patterns $D_i$ gives an algorithm that is polynomial in all parameters except the rank of the data matrix. In practice, with usual datasets, the rank is high and there are too many patterns to consider them all.
-
-There has been some work on solving the convex problem quickly. The first idea is to take a random subset of activation patterns and use standard convex solvers. Current convex solvers (ECOS, ...) are not tailored to problems with many constraints. There is some hope in considering the unconstrained version of the problem to build an approximation. In most deep learning scenarios, however, it is hard to be faster than, or even to compete with, simple gradient descent running on GPUs.
-
-| Dataset | Convex | Adam | SGD | Adagrad |
-|----------|--------|------|------|---------|
-| MNIST | 97.6 | 98.0 | 97.2 | 97.5 |
-| CIFAR-10 | 56.4 | 50.1 | 54.3 | 54.2 |
-
-
-Test accuracy on popular datasets for a single-layer network with 5000 neurons.
-
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/quantgraph.png" class="img-fluid" %}
-
-
-Time to solve problems from the UCI datasets with Adam on the non-convex problem and with a custom convex solver (using the augmented Lagrangian method). The code for the paper's experiments is available on GitHub, as well as the convex problem toolkit.
-
-For relatively small datasets and networks, convex solvers are fast and do not require any tuning to get convergence. Adjusting the regularization directly reduces the number of neurons needed.
-
-
-A convex equivalent of deeper networks exists but exacerbates the existing problems. The only way to make it tractable is to optimize layer by layer. This is still a work in progress and needs further improvements to be competitive.
-
-### Activation patterns are not constant in the non-convex problem
-
-Let's set aside the performance concerns and use the reformulation as a new point of view for observation. Our non-convex problem is equivalent to a convex, well-specified optimization problem with constraints. The global optima might be the same, but training the network with gradient descent almost always leads to a local minimum. Because there are too many activations to consider them all, the convex problem with a subset of patterns also only corresponds to a local minimum of the non-convex problem. However, it is not clear whether the two approaches find the same kind of local minimum.
-
-Activation patterns can and will change during gradient descent in the non-convex problem. In some cases, this pattern shifting is useful because the new activation patterns may lead to a better minimizer. To verify this, we monitor the number of unique activation patterns used by the network at each step of gradient descent. If two neurons have the same activation pattern (_i.e._ they activate and deactivate the same data points), we count them as one.
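-
-The quantity plotted below can be computed in a few lines; a sketch (assuming the first-layer weights are stored as the rows of a matrix `W`):
-
-```python
-import numpy as np
-
-def num_unique_patterns(W, X):
-    """Count distinct activation patterns among the neurons (rows of W)."""
-    patterns = (W @ X.T >= 0).astype(int)   # shape (m, n): one row per neuron
-    return len({tuple(row) for row in patterns})
-```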
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/nbactiv.png" class="img-fluid" %}
-
-
-Training a network with 100 random data points in 10 dimensions. The network only has 20 randomly initialized neurons and the labels depend linearly on the input. Each neuron has a unique activation pattern, as can be seen on the graph. This is expected in this setting because there are so many possible activation patterns (close to $10^{25}$; the number of activation patterns equals the number of regions of the partition of the space by hyperplanes perpendicular to the rows of $X$ and passing through the origin, which is bounded by \(2 r \left(\frac{e ~ (n-1)}{r}\right)^r\) with $r$ the rank of $X$). However, as training progresses, neurons align themselves to the same pattern. After 300 steps, the 20 neurons only share 5 unique activation patterns.
-
-However, we can highlight an aspect that sets the two formulations apart. The convex problem has fixed activation patterns: if those patterns miss important data points, the convex solution will not be globally optimal. Meanwhile, in the non-convex problem, gradient descent keeps shifting from pattern to pattern until it converges.
-
-__Illustration.__
-
-We will further study this setting with 100 data points and 20 neurons in high dimensions. To compare how the two methods deal with activation patterns, we will use the activation pattern of the neurons of the non-convex problem to construct a convex problem and solve it. To be more explicit, for each non-convex neuron $$\pmb{w}_i$$, we find its activation pattern and add a $$\pmb{u}_i$$ constrained to this pattern to the convex problem. In the end, we have a convex problem with 20 neurons that will activate the same data points as the non-convex neurons.
-
-We train the non-convex network using gradient descent, and at each step, we construct a convex problem, solve it, and compare its global minimum to our current non-convex loss. This convex problem fully describes the local minimum we would find if the non-convex problem was constrained to never change its activation patterns.
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/cvx_vs.png" class="img-fluid" %}
-
-
-
-Training a 20-neuron network with gradient descent and using the same activation patterns to solve the convex equivalent. For each step, we plot the current loss of the non-convex network and the optimal loss of the corresponding convex problem. At initialization (first point on the graph), the non-convex loss is 1; we take the current activation patterns, build the convex problem, solve it, and find an optimal loss of $0.1$. At the next step, the non-convex loss has decreased and the activation patterns have changed, so we find a different optimal loss for the convex problem. The optimal convex loss from step 0 is quickly beaten by gradient descent (at around step 175): the initial activation patterns were far from optimal, and gradient descent continually improves them. We use cvxpy to define the problem and solve it using ECOS.
-
-
-In general, we cannot predict which patterns will be used by the neurons found by gradient descent, or which patterns are the best. Thus we cannot hope that the convex problem will give us insight on its own, as it requires knowing the activation patterns. We can, however, predict what (some of) the optimal solutions will look like: a spline interpolation with knots at the training samples.
-
-In the next section, we focus on cases where the non-convex minima can be accurately described by convex problems.
-
-### On large initialization scale
-
-The initialization scale of the network is the absolute size of the neurons' parameters. To change the scale, we can simply multiply every parameter by a scalar. The initialization of neurons is a large topic in machine learning, as it has a strong influence on the quality of the local minimum. By default in popular libraries, _He initialization_ is used: it draws neurons from a normal distribution centered on 0 with a variance scaling in $$1/m$$, with $$m$$ the number of neurons. However, the literature offers many alternatives to pick from.
-
-We say we are in the large-scale regime when neurons do not move far from their initial values during gradient descent. This typically happens when the parameters of each neuron are initialized with large values.
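-
-Concretely, changing the scale only amounts to multiplying the initial parameters by a factor; a NumPy sketch (the scale values are ours, purely illustrative):
-
-```python
-import numpy as np
-
-def init_two_layer(m, d, scale, seed=0):
-    """Gaussian initialization with variance ~1/m, multiplied by a global scale."""
-    rng = np.random.default_rng(seed)
-    W = scale * rng.standard_normal((m, d)) / np.sqrt(m)   # first layer
-    a = scale * rng.standard_normal(m) / np.sqrt(m)        # second layer
-    return W, a
-
-W_big, a_big = init_two_layer(1000, 2, scale=10.0)      # large scale: patterns barely move
-W_small, a_small = init_two_layer(1000, 2, scale=1e-3)  # small scale: neurons align first
-```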
-
-The theory states that you can push the scale high enough so that neurons do not change their activation patterns at all. If this holds, the convex reformulation describes exactly the minima that gradient descent will reach. However, it is not possible to observe this in practice, as the loss becomes very small and the training process is too slow to carry on to the end. The NTK briefly mentioned in the introduction operates in this setting, using the fact that the network is very close to its linear approximation. On a similar note, reducing the step size of the first layer guarantees convergence.
-
-__Illustration.__
-
-Using an animation, we plot every step of gradient descent in the non-convex problem until the loss is small enough. As mentioned before, training is too slow to continue until we reach a true local minimum described by the convex problem. We plot the output of the network, which is the sum of the outputs of all the neurons, and we focus on the activation point of each neuron.
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/bigscale_movie.gif" class="img-fluid" %}
-
-
-Training a network of 1000 neurons with large initial values using gradient descent. The output of the network is in blue, and the four data points (red crosses) represent linear data. Each green triangle represents one neuron, with its activation point horizontally and its norm vertically. The orientation of the triangle indicates on which side the neuron activates the data. At initialization, the distribution of the activation points is uniform. The activation points barely move: only a few neurons among the thousand change their pattern.
-
-
-Here, computing the convex optimum gives us a single neuron that fits the linear data. While the non-convex problem has converged to a very low loss, the two outputs are completely different.
-
-
-A side effect of the large initialization is catastrophic overfitting, i.e. the output varies wildly between data points, which negatively impacts the test loss.
-
-
-### On very small initialization
-
-At the other extreme, the small-scale setting effectively lets neurons align themselves before ever decreasing the loss. In theory, if you push the scale down enough, neurons will converge to a finite set of directions before trying to fit the objective.
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/smallscale_movie.gif" class="img-fluid" %}
-
-
-Training a network of 1000 neurons with very small initial values using gradient descent. The output of the network is in blue, and the four data points (red crosses) represent linear data. Each green triangle represents one neuron, with its activation point horizontally and its norm vertically. The orientation of the triangle indicates on which side the neuron activates the data. At initialization, the distribution of the activation points is uniform. However, as training progresses, most neurons that activate toward the right converge to $-1.3$. Once the norm of the neurons activating at $-1.3$ is large enough, the loss decreases and we quickly reach convergence.
-
-
-Taking a look at the loss on the same problem, we can identify the two distinct regimes: alignment and fitting (then convergence).
-
-{% include figure.html path="assets/img/2024-05-07-hidden-convex-relu/lastgif_plot.png" class="img-fluid" %}
-
-Plot of the loss during gradient descent in the same setting as the animation above. In the first half, only the directions of the neurons (i.e. their activation patterns) change; the neurons start fitting the four data points once their parameters are large enough.
-
-If you take orthogonal data and a small scale, the behavior is very predictable even in a regression setting.
-
-
-Unless mentioned otherwise, all experiments were run using full-batch vanilla gradient descent. In practice, adding momentum or using the Adam optimizer is much easier to use and converges faster. However, the resulting behavior is much less predictable.
-
-## Conclusion
-
-The main takeaway is that the best network for a given dataset can be found exactly by solving a convex problem. Additionally, the convex problem can describe every local minimum found by gradient descent in the non-convex setting. However, finding the global optimum is impossible in practice, and approximations are still costly in precision. While there is no evident link between feature learning in the non-convex problem and in the convex reformulation, many settings allow for a direct equivalence, making the whole convex toolkit available for proofs.
-
-The performance side of the convex reformulation will benefit from dedicated software as has been the case for gradient descent in deep networks. Only then will it offer a no-tuning alternative to costly stochastic gradient descent. In smaller settings, it already allows us to quickly find all the possible local minima that are so important in machine learning.
-
-Despite advancements in understanding the optimization landscape of neural networks, a significant gap persists in reconciling theory with practical challenges, notably because of early stopping. In real-world scenarios, networks often cease learning before reaching a local minimum; this has a direct impact (for instance in the large-scale initialization regime), but results on this are still limited.
-
-## Acknowledgements
-
-This work is partly funded by the ANR JCJC project ANR-21-CE23-0022-01.
diff --git a/_posts/2024-05-07-language-model-development-as-a-new-subfield.md b/_posts/2024-05-07-language-model-development-as-a-new-subfield.md
deleted file mode 100644
index bf4eb8ee..00000000
--- a/_posts/2024-05-07-language-model-development-as-a-new-subfield.md
+++ /dev/null
@@ -1,132 +0,0 @@
----
-layout: distill
-title: "A New Alchemy: Language Model Development as a Subfield?"
-description: This blog post makes the case that the body of research on language models has become sufficiently large and mature that we can start thinking about “language model development” as a new subfield.
- To support this claim, we sketch out the focuses and methodologies of this new subfield.
- In addition, we provide some personal reflections on what to do when your field of study gives birth to a new one.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-authors:
- - name: Colin Raffel
- url: "https://colinraffel.com/"
- affiliations:
- name: University of Toronto, Vector Institute
-
-# must be the exact same name as your blogpost
-# bibliography: 2024-05-07-distill-example.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-toc:
- - name: Some history
- - name: Language model development
- - name: A New Alchemy
----
-
-Historically, language models have served as an important component of many learning systems -- for example, to improve the transcriptions generated by a speech recognition system.
-However, the impact and usage of language models has grown dramatically over the past few years.
-Arguably, this growth is simply thanks to the fact that language models have gotten *better*, i.e. more accurate at predicting some text based on some context.
-Since most text-based tasks can be cast as predicting a response to a request (e.g. "summarize the following article", "write me a Python function that queries Wikipedia", etc.), recent large language models (LLMs) have proven somewhat effective at performing an incredibly wide range of tasks.
-Improvements in the language understanding and generation capabilities of LLMs have also led to their adoption in many larger systems (e.g. robots, image processing/generation, etc.), where they increasingly enable natural language to be used as an interface.
-These advances have led to a huge amount of research into building and using language models.
-I think this body of research has become sufficiently large and mature that we can start thinking about "language model development" as a new subfield.
-The goal of this blog post is to sketch out the focuses and methodologies of the subfield of language model development as well as to provide some personal reflections on what to do when your field of study gives birth to a new one.
-
-
-## Some history
-
-As a subfield, language modeling has many sibling and parent fields, including information theory, artificial intelligence, natural language processing, and machine learning.
-In my biased opinion, many recent advances in language modeling have stemmed from advances in deep learning.
-When thinking about fields like deep learning, I think it can be valuable to define what the assumptions and major problems of the field are.
-For deep learning, I would roughly say that the assumptions are:
-
-1. We should end-to-end optimize everything.
-1. Training a bigger model on a bigger dataset should yield improved performance, but we should also strive to develop efficient and performant model architectures.
-1. If we can bake structure into our model (e.g. convolutions for images), things work better...
-1. but what we really want is a system that can learn everything from data and relies on as few hard-coded assumptions as possible.
-1. We care less about theoretical guarantees and more about how well something works in practice.
-
-Notably, the assumptions of a field are not necessarily scientifically or philosophically motivated - they can be cultural or arise from extraneous factors (e.g. the availability of GPUs).
-The major problems of the field of deep learning might be:
-
-1. How can we design neural network architectures that work well for a given problem, or better yet, across a wide variety of problems?
-1. Similarly, what objective works best?
-1. How should we optimize that objective?
-1. How can we ensure all of the above can be scaled up effectively?
-
-Arguably, one of the biggest successes of recent deep learning research is a powerful recipe for training effective models on a wide variety of problems, namely, the Transformer trained with some variant of Adam.
-While the objective used can vary across problem settings, in text-based problems a simple language modeling objective works well (and, as discussed above, encapsulates pretty much any text-based task).
-An important aspect of this Transformer recipe is its scalability, i.e. the ability to attain predictable gains from scaling up training compute and/or dataset size.
-
-## Language model development
-
-I think the scalability of the Transformer has ushered in a new era of research that is distinct from deep learning research.
-For the first time, we can (to a significant degree) stop worrying about what model architecture to use, how to train the model, what objective to use, whether we'll continue to get returns from scaling, etc.
-Instead, this new line of research primarily aims to study the development of language models in order to expand and understand their capabilities.
-In addition, the fact that recent LLMs are reasonably competent at a huge range of tasks has led to major differences in terms of how we use LLMs (when compared to e.g. how we built and used neural networks in the context of deep learning).
-For lack of a better term, I'll refer to this new (sub)field as "language model development", which might have the following assumptions:
-
-1. We can assume that the model architecture, optimizer, and objective are basically fixed.
-1. We hope or expect that a given LLM can be induced to perform basically any task out-of-the-box without performing any additional training (i.e. updating its parameters), and in general we should avoid updating parameters to specialize a model to a given task (i.e. task-specific fine-tuning).
-1. The computational cost of getting a model to perform a task is mostly irrelevant, or at least, these costs will be resolved by something else (e.g. better/more hardware).
-1. If we invest more compute in training an LLM, it will [produce better results](https://arxiv.org/abs/2001.08361).
-
-Arguably, some of these assumptions could be considered consequences of the fact that many state-of-the-art language models are only available through black-box APIs.
-The major problems of language model development are something like:
-
-1. How can we get the model to do what we want (i.e. "prompt engineering")?
-1. How can we make the model run as efficiently as possible?
-1. To the extent that we are going to update a model, how can we update it so that it is better at following instructions and less likely to generate harmful content (i.e. alignment)?
-1. More broadly, if we are really hoping the model can do *anything*, how do we prevent it from doing things we don't want it to?
-1. How can we integrate language models into other systems (i.e. tool use, multimodality, etc.)?
-
-Let me give a few additional examples of papers and techniques that I think aim to attack these problems under the aforementioned assumptions.
-
-- An early technique for "getting an LLM to do what we want" (goal #1) is [few-shot in-context learning (ICL)](https://arxiv.org/abs/2005.14165), where a few examples of the desired input/output behavior are provided in the model's input before the model is asked to process an unseen example (a minimal prompt-construction sketch appears after this list).
- Few-shot ICL avoids updating the model's parameters (assumption #1) and mostly ignores the fact that it significantly increases computational costs (assumption #3).
- A related and more recent variant of ICL is ["chain-of-thought prompting"](https://arxiv.org/abs/2201.11903), which adds reasoning steps to the in-context examples in hopes of improving performance by inducing the model to generate similar reasoning steps before generating its prediction.
- The fact that including reasoning steps further increases computational costs is, again, mostly ignored (assumption #3).
-- Techniques like [FlashAttention](https://arxiv.org/abs/2205.14135) and [Speculative Decoding](https://arxiv.org/abs/2211.17192) aim to make the model run more efficiently (goal #2) without changing the model or its outputs whatsoever (assumption #1).
- More broadly, techniques like the [Heavy-Hitter Oracle](https://arxiv.org/abs/2306.14048) or [quantization](https://arxiv.org/abs/2208.07339) aim to reduce memory or computational costs with minimal performance degradation.
- The pursuit of these techniques, along with orthogonal hardware advances like NVIDIA's Transformer Engine, arguably supports the apparent disregard for increases in computational cost that arise from using a larger model (assumption #3).
-- While there certainly has been some effort to improve over the Transformer architecture or the optimizer used to train LLMs (in violation of assumption #1), the vast majority of these improvements have not been widely adopted, either due to inertia (i.e., enforcement of assumption #1) or the apparent fact that [they do not always transfer across applications](https://arxiv.org/abs/2102.11972).
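-
-As promised above, here is what assembling a few-shot ICL prompt typically looks like in code (the exact formatting below is just one common convention, not a prescribed API):
-
-```python
-def few_shot_prompt(examples, query):
-    """Build a few-shot ICL prompt: input/output demonstrations followed by the query."""
-    demos = "\n\n".join(f"Input: {x}\nOutput: {y}" for x, y in examples)
-    return f"{demos}\n\nInput: {query}\nOutput:"
-
-prompt = few_shot_prompt(
-    [("The movie was great!", "positive"), ("Terrible acting.", "negative")],
-    "I loved the soundtrack.",
-)
-# `prompt` is sent as-is to the LLM; no parameters are updated.
-```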
-
-Separately, a sign of the maturity of a new subfield is the development of teaching materials.
-I think my friend Sasha Rush is leading the charge here, with e.g. [GPTWorld for learning prompting](https://github.com/srush/GPTWorld), [LLM training puzzles for learning about distributed training](https://github.com/srush/LLM-Training-Puzzles), and [Transformer puzzles for understanding how Transformers might work](https://github.com/srush/Transformer-Puzzles).
-Another sign is the establishment of a conference on the subject, and we [have one of those now too](https://colmweb.org/).
-
-## A New Alchemy
-
-LLMs have ushered in a paradigm shift in the path toward imbuing computers with human-like capabilities.
-This paradigm shift is being felt in various fields, including deep learning (where the work of designing new architectures or optimizers is increasingly less relevant), natural language processing (where we now have a recipe that works reasonably well across subproblems that previously demanded custom methodologies), and beyond.
-
-I started my PhD in 2012 during a similar paradigm shift from what I'd call "statistical machine learning" to deep learning.
-Unlike deep learning, statistical ML prioritized theoretical guarantees (e.g. convexity of the objective function and/or convergence under certain conditions).
-These guarantees arguably limited model expressivity, which arguably necessitated things like feature engineering that deep learning strove to avoid.
-While deep learning by no means "solved" the problems of statistical ML (just as language model development does not "solve" deep learning), it nevertheless presented a paradigm that made dramatic progress on the target problems of statistical ML and unlocked new applications.
-Such empirical successes of deep learning -- which almost entirely eschewed theoretical guarantees -- led to a great deal of hand-wringing on the part of the statistical ML crowd.
-
-As my research increasingly made use of deep learning, I started to find myself at the receiving end of this hand-wringing.
-For example, during my first-ever oral presentation at a conference, I was presenting work that made use of convolutional neural networks.
-During questions, an audience member expressed distaste at my use of "*convoluted*" neural networks and suggested that something simpler would have worked better (of course I had tried simpler models and they worked significantly worse, but let's put that aside for the moment).
-This kind of despair was common at the time - people were applying deep neural networks in settings where they may or may not have been overkill, simply because it was the zeitgeist.
-At another conference I attended during my PhD, I happened to share a hostel room with a computer vision researcher who went on a long rant about the atrocity of deep learning (sometimes I wonder what this researcher is working on now).
-I think this sentiment is most elegantly laid out in [Ali Rahimi's NeurIPS 2017 test-of-time award acceptance speech](https://www.youtube.com/watch?v=x7psGHgatGM), where he argues that deep learning is like alchemy - trial-and-error that yields some effective techniques but lacks rigor.
-Ali's speech had a big impact on me and others but arguably didn't really stop people from continuing to develop and apply deep learning without worrying about rigor and in settings where simpler methods would have sufficed (simply because using a big fancy neural network was sexier).
-
-These experiences led me to promise myself that when my field of study gave birth to another, I wouldn't dig my feet in and resist, I'd follow the tide of progress.
-Now that this is (arguably) happening I'm finding it more difficult than I had anticipated.
-As much as I wish it wasn't true, I cringe a little whenever I see a new LLM technique that ignores a dramatic increase in computational cost and bends over backwards to avoid updating the model's parameters, or an application of an LLM where something dramatically cheaper would suffice, or a paper studying the behaviors of an LLM as if it's a black box (or studying an LLM API, in which case it actually *is* somewhat of a black box), and on and on.
-And try as I might, I can't resist trying to stem the tide -- for example, the [T-Few paper](https://arxiv.org/abs/2205.05638) aimed to convince everyone that few-shot ICL was absurdly computationally inefficient and that fine-tuning specialized models is cheaper and better.
-Of course, people are still using few-shot ICL and are still avoiding task-specific fine-tuning at all costs, because that's the zeitgeist -- and I think this isn't totally wrong, because in tandem there's a huge amount of synergistic work on making LLMs more efficient and effective.
-But, to be honest, it still *feels* a little wrong, and I'm not sure if I'll be able to shake that feeling.
-
-So, what's the best course of action [when you used to be with it, but then they changed what "it" was](https://www.youtube.com/watch?v=LV0wTtiJygY)?
-I think there were many ML researchers who successfully rode the tide from statistical ML to deep learning -- they willingly embraced the new field while bringing their knowledge and sense of rigor to their deep learning research.
-In other words, they used their past knowledge to provide a broader and deeper perspective that newcomers may have lacked.
-An especially prominent product of this kind of research is arguably the [Variational Autoencoder (VAE)](https://arxiv.org/abs/1312.6114), which connected ideas from variational inference to the autoencoder neural network architecture.
-VAEs are still an important component of state-of-the-art diffusion-based generative models.
-Hopefully, those of us who were working on deep learning and NLP before the LLM era can bring a similar perspective (and avoid digging our feet in too much).
diff --git a/_posts/2024-05-07-mode-switching.md b/_posts/2024-05-07-mode-switching.md
deleted file mode 100644
index 7e1c5e08..00000000
--- a/_posts/2024-05-07-mode-switching.md
+++ /dev/null
@@ -1,615 +0,0 @@
----
-layout: distill
-title: Behavioral Differences in Mode-Switching Exploration for
- Reinforcement Learning
-description: In 2022, researchers from Google DeepMind presented an initial
- study on mode-switching exploration, by which an agent separates its
- exploitation and exploration actions more coarsely throughout an episode
- by intermittently and significantly changing its behavior policy. We
- supplement their work in this blog post by showcasing some observed
- behavioral differences between mode-switching and monolithic exploration
- on the Atari suite and presenting illustrative examples of its benefits.
- This work aids practitioners and researchers by providing practical
- guidance and eliciting future research directions in mode-switching
- exploration.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-# authors:
-# - name: Anonymous
-
-authors:
- - name: Loren J Anderson
- url:
- affiliations:
- name: USA Space Force
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-mode-switching.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: 1. Introduction
- subsections:
- - name: Mode-Switching Distinctions
- - name: Mode-Switching Basics
- - name: Blog Post Motivation
- - name: 2. Experiments
- subsections:
- - name: Concentrated Terminal States
- - name: Early Exploration
- - name: Concentrated Return
- - name: Post-Exploration Entropy
- - name: Top Exploitation Proportions
- - name: 3. Conclusion
- subsections:
- - name: Acknowledgements
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-
----
-
-## 1. Introduction
-
-Imagine learning to ride a bicycle for the first time. This task
-requires the investigation of numerous actions such as steering the
-handlebars to change direction, shifting weight to maintain balance, and
-applying pedaling power to move forward. To achieve any satisfaction, a
-complex sequence of these actions must be taken for a substantial amount of
-time. However, a dilemma emerges: many other tasks such as eating, sleeping, and working may result in more immediate satisfaction (e.g. lowered hunger, better rest, bigger paycheck), which may tempt the learner to favor other tasks. Furthermore, if enough satisfaction is not quickly achieved, the learner may even abandon the task of learning to ride a bicycle altogether.
-
-One frivolous strategy (Figure 1, Option 1) to overcome this dilemma is to
-interleave a few random actions on the bicycle throughout the remaining
-tasks of the day. This strategy neglects the sequential nature of bicycle
-riding and will achieve satisfaction very slowly, if at all. Furthermore,
-this strategy may interrupt and reduce the satisfaction of the other daily
-tasks. The more intuitive strategy (Figure 1, Option 2) is to dedicate
-significant portions of the day to explore the possible actions of bicycle
-riding. The benefits of this approach include testing the sequential
-relationships between actions, isolating different facets of the
-task for quick mastery, and providing an explicit cutoff point to shift
-focus and accomplish other daily tasks. Also -- let's face it -- who wants to wake up in the middle of the night to turn the bicycle handlebar twice
-before going back to bed?
-
-{% include figure.html path="assets/img/2024-05-07-mode-switching/bike.png" class="img-fluid" %}
-
- Figure 1: Illustrative difference between monolithic and mode-switching
-behavior policies.
-
-
-The above example elicits the main ideas of the paper *When Should Agents
-Explore?* , published by
-researchers from Google DeepMind at ICLR 2022, which is the central piece
-of literature discussed throughout this blog post. The first strategy
-presented in the preceding paragraph is known as a **monolithic** behavior
-policy that interleaves exploration actions (e.g. learning to ride a
-bicycle) among the more frequent exploitation actions (e.g. work, sleep) in
-a reinforcement learning (RL) environment. In contrast, the second strategy
-presented above is a **mode-switching** behavior policy, as it more
-coarsely separates exploration and exploitation actions by switching
-between disparate behavior modes throughout an episode. Mode-switching
-policies subsume monolithic policies at the cost of increased complexity
-through introducing a new question: *when to switch*. Similar aspects of
-mode-switching for diverse exploration have been observed in the
-exploratory behavior of humans and animals, which served as a notable motivation for the initial mode-switching study.
-
-This introduction section continues with a brief discussion of topics
-related to mode-switching behavior policies, ranging from different temporal
-granularities to algorithms in the literature that exhibit mode-switching
-behavior. We emphasize practical understanding rather than attempting to present an
-exhaustive classification or survey of the subject. Afterwards, we discuss
-our motivation and rationale for this blog post: the authors of the initial
-mode-switching study showed that training with mode-switching
-behavior policies surpassed the performance of training with monolithic
-behavior policies on hard-exploration Atari games; we augment their work by
-presenting observed differences between mode-switching and monolithic
-behavior policies through supplementary experiments on the Atari benchmark
-and other illustrative environments. Possible avenues for applications and
-future investigations are emphasized throughout the discussion of each experiment. It is assumed that the interested reader has basic knowledge of RL techniques and challenges before proceeding to the rest of this blog post.
-
-### Mode-Switching Distinctions
-
-Mode-switching behavior policies (which we will sometimes shorten to
-*switching
-policies*, and likewise to *monolithic policies*) were explicitly
-introduced in the initial mode-switching study,
-and we will now focus on briefly contrasting switching policies against
-monolithic policies and the previous exploration literature. Figure 2
-illustrates the high-level, pivotal difference between switching and
-monolithic policies: at the beginning of each time step, the agent may use
-all of its available information to determine its behavior mode
-for the current time step and then output a corresponding behavior policy to
-determine
-the action. A key distinction is that switching policies can drastically
-change between time steps since the modes can be tailored to a variety of
-different purposes (e.g. exploration, exploitation, mastery, novelty). As
-the graphic illustrates, switching is such a general addition to an
-algorithm that it was not exhaustively characterized in the initial study.
-
-{% include figure.html path="assets/img/2024-05-07-mode-switching/box.png" class="img-fluid" %}
-
- Figure 2: Introduction of mode-switching behavior to standard
-agent-environment RL interaction.
-
-
-A **mode period** is defined as a sequence of time steps in a single mode.
-At the finest granularity, *step-level* periods only last one step in
-length; the primary example is $\epsilon$-greedy exploration because its
-behavior policy switches between explore and exploit mode at the level of
-one time step. At the other extreme,
-*experiment-level* periods encompass the entire training duration, possibly
-to be used in offline RL (ORL) algorithms. A finer granularity is *episode-level*, in which a single behavior policy is chosen for one entire episode at a time, such as when diversifying the stochasticity of a policy throughout training. The switching policies analyzed in this blog post produce *intra-episodic* periods at a granularity between step-level periods and episode-level periods. Intra-episodic periods generally occur at least a few times during an episode and last for more than a few time steps. The practice and study of interpolating between extremes has occurred in areas such as $n$-step returns and colored noise with notable success, making the study of intra-episodic mode periods even more enticing.
-
-The question investigated by the initial mode-switching study is *when to
-switch*. This blog post and the initial study only perform experiments
-with two possible modes, exploration and exploitation, so the question of
-*when to switch* reduces to the question of *when to explore*. Other
-questions regarding exploration include *how much to explore* that analyzes
-the proportion of exploration actions taken over the entire course of
-training. This problem encompasses the annealing of exploration
-hyperparameters including $\epsilon$ from $\epsilon$-greedy policies and the entropy bonus $\beta$ from softmax
-policies. Another related
-question is *how to explore* that includes strategies such as randomly, optimistically, and intrinsically. These two questions are separate from the question of *when* to explore, as they usually consider a smooth change in the behavior policy after each time step; switching policies incorporate a much more rigid change in the behavior policy, meriting a separate analysis.
-
-### Mode-Switching Basics
-
-The preceding subsection narrowed our focus to determining *when to explore*
-using *intra-episodic* mode periods. At the time of publication of the
-initial mode-switching study, the previous literature contained a
-few works that had incorporated basic aspects of intra-episodic
-mode-switching exploration. For example, Go-Explore is a resetting algorithm that explores randomly after resetting to previously-encountered
-promising states at the beginning of an episode. However, this algorithm
-implements only one switch from resetting to exploration over the course of
-an episode. Temporally-extended $\epsilon$-greedy exploration generalizes $\epsilon$-greedy
-exploration by sampling from a distribution the number of time steps that an
-exploration action should repeat. This method of switching is
-intra-episodic, but it only allows repetition of an action during explore
-mode. The initial mode-switching study extends the above and other work in
-many dimensions and may soon be viewed as the seminal work on
-mode-switching behavior policies; we discuss the most fundamental facets of
-mode-switching architectures below.
-
-The **starting mode** is the mode of the algorithm on the first time step,
-usually exploit mode. The **set of behavior modes** (e.g. explore and
-exploit) must contain at least two modes, and the set of behaviors induced
-by all modes should be fairly diverse. The switching **trigger** is the
-mechanism that prompts the agent to switch modes and is perhaps the most
-interesting consideration of switching policies. An *informed* trigger
-incorporates aspects of the state, action, and reward signals; it is actuated after crossing a prespecified threshold such as the
-difference between the expected and realized reward. A *blind* trigger acts
-independently of these signals; for example, it can be actuated after a
-certain number of time steps has elapsed or actuated randomly at each time
-step with a prespecified probability. A **bandit meta-controller** may be employed to choose the switching
-hyperparameters (e.g. termination probability, mode length, informed threshold) at the beginning of each episode to maximize episodic return and prevent additional hyperparameter tuning. Finally, **homeostasis** can be added when using trigger thresholds (e.g. for informed triggers), which adapts the switching threshold to a target rate across the course of training, again for ease of hyperparameter tuning. These facets are diverse enough that we end the discussion here for brevity and summarize them in Table 1.
-
-| Mode-Switching Facet | Description |
-| ------------- | ------------- |
-| Starting Mode | Mode during first time step at episode start |
-| Behavior Mode Set | Set of modes with diverse set of associated behavior policies |
-| Trigger | Informs agent when to switch modes |
-| Bandit Meta-Controller | Adapts switching hyperparameters to maximize episodic return |
-| Homeostasis | Adapts switching threshold to achieve a target rate |
-
-
-
- Table 1: Various facets of mode-switching policies.
-
-
-### Blog Post Motivation
-
-The initial mode-switching study performed experiments solely on 7
-hard-exploration Atari games. The focus of the study was to show the
-increase in score on these games when using switching
-policies versus monolithic policies. One area of future work pointed out by
-the reviewers is to increase the understanding of these less-studied
-policies. For example, the [meta review](https://openreview.net/forum?id=dEwfxt14bca&noteId=C0cPgElgV7P) of the paper stated that an illustrative
-task may help provide intuition of the method. The [first reviewer](https://openreview.net/forum?id=dEwfxt14bca&noteId=Fjc2fBjmhwZ) noted how
-the paper could be greatly improved through demonstrating specific benefits
-of the method on certain tasks. The [second reviewer](https://openreview.net/forum?id=dEwfxt14bca&noteId=e3xcQZnyuyt) stated that discussing observed differences on the different domains may be useful. The [third reviewer](https://openreview.net/forum?id=dEwfxt14bca&noteId=Qcv_GiwGPhr) mentioned how the paper could be strengthened by developing guidelines for practical use. The [last reviewer](https://openreview.net/forum?id=dEwfxt14bca&noteId=W6v6g6zFQHi) stated that it would be helpful to more thoroughly compare switching policies to monolithic policies for the sake of highlighting their superiority.
-
-We extend the initial mode-switching study and progress towards
-further understanding of these methods in this blog post through additional
-experiments. The following experiments each discuss an observed behavioral
-difference in switching policies versus monolithic policies. We focus
-on behavioral differences in this work, as they are observable in the
-environment and are not unique to the architecture of certain agents. Our experiments are performed
-on 10 commonly-used Atari games, and we also provide another
-illustrative task or chart for each experiment to further enhance
-understanding. One highlight of this work is showcasing how switching
-policies not only influence exploration but also significantly influence
-exploitation. Our work serves as a first step in empirically delineating
-the differences between switching policies and monolithic policies for the use of practitioners and researchers alike.
-
-## 2. Experiments
-
-This section begins with a discussion on the experimental setup before
-delving into five experiments that highlight observational differences in
-switching and monolithic behavior policies. The complete details of the
-agent and environments can be found in the accompanying [GitHub repository](https://github.com/LorenJAnderson/when-to-explore).
-- The experimental testbed comprises 10 commonly-used Atari games: Asterix, Breakout,
- Space Invaders, Seaquest, Q*Bert, Beam Rider, Enduro, MsPacman, Bowling,
- and River Raid. Environments follow the standard Atari protocols of incorporating sticky actions and only providing a terminal signal when all lives are lost.
-- A Stable-Baselines3 DQN policy
- is trained on each game for 25 epochs of 100K time steps each, totaling 2.5M time steps or 10M frames due to frame skipping. The DQN policy
- takes an exploration action on 10% of time steps, after the exploration
- rate has been linearly annealed from 100% across the first 250K time steps.
-- A switching policy and monolithic policy were evaluated on the testbed
- using the greedy actions of the trained DQN policy when taking
- exploitation actions. Evaluations were made for 100 episodes for each
- game and epoch. The monolithic policy was $\epsilon$-greedy with a 10%
- exploration rate. The switching policy we chose to examine
- incorporates blind switching; we leave an analogous investigation of
- informed switching policies to future work (see initial study for
- background and experiments using informed switching policies).
- The policy begins in
- exploit mode and randomly switches to uniform random explore mode 0.7% of
- the time. It randomly chooses an explore mode length from the set
- $\\{5, 10, 15, 20, 25\\}$ with probabilities $\\{0.05, 0.20, 0.50, 0.20, 0.05\\}$.
- During experimentation, we determined that this switching policy took
- exploration actions at almost the same rate as the monolithic policy (10%).
- A minimal sketch of this switching policy is shown directly after this list.
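-
-Below is a minimal Python sketch of this blind switching policy; the class name, structure, and the `greedy_action` callable (standing in for the trained DQN's argmax action) are our own illustration, not code from the accompanying repository.
-
-```python
-import numpy as np
-
-EXPLORE_LENGTHS = [5, 10, 15, 20, 25]
-EXPLORE_PROBS = [0.05, 0.20, 0.50, 0.20, 0.05]
-SWITCH_PROB = 0.007
-
-
-class BlindSwitchingPolicy:
-    """Starts in exploit mode and blindly switches into finite explore periods."""
-
-    def __init__(self, greedy_action, n_actions, seed=None):
-        self.greedy_action = greedy_action  # exploitation policy, e.g. the trained DQN's argmax
-        self.n_actions = n_actions
-        self.rng = np.random.default_rng(seed)
-        self.explore_steps_left = 0  # 0 means exploit mode
-
-    def act(self, obs):
-        if self.explore_steps_left == 0 and self.rng.random() < SWITCH_PROB:
-            # Blind trigger: enter explore mode for a randomly sampled number of steps.
-            self.explore_steps_left = int(self.rng.choice(EXPLORE_LENGTHS, p=EXPLORE_PROBS))
-        if self.explore_steps_left > 0:
-            self.explore_steps_left -= 1
-            return int(self.rng.integers(self.n_actions))  # uniform random explore action
-        return self.greedy_action(obs)  # exploit action
-```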
-
-We briefly cite difficulties and possible confounding factors in our
-experimental design to aid other researchers during future studies on this
-topic.
-- The DQN policy was trained using a monolithic policy, and unsurprisingly,
-monolithic policies had slightly higher evaluation scores. Additional
- studies may use exploitation actions from a policy trained with switching
- behavior for comparison.
-- Many of our experiments aim to evaluate the effect of exploration
- or exploitation actions on some aspect of agent behavior. Due to delayed
- gratification in RL, the credit assignment problem persists and confounds the
- association of actions to behaviors. To attempt to mitigate some
- confounding factors of this problem, we weight the behavior score of the
- agent at an arbitrary time step by the proportion of exploration or
- exploitation actions in a small window of past time steps; for example,
- in the first experiment, we weight the effect of taking exploration
- actions on yielding terminal states by calculating the proportion of exploration
- actions within 10 time steps of reaching the terminal state. Then, we
- average the proportions across 100 evaluation episodes to compute a final score for a single epoch for a single game.
-- Lastly, we only claim to have made observations about the behavioral differences, and we do not claim to have produced statistically significant results; we leave this analysis to future work.
-
-### Concentrated Terminal States
-
-Exploration actions are generally considered to be suboptimal and are
-incorporated to learn about the state space rather than accrue the most
-return. Many environments contain regions of the state space that simply do
-not need more exploration, such as critical states that require directed behavior for
-meaningful progress. For instance, a self-driving car needing to merge onto
-a highway is in a critical state, as it has few behaviors that will keep it
-driving correctly. In these critical states, suboptimal action choices may
-cause the agent to reach a terminal state more quickly than desired. We
-investigate if terminal states are more concentrated after an exploration
-period of a switching policy due to the many exploration actions taken in
-succession.
-
-Our first experiment attempts to analyze the relationship between taking
-many exploration actions in succession and reaching a terminal state. Each
-terminal state is given a score equal
-to the proportion of exploration actions during the past 10 time steps (see
-second paragraph of Experiments section for rationale). Final scores for
-each behavior policy and epoch are computed by averaging the scores of each terminal state across all 100 evaluation episodes and each game. The results are shown in
-Figure 3. Switching policies produced terminal states that more closely
-followed exploration actions. Furthermore, the effect was more pronounced
-as the policies improved, most likely because the growing gap in quality
-between exploitation and exploration actions is more detrimental to
-switching policies, which explore multiple times in succession. Note how
-the scores for monolithic policies are near 0.10 on
-average, which is the expected proportion of exploration actions per
-episode and therefore suggests that exploration actions had little effect.
-These results demonstrate that switching policies may be able to
-concentrate terminal states to specific areas of an agent's trajectory.
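-
-A minimal sketch of this scoring, assuming each evaluation episode is logged as a list of per-step booleans marking which actions were exploratory (the logging format is our assumption, not necessarily that of the accompanying repository):
-
-```python
-import numpy as np
-
-
-def terminal_state_score(is_explore, window=10):
-    """Proportion of exploration actions within `window` steps of the terminal state."""
-    return float(np.mean(np.asarray(is_explore[-window:], dtype=float)))
-
-
-def epoch_score(episodes, window=10):
-    """Average the per-episode scores, e.g. over 100 evaluation episodes of one game and epoch."""
-    return float(np.mean([terminal_state_score(ep, window) for ep in episodes]))
-```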
-
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_1_1.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_1_2.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 3 (Left): Terminal states are more concentrated after switching
-exploration periods. Figure 4 (Right): Switching policies perform better on
-cliffwalk environments.
-
-
-We showcase a quick illustrative example of the ability of switching
-policies to concentrate terminal states more uniformly in a cliffwalk
-environment (Figure 4). The agent starts at the black circle in the middle
-column and top row of a 101$\times$11 grid and attempts to reach the white
-'x' at the bottom. All states aside from those in the middle column are
-terminal, and the heatmaps show the visitation frequency per episode of all
-non-terminal states across 10K episodes. When the exploitation policy is to
-move only downward
-and the behavior policies are the usual policies in these experiments, the
-agent incorporating a switching policy more heavily
-concentrates the terminal states in exploration mode and visits states
-further down the cliffwalk environment at a higher rate per episode.
-
-
-Environments that incorporate checkpoint states that agents must traverse
-to make substantial progress may benefit from switching policies that
-concentrate exploration periods away from the checkpoints. For example, the
-game of Montezuma's Revenge sometimes
-requires the agent to retrieve a key before advancing through a door,
-and the agent may achieve faster learning by concentrating exploration
-actions away from states near the key after that action is learned. One
-notable and emerging area of RL research that may benefit from
-concentrating terminal states is safe RL. In safe RL, certain safety constraints are
-required during the learning and deployment process. In some situations,
-the safety constraints are closely aligned with terminal states (e.g. aerospace), and concentrating exploration actions away from terminal states may aid in achieving those safety constraints.
-
-### Early Exploration
-
-Monolithic policies uniformly take exploration actions throughout an episode,
-and as a result, the exploration steps are less concentrated than those of
-switching policies. While the expected number of exploration steps may be
-the same per episode in monolithic policies, certain situations may require
-more concentrated exploration during the beginning of episodes. For example,
-the build orders in StarCraft II significantly influence the possible
-future strategies, making exploration crucial throughout the beginning time
-steps. Early suboptimal actions have also been manually implemented to
-achieve certain effects: passive actions are taken in Atari games to
-prevent memorization of trajectories, and 30 random actions were taken at
-the beginning of Go games when training the AlphaGo engine to force agents
-to encounter more diverse data. We investigate the flexibility of switching policies to concentrate exploration actions in the beginning of episodes.
-
-We perform an experiment to determine how quickly a policy takes a
-prespecified number of exploration actions. Specifically, we compute the
-average number of time steps it takes for a policy to take at least $x$
-total exploration actions across its top 10 of 100 fastest episodes, and we repeat this process for $x \in \\{1, 2, 3, \ldots,
-20\\}$. We compare the top 10 fastest episodes because we are only
-interested in gauging whether switching behavior is flexible enough to
-achieve this specific facet of exploration (beginning exploration) in a
-small percentage of episodes, not in every episode. Note that
-this experiment did not need to utilize the Atari signals, so we only used
-data from the last epoch. Results were again averaged over each game and
-shown in Figure 5. It is clear that some episodes contain many more
-exploration actions concentrated
-in the beginning few time steps with switching policies. This makes sense
-intuitively, as only one switch needs to occur early in an episode with a
-switching policy for many exploration actions to be taken immediately
-afterwards. The difference increases roughly linearly for greater numbers of necessary exploration actions, showing that switching natively produces more episodes with exploration concentrated in the beginning.
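-
-A small sketch of this metric, using the same assumed per-episode exploration flags as before:
-
-```python
-import numpy as np
-
-
-def steps_to_x_explorations(is_explore, x):
-    """Time step at which an episode takes its x-th exploration action (inf if it never does)."""
-    cum = np.cumsum(np.asarray(is_explore, dtype=int))
-    hits = np.flatnonzero(cum >= x)
-    return int(hits[0]) + 1 if hits.size else float("inf")
-
-
-def fastest_k_average(episodes, x, k=10):
-    """Average over the k fastest episodes (out of e.g. 100) to reach x exploration actions."""
-    times = sorted(steps_to_x_explorations(ep, x) for ep in episodes)
-    return float(np.mean(times[:k]))
-
-
-# curve = [fastest_k_average(episodes, x) for x in range(1, 21)]
-```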
-
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_2_1.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_2_2.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 5 (Left): Switching policies can explore more frequently earlier
-during the episode. Figure 6 (Right): Switching policies have better
-exploration near the start state on downwalk environments.
-
-
-We illustrate beginning exploration with a downwalk environment in which an
-agent attempts to first move to the middle column and then down the middle
-column to the white 'x' (Figure 6). The agent starts in the
-second row in the middle column at the white circle, and visitation
-frequencies across 1K episodes are shown for all states aside from those
-between the white circle and the white 'x', inclusive. We chose to
-analyze this environment because it is a crude approximation of the trajectory of agents that have learned a single policy and immediately move away from the initial start state at the beginning of an episode. The switching and monolithic policies are the same as before, and switching produces much higher visitation counts at states further from the obvious exploitation trajectory.
-
-Environments that may benefit from flexible early exploration are sparse
-reward environments that provide a single nonzero reward at the terminal
-state. Many game environments fall into this category, since a terminal
-reward of 1 can be provided for a win, -1 for a loss, and 0 for a draw. In
-such environments, agents usually need to learn at states near the sparse
-reward region before learning at states further away, also known as
-cascading. After learning near
-the sparse reward region, the agent may need to reconsider earlier actions,
-and switching policies natively allow for this type of exploration. Future
-work may consider the extent to which switching aids in improving policies
-near the start state in sparse reward environments.
-
-### Concentrated Return
-
-In contrast to the investigation in the first experiment, exploitation
-actions of a trained agent are presumed to be better than all other
-alternatives. Since agents aim to maximize the expected return in an
-environment, exploitation actions often accrue relatively large amounts of
-expected return. For example, the initial experiments of DQN and double DQN (DDQN) decreased the exploration constant (thereby
-increasing exploitation) during testing runs to achieve higher scores and
-ultimately demonstrate superhuman performance on Atari. In this subsection, we investigate the effect of the concentrated exploitation actions of switching policies on expected return.
-
-We perform an experiment to determine the proportion of return that is
-concentrated during exploitation periods. Each reward during an episode is
-weighted by the proportion of exploitation actions during the past 10 time
-steps. The score for each episode is the sum of weighted rewards divided by
-the total rewards. Scores for each behavior policy and epoch are computed
-by averaging scores across all games. The results are shown in Figure 7.
-Quite quickly, exploitation steps of switching policies contain a greater
-percentage of the return than those of monolithic policies. This trend seems
-fairly constant after roughly 2M frames, with switching policies having
-roughly 95% of the return in exploitation steps and monolithic policies
-having roughly 90% of the return; from another point of view, exploration
-steps yield 5% of the return for switching policies and 10% of the return for
-monolithic policies. These results agree with Experiment 1, as switching
-policies generally reach terminal states more frequently in explore
-mode, after which no further rewards can be received. Since most of the rewards in our
-selected Atari games are positive, switching policies should accrue
-lower return while in explore mode.
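-
-A minimal sketch of this weighting, again assuming per-step reward and exploration-flag logs:
-
-```python
-import numpy as np
-
-
-def exploit_return_share(rewards, is_explore, window=10):
-    """Share of an episode's return attributed to exploitation periods."""
-    rewards = np.asarray(rewards, dtype=float)
-    is_exploit = 1.0 - np.asarray(is_explore, dtype=float)
-    weighted = sum(
-        r * is_exploit[max(0, t - window + 1): t + 1].mean()  # exploit proportion over the past steps
-        for t, r in enumerate(rewards)
-    )
-    total = rewards.sum()
-    return float(weighted / total) if total != 0 else float("nan")
-```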
-
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_3_1.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_3_2.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 7 (Left): Switching policies concentrate return in exploitation
-mode. Figure 8 (Right): Switching policies concentrate return in the
-beginning of episodes.
-
-
-One notable case in which exploitation steps are concentrated together is in
-resetting methods such as Go-Explore
-that reset to promising states at the beginning of the episode and explore
-from there. Promising states are usually defined as states that are
-frequently traversed in trajectories that accrue high return. More
-generally, resetting methods aim to prevent *derailment*, whereby an agent
-is unable to return or is *derailed* from returning to promising states
-through its exploratory mechanisms. Since our switching agent begins in
-exploit mode, which aims to accrue the most return, we investigate whether
-switching policies possess characteristics that are inherent to resetting
-methods.
-
-In Figure 8, we plot the proportion of episode return over the past 5% of
-the episode versus the current proportion of episode that is complete. Data
-is taken from the last training epoch. The results show that switching
-policies concentrate return more towards the beginning of each episode,
-most likely because the first exploit mode of a switching policy is
-relatively long. Future work involves determining the extent to which the
-beginning exploitation mode of switching policies serves as a flexible
-alternative to resetting, which would have applications in situations
-that do not allow for manual resets such as model-free RL.
-
-### Post-Exploration Entropy
-
-Monolithic policies such as $\epsilon$-greedy are nearly on-policy when any
-exploration constants have been annealed. In contrast, the exploration
-periods of switching policies are meant to free the agent from its current
-exploitation policy and allow the agent to experience significantly
-different trajectories than usual. Due to the lack of meaningful learning at
-states that are further from usual on-policy trajectories, the exploitation actions at those states are more likely to have greater diversity. In this experiment, we investigate the diversity of the action distribution after exploration periods.
-
-We quantify the diversity of the realized action distribution in the time
-step immediately after each exploration period. The diversity is quantified
-by entropy, which takes higher values for more random data and vice versa. An
-action distribution is constructed for each game and epoch, and
-the entropies across games are averaged. The results are shown in Figure 9.
-The entropy of the action distribution for switching policies is distinctly
-greater than that of monolithic policies. Like most of the previous results, this quantity plateaus only after roughly 2M frames have elapsed.
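-
-A minimal sketch of this computation, assuming logs of the actions taken and the per-step exploration flags:
-
-```python
-import numpy as np
-
-
-def post_exploration_entropy(actions, is_explore, n_actions):
-    """Entropy of the actions realized on the first step after each exploration period."""
-    counts = np.zeros(n_actions)
-    for t in range(1, len(actions)):
-        if is_explore[t - 1] and not is_explore[t]:  # first exploit step after an explore period
-            counts[actions[t]] += 1
-    if counts.sum() == 0:
-        return float("nan")
-    p = counts / counts.sum()
-    p = p[p > 0]
-    return float(-(p * np.log(p)).sum())
-```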
-
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_4_1.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_4_2.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 9 (Left): Switching policies produce action distributions
-with higher entropy after exploration periods. Figure 10 (Right): Agent has
-random exploitation actions in states that are visited less frequently.
-
-
-To illustrate this idea, we create a gridworld environment that provides
-the agent a reward of -1 for each time step that the agent is still on the
-grid; the agent's goal is to leave the grid as quickly as possible. The
-agent begins in the center of the grid and learns through discrete
-Q-learning. Distinct actions have separate colors in Figure 10, with arrows
-showing the exploit action. The agent learns that it is fastest to exit the
-grid by going left or right. Notably, the actions near the top and bottom
-of the grid are seemingly random, as the agent has not seen and learned from those states as frequently as the others. Switching
-policies are more likely to reach the top and bottom areas of the gridworld
-state space and consequently would be more likely to have a higher entropy
-of the action distribution after exploration.
-
-The difference in the entropy of the action distributions suggests that
-more diverse areas of the state space may be encountered after exploration
-modes with switching policies. This phenomenon is closely tied to the
-notion of *detachment*, whereby
-agents forget how to return or are *detached* from areas of high reward,
-perhaps by focusing too unimodally on one region of the state space. The concentrated behavior of switching policies may provide enough consecutive exploration actions to explore a more diverse set of trajectories. Future work could investigate the ability of switching policies to curb detachment on environments with multiple regions of the state space with high reward.
-
-### Top Exploitation Proportions
-
-Our final investigation involves the change in exploitation proportion
-under switching policies. Since the probability of switching to explore
-mode is very low, there may be some episodes where the switch seldom happens
-if at all. This creates a distribution of exploitation action proportions
-per episode that is more extreme than that of monolithic policies, yet it
-is still not as extreme as using a single mode throughout the entire
-episode. Investigations of methods having similar interpolative
-characteristics have been conducted recently; for example, an action noise
-called pink noise was introduced and
-achieved better performance than white and red noise. Pink
-noise is more temporally-correlated than white noise but not as much as red noise. Here, we investigate the return of the most extreme episodes in exploitation proportion.
-
-We perform an experiment to compare the return of the episodes with
-highest exploitation proportions between switching and monolithic policies.
-The returns of the top 10 of 100 episodes ranked by exploitation proportion
-of each epoch and game were averaged. Then, a ratio between the averages of
-switching and monolithic policies was computed and averaged across games. The
-results are plotted in Figure 11.
-There does not appear to be a clear trend aside from the ratio hovering mostly above 1.00, indicating that the top exploitation episodes of switching policies accrue more return than those of monolithic policies.
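-
-A short sketch of how this ratio can be computed from per-episode logs of return and exploitation proportion (our own illustration):
-
-```python
-import numpy as np
-
-
-def top_exploit_return(returns, exploit_fracs, k=10):
-    """Average return of the k episodes with the highest exploitation proportion."""
-    top = np.argsort(exploit_fracs)[::-1][:k]
-    return float(np.mean(np.asarray(returns)[top]))
-
-
-# ratio = top_exploit_return(switch_returns, switch_fracs) / top_exploit_return(mono_returns, mono_fracs)
-```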
-
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_5_1.png" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-mode-switching/exp_5_2.png" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 11 (Left): Switching policies have higher return for episodes
-with largest exploit proportion. Figure 12 (Right): Switching policies have
-more extreme exploration and exploitation proportions per episode.
-
-
-The results are best illustrated through plotting the switching and
-monolithic exploitation proportions for 1K episodes (10 games of the last
-epoch) as shown in Figure 12. The top 100 switching episodes with the highest
-exploitation proportion take more exploitation actions than any monolithic
-episode. Therefore, the corresponding distribution is indeed more
-extreme.
-
-While the previous discussion has illustrated that some switching episodes
-exploit more and generate more return, it does not specifically explain why
-training with mode-switching is superior; in particular, the slightly
-greater return is not necessary for learning an optimal policy as long as a
-similar state distribution is reached during training. One
-possibility is the fact that mode-switching policies train on a more
-diverse set of behavior and must generalize to that diversity.
-Reinforcement learning algorithms are notorious for overfitting, and future work
-may investigate the extent to which generalization is improved upon using switching policies.
-
-
-## 3. Conclusion
-
-This blog post highlighted five observational differences between
-mode-switching and monolithic behavior policies on Atari and other
-illustrative tasks. The analysis showcased the flexibility of mode-switching policies, such as the ability to explore earlier in episodes and exploit at a notably higher rate. As the original study of mode-switching behavior by DeepMind was primarily concerned with performance, the experiments in this blog post supplement the study by providing a better understanding of the strengths and weaknesses of mode-switching exploration. Due to the vast challenges in RL, we envision that mode-switching policies will need to be tailored to specific environments to achieve the greatest performance gains over monolithic policies. Pending a wealth of future studies, we believe that mode-switching has the potential to become the default behavioral policy to be used by researchers and practitioners alike.
-
-### Acknowledgements
-
-We thank Nathan Bittner for a few helpful discussions on the topic of
-mode-switching exploration. We also thank Theresa Schlangen (Theresa
-Anderson at the time of publication) for helping polish some of the
-figures.
diff --git a/_posts/2024-05-07-primacy-bias-and-why-it-helps-to-forget.md b/_posts/2024-05-07-primacy-bias-and-why-it-helps-to-forget.md
deleted file mode 100644
index 250251f8..00000000
--- a/_posts/2024-05-07-primacy-bias-and-why-it-helps-to-forget.md
+++ /dev/null
@@ -1,425 +0,0 @@
----
-layout: distill
-title: "It's Time to Move On: Primacy Bias and Why It Helps to Forget"
-description: "'The Primacy Bias in Deep Reinforcement Learning' demonstrates how the first experiences of a deep learning model can cause catastrophic memorization and how this can be prevented. In this post we describe primacy bias, summarize the authors' key findings, and present a simple environment to experiment with primacy bias."
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-# authors:
-# - name: Anonymous
-
-authors:
- - name: Matthew Kielo
- url: https://mkiel.org/
- affiliations:
- name: Georgia Institute of Technology
- - name: Vladimir Lukin
- url: https://github.com/divannyteoretik
- affiliations:
- name: Georgia Institute of Technology
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-primacy-bias-and-why-it-helps-to-forget.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Introduction to Primacy Bias
- - name: Off Policy Deep Reinforcement Learning
- subsections:
- - name: Are we Overcomplicating?
- - name: Selecting a Replay Ratio
- subsections:
- - name: Heavy Priming
- - name: Weight Resets
- subsections:
- - name: Do Resets Work?
- - name: "What’s The Catch?"
- - name: Implementing Primacy Bias
- subsections:
- - name: 2x2 Switching Frozen Lake
- - name: Results
- - name: Conclusions
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-
-
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-## Introduction to Primacy Bias
-
-Primacy bias occurs when a model's training is damaged by overfitting to its first experiences. This can be caused by poor hyperparameter selection, the underlying dynamics of the system being studied, or simply bad luck.
-
-In this post we explore the paper “The Primacy Bias in Deep Reinforcement Learning” by Nikishin et al., presented at ICML 2022. We will present primacy bias and how it applies to deep reinforcement learning, discuss how the authors prevent primacy bias, and finish by experimenting with our own toy example of primacy bias.
-
-Like many deep learning concepts, primacy bias takes inspiration from psychology. For example, you might have a friend who “doesn’t like math” because they had a bad experience in primary school. Now, they avoid the subject despite having an aptitude for it. It turns out that for humans and machines, first impressions matter more than they should. This is primacy bias.
-
-## Off Policy Deep Reinforcement Learning
-
-Nikishin et al. discuss a specific type of model that is particularly sensitive to primacy bias: *off-policy deep reinforcement learning*. Here, the goal is to learn a *policy* that makes good decisions in an interactive environment. Off-policy algorithms achieve this by separating decision-making from learning. Deep Q-Learning (DQN), one of the first popular off-policy algorithms, separates the learning process into two steps:
-
-1. Data Collection: use the current policy to interact with the environment and save memories to a dataset called the *replay buffer*.
-2. Learning: sample from the replay buffer to perform gradient updates on the policy.
-
-### Are we Overcomplicating?
-For those without a reinforcement learning background, this might seem needlessly complicated. Why can’t we simply explore with a random policy and then fit a model all at once?
-
-Although this is sometimes done, the quality of the memories in the replay buffer is proportional to the quality of the policy that gathered the experience. Consider an agent learning to play chess. A random policy might have enough data to learn how to play the start of the game effectively, but it will never learn how to chase an opponent’s king around an empty board. If a policy isn’t smart enough to get the agent out of the ‘early’ game, it will never collect experiences to learn the ‘mid’ or ‘late’ games.
-
-
-## Selecting a Replay Ratio
-
-The *replay ratio* is the total number of gradient updates per environment interaction. If the number of experiences is fixed, then modifying the replay ratio is equivalent to changing the number of training epochs in a typical deep learning problem.
-
-Most researchers know the importance of training for a sufficient number of epochs. Training for more epochs is preferred and methods such as early stopping, weight regularization, and dropout layers can mitigate the risk of overfitting. At worst, if you end up with an overfit model then you can retrain it from scratch.
-
-In deep reinforcement learning, the replay ratio is typically set to one. Unfortunately, finding the correct replay ratio is difficult. We want the agent to learn as much as possible, but there is a path dependency that is hard to ignore. If the policy becomes overfit early, it will have less meaningful interactions with the environment, creating a negative feedback loop. If you don’t catch overfitting in your poker bot until it loses a couple of tournaments, then you might have spent a lot of money on a dataset of how to lose poker hands.
-
-### Heavy Priming
-
-To quantify this, Nikishin et al. perform an experiment with heavy priming. The goal is to train an agent on the "quadruped-run" environment, where an agent learns to manipulate joint movement to travel forward.
-
-First, a baseline is trained with default parameters. Next, to create heavy priming, the agent collects 100 interactions and then trains for 100K steps. The model with heavy priming never recovers, an example of catastrophic memorization.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/heavy-priming.jpeg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Example of Heavy Priming by Nikishin et al.
-
-
-
-## Weight Resets
-
-To avoid primacy bias, Nikishin et al. propose the following solution: freely increase the replay ratio, but periodically perform a *weight reset* to reinitialize all of the agent’s weights while preserving the replay buffer. This destroys any learned information in the network's weights. At worst, if there is no primacy bias, the replay buffer will contain enough information to retrain to the previous weights. At best, primacy bias is eliminated, and the model finds a better optimum.
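-
-A minimal sketch of what a reset could look like for a PyTorch network (an illustration of the idea, not necessarily the exact procedure from the paper, which also considers resetting only parts of the network):
-
-```python
-import torch.nn as nn
-
-
-def reset_weights(network: nn.Module) -> None:
-    """Reinitialize every submodule that defines reset_parameters().
-
-    The replay buffer is left untouched, so the agent can re-learn
-    from all of its stored experience after the reset.
-    """
-    for module in network.modules():
-        if hasattr(module, "reset_parameters"):
-            module.reset_parameters()
-```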
-
-To think about this concretely, consider a 100-step training loop. At each step we:
-
-1. Gather 1 observation.
-2. Add it to the replay buffer.
-3. Select a random sample from the replay buffer.
-4. Perform a gradient update to the model with the sample.
-
-After 100 steps, the first observation will have been sampled on average 5.19 times. The 50th observation will have been sampled 0.71 times, and the final observation will have been sampled on average 0.01 times. This can be summarized in a plot.
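-
-These averages follow directly from uniform sampling over the growing buffer; a quick back-of-the-envelope check (our own sketch, not code from the paper):
-
-```python
-def expected_samples(i: int, n_steps: int = 100) -> float:
-    """Expected number of times observation i is drawn when, at each of n_steps
-    steps, one observation is sampled uniformly from the buffer collected so far
-    (observation i is in the buffer from step i onward)."""
-    return sum(1.0 / k for k in range(i, n_steps + 1))
-
-
-print(round(expected_samples(1), 2))    # ~5.19
-print(round(expected_samples(50), 2))   # ~0.71
-print(round(expected_samples(100), 2))  # ~0.01
-```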
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/samples11.jpeg" class="img-fluid rounded z-depth-1" %}
-
-
- How often an example is sampled on average in a 100-step training loop.
-
-
-
-
-Some solutions to mitigate this include recency weighting or prioritized experience replay; however, weight resets offer a theoretically parameter-free way to fix this. If weights are trained from scratch at every step, then all prior observations will have equal influence.
-
-In practice, weight resets are a bit more complicated. Ideally, we retrain the model from scratch after each observation. Unfortunately this isn’t realistic (on my computer). This leaves us with two decisions:
-
-1. Select a reset frequency.
-2. Decide what to reset.
-
-Resetting often will prevent primacy bias, but it requires a high replay ratio. This trade-off is discussed in detail in the follow-up work "Sample-Efficient Reinforcement Learning by Breaking the Replay Ratio Barrier", published at ICLR 2023. In particular, a heatmap is shared showing the trade-off between data and computation budget on a dynamic motion control problem:
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/compute-data-tradeoff.jpeg" class="img-fluid rounded z-depth-1" %}
-
-
-
- "Performance of SR-SAC in DMC15 as a function of the number of interactions and of the number of agent updates, determined by the replay ratio."
-
-
-
-
-### Do Resets Work?
-
-Nikishin et al. show that, on average, resets work well.
-
-1. Immediately after a reset there is a sudden drop in performance that quickly recovers.
-2. Resets never irreparably harm a model. At worst, the model returns to the pre-reset level (e.g., cheetah-run), but sometimes it can perform substantially better (e.g., humanoid-run).
-
-These results are consistent across multiple algorithms and environments, including the continuous-control DeepMind Control Suite and the discrete Atari 100k benchmark.
-
-
-Episode return over time on a subset of DeepMind Control, with and without resets, using the SAC algorithm. Averaged over 10 random seeds.
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/mujuco-resets-sample.jpeg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 4, from Nikishin et al.
-
-
-
-
-Episode return over time in DeepMind Control, with and without resets, using the DRQ algorithm. Averaged over 20 random seeds.
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/mujuco-resets-full.jpeg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Figure 18, from Appendix C of Nikishin et al.
-
-
-
-
-
-Per-game scores in Atari, with and without reset, using the SPR algorithm. Averaged over 20-100 random seeds.
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/atari.jpeg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Table 7, from Appendix C of Nikishin et al.
-
-
-
-
-After seeing the success of resets, it is reasonable to wonder how weight resets compare to other regularization tools. The authors test this as well and show that resets improve outcomes in their experiments on average more than either dropout or L2 regularization (which actually perform worse than the baseline).
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/dropoutsetc.jpeg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Comparison of Base Algorithm, Resets (+ resets), Dropout (+ dropout), and L2 (+ L2). Averaged over 10 runs.
-
-
-
-
-### What's The Catch?
-
-While these results are impressive, they come at a cost. At minimum, increasing the replay ratio increases the compute time linearly. D'Oro et al. (2023) note that running the full dynamic control benchmark with a replay ratio of 32 takes 4 GPU days on an NVIDIA V100. Using a replay ratio of 16 on Atari 100K requires 5 GPU hours per run.
-
-Additionally, implementing weight resets requires a sneaky number of design decisions. The results from the paper show reset rules specifically chosen for each environment and algorithm.
-
-Some of these considerations include:
-
-1. How often should you reset? Every step is ‘ideal’ but it is also ideal to get results this year.
-2. What is the optimal replay ratio to maximally learn per sample and sustain the reset frequency?
-3. What exactly should I reset? Full model? Last layer?
-
-These are open questions. For weight resets to become widely used, new heuristics and best practices will need to be developed. The answers may depend on both the network architecture and the underlying system dynamics. Trying to imagine the precise behaviors induced by primacy bias on Atari and DeepMind Control can be difficult.
-
-
-
-## Implementing Primacy Bias
-
-The best way to learn something is through practice. In this section we will present a minimum example of primacy bias. The associated code is [released as a notebook](https://github.com/mkielo3/iclr-blog2024-primacy-bias) along with additional experiments.
-
-The biggest obstacle to studying primacy bias is the compute required. Training time scales linearly with replay ratio, and a high replay ratio is necessary to extract maximal information per sample and to recover after each reset. To work around this, we present an MVP: Minimum Viable Primacy (bias).
-
-We use a modified version of the Frozen Lake environment provided by Farama Gymnasium with a DQN model (one of the first models to popularize a replay buffer) based on the CleanRL implementation.
-
-
-### 2x2 Switching Frozen Lake
-
-Frozen Lake is a simple pathfinding problem. The model receives a reward if it successfully traverses a grid to reach a goal. The model can fail in two ways: 1) it falls in a hole or 2) it takes too long to reach the goal. The model observes its location on the grid and each action is a move one tile up, down, left, or right.
-
-To simplify the problem, we restrict the map size to 2x2 and keep the environment deterministic. The agent always starts in the top left corner and is rewarded if it reaches the bottom right corner. A hole is placed in one of the two remaining spaces. The agent fails if it takes more than 2 steps or falls in a hole. Each map has exactly one solution.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/fl.jpeg" class="img-fluid rounded z-depth-1" %}
-
-
- MVP: Switching 2x2 Frozen Lake Environment, with solution in red.
-
-
-
-
-The agent attempts to cross the lake 1,000 times. To force primacy bias, we show the agent Map 1 for the first 200 crossings, and Map 2 for the last 800. The maps are deliberately chosen to have opposite solutions. After 400 crossings the agent will have experienced each map equally, and afterwards it should begin to prefer Map 2 with increasing confidence. Our agent is maximally exploitative and will always take the action it thinks is best.
-
-Each trial is considered expensive (our agent doesn't want to freeze). A good algorithm will maximize the number of successful crossings in the 1,000 attempts. Each attempt is saved to the replay buffer and any reset will fully reinitialize all network weights.
-
-The advantage of this environment is that it is very fast. A trial of 1,000 crossings with a replay ratio of 1 completes in less than 5 seconds on a CPU. The disadvantage of this environment is that it's incredibly simple, and findings might not generalize to more complex problems.
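-
-For reference, here is a minimal sketch of how the two maps could be specified with Gymnasium's built-in Frozen Lake environment. The exact 2x2 layouts are our assumption, chosen to match the description above and the later caption (Map 1 is solved by first moving down, Map 2 by first moving right); the accompanying notebook may construct them differently.
-
-```python
-import gymnasium as gym
-
-# Assumed layouts: S = start, F = frozen, H = hole, G = goal.
-MAP_1 = ["SH",
-         "FG"]  # solution: down, then right
-MAP_2 = ["SF",
-         "HG"]  # solution: right, then down
-
-
-def make_lake(crossing: int) -> gym.Env:
-    # Map 1 for the first 200 crossings, Map 2 for the remaining 800;
-    # max_episode_steps=2 enforces "fails if it takes more than 2 steps".
-    desc = MAP_1 if crossing < 200 else MAP_2
-    return gym.make("FrozenLake-v1", desc=desc, is_slippery=False, max_episode_steps=2)
-```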
-
-### Results
-
-The first thing we do is inspect how our model scores its first action with and without resets for each cross.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/q_vals/01.svg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Model scores for first action over time (after softmax), with and without resets. The correct first action is down for the first 200 episodes and right afterwards. Replay ratio of 16 with results averaged over 25 seeds.
-
-
-
-Additional action values over time for various learning rates.
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/q_vals/001.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/q_vals/0001.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/q_vals/00005.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/q_vals/00001.svg" class="img-fluid rounded z-depth-1" %}
-
-
-
-
-
-
-
-
-Both models quickly determine that moving down is correct. The resetting model will periodically score actions equally before quickly recovering. Without resets, the map switch is only recognized after the 800th crossing. With resets, this switch happens around crossing 500. We also see that after the map switch the model without resets tries to adjust by increasing the scores for the incorrect left and up actions (which led to failure in two steps instead of one).
-
-We can also plot the reward per crossing, averaged over 25 seeds. Similar to the first result, the model with resets periodically fails, but also adapts to the map switch faster.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/reward/01.svg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Model score over time, with and without resets. Replay ratio of 16. Average of 25 seeds.
-
-
-
-
-Additional scores over time for various learning rates.
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/reward/001.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/reward/0001.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/reward/00005.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/reward/00001.svg" class="img-fluid rounded z-depth-1" %}
-
-
-
-
-
-
-
-
-
-Next, we conduct a hyperparameter sweep with replay ratios 1, 4, 16 and reset frequencies 0, 50, 100, 500. We then compare the average number of successful crossings. A random policy will earn the reward 1/16 of the time.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/grid/01.svg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Full period average score, averaged across all crossings. Average of 25 seeds.
-
-
-
-Additional average scores for various learning rates.
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/grid/001.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/grid/0001.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/grid/00005.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/grid/00001.svg" class="img-fluid rounded z-depth-1" %}
-
-
-
-
-
-
-
-
-In general, the results match our expectations. With a learning rate of 0.01, a higher replay ratio improves results, and having resets is always helpful. A high replay ratio with resets is necessary to achieve a score over 0.6 for all learning rates. Reset frequency and replay ratio must be adjusted alongside the learning rate, which scales how quickly the network can adapt in a non-stationary environment.
-
-As a final experiment, we vary model size. We compare a much smaller two-layer DQN architecture to the larger three-layer model used in prior experiments. Interestingly, this produces the highest score yet with a reset frequency of 10 steps, although the result quickly disappears with a lower learning rate.
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/little/01-2.svg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Full period average score. Average of 25 seeds. Split by Network Size with Replay Ratio of 16.
-
-
-
-Additional average scores for various learning rates by network size.
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/little/001.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/little/0001.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/little/00005.svg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/little/00001.svg" class="img-fluid rounded z-depth-1" %}
-
-
-
-
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-primacy-bias-and-why-it-helps-to-forget/little/misc.svg" class="img-fluid rounded z-depth-1" %}
-
-
-
- Comparison of 3 layer and 2 layer networks. Reset every 10 steps with a replay ratio of 16. Average of 25 seeds.
-
-
-
-
-## Conclusions
-
-In this blogpost, we discuss primacy bias and how it affects off-policy deep reinforcement learning. We highlight a subset of results and apply weight resets to a new problem.
-
-We hope that more examples of primacy bias continue to be discovered and studied. Eventually, we would like to identify specific behaviors that are catastrophically memorized and create guiding principles to identify environments that are most at risk of primacy bias. Over time, we hope this might unlock new applications of deep reinforcement learning.
-
-Even as the theory continues to develop, there is little harm in attempting periodic weight resets with a high replay ratio to train off-policy reinforcement learning agents.
-
-Finally, primacy bias might not always be a bad thing. If you decide to take a new shortcut to work by walking down an alley and the first thing you notice is how dark and unsafe it seems, then maybe it’s a good idea to turn back. As always, it is up to the modeller to decide whether primacy bias should be treated in their problem.
-
-## Acknowledgements
-
-This blogpost is derived from our work that began in Dr. Zsolt Kira's excellent Deep Learning course at Georgia Tech.
-
diff --git a/_posts/2024-05-07-rlhf-without-rl.md b/_posts/2024-05-07-rlhf-without-rl.md
deleted file mode 100644
index 7f0a9e06..00000000
--- a/_posts/2024-05-07-rlhf-without-rl.md
+++ /dev/null
@@ -1,328 +0,0 @@
----
-layout: distill
-title: RLHF without RL - Direct Preference Optimization
-description: We discuss the RL part of RLHF and its recent displacement by direct preference optimization (DPO).
- With DPO, a language model can be aligned with
- human preferences without sampling from an LM, thereby significantly
- simplifying the training process. By now, DPO has been implemented in many projects and seems to be here to stay.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-authors:
- - name: Michael Panchenko
- url: "https://transferlab.ai/authors/michael-panchenko"
- affiliations:
- name: appliedAI initiative GmbH
-
-bibliography: 2024-05-07-rlhf-without-rl.bib
-
-toc:
- - name: Background
- id: background
- - name: Is RLHF Reinforcement Learning?
- id: is-rlhf-reinforcement-learning
- - name: Direct Preference Optimization
- id: direct-preference-optimization
- - name: DPO in the Wild - Experiments, LLMs and Software
- id: dpo-in-the-wild-experiments-llms-and-software
- - name: Closing Remarks
- id: closing-remarks
-
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-## Background
-
-Reinforcement learning from human feedback (RLHF) is an important technique for
-aligning (large) language models (LM)
-with human preferences. It was introduced by Christiano et al. and then first
-applied to language models in the work by Ziegler et al.
-Since then, RLHF has become a central building block of many LLM-based applications,
-including the first versions of ChatGPT.
-
-RLHF for language models works roughly as follows:
-
-1. Collect a dataset of prompts $\mathcal{D}$ for the LM, typically containing
- instructions or questions.
-2. For each prompt $x\in \mathcal{D}$, collect a set of completions $y_1, ..., y_N$ from the
- LM. One can increase the temperature of the language model for this step to get
- sufficient variability in the completions.
-3. Ask human annotators to rate the completions, thereby obtaining a dataset of preferences
- $x, y_{rank_1}, ..., y_{rank_N}$.
-4. Train a parameterized reward function $r_\phi$ (mapping pairs $(x,y)$ to scalars) on the collected preferences by minimizing the loss
-
- $$
- \mathcal{L}(r) = - \mathbb{E}_{(x, y_{rank_i})} \left[ \log \frac{e^{r(x, y_{rank_i})}}{\sum_{j=1}^N e^{r(x, y_{rank_j})}} \right].
- $$
-
- This loss is inspired by the Bradley-Terry model for pairwise comparisons and by
- maximum-entropy inverse RL.
- Intuitively, it encourages the reward function to assign higher rewards to completions that are preferred by humans.
- Usually, the reward function is parameterized by the LM itself with an additional linear layer. Thus, the mapping from $(x, y)$ to $r(x, y)$ is given by
- simply concatenating the sequences $x$ and $y$ and passing the embedding of the last (or a differently selected) token through a linear layer. A minimal sketch of this preference loss is given after this list.
-5. Fine-tune the LM by viewing it as a policy $\pi_\theta$ and using RL with the learned reward function $r_\phi$ as the
- reward. For this step, a separate dataset of prompts $\mathcal{D}\_{\text{RL}}$ is used to query the LM and collect completions.
- Since the reward is learned on a very limited subset of possible completions, and is therefore unreliable on
- off-distribution data, it would be unwise to optimize it without any regularization.
-
- The typical choice of regularization is the KL-divergence between the policy (i.e. the aligned/fine-tuned LM) and a reference
- policy $\pi_{\text{ref}}$ (usually the pretrained LM before fine-tuning). The RLHF objective then becomes
-
- $$
- \tag{1}
- \label{eq:rlhf}
- J(\pi_\theta) = \mathbb{E}_{x \sim \mathcal{D}_\text{RL}, y\sim \pi_\theta(y \mid x)} \left[
- r_\phi(x, y) - \beta D_{\text{KL}} \left( \pi_\theta(y \mid x) || \pi_\text{ref}(y \mid x) \right)
- \right],
- $$
-
- which is then used to find the optimal policy $\pi_\theta$ by some optimization algorithm, typically a variant
- of proximal policy optimization (PPO). Here $D_{\text{KL}}$ denotes the
- KL-divergence between two distributions, and the temperature $\beta$ is a hyperparameter
-that controls the strength of the regularization.
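-
-As a concrete illustration of step 4, here is a minimal sketch of the preference loss for the common case of two completions per prompt (the pairwise specialization of the loss above); the tensor names are ours and stand for the scalar outputs of the reward model, not the API of any particular library.
-
-```python
-import torch
-import torch.nn.functional as F
-
-
-def reward_model_loss(r_chosen: torch.Tensor, r_rejected: torch.Tensor) -> torch.Tensor:
-    """Negative log-likelihood of the Bradley-Terry model for a batch of preference pairs.
-
-    r_chosen / r_rejected: shape (batch,), the scalar rewards r_phi(x, y) of the
-    preferred and dispreferred completions. Only their difference matters.
-    """
-    return -F.logsigmoid(r_chosen - r_rejected).mean()
-```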
-
-The resulting LLMs are very powerful and so widely used that we don't need to further elaborate on their performance here.
-Note, however, that the RLHF scheme has quite some complexity when it comes to actually making it work in practice.
-
-## Is RLHF Reinforcement Learning?
-
-From the beginning, RLHF has sparked some controversy. Some regarded it as one of the prime applications of reinforcement learning
-(which may currently be perceived as "less hot" than LLMs, so applying RL to LLMs works in RL's favor).
-At the same time, others were skeptical about whether RLHF is reinforcement learning at all.
-
-Indeed, some crucial components of RL are missing in RLHF. First, the current forms of RLHF do not involve sequential decision-making
-(although there is some work on that, e.g., the ILQL algorithm).
-While the rollout of a completion can formally be viewed as a sequence of actions, the reward is not given until the completion
-has ended. Moreover, for the purpose of RLHF the LM itself can be regarded as a direct mapping from inputs to distributions over completions,
-rather than a sequential decision-making agent in the space of tokens. Thus, at best, RLHF is a form of single-step,
-immediate-reward RL - in other words, a *contextual bandit*.
-
-Even more troubling than the non-sequential nature of RLHF may be its information flow. While the policy optimization of RLHF is framed as an online RL algorithm,
-*the environment consists of the policy itself*. Usually, in online RL an agent is able to extract new information from the environment.
-In RLHF, however, the information is not "new" in the sense that it is not extracted from something external to the agent itself.
-The only information not originally contained in the LM is in the preferences data (notably, not even in the completions themselves,
-but only in their rankings), and it is only used to fit a reward function. Thus, RLHF is more reminiscent of offline RL or supervised learning
-than of online RL.
-
-Because of this 1-step nature of RLHF and due to the (unusual for RL) application of training enormous models,
-the majority of RLHF software is not set up to be compatible with gym(nasium) or other environment interfaces. Take,
-for example, the well known [trl](https://github.com/huggingface/trl) and [trlx](https://github.com/CarperAI/trlx) libraries,
-which barely mention environments at all. A notable exception is the [RL4LMs project](https://github.com/allenai/RL4LMs) by AllenAI,
-which unfortunately seems to be abandoned, and is based on the deprecated gym instead of
-[gymnasium](https://gymnasium.farama.org/). For practical RLHF, training in parallel on massive datasets
-is a necessity, which somewhat complicates the use of standard environment and training interfaces.
-
-The view that RLHF is not "really" RL, or at least does not have to be,
-has become even more popular after the publication of the DPO algorithm,
-which we will discuss in the next section.
-
-## Direct Preference Optimization
-
-The direct preference optimization (DPO) algorithm by Rafailov et al.
-aligns language models (LMs) to human preferences without having to sample from the LM and without using RL explicitly.
-Interestingly, DPO still optimizes the same objective as RLHF, but does so purely by supervised learning.
-This results in a much simpler training procedure and
-reportedly better performance in a number of experiments.
-
-The mathematical derivation of DPO is short and insightful. It is based on the following observations:
-
-### 1. Reward as a Function of the Policy
-
-The RLHF objective (\ref{eq:rlhf}) has an exact (non-parametric) solution for the optimal policy $\pi_r$:
-
-$$
-\pi_r(y \mid x) = \frac{1}{Z(x)} \pi_{\text{ref}}(y \mid x) \exp
- \left( \frac{1}{\beta} r(x, y) \right).
-$$
-
-This expression is well known in the RL literature and is sometimes referred to as the *Boltzmann policy*
-(note that in the 1-step RL setting, the Q-function is given by the reward itself).
-
-Similar results were proved in the REPS algorithm and follow-up work, including more recent papers in that
-direction. While this solution for $\pi_r$ in
-itself is intractable (because of the partition function $Z(x)$), it can be used
-to express the reward as a function of the optimal policy:
-
-$$
- \tag{2}
- \label{eq:reward-as-function-of-policy}
- r(x, y) = \beta \log \left( \frac{\pi_r(y \mid x)}{\pi_{\text{ref}}(y \mid x)} \right) + \log Z(x).
-$$
-
-### 2. Only Differences of Rewards Are Needed
-
-For simplicity, let us consider that only two completions are collected per
-input, which are then ranked as $y_w$ and $y_l$ (for winning and losing).
-DPO can be easily extended to the case of more completions per input, but the
-notation becomes more cumbersome.
-
-The reward $r_\phi$ is then learned by minimizing the loss:
-
-$$
-    \mathcal{L}_\phi = - \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[
- \log \frac{ e ^ {r_\phi(x, y_w)}}{ e^{r_\phi(x, y_w)} + e^{r_\phi(x, y_l)}}
- \right]
-$$
-
-which is equivalent to
-
-$$
- \tag{3}
- \label{eq:reward-loss-binary}
- \mathcal{L}_\phi = - \mathbb{E}_{(x, y_w, y_l)\sim \mathcal{D}} \left[
- \log \sigma \left( r_\phi(x, y_w) - r_\phi(x, y_l) \right)
- \right],
-$$
-
-where $\sigma$ is the sigmoid function. Note that only _differences of rewards_
-enter (\ref{eq:reward-loss-binary}).
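-
-In code, this is simply a logistic loss on the reward difference. A minimal sketch (illustrative, not taken from a specific library):
-
-{% highlight python %}
-import torch.nn.functional as F
-
-def reward_model_loss(rewards_w, rewards_l):
-    # Illustrative sketch of the pairwise reward-model loss in eq. (3).
-    # rewards_w: r_phi(x, y_w) for the preferred completions, shape (batch,)
-    # rewards_l: r_phi(x, y_l) for the rejected completions, shape (batch,)
-    # Only the difference of rewards enters the loss.
-    return -F.logsigmoid(rewards_w - rewards_l).mean()
-{% endhighlight %}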
-
-### 3. DPO Objective
-
-After plugging the expression for the reward (\ref{eq:reward-as-function-of-policy})
-into the loss (\ref{eq:reward-loss-binary}),
-the partition function $Z(x)$ cancels out. Replacing the
-optimal $\pi_r$ with the parameterized $\pi_\theta$, the DPO objective is obtained as
-
-$$
- \mathcal{L}_{\text{DPO}}(\pi_\theta ; \pi_{\text{ref}}) :=
- - \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[
- \log \sigma \left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} -
- \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right)
- \right].
-$$
-
-Thus, instead of first learning a reward and then finding the optimizing policy,
-one directly finds the optimal policy such that its reward as obtained from
-(\ref{eq:reward-as-function-of-policy})
-corresponds to collected human preferences (i.e., a reward that
-optimizes (\ref{eq:reward-loss-binary})). Note that while the induced reward function
-itself is intractable, the differences of rewards remain tractable and can be
-computed using the learned policy. This should be sufficient for practical
-purposes, where rewards are mostly used to rank completions and, e.g., perform
-rejection sampling.
-
-The paper includes some more details and a discussion of the interpretation of
-the DPO update, and a detailed comparison to standard RLHF,
-but the essence of the method is captured by the above derivation.
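-
-For concreteness, here is a minimal sketch of the DPO loss, assuming per-sequence log-probabilities of the policy and the (fixed) reference model have already been computed; it illustrates the objective above and is not the reference implementation:
-
-{% highlight python %}
-import torch.nn.functional as F
-
-def dpo_loss(logp_policy_w, logp_policy_l, logp_ref_w, logp_ref_l, beta=0.1):
-    # Illustrative sketch of the DPO objective; not the reference implementation.
-    # logp_*: log pi(y | x), summed over completion tokens, shape (batch,).
-    # The implicit rewards are beta times the log-ratios to the reference model;
-    # the intractable log Z(x) cancels in the difference.
-    reward_w = beta * (logp_policy_w - logp_ref_w)
-    reward_l = beta * (logp_policy_l - logp_ref_l)
-    return -F.logsigmoid(reward_w - reward_l).mean()
-{% endhighlight %}
-
-Since $\pi_{\text{ref}}$ is kept fixed, its log-probabilities can be precomputed once over the preference dataset, so training reduces to standard supervised mini-batch optimization.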
-
-## DPO in the Wild - Experiments, LLMs and Software
-
-The original experiments in the paper were conducted on small-scale models
-and datasets, and as such were not very convincing. We partially include them here for
-completeness:
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-rlhf-without-rl/original-evaluation.svg" class="img-fluid" %}
-
-
-
- Original evaluation of DPO on small-scale models and datasets.
- Left: TL;DR summarization win rates vs.
- human-written summaries, using GPT-4 as evaluator. DPO exceeds PPO’s best-case
- performance on summarization, while being more robust to changes in the sampling
- temperature.
- Right: The frontier of expected reward vs KL to the reference
- policy. DPO provides the highest expected reward for all KL values,
- demonstrating the quality of the optimization.
-
-
-Fortunately, DPO's simplicity has made it attractive to many researchers and engineers.
-By now, only a few months after the publication of the paper, it is
-already included in [trl](https://huggingface.co/docs/trl/dpo_trainer) as well as
-the ray-based library [OpenRLHF](https://github.com/OpenLLMAI/OpenRLHF) (which is
-notably not using rllib, but that's a story for another day). Moreover, several large models have been trained with DPO,
-including [Zephyr 7B](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha) and the 70B
-parameters [TÜLU 2](https://github.com/allenai/open-instruct). Here is what the
-authors of the latter had to say about DPO:
-
-
- DPO training significantly improves AlpacaEval and MT-Bench performance. At all sizes,
- DPO training provides significant improvements in AlpacaEval, with our largest DPO-trained model
- significantly outperforming GPT-3.5-turbo-0314 (89.4 vs. 95.1) and is competitive with GPT-4 ...
- We also observe that DPO training provides a large boost in MT-Bench
- performance for the 13B and 70B size models, with TÜLU 2+DPO 70B being the best-performing
- open model compared to all other models on the MT-Bench leaderboard.
-
-
-
- DPO training is stable at large scales. We find that DPO training scales without issues with 70Bsize models,
- with DPO training still providing large benefits for open-ended generation (AlpacaEval)
- even at the 70B size. This suggests DPO is a promising path for training large models on human
- feedback without the engineering complexity required by PPO. To our knowledge, TÜLU 2+DPO
- 70B is the largest publicly-released DPO-trained model.
-
-
-
- DPO does not dramatically harm most other metrics. We find that DPO training does not
- significantly change performance in most other metrics we measure, such as factual reasoning
- (MMLU) or reasoning (BBH, GSM8k), with the exception of multilinguality (which we discuss
- below). This suggests that DPO training does not significantly change model capabilities.
- DPO training significantly drops multilingual capabilities. We find that DPO training significantly drops performance in TydiQA, which tests the multilingual capabilities of our model. However,
- we note that both our supervised finetuning and DPO data mixes do not explicitly contain multilingual
- data, and are majority English-language. As such, DPO training is likely to make multilingual outputs
- further out-of-distribution, and mixing in multilingual data at instruction tuning and DPO training
- stages may significantly improve these results.
-
-
-
- DPO training increases model verbosity. As seen in Table 4, TÜLU 2+DPO models generally
- output answers of longer length than those trained without DPO. This is in line with prior work
- showing a bias toward verbosity from RLHF training. However, we note that our DPO-trained models appear dramatically less verbose than other openweight models, which future work will investigate.
-
-
-## Closing Remarks
-
-One may find it surprising that supervised learning is able to replace RL
-on a formal level: RLHF samples _new_ data from the language model during training, whereas DPO
-does not.
-
-However, after paying closer attention to the information flow
-of RLHF as described above, it may not be too surprising after all. The sampled
-data is not really new - it is created using the very same model that one is trying
-to optimize. The rewards for these samples are also not new, they are obtained
-by fitting a reward function to the preferences, and no new human preferences are
-retrieved during optimization. So from the information-flow perspective,
-supervised learning and RL are indeed equivalent in this particular case. Maybe
-Francois Chollet was not too extreme in suggesting to _get rid of deep RL
-altogether_ in his tweet (note that it predates DPO; personally, I don't believe deep RL is completely futile, but for RLHF he was on point):
-{% twitter https://twitter.com/fchollet/status/1630241783111364608?s=20 %}
-
-Another surprising aspect of DPO is the question: *Why has nobody done this before?*
-Hopefully after reading this blog post, you will agree that the derivation of DPO is
-not particularly complicated, so why did it take almost 4 years after the introduction of RLHF,
-especially considering how tricky RLHF can be to implement?
-I don't have an answer, though my intuition is that sometimes as a community we put too much
-effort into following a working solution, instead of taking a step back
-and searching for a simpler path. We might have witnessed a large scale instance of the
-[Region-beta paradox](https://en.wikipedia.org/wiki/Region-beta_paradox).
-
-As a final note on community dynamics: supervised and self-supervised learning are now making more headlines
-compared to reinforcement learning, and DPO might have the effect of slowing down
-the complicated (but, as I believe, necessary) marriage of RL and LLMs.
-I do think that planning and search should play some part in LLM training in the future,
-although only for settings in which there is an actual environment from which new information
-can be extracted (like tool-use or robotics). For now, however, taking the RL out of RLHF
-seems like a good step forward. If DPO can be made beneficial for most LLM trainings, I believe
-that one can firmly answer the opening question of this blog as:
-
-*Is RLHF really (online) RL? No, it is not.*
diff --git a/_posts/2024-05-07-robust-foundation-model.md b/_posts/2024-05-07-robust-foundation-model.md
deleted file mode 100644
index 6bf87f8e..00000000
--- a/_posts/2024-05-07-robust-foundation-model.md
+++ /dev/null
@@ -1,863 +0,0 @@
----
-layout: distill
-title: 'Towards Robust Foundation Models: Adversarial Contrastive Learning'
-description: Foundation models pre-trained on large-scale unlabelled datasets using self-supervision can generalize to a wide range of downstream tasks. Existing work has shown that adversarial attacks can effectively fool any downstream models fine-tuned from a pre-trained foundation model. The existence of such adversarial attacks necessitates the development of robust foundation models which can yield both standard generalization and adversarial robustness to safety-critical downstream tasks. Currently, adversarial contrastive learning (ACL) is one of the most effective methods for outputting a robust foundation model. ACL incorporates contrastive learning with adversarial data to effectively output a robust representation without requiring costly annotations. In this blog, we introduce two NeurIPS 2023 publications that can enhance ACL's efficacy and efficiency, respectively. (1) This blog introduces Adversarial Invariant Regularization (AIR), which is a state-of-the-art ACL algorithm. A causal theoretical framework is built to interpret ACL, and then the AIR algorithm is derived from the causal framework to regulate and improve ACL. (2) This blog also introduces a Robustness-aware Coreset Selection (RCS) method to speed up ACL. RCS does not require label information and searches for an informative training subset that can maintain the adversarial robustness. For the first time, RCS enables the application of ACL on the large-scale ImageNet-1K dataset.
-# Your blog post's abstract.
- # Please add your abstract or summary here and not in the main body of your text.
- # Do not include math/latex or hyperlinks.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-# authors:
-# - name: Anonymous
-
-authors:
- - name: Jingfeng Zhang
- url: https://zjfheart.github.io/
- affiliations:
- name: The University of Auckland & RIKEN Center for Advanced Intelligence Project
- - name: Xilie Xu
- url: https://godxuxilie.github.io/
- affiliations:
- name: National University of Singapore
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-robust-foundation-model.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Foundation Models
- subsections:
- - name: Contrastive Learning (CL)
- - name: Robust Foundation Models
- subsections:
- - name: Adversarial Contrastive Learning (ACL)
- # subsections:
- # - name: Interactive Figures
- - name: Enhancing ACL via Adversarial Invariant Regularization (AIR)
- subsections:
- - name: Causal View of ACL
- - name: the Methodology of AIR
- - name: Empirical Results
- - name: Robust Self-Supervised Learning (RobustSSL) Benchmark
- - name: Efficient ACL via Robustness-Aware Coreset Selection (RCS)
- subsections:
- - name: Motivation---ACL is Inefficient
- - name: the Methodology of RCS
- - name: Experimental Results
-
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-
-
-## Foundation Models
-
-
-Foundation models are pre-trained on large-scale unlabelled datasets using self-supervised learning methods, which makes them generalizable to a wide range of downstream tasks via fine-tuning. For example, GPT-3 has been successfully commercialized as a powerful text generation application. Vision transformers have been widely used in computer vision tasks such as object detection and medical analysis. BLIP is a vision-language pre-trained model that can perform many vision-language tasks such as visual question answering. CLAP is a language-audio pre-trained model that can be used for understanding paired text and audio.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/foundation_models.png" class="img-fluid" %}
-
-
-
-
-
-
-### Contrastive Learning (CL)
-
-To build foundation models, contrastive learning (CL) is one of the popular self-supervised learning methods. CL aims to maximize the agreement between different natural views of the original data.
-
-Let $$f_\theta: \mathcal{X} \rightarrow \mathcal{Z}$$ be a feature extractor parameterized by $$\theta$$, $$g:\mathcal{Z} \rightarrow \mathcal{V}$$ be a projection head that maps representations to the space where the contrastive loss is applied, and $$\tau_i, \tau_j: \mathcal{X} \rightarrow \mathcal{X}$$ be two transformation operations randomly sampled from a pre-defined transformation set $$\mathcal{T}$$. Given a mini-batch $$B \sim \mathcal{X}^\beta$$ consisting of $$\beta$$ samples, we denote the augmented minibatch $$B^\prime = \{ \tau_i(x_k), \tau_j(x_k) \mid \forall x_k \in B \}$$ consisting of $$2\beta$$ samples. We take $$h_\theta(\cdot) = g \circ f_\theta(\cdot)$$ and $$x_k^u = \tau_u(x_k)$$ for any $$x_k \sim \mathcal{X}$$ and $$u \in \{i,j\}$$. The contrastive loss between different natural views (i.e., $$x_k^i$$ and $$x_k^j$$) is formulated as follows:
-
-$$ \ell_\mathrm{CL}(x_k^i,x_k^j; \theta)\!=\!-\! \sum\limits_{u \in \{i,j\}} \! \log \frac{e^{\mathrm{sim} \left(h_\theta(x_k^i), h_\theta(x_k^j) \right)/t}}{\sum\limits_{x \in B^\prime \setminus \{x_k^u\}} e^{\mathrm{sim} \left( h_\theta(x_k^u), h_\theta(x) \right)/t}}, $$
-
-where $$\mathrm{sim}(\cdot,\cdot)$$ is the cosine similarity function.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/SCL.png" class="img-fluid" %}
-
-
-
-    Intuitively, CL aims to maximize the agreement between different natural views (the dashed blue lines).
-
-
-**How to implement CL at the pre-training stage in practice?**
-
- Click here to see the PyTorch code for calculating the contrastive loss. You can copy-paste it to compute the contrastive loss conveniently.
-The code is copied from https://github.com/GodXuxilie/Enhancing_ACL_via_AIR.
-{% highlight python %}
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class CL(nn.Module):
-
- def __init__(self, normalize=True, temperature=0.5):
- super(CL, self).__init__()
- self.normalize = normalize
- self.temperature = temperature
-
- def forward(self, zi, zj):
- # zi: the representation of natural view x^i.
- # zj: the representation of natural view x^j.
-
- bs = zi.shape[0]
- labels = torch.zeros((2*bs,)).long().to(zi.device)
- mask = torch.ones((bs, bs), dtype=bool).fill_diagonal_(0)
-
- zi_norm = F.normalize(zi, p=2, dim=-1) if self.normalize else zi
- zj_norm = F.normalize(zj, p=2, dim=-1) if self.normalize else zj
-
- ### Contrastive Loss ###
- logits_ii = torch.mm(zi_norm, zi_norm.t()) / self.temperature
- logits_ij = torch.mm(zi_norm, zj_norm.t()) / self.temperature
- logits_ji = torch.mm(zj_norm, zi_norm.t()) / self.temperature
- logits_jj = torch.mm(zj_norm, zj_norm.t()) / self.temperature
-
- logits_ij_pos = logits_ij[torch.logical_not(mask)]
- logits_ji_pos = logits_ji[torch.logical_not(mask)]
- logits_ii_neg = logits_ii[mask].reshape(bs, -1)
- logits_ij_neg = logits_ij[mask].reshape(bs, -1)
- logits_ji_neg = logits_ji[mask].reshape(bs, -1)
- logits_jj_neg = logits_jj[mask].reshape(bs, -1)
-
- pos = torch.cat((logits_ij_pos, logits_ji_pos), dim=0).unsqueeze(1)
- neg_i = torch.cat((logits_ii_neg, logits_ij_neg), dim=1)
- neg_j = torch.cat((logits_ji_neg, logits_jj_neg), dim=1)
- neg = torch.cat((neg_i, neg_j), dim=0)
-
- logits = torch.cat((pos, neg), dim=1)
- nat_contrastive_loss = F.cross_entropy(logits, labels)
- return nat_contrastive_loss
-{% endhighlight %}
-
-
-Besides, you can use the following script to conduct self-supervised pre-training via CL using ResNet-18 on CIFAR-10:
-{% highlight bash %}
-# Pre-training stage via CL
-git clone https://github.com/GodXuxilie/Enhancing_ACL_via_AIR.git
-cd Enhancing_ACL_via_AIR
-PRE_TRAIN_DIR=CL_ResNet18_cifar10
-python pretraining.py $PRE_TRAIN_DIR --dataset cifar10 \
- --model r18 \
- --pgd_iter 0 --lambda1 0 --lambda2 0
-{% endhighlight %}
-
-
-## Robust Foundation Models
-Existing work has shown that there exist adversarial attacks that can fool the foundation representations into outputting incorrect predictions by adding imperceptible adversarial perturbations to the original inputs in downstream tasks.
-The existence of adversarial attacks necessitates the development of robust foundation models in safety-critical downstream tasks.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/adv_attack.png" class="img-fluid" %}
-
-
-
-The foundation representation is vulnerable to adversarial attacks, which can cause the model to wrongly predict a car as 'NOT a car'.
-
-
-Robust foundation models are pre-trained on large-scale datasets via robust self-supervised learning methods. Robust foundation models have the following two critical properties:
-- Robust foundation representations are generalizable to downstream tasks;
-- Fine-tuned robust foundation representations are adversarially robust against adversarial attacks in downstream tasks.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/robust_foundation_models.png" class="img-fluid" %}
-
-
-
-### Adversarial Contrastive Learning (ACL)
-
-To learn robust foundation representations, adversarial contrastive learning (ACL) is one of the most popular and effective robust self-supervised learning methods. ACL incorporates CL with adversarial data to build a robust foundation model without requiring costly annotations. ACL aims to maximize the agreement between different natural views as well as the agreement between different adversarial views. The adversarial contrastive loss given a data point $$x_k \in \mathcal{X}$$ is formulated as follows:
-
-$$ \ell_\mathrm{ACL}(x_k;\theta) = (1 + \omega) \cdot \ell_\mathrm{CL}(\tilde{x}_{k}^i, \tilde{x}_{k}^j; \theta) + (1 - \omega) \cdot \ell_\mathrm{CL}(x_k^i, x_k^j; \theta), $$
-
-where adversarial views are formulated as follows:
-
-$$ \tilde{x}_{k}^i, \tilde{x}_{k}^j = \mathop{\arg\max}_{
- {\Large \tilde{x}_{k}^i \in \mathcal{B}_\epsilon[x_k^i]}
- \atop
- {\Large \tilde{x}_{k}^j \in \mathcal{B}_\epsilon[x_k^j]}
- } \ell_\mathrm{CL}(\tilde{x}_{k}^i, \tilde{x}_{k}^j; \theta). $$
-
-Note that $$\omega \in [0,1]$$ is a scalar and $$\mathcal{B}_\epsilon[x]$$ is a constraint that ensures the adversarial data $$\tilde{x}$$ is in the $$\epsilon$$-ball around data $$x$$.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/ACL.png" class="img-fluid" %}
-
-
-
-    Intuitively, ACL aims to maximize the agreement between different natural views (the dashed blue lines) and the agreement between different adversarial views (the dashed red lines).
-
-
-Here is the generation procedure of adversarial data via Projected Gradient Descent (PGD). Given an initial positive pair $$(x_k^{i,(0)}, x_k^{j,(0)})$$, the number of PGD steps $$T \in \mathbb{N}$$, step size $$\rho > 0$$, and adversarial budget $$\epsilon \geq 0$$, PGD iteratively updates the pair of data from $$t=0$$ to $$T-1$$ as follows:
-
-$$ x_k^{i,(t+1)} \! = \! \Pi_{\mathcal{B}_\epsilon[x_k^{i,(0)}]} \big( x_k^{i,(t)} + \rho \cdot \mathrm{sign} \big(\nabla_{x_k^{i,(t)}} \ell_\mathrm{CL}(x_k^{i,(t)}, x_k^{j,(t)}; \theta) \big) \big), $$
-
-$$ x_k^{j,(t+1)} \! = \! \Pi_{\mathcal{B}_\epsilon[x_k^{j,(0)}]} \big( x_k^{j,(t)} + \rho \cdot \mathrm{sign} \big(\nabla_{x_k^{j,(t)}} \ell_\mathrm{CL}(x_k^{i,(t)}, x_k^{j,(t)}; \theta) \big) \big), $$
-
-where $$\Pi_{\mathcal{B}_\epsilon[x]}$$ projects the data into the $$\epsilon$$-ball around the initial point $$x$$. Generating adversarial data requires $$T$$ iterations of forward and backward propagation, which makes the training procedure extremely slow.
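-
-Below is a minimal PyTorch sketch of this PGD procedure for ACL. It is only an illustration (not the official implementation): `model` stands for $$h_\theta = g \circ f_\theta$$ and `criterion` for the contrastive loss $$\ell_\mathrm{CL}$$ defined above.
-
-{% highlight python %}
-import torch
-
-def pgd_for_acl(model, criterion, xi, xj, epsilon=8/255, rho=2/255, steps=5):
-    # Illustrative PGD sketch for ACL; not the official implementation.
-    # xi, xj: two natural views of a mini-batch, shape (batch, C, H, W).
-    # Start from the natural views (random initialization inside the epsilon-ball is also common).
-    xi_adv, xj_adv = xi.clone().detach(), xj.clone().detach()
-    for _ in range(steps):
-        xi_adv.requires_grad_(True)
-        xj_adv.requires_grad_(True)
-        loss = criterion(model(xi_adv), model(xj_adv))
-        grad_i, grad_j = torch.autograd.grad(loss, [xi_adv, xj_adv])
-        with torch.no_grad():
-            # Gradient ascent on the contrastive loss (sign of the gradient).
-            xi_adv = xi_adv + rho * grad_i.sign()
-            xj_adv = xj_adv + rho * grad_j.sign()
-            # Project back into the epsilon-ball around the natural views
-            # and clip to the valid pixel range.
-            xi_adv = torch.clamp(torch.min(torch.max(xi_adv, xi - epsilon), xi + epsilon), 0.0, 1.0)
-            xj_adv = torch.clamp(torch.min(torch.max(xj_adv, xj - epsilon), xj + epsilon), 0.0, 1.0)
-    return xi_adv.detach(), xj_adv.detach()
-{% endhighlight %}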
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/pgd_step.gif" class="img-fluid" %}
-
-
-
- The generation procedure of adversarial data in ACL. The adversarial data $\tilde{x}_k^i$ and $\tilde{x}_k^j$ are updated from the low-loss region to the high-loss region step by step according to the loss gradient.
-
-
-At each epoch, ACL conducts steps (1) and (2) alternatively:
-
-- Step (1): generating adversarial data (i.e., $$\tilde{x}_k^i$$ and $$\tilde{x}_k^j$$) via PGD;
-
-- Step (2): updating model parameters via minimizing adversarial contrastive loss to maximize agreements on the adversarial data and natural data.
-
-
-**How to implement ACL at the pre-training stage in practice?**
-
-Click here to see the PyTorch code for calculating the adversarial contrastive loss. You can copy-paste it to compute the adversarial contrastive loss conveniently. The code is copied from https://github.com/GodXuxilie/Enhancing_ACL_via_AIR.
-{% highlight python %}
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class ACL(nn.Module):
-
- def __init__(self, normalize=True, temperature=0.5):
- super(ACL, self).__init__()
- self.normalize = normalize
- self.temperature = temperature
-
- def forward(self, zi, zj, zi_adv, zj_adv, weight=0.5):
- # zi: the representation of natural view x^i.
- # zj: the representation of natural view x^j.
- # zi_adv: the representation of adversarial view \tilde{x}^i.
- # zj_adv: the representation of adversarial view \tilde{x}^j.
-
- bs = zi.shape[0]
- labels = torch.zeros((2*bs,)).long().to(zi.device)
- mask = torch.ones((bs, bs), dtype=bool).fill_diagonal_(0)
-
- zi_norm = F.normalize(zi, p=2, dim=-1) if self.normalize else zi
- zj_norm = F.normalize(zj, p=2, dim=-1) if self.normalize else zj
- zi_adv_norm = F.normalize(zi_adv, p=2, dim=-1) if self.normalize else zi_adv
-        zj_adv_norm = F.normalize(zj_adv, p=2, dim=-1) if self.normalize else zj_adv
-
- ### Adversarial Contrastive Loss ###
-
- logits_ii = torch.mm(zi_norm, zi_norm.t()) / self.temperature
- logits_ij = torch.mm(zi_norm, zj_norm.t()) / self.temperature
- logits_ji = torch.mm(zj_norm, zi_norm.t()) / self.temperature
- logits_jj = torch.mm(zj_norm, zj_norm.t()) / self.temperature
-
- logits_ij_pos = logits_ij[torch.logical_not(mask)]
- logits_ji_pos = logits_ji[torch.logical_not(mask)]
- logits_ii_neg = logits_ii[mask].reshape(bs, -1)
- logits_ij_neg = logits_ij[mask].reshape(bs, -1)
- logits_ji_neg = logits_ji[mask].reshape(bs, -1)
- logits_jj_neg = logits_jj[mask].reshape(bs, -1)
-
- pos = torch.cat((logits_ij_pos, logits_ji_pos), dim=0).unsqueeze(1)
- neg_i = torch.cat((logits_ii_neg, logits_ij_neg), dim=1)
- neg_j = torch.cat((logits_ji_neg, logits_jj_neg), dim=1)
- neg = torch.cat((neg_i, neg_j), dim=0)
-
- logits = torch.cat((pos, neg), dim=1)
- nat_contrastive_loss = F.cross_entropy(logits, labels)
-
- logits_ii_adv = torch.mm(zi_adv_norm, zi_adv_norm.t()) / self.temperature
- logits_ij_adv = torch.mm(zi_adv_norm, zj_adv_norm.t()) / self.temperature
- logits_ji_adv = torch.mm(zj_adv_norm, zi_adv_norm.t()) / self.temperature
- logits_jj_adv = torch.mm(zj_adv_norm, zj_adv_norm.t()) / self.temperature
-
- logits_ij_pos_adv = logits_ij_adv[torch.logical_not(mask)]
- logits_ji_pos_adv = logits_ji_adv[torch.logical_not(mask)]
- logits_ii_neg_adv = logits_ii_adv[mask].reshape(bs, -1)
- logits_ij_neg_adv = logits_ij_adv[mask].reshape(bs, -1)
- logits_ji_neg_adv = logits_ji_adv[mask].reshape(bs, -1)
- logits_jj_neg_adv = logits_jj_adv[mask].reshape(bs, -1)
-
- pos_adv = torch.cat((logits_ij_pos_adv, logits_ji_pos_adv), dim=0).unsqueeze(1)
- neg_i_adv = torch.cat((logits_ii_neg_adv, logits_ij_neg_adv), dim=1)
- neg_j_adv = torch.cat((logits_ji_neg_adv, logits_jj_neg_adv), dim=1)
- neg_adv = torch.cat((neg_i_adv, neg_j_adv), dim=0)
-
- logits_adv = torch.cat((pos_adv, neg_adv), dim=1)
- adv_contrastive_loss = F.cross_entropy(logits_adv, labels)
-
- return (1 - weight) * nat_contrastive_loss + (1 + weight) * adv_contrastive_loss
-{% endhighlight %}
-
-
-Besides, you can use the following script to conduct robust self-supervised pre-training via ACL using ResNet-18 on CIFAR-10:
-{% highlight bash %}
-# Pre-training stage via ACL
-git clone https://github.com/GodXuxilie/Enhancing_ACL_via_AIR.git
-cd Enhancing_ACL_via_AIR
-PRE_TRAIN_DIR=ACL_ResNet18_cifar10
-python pretraining.py $PRE_TRAIN_DIR --dataset cifar10 \
- --model r18 \
- --DynAug --lambda1 0 --lambda2 0
-{% endhighlight %}
-
-**How to utilize robust foundation representations via fine-tuning in downstream tasks?**
-
-At the fine-tuning stage, a classifier is randomly initialized and appended to the pre-trained feature extractor for solving the classification tasks.
-There are three types of fine-tuning modes:
-1. Standard linear fine-tuning (SLF): fine-tuning only the classifier on natural data while freezing the feature extractor.
-2. Adversarial linear fine-tuning (ALF): fine-tuning only the classifier on adversarial data while freezing the feature extractor.
-3. Adversarial full fine-tuning (AFF): adversarially fine-tuning both the feature extractor and the classifier.
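-
-To make the distinction concrete, here is a minimal sketch (with hypothetical `feature_extractor` and `classifier` modules) of which parameters each mode trains and which data it trains on:
-
-{% highlight python %}
-def configure_finetuning(feature_extractor, classifier, mode="SLF"):
-    # Sketch of the three fine-tuning modes; illustrative, not the repo's code.
-    # SLF/ALF freeze the feature extractor and only train the linear classifier;
-    # AFF trains both the feature extractor and the classifier.
-    if mode in ("SLF", "ALF"):
-        for p in feature_extractor.parameters():
-            p.requires_grad = False
-        params = list(classifier.parameters())
-    else:  # AFF
-        params = list(feature_extractor.parameters()) + list(classifier.parameters())
-    # SLF trains on natural data; ALF and AFF train on adversarial data
-    # generated against the downstream classification loss (e.g., via PGD).
-    use_adversarial_data = mode in ("ALF", "AFF")
-    return params, use_adversarial_data
-{% endhighlight %}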
-
-You can use the following script to transfer an adversarially pre-trained ResNet-18 on CIFAR-10 to a downstream task CIFAR-100 via fine-tuning:
-{% highlight bash %}
-# Fine-tuning stage
-cd Enhancing_ACL_via_AIR
-PRE_TRAIN_DIR=ACL_ResNet18_cifar10
-FINETUNE_DIR=ACL_ResNet18_cifar10_cifar100
-MODE=ALL # choose one of SLF, ALF, AFF, or ALL
-python finetuning.py --mode $MODE \
- --experiment $FINETUNE_DIR \
- --checkpoint ./checkpoints/$PRE_TRAIN_DIR/model.pt \
- --dataset cifar100 \
- --model r18 \
- --eval-AA --eval-OOD --pretraining DynACL
-{% endhighlight %}
-Note that `MODE=ALL` means that `finetuning.py` sequentially conducts fine-tuning in all three modes (i.e., SLF, ALF, and AFF) and writes the result of each fine-tuning mode to the log file `$FINETUNE_DIR/results/log.txt`.
-
-## Enhancing ACL via Adversarial Invariant Regularization (AIR)
-
-Here, we introduce the NeurIPS 2023 paper which proposes Adversarial Invariant Regularization (AIR), a regularizer that, based on a causal theoretical framework, regulates both standard and robust representations to be style-independent. Empirically, AIR yields state-of-the-art performance in terms of robustness against adversarial attacks and common corruptions as well as standard generalization in downstream tasks.
-
-### Causal View of ACL
-
-AIR first introduces the causal graph of the ACL as shown in the following figure.
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/causal_graph.png" class="img-fluid" %}
-
-
-
- The causal graph of the ACL.
-
-During **the data generation procedure**:
-
-- $$c$$ is the content variable, which can be regarded as the original data in the datasets.
-- $$s$$ is the style factor, which can be regarded as the data transformation functions that modify the content while maintaining its semantic meaning. Note that factors $$c$$ and $$s$$ are independent.
-- $$x$$ is the natural data, which is decided by the content factor $$c$$ and the style factor $$s$$.
-- $$y_t \in \{ y_i \}_{i=1}^{T}$$ is the label from an unknown downstream task. Note that $$y_t$$ is only decided by the content factor $$c$$.
-- $$y^R$$ is the proxy label, which is a refinement of $$y_t$$. $$y^R$$ is used for self-supervised learning without labels. As illustrated in the following figure, the label `dog` is refined into proxy labels `golden retriever with yellow hair` and `labrador retriever with black hair`. Therefore, when there is no target label, we can train models by differentiating these two different pictures using the contrastive loss.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/proxy_label.png" class="img-fluid" %}
-
-
-
- The illustration of the proxy label $y^R$ which is a refinement of the label $y_t$.
-
-
-- $$\tilde{x}$$ is the adversarial data of $x$. Since the generation procedure of $$\tilde{x}$$ in ACL does not use the labels, the adversarial data $$\tilde{x}$$ is decided by the natural data $$x$$ and the model parameter $$\theta$$.
-
-During **the learning procedure**, ACL optimizes the parameters $$\theta$$ by maximizing both of the conditional probabilities $$p(y^R \mid x)$$ and $$p(y^R \mid \tilde{x})$$.
-
-### the Methodology of AIR
-
-**Style-invariant criterion.**
-
-From the causal view of ACL, the learning procedure should satisfy the style-independent criterion. That is to say, the intervention on the style factor should not affect the conditional probability, i.e., $$p^{do(\tau_i)}(y^R \mid x) = p^{do(\tau_j)}(y^R \mid x)$$ where $$do(\tau)$$ is the intervention approximated by the data augmentation function $\tau \in \mathcal{T}$.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/AIR_invariant.png" class="img-fluid" %}
-
-
-
- According to causal reasoning, the style factor $s$ should not affect $p(y^R \mid x)$.
-
-
-Assuming that the path $$x \rightarrow \tilde{x} \rightarrow y^R$$ in the causal graph satisfies the Markov condition, we can obtain that
-
-$$p(y^R \mid x) = p(y^R \mid \tilde{x})p(\tilde{x} \mid x).$$
-
-Therefore, ACL should follow the style-independent criterion as follows:
-
-$$
-p^{do(\tau_i)}(y^R \mid \tilde{x}) p^{do(\tau_i)}(\tilde{x} \mid x) = p^{do(\tau_j)}(y^R \mid \tilde{x}) p^{do(\tau_j)}(\tilde{x} \mid x) \quad \forall \tau_i, \tau_j \in \mathcal{T}
-.$$
-
-The conditional probability $$p^{do(\tau_u)}(y^R \mid \tilde{x})$$ for $$u \in \{i,j\}$$ is calculated as the cosine similarity between the original data $$x$$ and the adversarial data $$\tilde{x}^u$$ normalized by the softmax function:
-
-$$
-p^{do(\tau_u)}(y^R \mid \tilde{x}) = \frac{e^{\mathrm{sim} \left(f_\theta(x), f_\theta(\tilde{x}^u) \right)/t}}
-{\sum\limits_{x_k \in B} e^{\mathrm{sim} \left( f_\theta(x_k), f_\theta(\tilde{x}_k^u) \right)/t}}.
-$$
-
-Note that $$y^R$$ is only decided by the content factor $$c$$. Empirically, the content factor $$c$$ can be approximated by the original data $$x$$ from the datasets.
-
-The conditional probability $$p^{do(\tau_u)}(\tilde{x} \mid x)$$ for $$u \in \{i,j\}$$ is calculated as the cosine similarity between the natural data $$x^u$$ and the adversarial data $$\tilde{x}^u$$ normalized by the softmax function:
-
-$$
-p^{do(\tau_u)}(\tilde{x} | x) = \frac{e^{\mathrm{sim} \left(f_\theta(\tilde{x}^u), f_\theta(x^u) \right)/t}}
-{\sum\limits_{x_k \in B} e^{\mathrm{sim} \left( f_\theta(\tilde{x}_k^u), f_\theta(x_k^u) \right)/t}}.
-$$
-
-
-
-**The loss function of AIR.**
-
-To achieve the style-invariant criterion, AIR is proposed to regulate the representations to be style-independent as follows:
-
-$$
-\mathcal{L}_\mathrm{AIR}(B;\theta, \epsilon) = \mathrm{KL}\left(p^{do(\tau_i)}(y^R \mid \tilde{x}) p^{do(\tau_i)}(\tilde{x} \mid x)
- \| p^{do(\tau_j)}(y^R \mid \tilde{x}) p^{do(\tau_j)}(\tilde{x} \mid x) ; B \right),
-$$
-
-in which $$\epsilon \geq 0$$ is the adversarial budget, $$B$$ is a mini-batch, and
-$$\mathrm{KL}(p(x) \| q(x); B) = \sum_{x \in B} p(x) \log \frac{p(x)}{q(x)}$$ denotes the Kullback–Leibler (KL) divergence.
-
-We provide an illustration of AIR for ACL. AIR aims to maximize the agreement between the original data and the adversarial view (the dashed yellow lines) and the agreement between the natural view and the adversarial view (the dashed pink lines).
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/AIR_understand.png" class="img-fluid" %}
-
-
-
- Intuitively, AIR aims to maximize the agreements among different natural views, different adversarial views, and original data.
-
-
-**Learning objective of AIR enhanced ACL.**
-
-The learning objective of AIR is formulated as follows:
-
-$$
-\mathop{\arg\min}_{\theta} \sum_{x \in U} \ell_\mathrm{ACL}(x; \theta) + \lambda_1 \cdot \mathcal{L}_\mathrm{AIR}(U;\theta,0) + \lambda_2 \cdot \mathcal{L}_\mathrm{AIR}(U;\theta,\epsilon),
-$$
-
-where $$\lambda_1 \geq 0$$ and $$\lambda_2 \geq 0$$ are two hyper-parameters.
-
-The official code of AIR is available at [https://github.com/GodXuxilie/Enhancing_ACL_via_AIR](https://github.com/GodXuxilie/Enhancing_ACL_via_AIR).
-Click here to see the PyTorch code for calculating the AIR loss. You can copy-paste it to compute the AIR loss conveniently.
-{% highlight python %}
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class AIR(nn.Module):
-
- def __init__(self, normalize=True, temperature=0.5):
- super(AIR, self).__init__()
- self.normalize = normalize
- self.temperature = temperature
-
- def forward(self, zi, zj, zi_adv, zj_adv, z_orig, weight=0.5, lambda1=0.5, lambda2=0.5):
- # zi: the representation of natural data x^i.
- # zj: the representation of natural data x^j.
- # zi_adv: the representation of adversarial data \tilde{x}^i.
- # zj_adv: the representation of adversarial data \tilde{x}^j.
- # z_orig: the representation of original data x.
-
- bs = zi.shape[0]
- labels = torch.zeros((2*bs,)).long().to(zi.device)
- mask = torch.ones((bs, bs), dtype=bool).fill_diagonal_(0)
-
- zi_norm = F.normalize(zi, p=2, dim=-1) if self.normalize else zi
- zj_norm = F.normalize(zj, p=2, dim=-1) if self.normalize else zj
- zi_adv_norm = F.normalize(zi_adv, p=2, dim=-1) if self.normalize else zi_adv
- zj_adv_norm = F.normalize(zj_adv, p=2, dim=-1) if self.normalize else zj_adv
- zo_norm = F.normalize(z_orig, p=2, dim=-1) if self.normalize else z_orig
-
- ### Adversarial Contrastive Loss ###
- logits_ii = torch.mm(zi_norm, zi_norm.t()) / self.temperature
- logits_ij = torch.mm(zi_norm, zj_norm.t()) / self.temperature
- logits_ji = torch.mm(zj_norm, zi_norm.t()) / self.temperature
- logits_jj = torch.mm(zj_norm, zj_norm.t()) / self.temperature
-
- logits_ij_pos = logits_ij[torch.logical_not(mask)]
- logits_ji_pos = logits_ji[torch.logical_not(mask)]
- logits_ii_neg = logits_ii[mask].reshape(bs, -1)
- logits_ij_neg = logits_ij[mask].reshape(bs, -1)
- logits_ji_neg = logits_ji[mask].reshape(bs, -1)
- logits_jj_neg = logits_jj[mask].reshape(bs, -1)
-
- pos = torch.cat((logits_ij_pos, logits_ji_pos), dim=0).unsqueeze(1)
- neg_i = torch.cat((logits_ii_neg, logits_ij_neg), dim=1)
- neg_j = torch.cat((logits_ji_neg, logits_jj_neg), dim=1)
- neg = torch.cat((neg_i, neg_j), dim=0)
-
- logits = torch.cat((pos, neg), dim=1)
- nat_contrastive_loss = F.cross_entropy(logits, labels)
-
- logits_ii_adv = torch.mm(zi_adv_norm, zi_adv_norm.t()) / self.temperature
- logits_ij_adv = torch.mm(zi_adv_norm, zj_adv_norm.t()) / self.temperature
- logits_ji_adv = torch.mm(zj_adv_norm, zi_adv_norm.t()) / self.temperature
- logits_jj_adv = torch.mm(zj_adv_norm, zj_adv_norm.t()) / self.temperature
-
- logits_ij_pos_adv = logits_ij_adv[torch.logical_not(mask)]
- logits_ji_pos_adv = logits_ji_adv[torch.logical_not(mask)]
- logits_ii_neg_adv = logits_ii_adv[mask].reshape(bs, -1)
- logits_ij_neg_adv = logits_ij_adv[mask].reshape(bs, -1)
- logits_ji_neg_adv = logits_ji_adv[mask].reshape(bs, -1)
- logits_jj_neg_adv = logits_jj_adv[mask].reshape(bs, -1)
-
- pos_adv = torch.cat((logits_ij_pos_adv, logits_ji_pos_adv), dim=0).unsqueeze(1)
- neg_i_adv = torch.cat((logits_ii_neg_adv, logits_ij_neg_adv), dim=1)
- neg_j_adv = torch.cat((logits_ji_neg_adv, logits_jj_neg_adv), dim=1)
- neg_adv = torch.cat((neg_i_adv, neg_j_adv), dim=0)
-
- logits_adv = torch.cat((pos_adv, neg_adv), dim=1)
- adv_contrastive_loss = F.cross_entropy(logits_adv, labels)
-
- ### Adversarial Invariant Regularization ###
- logits_io = torch.mm(zi_norm, zo_norm.t()) / self.temperature
- logits_jo = torch.mm(zj_norm, zo_norm.t()) / self.temperature
- probs_io_zi = F.softmax(logits_io[torch.logical_not(mask)], -1)
- probs_jo_zj = F.log_softmax(logits_jo[torch.logical_not(mask)], -1)
- AIR_standard = F.kl_div(probs_io_zi, probs_jo_zj, log_target=True, reduction="sum")
-
- logits_io = torch.mm(zi_adv_norm, zi_norm.t()) / self.temperature
- logits_jo = torch.mm(zj_adv_norm, zj_norm.t()) / self.temperature
- probs_io_zi_adv_consis = F.softmax(logits_io[torch.logical_not(mask)], -1)
- probs_jo_zj_adv_consis = F.softmax(logits_jo[torch.logical_not(mask)], -1)
-
- logits_io = torch.mm(zi_adv_norm, zo_norm.t()) / self.temperature
- logits_jo = torch.mm(zj_adv_norm, zo_norm.t()) / self.temperature
- probs_io_zi_adv = F.softmax(logits_io[torch.logical_not(mask)], -1)
- probs_jo_zj_adv = F.softmax(logits_jo[torch.logical_not(mask)], -1)
-
- probs_io_zi_adv = torch.mul(probs_io_zi_adv, probs_io_zi_adv_consis)
- probs_jo_zj_adv = torch.mul(probs_jo_zj_adv, probs_jo_zj_adv_consis)
- AIR_robust = F.kl_div(probs_io_zi_adv, torch.log(probs_jo_zj_adv), log_target=True, reduction="sum")
-
- return (1 - weight) * nat_contrastive_loss + (1 + weight) * adv_contrastive_loss + lambda1 * AIR_standard + lambda2 * AIR_robust
-{% endhighlight %}
-
-
-Besides, you can use the following script to conduct robust self-supervised pre-training via AIR using ResNet-18 on CIFAR-10:
-{% highlight bash %}
-# Pre-training stage via AIR
-git clone https://github.com/GodXuxilie/Enhancing_ACL_via_AIR.git
-cd Enhancing_ACL_via_AIR
-PRE_TRAIN_DIR=AIR_ResNet18_cifar10
-python pretraining.py $PRE_TRAIN_DIR --dataset cifar10 --model r18 --DynAug
-{% endhighlight %}
-
-
-### Empirical Results
-
-**AIR yields state-of-the-art cross-task robustness transferability against adversarial attacks.**
- - $$\mathcal{D}_1 \rightarrow \mathcal{D}_2$$ means that the model is pre-trained on dataset $$\mathcal{D}_1$$ and fine-tuned on the downstream dataset $$\mathcal{D}_2$$.
- - `SA` refers to the standard accuracy calculated as the average accuracy on the natural test data in the downstream dataset $$\mathcal{D}_2$$.
- - `AA` refers to the robust accuracy calculated as the average accuracy on the adversarial test data generated via [adversarial attacks](https://github.com/fra31/auto-attack) in the downstream dataset $$\mathcal{D}_2$$.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/AIR_cross_attack.png" class="img-fluid" %}
-
-
-
-**AIR yields state-of-the-art cross-task robustness transferability against common corruptions.**
-
-`CS-#` refers to the average accuracy evaluated on the test data under common corruptions with corruption severity (CS) of `#` $$ \in $$ \{1,3,5\} in the downstream dataset $$\mathcal{D}_2$$.
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/AIR_cross_corrup.png" class="img-fluid" %}
-
-
-
-To reproduce the above results of the transferability from CIFAR-10 to CIFAR-100, you can use the following scripts.
-
-- At the pre-training stage, you can conduct AIR using ResNet-18 on CIFAR-10.
-{% highlight bash %}
-# Pre-training stage using AIR
-git clone https://github.com/GodXuxilie/Enhancing_ACL_via_AIR.git
-cd Enhancing_ACL_via_AIR
-PRE_TRAIN_DIR=AIR_ResNet18_cifar10
-python pretraining.py $PRE_TRAIN_DIR --dataset cifar10 --model r18 --DynAug
-{% endhighlight %}
-
-- At the fine-tuning stage, you can fine-tune the pre-trained ResNet-18 on the downstream task CIFAR-100. The following script automatically conducts all three fine-tuning modes (i.e., SLF, ALF, and AFF). After fine-tuning, you can check the standard accuracy as well as the robust accuracy under adversarial attacks and common corruptions for each fine-tuning mode in the log file `$FINETUNE_DIR/results/log.txt`.
-
-{% highlight bash %}
-# Fine-tuning stage
-cd Enhancing_ACL_via_AIR
-PRE_TRAIN_DIR=AIR_ResNet18_cifar10
-FINETUNE_DIR=AIR_ResNet18_cifar10_cifar100
-python finetuning.py --experiment $FINETUNE_DIR \
- --checkpoint ./checkpoints/$PRE_TRAIN_DIR/model.pt \
- --dataset cifar100 \
- --model r18 \
- --mode ALL \
- --eval-AA --eval-OOD --pretraining DynACL_AIR
-{% endhighlight %}
-
-
-### Robust Self-Supervised Learning (RobustSSL) Benchmark
-
-The website of the RobustSSL Benchmark is at [https://robustssl.github.io/](https://robustssl.github.io/).
-
-**AIR ranks FIRST in [RobustSSL Benchmark](https://robustssl.github.io/)!** For more information regarding the leaderboards, please check the website of [RobustSSL Benchmark](https://robustssl.github.io/).
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/leaderboard.png" class="img-fluid" %}
-
-
-
- A screenshot of the leaderboard shown in RobustSSL Benchmark.
-
-
-
-## Efficient ACL via Robustness-Aware Coreset Selection (RCS)
-
-Here, we introduce the NeurIPS 2023 spotlight paper which proposes Robustness-Aware Coreset Selection (RCS), a method that selects an informative coreset without label annotations to speed up ACL. Theoretically, Xu et al. (2023) show that a greedy search algorithm can efficiently find the coreset. Empirically, RCS can speed up both ACL and supervised robust pre-training by a large margin on CIFAR and ImageNet-1K datasets without significantly hurting the robustness transferability. This paper provides the first proof of concept that ACL can be applied to large-scale datasets.
-
-### Motivation---ACL is Inefficient
-
-ACL is computationally prohibitive on large-scale datasets since generating adversarial data requires expensive computational overheads.
-
-Empirically, ACL on the entire ImageNet-1K dataset (1,281,167 training data points) requires about **650 hours** evaluated on RTX A5000 GPUs.
-Due to this inefficiency, ACL had not been applied to the ImageNet-1K dataset before RCS.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/PGD.png" class="img-fluid" width="100" height="100" %}
-
-
-
- ACL is inefficient because $T$ PGD steps require expensive computational overheads.
-
-
-### the Methodology of RCS
-
-**Intuition of RCS.**
-
-To speed up ACL, RCS follows an intuitive idea: find an informative training subset (called a "coreset"). The coreset directly decreases the number of training samples, thus significantly accelerating ACL. Besides, since the coreset is informative for improving $$f$$'s adversarial robustness, training on it should still enable ACL to output an effective robust foundation model.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/intuition.png" class="img-fluid" %}
-
-
-
-    RCS generates an informative coreset to make ACL efficiently obtain an effective robust foundation model. Image from https://medium.com/analytics-vidhya/sampling-statistical-approach-in-machine-learning-4903c40ebf86.
-
-
-**Representational Distance (RD) as a measurement of $$f$$'s adversarial robustness without labels.**
-
-The RD of a data point, denoted $$\ell_\mathrm{RD}(x;\theta)$$, is quantified by the representational distance between the natural data and its adversarial counterpart, i.e.,
-
-$$\ell_{\mathrm{RD}}(x; \theta) = d(g \circ f_\theta(\tilde{x}), g \circ f_\theta(x)) \quad \mathrm{s.t.} \quad \tilde{x} = \mathop{\arg\max}_{x^{\prime} \in \mathcal{B}_\epsilon[x]} \quad d(g \circ f_\theta(x^{\prime}), g \circ f_\theta(x)),$$
-
-in which the PGD method is used to generate adversarial data $$\tilde{x}$$ within the $$\epsilon$$-ball centered at $$x$$ and
-$$d(\cdot, \cdot): \mathcal{V} \times \mathcal{V} \rightarrow \mathbb{R}$$ is a distance function, such as the KL divergence.
-The smaller the RD, the less sensitive the representations are to adversarial perturbations, and thus the more adversarially robust they are.
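-
-As a minimal sketch (assuming the adversarial counterparts have already been generated via PGD, and treating the temperature-scaled representations as distributions, which is only one possible choice of the distance $$d$$), the RD loss over a batch could be computed as follows:
-
-{% highlight python %}
-import torch.nn.functional as F
-
-def rd_loss(z_nat, z_adv, temperature=0.5):
-    # Illustrative sketch of the RD loss; the concrete choice of distance d here
-    # (KL divergence over softmax-normalized representations) is an assumption.
-    # z_nat: projected representations g(f(x)) of natural validation data, shape (batch, dim).
-    # z_adv: projected representations g(f(x_adv)) of their adversarial counterparts.
-    p_nat = F.softmax(z_nat / temperature, dim=-1)
-    log_p_adv = F.log_softmax(z_adv / temperature, dim=-1)
-    return F.kl_div(log_p_adv, p_nat, reduction="batchmean")
-{% endhighlight %}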
-
-**Objective function of RCS.**
-
-To realize the intuitive idea, RCS is formulated as follows:
-
-$$ S^* = \mathop{\arg\min}_{S \subseteq X, |S|/|X| = k} \mathcal{L}_{\mathrm{RD}}(U; \theta(S)),$$
-
-$$\theta(S) = \mathop{\arg\min}_{\theta} \mathcal{L}_\mathrm{ACL}(S; \theta),$$
-
-in which $$S^*$$ is the coreset, $$U$$ is an unlabeled validation set, $$k \in (0,1]$$ is the subset fraction that controls the size of the coreset, $$ \mathcal{L}_{\mathrm{RD}}(U; \theta(S)) = \sum_{x \in U} \ell_\mathrm{RD}(x; \theta(S)) $$, and $$ \mathcal{L}_\mathrm{ACL}(S; \theta) = \sum_{x \in S} \ell_\mathrm{ACL}(x; \theta) $$.
-
-Intuitively, given a coreset $$S^*$$, after the model parameters are updated to $$ \theta(S^{*}) $$ via minimizing the ACL loss on the coreset $$\mathcal{L}_\mathrm{ACL}(S^*; \theta)$$, the model will achieve the minimized RD loss on the validation dataset $$\mathcal{L}_{\mathrm{RD}}(U; \theta(S^*))$$, thus being adversarially robust.
-
-Then, RCS can be converted into a problem of maximizing a set function subject to a cardinality constraint as follows:
-
-$$S^* = \mathop{\arg\max}_{S \subseteq X, |S|/|X| = k} G_\theta(S),$$
-
-$$G_\theta(S \subseteq X) \triangleq - \mathcal{L}_\mathrm{RD}(U; \theta(S)) = - \mathcal{L}_\mathrm{RD}(U; \theta - \eta \nabla_\theta \mathcal{L}_\mathrm{ACL}(S; \theta)),$$
-
-where $$G:2^\mathcal{X} \rightarrow \mathbb{R}$$ is a set function, $$\theta(S)$$ is estimated using the one-step approximation and $$\eta \in \mathbb{R}^+$$ is the learning rate.
-
-**RCS via Greedy Search.**
-
-The vanilla solution of traversing all subsets and selecting the subset that has the largest $$G_\theta(S)$$ is intractable.
-Xu et al. (2023) show that the set function $$G_\theta(S)$$ satisfies the following two critical properties, which motivates a greedy search to efficiently search for the coreset.
-
-The set function $$G_\theta(S)$$ is shown to be submodular (strictly speaking, the authors of RCS rigorously prove that a proxy set function is weakly submodular, and that greedy search provides a guaranteed lower bound for the proposed set function maximization problem based on this weakly submodular proxy; please refer to the RCS paper for details), which means it satisfies the following two properties:
-
-- Monotonicity: As more data is added to the set, the representation becomes better. $$G_\theta(x \mid S) = G_\theta(S \cup \{x\}) - G_\theta(S) \geq 0$$ for any $$S \subseteq X$$ and $$x \in X \setminus S$$.
-- Diminishing returns: As the set has more data, the marginal gain of extra data for learning representations gradually diminishes. $$G_\theta(x \mid A) \geq G_\theta(x \mid B)$$ for any $$A \subseteq B \subseteq X$$ and $$x \in X \setminus B$$.
-
-Therefore, RCS greedily searches for the data $$x$$ that has the largest marginal gain and then adds it to the coreset.
-
-
-**Pseudo-code of efficient ACL via RCS.**
-
-- Step 1 (Warm-up): Warm up training on the entire training set to find a better starting point $$f_\theta$$.
-- **Step 2.1 (RCS)**: $$S \gets\emptyset$$. $$\theta' \gets \theta$$. Compute gradients $$ Q \gets \{ q_k = \nabla_\theta \mathcal{L}_\mathrm{ACL}(x_k; \theta) \mid \forall x_k \in X \}$$ on unlabeled training dataset $$X$$.
-- **Step 2.2 (RCS)**: Compute gradients $$q_U \gets \nabla_\theta \mathcal{L}_\mathrm{RD}(U; \theta')$$ on unlabeled validation dataset $$U$$.
-- **Step 2.3 (RCS)**: Select a data $$x_k$$, whose gradient $$q_k$$ matches best with $$q_U$$, i.e., $$\mathop{\arg\max}_k \{q_k^\top q_U \}$$.
-- **Step 2.4 (RCS)**: $$S \gets S \cup \{x_k\}$$, $$X \gets X \setminus \{ x_k \}$$, $$\theta' \gets \theta' - \eta' q_k$$.
-- **Step 2.5 (RCS)**: Repeat Steps 2.2-2.4 until $$\mid S\mid/\mid X\mid = k$$.
-- Step 3 (ACL training): Update parameters $$\theta \gets \theta - \eta \nabla_\theta \mathcal{L}_\mathrm{ACL}(S; \theta)$$.
-- Step 4: Every few epochs, go to Step 2.1 to generate a new coreset; otherwise, go to Step 3 to update model parameters. The algorithm stops when reaching the final training epoch.
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/RCS_algo.png" class="img-fluid" %}
-
-
-
- A pipeline of efficient ACL via RCS. After the warm-up periods, the model is trained on the coreset. Thus, RCS makes the training procedure much more efficient by decreasing the number of training data.
-
-
-Intuitively, RCS greedily selects and adds into the coreset the data $$x$$ whose training loss gradient (i.e., $$\nabla_\theta\mathcal{L}_\mathrm{ACL}(\{x\}; \theta)$$) matches the validation loss gradient (i.e., $$\nabla_\theta\mathcal{L}_\mathrm{RD}(U; \theta(S))$$) most closely. In this way, training on the data selected by RCS is most beneficial in optimizing the RD loss, which is thus most helpful for improving $$f$$'s adversarial robustness.
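-
-Below is a minimal sketch of the greedy selection loop in Steps 2.1-2.5. The arguments `train_grads` and `rd_grad_fn` are placeholders for the (flattened) ACL and RD loss gradients; practical gradient-matching implementations often approximate these with last-layer, batch-wise gradients for efficiency, which is omitted here.
-
-{% highlight python %}
-import torch
-
-def rcs_greedy_selection(train_grads, rd_grad_fn, theta, fraction=0.2, lr=0.01):
-    # Illustrative sketch of the greedy coreset selection (Steps 2.1-2.5);
-    # gradient computation details are abstracted away.
-    # train_grads: dict mapping sample index -> flattened ACL loss gradient q_k (Step 2.1).
-    # rd_grad_fn: callable returning the flattened RD loss gradient q_U at theta' (Step 2.2).
-    # theta: flattened current model parameters.
-    budget = int(fraction * len(train_grads))
-    coreset, remaining = [], dict(train_grads)
-    theta_prime = theta.clone()
-    for _ in range(budget):
-        q_u = rd_grad_fn(theta_prime)                                   # Step 2.2
-        # Step 2.3: pick the sample whose gradient matches q_U best.
-        best_k = max(remaining, key=lambda k: torch.dot(remaining[k], q_u).item())
-        q_best = remaining.pop(best_k)
-        coreset.append(best_k)                                          # Step 2.4
-        theta_prime = theta_prime - lr * q_best                         # one-step update
-    return coreset
-{% endhighlight %}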
-
-The official code of RCS is available at [https://github.com/GodXuxilie/Efficient_ACL_via_RCS](https://github.com/GodXuxilie/Efficient_ACL_via_RCS).
-
-### Experimental Results
-
-
-**RCS significantly speeds up ACL on CIFAR-10.**
-- The term `speed-up ratio` refers to the ratio of the time consumption of pre-training on the entire training set to the time consumption of pre-training on the training subset. Thus, the larger the speed-up ratio is, the more efficient the pre-training procedure is.
-- The terms `standard test accuracy` and `robust test accuracy` refer to the average accuracy evaluated on natural test data and adversarial test data, respectively. Thus, the higher the line is, the more effective the pre-training method is.
-
-The results obtained by RCS are located in the upper-right corner, i.e., pre-training via RCS is both more efficient and more effective.
-
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/RCS_exp1.png" class="img-fluid" %}
-
-
-
-To reproduce the above results of the robustness transferability from CIFAR-10 to CIFAR-100, you can use the following scripts.
-
-- At the pre-training stage, you can conduct ACL via RCS using ResNet-18 on CIFAR-10.
-
-{% highlight bash %}
-# Pre-training stage using RCS
-git clone https://github.com/GodXuxilie/Efficient_ACL_via_RCS.git
-cd Efficient_ACL_via_RCS/ACL_RCS/small_scale_datasets
-PRE_TRAIN_DIR=ACL_RCS_ResNet18_cifar10
-python DynACL_RCS.py $PRE_TRAIN_DIR --ACL_DS --dataset cifar10 --fraction 0.2
-{% endhighlight %}
-
-- At the fine-tuning stage, you can fine-tune the pre-trained ResNet-18 on CIFAR-100. The test accuracies are saved in `$FINETUNE_DIR/results/log.txt`.
-{% highlight bash %}
-# Fine-tuning stage (SLF, ALF, AFF)
-cd Efficient_ACL_via_RCS/ACL_RCS/small_scale_datasets
-PRE_TRAIN_DIR=ACL_RCS_ResNet18_cifar10
-FINETUNE_DIR=ACL_RCS_ResNet18_cifar10_cifar100
-python finetuning.py --experiment $FINETUNE_DIR \
- --checkpoint ./checkpoints/$PRE_TRAIN_DIR/model.pt \
- --dataset cifar100 \
- --model r18 \
- --mode ALL --eval-AA --eval-OOD --pretraining DynACL_RCS
-{% endhighlight %}
-
-
-**For the first time, ACL was conducted efficiently on ImageNet-1K via RCS.**
-The results demonstrate that ACL can be applied to large-scale datasets. Here, `SA` refers to the standard test accuracy and `RA` refers to the robust test accuracy.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/RCS_exp2.png" class="img-fluid" %}
-
-
-
-To reproduce the above results of the robustness transferability from ImageNet-1K to CIFAR-10, you can use the following scripts.
-- At the pre-training stage, you can conduct ACL via RCS using a Wide ResNet with width 10 and depth 28 (WRN-28-10) on ImageNet-1K at $$32 \times 32$$ resolution.
-
-{% highlight bash %}
-# Pre-training stage using RCS
-git clone https://github.com/GodXuxilie/Efficient_ACL_via_RCS.git
-cd Efficient_ACL_via_RCS/ACL_RCS/ImageNet_32
-PRE_TRAIN_DIR=ACL_RCS_WRN_ImageNet
-python ACL_RCS.py $PRE_TRAIN_DIR --gpu 0,1,2,3 --ACL_DS --fraction 0.05
-{% endhighlight %}
-
-- At the fine-tuning stage, you can fine-tune the ImageNet-1K pre-trained models on CIFAR-10.
-{% highlight bash %}
-cd Efficient_ACL_via_RCS/ACL_RCS/ImageNet_32
-PRE_TRAIN_DIR=ACL_RCS_WRN_ImageNet
-FINETUNE_DIR=ACL_RCS_WRN_ImageNet_cifar10
-# Fine-tuning stage (SLF)
-python transfer.py --out_dir $FINETUNE_DIR/SLF \
-                   --resume $PRE_TRAIN_DIR/model.pt \
- --dataset cifar10 \
- --lr 0.01 --linear
-# Fine-tuning stage (ALF)
-python adv_tune.py --out_dir $FINETUNE_DIR/ALF \
- --resume $PRE_TRAIN_DIR/model.pt \
- --dataset cifar10 \
- --lr 0.1 --linear
-# Fine-tuning stage (AFF)
-python adv_tune.py --out_dir $FINETUNE_DIR/AFF \
- --resume $PRE_TRAIN_DIR/model.pt \
- --dataset cifar10 \
- --lr 0.1
-{% endhighlight %}
-
-**RCS can speed up Standard Adversarial Training (SAT) on ImageNet-1K.** The results show that RCS is applicable to robust pre-training in the supervised setting.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-robust-foundation-model/RCS_exp3.png" class="img-fluid" %}
-
-
-
-To reproduce the above results of the robustness transferability from ImageNet-1K to CIFAR-10, you can use the following scripts.
-
-- At the pre-training stage, you can conduct SAT using WRN-28-10 on ImageNet-1K of $$32 \times 32$$ resolution.
-{% highlight bash %}
-git clone https://github.com/GodXuxilie/Efficient_ACL_via_RCS.git
-cd Efficient_ACL_via_RCS/SAT_RCS/ImageNet_32
-# Pre-training stage using RCS
-PRE_TRAIN_DIR=SAT_RCS_WRN_ImageNet
-nohup python SAT_RCS.py --gpu 0,1,2,3 --out_dir $PRE_TRAIN_DIR --fraction 0.2
-{% endhighlight %}
-
-- At the fine-tuning stage, you can fine-tune ImageNet-1K pre-trained WRN-28-10 on CIFAR-10.
-{% highlight bash %}
-cd Efficient_ACL_via_RCS/SAT_RCS/ImageNet_32
-PRE_TRAIN_DIR=SAT_RCS_WRN_ImageNet
-FINETUNE_DIR=SAT_RCS_WRN_ImageNet_cifar10
-# Fine-tuning stage (ALF)
-python adv_tune.py --out_dir $FINETUNE_DIR/ALF \
- --resume $PRE_TRAIN_DIR/checkpoint.pth.tar \
- --dataset cifar10 \
- --lr 0.1 \
- --linear
-# Fine-tuning stage (AFF)
-python adv_tune.py --out_dir $FINETUNE_DIR/AFF \
-                   --resume $PRE_TRAIN_DIR/checkpoint.pth.tar \
- --dataset cifar10 \
- --lr 0.1
-{% endhighlight %}
-
-
-
diff --git a/_posts/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.md b/_posts/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.md
deleted file mode 100644
index b3a6e865..00000000
--- a/_posts/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.md
+++ /dev/null
@@ -1,680 +0,0 @@
----
-layout: distill
-title: The N Implementation Details of RLHF with PPO
-description: Reinforcement Learning from Human Feedback (RLHF) is pivotal in the modern application of language modeling, as exemplified by ChatGPT. This blog post delves into an in-depth exploration of RLHF, attempting to reproduce the results from OpenAI's inaugural RLHF paper, published in 2019. Our detailed examination provides valuable insights into the implementation details of RLHF, which often go unnoticed.
-
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-# authors:
-# - name: Anonymous
-
-
-authors:
- - name: Shengyi Costa Huang
- affiliations:
- name: Hugging Face
- - name: Tianlin Liu
- affiliations:
- name: University of Basel
- - name: Leandro von Werra
- affiliations:
- name: Hugging Face
-
-
-# authors:
-# - name: Albert Einstein
-# url: "https://en.wikipedia.org/wiki/Albert_Einstein"
-# affiliations:
-# name: IAS, Princeton
-# - name: Boris Podolsky
-# url: "https://en.wikipedia.org/wiki/Boris_Podolsky"
-# affiliations:
-# name: IAS, Princeton
-# - name: Nathan Rosen
-# url: "https://en.wikipedia.org/wiki/Nathan_Rosen"
-# affiliations:
-# name: IAS, Princeton
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Matching Learning Curves
- - name: General Implementation Details
- - name: Reward Model Implementation Details
- - name: Policy Training Implementation Details
- - name: PyTorch Adam optimizer numerical issues w.r.t RLHF
- - name: Limitations
- - name: Conclusion
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-
-**Reinforcement Learning from Human Feedback** (RLHF) has been an impactful technique for training modern language models such as ChatGPT. In our quest to learn more about RLHF, this blog post closely examines OpenAI’s inaugural RLHF paper published in 2019 together with its open-source codebase available at [*openai/lm-human-preferences*](https://github.com/openai/lm-human-preferences). Despite being based on TensorFlow-1, the codebase released by OpenAI is very well evaluated and benchmarked, making it a good place to study RLHF implementation and engineering details.
-
-We aim to:
-
-1. reproduce OpenAI’s results in stylistic tasks and match the learning curves of [*openai/lm-human-preferences*](https://github.com/openai/lm-human-preferences), using the modern PyTorch and JAX frameworks in conjunction with HuggingFace Transformers that are predominantly used by the open-source community nowadays;
-2. present a checklist of implementation details, similar to the spirit of [*The 37 Implementation Details of Proximal Policy Optimization*](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/) and [*Debugging RL, Without the Agonizing Pain*](https://andyljones.com/posts/rl-debugging.html);
-3. provide a simple-to-read and minimal reference implementation of RLHF.
-
-This work is just for educational / learning purposes. For advanced users requiring more features, such as running larger models with parameter-efficient fine-tuning, [*huggingface/trl*](https://github.com/huggingface/trl) would be a great choice.
-
-- In [Matching Learning Curves](#matching-learning-curves), we show our main contribution: creating a codebase that can reproduce OpenAI’s results in the stylistic tasks and matching learning curves very closely with [*openai/lm-human-preferences*](https://github.com/openai/lm-human-preferences).
-- We then take a technical deep dive into the implementation details that are relevant to reproducing OpenAI’s work. In [General Implementation Details](#general-implementation-details), we talk about basic details, such as how rewards/values are generated and how responses are generated. In [Reward Model Implementation Details](#reward-model-implementation-details), we talk about details such as reward normalization. In [Policy Training Implementation Details](#policy-training-implementation-details), we discuss details such as rejection sampling and reward “whitening”.
- - In [**PyTorch Adam optimizer numerical issues w.r.t RLHF**](#pytorch-adam-optimizer-numerical-issues-wrt-rlhf), we highlight a very interesting implementation difference in Adam between TensorFlow and PyTorch, which causes an aggressive update in the model training.
-- Next, we examine the effect of training different base models (e.g., gpt2-xl, falcon-1b) given that the reward labels are produced with `gpt2-large`.
-- Finally, we conclude our work with limitations and discussions.
-
-
-
-
-Here are the important links:
-
-- 💾 [Our reproduction codebase](https://github.com/vwxyzjn/lm-human-preference-details)
-- 🤗 [Demo of RLHF model comparison](https://huggingface.co/spaces/lm-human-preference-details/rlhf-demo)
-- 🐝 [All w&b training logs](https://wandb.ai/openrlbenchmark/lm_human_preference_details)
-
-# Matching Learning Curves
-
-Our main contribution is to reproduce OpenAI’s results in stylistic tasks, such as sentiment and descriptiveness. As shown in the figure below, our codebase (orange curves) can produce nearly identical learning curves as OpenAI’s codebase (blue curves).
-
-
-
-{% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/curve-matching.png" class="img-fluid" %}
-
-
-
-## A note on running openai/lm-human-preferences
-
-To make a direct comparison, we ran the original RLHF code at [*openai/lm-human-preferences*](https://github.com/openai/lm-human-preferences), which offers valuable metrics to help validate and diagnose our reproduction. We were able to set up the original TensorFlow 1.x code, but it requires a very specific setup:
-
-- OpenAI’s datasets were partially corrupted/lost (so we replaced them with similar HF datasets, which may or may not cause a performance difference)
- - Specifically, its book dataset was lost during OpenAI’s GCP - Azure migration ([https://github.com/openai/lm-human-preferences/issues/17#issuecomment-1044051496](https://github.com/openai/lm-human-preferences/issues/17#issuecomment-1044051496)). We replaced the book dataset with Hugging Face’s `bookcorpus` dataset, which is, in principle, what OpenAI used.
-- It can’t run on 1 V100 because it doesn’t implement gradient accumulation. Instead, it uses a large batch size and splits the batch across 8 GPUs, and will OOM on just 1 GPU.
-- It can’t run on 8x A100 because it uses TensorFlow 1.x, which is incompatible with Cuda 8+
-- It can’t run on 8x V100 (16GB) because it will OOM
-- It can only run on 8x V100 (32GB), which is only offered by AWS as the `p3dn.24xlarge` instance.
-
-# General Implementation Details
-
-We now take a technical deep dive into the implementation details that are relevant to reproducing OpenAI’s work. In this section, we talk about basic details, such as how rewards/values are generated and how responses are generated. Here are these details in no particular order:
-
-1. **The reward model and policy’s value head take input as the concatenation of `query` and `response`**
- 1. The reward model and policy’s value head do *not* only look at the response. Instead, it concatenates the `query` and `response` together as `query_response` ([lm_human_preferences/rewards.py#L105-L107](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/rewards.py#L105-L107)).
-    2. So, for example, if `query = "he was quiet for a minute, his eyes unreadable"`, and the `response = "He looked at his left hand, which held the arm that held his arm out in front of him."`, then the reward model and policy’s value head do a forward pass on `query_response = "he was quiet for a minute, his eyes unreadable. He looked at his left hand, which held the arm that held his arm out in front of him."` and produce rewards and values of shape `(B, T, 1)`, where `B` is the batch size, `T` is the sequence length, and `1` is the reward head dimension of 1 ([lm_human_preferences/rewards.py#L105-L107](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/rewards.py#L105-L107), [lm_human_preferences/policy.py#L111](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/policy.py#L111)).
- 3. The `T` means that each token has a reward associated with it and its previous context. For example, the `eyes` token would have a reward corresponding to `he was quiet for a minute, his eyes`.
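-
-    To make the shapes concrete, here is a minimal PyTorch sketch (not OpenAI's TF1 code); the `reward_head` below is a hypothetical stand-in for the scalar head described above:
-
-    ```python
-    import torch
-    import transformers
-
-    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
-    backbone = transformers.AutoModel.from_pretrained("gpt2")        # GPT-2 backbone without the LM head
-    reward_head = torch.nn.Linear(backbone.config.hidden_size, 1)    # hypothetical scalar reward head
-
-    query = "he was quiet for a minute, his eyes unreadable."
-    response = " He looked at his left hand."
-    inputs = tokenizer(query + response, return_tensors="pt")        # forward pass on the concatenation
-    hidden_states = backbone(**inputs).last_hidden_state             # (B, T, d_model)
-    rewards = reward_head(hidden_states)                             # (B, T, 1): one reward per token and its context
-    print(rewards.shape)
-    # For RLHF, only the last token's reward is extracted later on: rewards[:, -1, :] has shape (B, 1).
-    ```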
-2. **Pad with a special padding token and truncate inputs.**
-    1. OpenAI sets a fixed input length for query `query_length`; it **pads** sequences that are too short with `pad_token` ([lm_human_preferences/language/datasets.py#L66-L67](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L66-L67)) and **truncates** sequences that are too long ([lm_human_preferences/language/datasets.py#L57](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L57)). See [here](https://huggingface.co/docs/transformers/pad_truncation) for a general introduction to the concept. When padding the inputs, OpenAI uses a token beyond the vocabulary ([lm_human_preferences/language/encodings.py#L56](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/encodings.py#L56)).
- 1. **Note on HF’s transformers — padding token.** According to ([transformers#2630#issuecomment-578159876](https://github.com/huggingface/transformers/issues/2630#issuecomment-578159876)), padding tokens were not used during the pre-training of GPT and GPT-2; therefore transformer’s gpt2 models have no official padding token associated with its tokenizer. A common practice is to set `tokenizer.pad_token = tokenizer.eos_token`, but in this work, we shall distinguish these two special tokens to match OpenAI’s original setting, so we will use `tokenizer.add_special_tokens({"pad_token": "[PAD]"})`.
-
- Note that having no padding token is a default setting for decoder models, since they train with “packing” during pretraining, which means that many sequences are concatenated and separated by the EOS token and chunks of this sequence that always have the max length are fed to the model during pretraining.
- 2. When putting everything together, here is an example
-
- ```python
- import transformers
- tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
- tokenizer.add_special_tokens({"pad_token": "[PAD]"})
- query_length = 5
- texts = [
- "usually, he would",
- "she thought about it",
- ]
- tokens = []
- for text in texts:
- tokens.append(tokenizer.encode(text)[:query_length])
-
- print("tokens", tokens)
- inputs = tokenizer.pad(
- {"input_ids": tokens},
- padding="max_length",
- max_length=query_length,
- return_tensors="pt",
- return_attention_mask=True,
- )
- print("inputs", inputs)
-
- """prints are
- tokens [[23073, 11, 339, 561], [7091, 1807, 546, 340]]
- inputs {'input_ids': tensor([[23073, 11, 339, 561, 50257],
- [ 7091, 1807, 546, 340, 50257]]), 'attention_mask': tensor([[1, 1, 1, 1, 0],
- [1, 1, 1, 1, 0]])}
- """
- ```
-
-3. **Adjust position indices correspondingly for padding tokens**
- 1. When calculating the logits, OpenAI’s code works by masking out padding tokens properly. This is achieved by finding out the token indices corresponding to the padding tokens ([lm_human_preferences/language/model.py#L296-L297](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L296-L297)), followed by adjusting their position indices correspondingly ([lm_human_preferences/language/model.py#L320](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L320)).
-    2. For example, if the `query=[23073, 50259, 50259]` and `response=[11, 339, 561]` (where `50259` is OpenAI’s padding token), it then creates position indices as `[[0 1 1 1 2 3]]` and logits as follows. Note how the logits corresponding to the padding tokens remain the same as before! This is the effect we should be aiming for in our reproduction.
-
- ```python
- all_logits [[[ -35.28693 -34.2875 -38.16074 ... -41.595802 -41.082108
- -35.36577 ]
- [ -35.28693 -34.2875 -38.16074 ... -41.595802 -41.082108
- -35.36577 ]
- [ -35.28693 -34.2875 -38.16074 ... -41.595802 -41.082108
- -35.36577 ]
- [-111.303955 -110.94471 -112.90624 ... -113.13064 -113.7788
- -109.17345 ]
- [-111.51512 -109.61077 -114.90231 ... -118.43514 -111.56671
- -112.12478 ]
- [-122.69775 -121.84468 -128.27417 ... -132.28055 -130.39604
- -125.707756]]] (1, 6, 50257)
- ```
-
-    3. **Note on HF’s transformers — `position_ids` and `padding_side`.** We can replicate the exact logits using Hugging Face’s transformers by 1) left-padding and 2) passing in the appropriate `position_ids`:
-
- ```python
- import torch
- import transformers
- tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
- tokenizer.add_special_tokens({"pad_token": "[PAD]"})
- pad_id = tokenizer.pad_token_id
- query = torch.tensor([
- [pad_id, pad_id, 23073],
- ])
- response = torch.tensor([
- [11, 339, 561],
- ])
- temperature = 1.0
-
- query = torch.tensor(query)
- response = torch.tensor(response).long()
- context_length = query.shape[1]
- query_response = torch.cat((query, response), 1)
- pretrained_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
- def forward(policy, query_responses, tokenizer):
- attention_mask = query_responses != tokenizer.pad_token_id
- position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum
- input_ids = query_responses.clone()
- input_ids[~attention_mask] = 0
- return policy(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- return_dict=True,
- output_hidden_states=True,
- )
- output = forward(pretrained_model, query_response, tokenizer)
- logits = output.logits
- logits /= temperature
- print(logits)
-
- """
- tensor([[[ -26.9395, -26.4709, -30.0456, ..., -33.2208, -33.2884,
- -27.4360],
- [ -27.1677, -26.7330, -30.2386, ..., -33.6813, -33.6931,
- -27.5928],
- [ -35.2869, -34.2875, -38.1608, ..., -41.5958, -41.0821,
- -35.3658],
- [-111.3040, -110.9447, -112.9062, ..., -113.1306, -113.7788,
- -109.1734],
- [-111.5152, -109.6108, -114.9024, ..., -118.4352, -111.5668,
- -112.1248],
- [-122.6978, -121.8447, -128.2742, ..., -132.2805, -130.3961,
-          -125.7078]]], grad_fn=<DivBackward0>)
- """
- ```
-
- 4. **Note on HF’s transformers — `position_ids` during `generate`:** during generate we should not pass in `position_ids` because the `position_ids` are already adjusted in `transformers` (see [huggingface/transformers#/7552](https://github.com/huggingface/transformers/pull/7552)).
-
-        Usually, we almost never pass `position_ids` in transformers. All the masking and shifting logic is already implemented, e.g., in the `generate` function (need permanent code link).
-4. **Response generation samples a fixed-length response without padding.**
-    1. During response generation, OpenAI uses `top_k=0, top_p=1.0` and simply samples categorically across the vocabulary ([lm_human_preferences/language/sample.py#L43](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/sample.py#L43)), and the code keeps sampling until a fixed-length response is generated ([lm_human_preferences/policy.py#L103](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/policy.py#L103)). Notably, even if it encounters EOS (end-of-sequence) tokens, it will keep sampling.
-    2. **Note on HF’s transformers — sampling could stop at `eos_token`:** in `transformers`, the generation could stop at `eos_token` ([src/transformers/generation/utils.py#L2248-L2256](https://github.com/huggingface/transformers/blob/67b85f24def79962ce075353c2627f78e0e53e9f/src/transformers/generation/utils.py#L2248-L2256)), which is not the same as OpenAI’s setting. To align the setting, we need to set `pretrained_model.generation_config.eos_token_id = None, pretrained_model.generation_config.pad_token_id = None`. Note that `transformers.GenerationConfig(eos_token_id=None, pad_token_id=None, ...)` does not work because `pretrained_model.generation_config` would override and set an `eos_token`.
-
- ```python
- import torch
- import transformers
- tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
- tokenizer.add_special_tokens({"pad_token": "[PAD]"})
- pad_id = tokenizer.pad_token_id
- query = torch.tensor([
- [pad_id, pad_id, 23073],
- ])
- response = torch.tensor([
- [11, 339, 561],
- ])
- response_length = 4
- temperature = 0.7
- pretrained_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
- pretrained_model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to
- pretrained_model.generation_config.pad_token_id = None # generate tokens without truncation / padding
- generation_config = transformers.GenerationConfig(
- max_new_tokens=response_length,
- min_new_tokens=response_length,
- temperature=temperature,
- top_k=0.0,
- top_p=1.0,
- do_sample=True,
- )
- context_length = query.shape[1]
- attention_mask = query != tokenizer.pad_token_id
- input_ids = query.clone()
- input_ids[~attention_mask] = 0 # set padding tokens to 0
- output = pretrained_model.generate(
- input_ids=input_ids,
- attention_mask=attention_mask,
- # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on.
- generation_config=generation_config,
- return_dict_in_generate=True,
- )
- print(output.sequences)
-
- """
- tensor([[ 0, 0, 23073, 16851, 11, 475, 991]])
- """
- ```
-
-    3. Note that in a more recent codebase, https://github.com/openai/summarize-from-feedback, OpenAI does stop sampling when encountering an EOS token ([summarize_from_feedback/utils/experiment_helpers.py#L19](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/utils/experiment_helpers.py#L19)). However, in this work we aim for a 1:1 replication, so we align with the original setting and keep sampling even when an `eos_token` is encountered.
-5. **Learning rate annealing for reward model and policy training.**
- 1. As Ziegler et al. (2019) suggested, the reward model is trained for a single epoch to avoid overfitting the limited amount of human annotation data (e.g., the `descriptiveness` task only had about 5000 labels). During this single epoch, the learning rate is annealed to zero ([lm_human_preferences/train_reward.py#L249](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L249)).
- 2. Similar to reward model training, the policy's learning rate is annealed to zero ([lm_human_preferences/train_policy.py#L172-L173](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L172-L173)).
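-
-    As a minimal sketch (assuming a PyTorch training loop rather than OpenAI's TF1 code, with a hypothetical `num_updates` steps per epoch), linear annealing to zero can be expressed with a `LambdaLR` scheduler:
-
-    ```python
-    import torch
-
-    model = torch.nn.Linear(768, 1)           # stand-in for the reward model / policy
-    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
-    num_updates = 1000                        # hypothetical number of optimizer steps in the single epoch
-    scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer, lr_lambda=lambda step: 1.0 - step / num_updates
-    )
-    for step in range(num_updates):
-        # ... forward pass and loss.backward() omitted in this sketch ...
-        optimizer.step()
-        scheduler.step()                      # the learning rate decays linearly to zero
-    print(scheduler.get_last_lr())            # [0.0] at the end of the epoch
-    ```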
-6. **Use different seeds for different processes**
- 1. When spawning 8 GPU processes to do data parallelism, OpenAI sets a different random seed per process ([lm_human_preferences/utils/core.py#L108-L111](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/utils/core.py#L108-L111)). Implementation-wise, this is done via `local_seed = args.seed + process_rank * 100003`. The seed is going to make the model produce different responses and get different scores, for example.
- 1. Note: We believe the dataset shuffling has a bug — the dataset is shuffled using the same seed for some reason ([lm_human_preferences/lm_tasks.py#L94-L97](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/lm_tasks.py#L94-L97)).
-
-# Reward Model Implementation Details
-
-In this section, we discuss reward-model-specific implementation details. We talk about details such as reward normalization and layer initialization. Here are these details in no particular order:
-
-1. **The reward model only outputs the value at the last token.**
- 1. Notice that the rewards obtained after the forward pass on the concatenation of `query` and `response` will have the shape `(B, T, 1)`, where `B` is the batch size, `T` is the sequence length (which is always the same; it is `query_length + response_length = 64 + 24 = 88` in OpenAI’s setting for stylistic tasks, see [launch.py#L9-L11](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L9-L11)), and `1` is the reward head dimension of 1. For RLHF purposes, the original codebase extracts the reward of the last token ([lm_human_preferences/rewards.py#L132](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/rewards.py#L132)), so that the rewards will only have shape `(B, 1)`.
- 2. Note that in a more recent codebase [*openai/summarize-from-feedback*](https://github.com/openai/summarize-from-feedback), OpenAI stops sampling when encountering EOS token ([summarize_from_feedback/utils/experiment_helpers.py#L19](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/utils/experiment_helpers.py#L19)). When extracting rewards, it is going to identify the `last_response_index`, the index before the EOS token ([#L11-L13](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/reward_model.py#L11-L13)), and extract the reward at that index ([summarize_from_feedback/reward_model.py#L59](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/reward_model.py#L59)). However in this work we just stick with the original setting.
-2. **Reward head layer initialization**
- 1. The weight of the reward head is initialized according to \\( \mathcal{N}\left(0,1 /\left(\sqrt{d_{\text {model }}+1}\right)\right) \\) ([lm_human_preferences/language/model.py#L368,](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L368) [lm_human_preferences/language/model.py#L251-L252](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L251-L252)). This aligns with the settings in Stiennon et al., 2020 ([summarize_from_feedback/query_response_model.py#L106-L107](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/query_response_model.py#L106-L107)) (P.S., Stiennon et al., 2020 had a typo on page 17 saying the distribution is \\( \mathcal{N}\left(0,1 /\left(d_{\text {model }}+1\right)\right) \\) without the square root)
- 2. The bias of the reward head is set to 0 ([lm_human_preferences/language/model.py#L254](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L254)).
-3. **Reward model normalization before and after**
-    1. In the paper, Ziegler et al. (2019) mentioned that "to keep the scale of the reward model consistent across training, we normalize it so that it has mean 0 and variance 1 for \\( x \sim \mathcal{D}, y \sim \rho(·|x) \\)." To perform the normalization process, the code first creates a `reward_gain` and `reward_bias`, such that the reward can be calculated by `reward = reward * reward_gain + reward_bias` ([lm_human_preferences/rewards.py#L50-L51](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/rewards.py#L50-L51)).
- 2. When performing the normalization process, the code first sets `reward_gain=1, reward_bias=0` ([lm_human_preferences/train_reward.py#L211](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L211)), followed by collecting sampled queries from the target dataset (e.g., `bookcorpus, tldr, cnndm`), completed responses, and evaluated rewards. It then gets the **empirical mean and std** of the evaluated reward ([lm_human_preferences/train_reward.py#L162-L167](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L162-L167)) and tries to compute what the `reward_gain` and `reward_bias` should be.
- 3. Let us use \\( \mu_{\mathcal{D}} \\) to denote the empirical mean, \\( \sigma_{\mathcal{D}} \\) the empirical std, \\(g\\) the `reward_gain`, \\(b\\) `reward_bias`, \\( \mu_{\mathcal{T}} = 0\\) **target mean** and \\( \sigma_{\mathcal{T}}=1\\) **target std**. Then we have the following formula.
-
- $$
- \begin{aligned}g*\mathcal{N}(\mu_{\mathcal{D}}, \sigma_{\mathcal{D}}) + b &= \mathcal{N}(g*\mu_{\mathcal{D}}, g*\sigma_{\mathcal{D}}) + b\\&= \mathcal{N}(g*\mu_{\mathcal{D}} + b, g*\sigma_{\mathcal{D}}) \\&= \mathcal{N}(\mu_{\mathcal{T}}, \sigma_{\mathcal{T}}) \\g &= \frac{\sigma_{\mathcal{T}}}{\sigma_{\mathcal{D}}} \\b &= \mu_{\mathcal{T}} - g*\mu_{\mathcal{D}}\end{aligned}
- $$
-
- 4. The normalization process is then applied **before** and **after** reward model training ([lm_human_preferences/train_reward.py#L232-L234](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L232-L234), [lm_human_preferences/train_reward.py#L252-L254](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L252-L254)).
-
-
- 5. Note that responses \\( y \sim \rho(·|x) \\) we generated for the normalization purpose are from the pre-trained language model \\(\rho \\). The model
- \\(\rho \\) is fixed as a reference and is not updated in reward learning ([lm_human_preferences/train_reward.py#L286C1-L286C31](https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_reward.py#L286C1-L286C31)).
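-
-To make the derivation above concrete, here is a minimal PyTorch-style sketch (not OpenAI's TF1 code) of computing `reward_gain` and `reward_bias` from rewards evaluated on sampled query-response pairs; `sampled_rewards` below is a hypothetical tensor standing in for the rewards collected as described above:
-
-```python
-import torch
-
-target_mean, target_std = 0.0, 1.0                        # desired reward statistics
-sampled_rewards = torch.randn(1024) * 3.2 + 0.7           # hypothetical rewards of sampled query-response pairs
-
-empirical_mean = sampled_rewards.mean()
-empirical_std = sampled_rewards.std()
-
-reward_gain = target_std / empirical_std                  # g = sigma_T / sigma_D
-reward_bias = target_mean - reward_gain * empirical_mean  # b = mu_T - g * mu_D
-
-normalized = sampled_rewards * reward_gain + reward_bias
-print(normalized.mean(), normalized.std())                # approximately 0 and 1
-```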
-
-# Policy Training Implementation Details
-
-In this section, we delve into details such as layer initialization, data post-processing, and dropout settings. We also explore techniques such as rejection sampling, reward "whitening", and adaptive KL. Here are these details in no particular order:
-
-1. **Scale the logits by sampling temperature.**
- 1. When calculating the log probability of responses, the model first outputs the logits of the tokens in the responses, followed by dividing the logits with the sampling temperature ([lm_human_preferences/policy.py#L121](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/policy.py#L121)). I.e., `logits /= self.temperature`
- 2. In an informal test, we found that without this scaling, the KL would rise faster than expected, and performance would deteriorate.
-2. **Value head layer initialization**
-    1. The weight of the value head is initialized according to \\(\mathcal{N}\left(0,0\right)\\) ([lm_human_preferences/language/model.py#L368,](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L368) [lm_human_preferences/language/model.py#L251-L252](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L251-L252)), i.e., the value head weights are effectively initialized to zero.
-    2. The bias of the value head is set to 0 ([lm_human_preferences/language/model.py#L254](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L254)).
-3. **Select query texts that start and end with a period**
- 1. This is done as part of the data preprocessing;
- 1. Tries to select text only after `start_text="."` ([lm_human_preferences/language/datasets.py#L51](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L51))
-        2. Tries to select text just before `end_text="."` ([lm_human_preferences/language/datasets.py#L61](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L61))
- 3. Then pad the text ([lm_human_preferences/language/datasets.py#L66-L67](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L66-L67))
-    2. When running `openai/lm-human-preferences`, OpenAI’s datasets were partially corrupted/lost ([openai/lm-human-preferences/issues/17#issuecomment-1044051496](https://github.com/openai/lm-human-preferences/issues/17#issuecomment-1044051496)), so we had to replace them with similar HF datasets, which may or may not cause a performance difference.
-    3. For the book dataset, we used [https://huggingface.co/datasets/bookcorpus](https://huggingface.co/datasets/bookcorpus), for which we found it unnecessary to extract sentences that start and end with periods, because the dataset is already pre-processed this way (e.g., `"usually , he would be tearing around the living room , playing with his toys ."`). To this end, we set `start_text=None, end_text=None` for the `sentiment` and `descriptiveness` tasks.
-4. **Disable dropout**
- 1. Ziegler et al. (2019) suggested, “We do not use dropout for policy training.” This is also done in the code ([lm_human_preferences/policy.py#L48](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/policy.py#L48)).
-5. **Rejection sampling**
- 1. Ziegler et al. (2019) suggested, “We use rejection sampling to ensure there is a period between tokens 16 and 24 and then truncate at that period (This is a crude approximation for ‘end of sentence.’ We chose it because it is easy to integrate into the RL loop, and even a crude approximation is sufficient for the intended purpose of making the human evaluation task somewhat easier). During the RL finetuning, we penalize continuations that don’t have such a period by giving them a fixed reward of −1.”
- 2. Specifically, this is achieved with the following steps:
- 1. **Token truncation**: We want to truncate at the first occurrence of `truncate_token` that appears at or after position `truncate_after` in the responses ([lm_human_preferences/train_policy.py#L378](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378))
- 2. **Run reward model on truncated response:** After the response has been truncated by the token truncation process, the code then runs the reward model on the **truncated response**.
- 3. **Rejection sampling**: if there is not a period between tokens 16 and 24, then replace the score of the response with a fixed low value (such as -1)([lm_human_preferences/train_policy.py#L384](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L384), [lm_human_preferences/train_policy.py#L384-L402](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L384-L402))
- 4. To give some examples in `descriptiveness`:
-
- {% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/descriptiveness-samples.png" class="img-fluid" %}
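-
-    To make the truncation and rejection steps concrete, here is a minimal Python sketch (not OpenAI's exact code; the response tokens and score below are hypothetical, and `13` is the GPT-2 token id for `"."`):
-
-    ```python
-    def truncate_at_period(response_tokens, truncate_token=13, truncate_after=16):
-        """Truncate at the first "." found at or after position `truncate_after`."""
-        for i in range(truncate_after, len(response_tokens)):
-            if response_tokens[i] == truncate_token:
-                return response_tokens[: i + 1], True
-        return response_tokens, False
-
-    response = list(range(100, 124))   # hypothetical 24-token response
-    response[19] = 13                  # a "." at position 19, i.e., between tokens 16 and 24
-    truncated, has_period = truncate_at_period(response)
-
-    score = 0.4                        # stand-in for the reward model's score on the truncated response
-    final_score = score if has_period else -1.0   # rejection sampling: fixed reward of -1 without a period
-    print(len(truncated), final_score)            # 20 0.4
-    ```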
-
-6. **Discount factor = 1**
- 1. The discount parameter \\(\gamma\\) is set to 1 ([lm_human_preferences/train_policy.py#L56](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L56)), which means that future rewards are given the same weight as immediate rewards.
-7. **Terminology of the training loop: batches and minibatches in PPO**
- 1. OpenAI uses the following training loop ([lm_human_preferences/train_policy.py#L184-L192](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L184-L192)). Note: we additionally added the `micro_batch_size` to help deal with the case in gradient accumulation. At each epoch, it shuffles the batch indices.
-
- ```python
-
- import numpy as np
- batch_size = 8
- nminibatches = 2
- gradient_accumulation_steps = 2
- mini_batch_size = batch_size // nminibatches
- micro_batch_size = mini_batch_size // gradient_accumulation_steps
- data = np.arange(batch_size).astype(np.float32)
- print("data:", data)
- print("batch_size:", batch_size)
- print("mini_batch_size:", mini_batch_size)
- print("micro_batch_size:", micro_batch_size)
- for epoch in range(4):
- batch_inds = np.random.permutation(batch_size)
- print("epoch:", epoch, "batch_inds:", batch_inds)
- for mini_batch_start in range(0, batch_size, mini_batch_size):
- mini_batch_end = mini_batch_start + mini_batch_size
- mini_batch_inds = batch_inds[mini_batch_start:mini_batch_end]
-
- # `optimizer.zero_grad()` set optimizer to zero for gradient accumulation
- for micro_batch_start in range(0, mini_batch_size, micro_batch_size):
- micro_batch_end = micro_batch_start + micro_batch_size
- micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
- print("____⏩ a forward pass on", data[micro_batch_inds])
- # `optimizer.step()`
- print("⏪ a backward pass on", data[mini_batch_inds])
-
- # data: [0. 1. 2. 3. 4. 5. 6. 7.]
- # batch_size: 8
- # mini_batch_size: 4
- # micro_batch_size: 2
- # epoch: 0 batch_inds: [6 4 0 7 3 5 1 2]
- # ____⏩ a forward pass on [6. 4.]
- # ____⏩ a forward pass on [0. 7.]
- # ⏪ a backward pass on [6. 4. 0. 7.]
- # ____⏩ a forward pass on [3. 5.]
- # ____⏩ a forward pass on [1. 2.]
- # ⏪ a backward pass on [3. 5. 1. 2.]
- # epoch: 1 batch_inds: [6 7 3 2 0 4 5 1]
- # ____⏩ a forward pass on [6. 7.]
- # ____⏩ a forward pass on [3. 2.]
- # ⏪ a backward pass on [6. 7. 3. 2.]
- # ____⏩ a forward pass on [0. 4.]
- # ____⏩ a forward pass on [5. 1.]
- # ⏪ a backward pass on [0. 4. 5. 1.]
- # epoch: 2 batch_inds: [1 4 5 6 0 7 3 2]
- # ____⏩ a forward pass on [1. 4.]
- # ____⏩ a forward pass on [5. 6.]
- # ⏪ a backward pass on [1. 4. 5. 6.]
- # ____⏩ a forward pass on [0. 7.]
- # ____⏩ a forward pass on [3. 2.]
- # ⏪ a backward pass on [0. 7. 3. 2.]
- # epoch: 3 batch_inds: [7 2 4 1 3 0 6 5]
- # ____⏩ a forward pass on [7. 2.]
- # ____⏩ a forward pass on [4. 1.]
- # ⏪ a backward pass on [7. 2. 4. 1.]
- # ____⏩ a forward pass on [3. 0.]
- # ____⏩ a forward pass on [6. 5.]
- # ⏪ a backward pass on [3. 0. 6. 5.]
- ```
-
-8. **Per-token KL penalty**
-    - The code adds a per-token KL penalty ([lm_human_preferences/train_policy.py#L150-L153](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L150-L153)) to the rewards, in order to discourage the policy from drifting too far from the original policy.
- - Using the `"usually, he would"` as an example, it gets tokenized to `[23073, 11, 339, 561]`. Say we use `[23073]` as the query and `[11, 339, 561]` as the response. Then under the default `gpt2` parameters, the response tokens will have log probabilities of the reference policy `logprobs=[-3.3213, -4.9980, -3.8690]` .
-    - During the first PPO update epoch and minibatch update, the active policy will have the same log probabilities, `new_logprobs=[-3.3213, -4.9980, -3.8690]`, so the per-token KL penalty would be `kl = new_logprobs - logprobs = [0., 0., 0.]`.
- - However, after the first gradient backward pass, we could have `new_logprob=[3.3213, -4.9980, -3.8690]` , so the per-token KL penalty becomes `kl = new_logprobs - logprobs = [-0.3315, -0.0426, 0.6351]`
-    - Then the `non_score_reward = -beta * kl`, where `beta` is the KL penalty coefficient \\(\beta\\); it is added to the `score` obtained from the reward model to create the `rewards` used for training. The `score` is only given at the end of the episode; it could look like `[0.4]`, and we would have `rewards = [-beta * -0.3315, -beta * -0.0426, -beta * 0.6351 + 0.4]`.
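-
-    A minimal PyTorch-style sketch (not OpenAI's exact code; `logprobs`, `ref_logprobs`, and `score` below are hypothetical values) of folding the per-token KL penalty into the rewards:
-
-    ```python
-    import torch
-
-    kl_coef = 0.15                                     # beta, the KL penalty coefficient
-    logprobs = torch.tensor([[-3.0, -4.9, -3.2]])      # log probs of response tokens under the active policy
-    ref_logprobs = torch.tensor([[-3.3, -5.0, -3.9]])  # log probs under the reference (initial) policy
-    score = torch.tensor([0.4])                        # reward-model score, given only at the end of the episode
-
-    kl = logprobs - ref_logprobs                       # per-token KL estimate
-    non_score_reward = -kl_coef * kl                   # penalize deviation from the reference policy
-    rewards = non_score_reward.clone()
-    rewards[:, -1] += score                            # the score is added to the last token's reward
-    print(rewards)
-    ```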
-9. **Per-minibatch reward and advantage whitening, with optional mean shifting**
-    1. OpenAI implements a `whiten` function that looks like below, basically normalizing the `values` by subtracting the mean and dividing by the standard deviation. When `shift_mean=False`, `whiten` adds the mean back so that only the scale is normalized.
-
- ```python
- def whiten(values, shift_mean=True):
- mean, var = torch.mean(values), torch.var(values, unbiased=False)
- whitened = (values - mean) * torch.rsqrt(var + 1e-8)
- if not shift_mean:
- whitened += mean
- return whitened
- ```
-
- 1. In each minibatch, OpenAI then whitens the reward `whiten(rewards, shift_mean=False)` without shifting the mean ([lm_human_preferences/train_policy.py#L325](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L325)) and whitens the advantages `whiten(advantages)` with the shifted mean ([lm_human_preferences/train_policy.py#L338](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L338)).
- 2. **Optimization note:** if the number of minibatches is one (which is the case in this reproduction) we only need to whiten rewards, calculate and whiten advantages once since their values won’t change.
-    3. **TensorFlow vs PyTorch note:** whitening behaves differently in TensorFlow and PyTorch because the variance calculation differs: `torch.var` applies Bessel's correction (unbiased) by default, whereas `tf.nn.moments` computes the biased (population) variance:
-
-    ```python
- import numpy as np
- import tensorflow as tf
- import torch
-
- def whiten_tf(values, shift_mean=True):
- mean, var = tf.nn.moments(values, axes=list(range(values.shape.rank)))
- mean = tf.Print(mean, [mean], 'mean', summarize=100)
- var = tf.Print(var, [var], 'var', summarize=100)
- whitened = (values - mean) * tf.rsqrt(var + 1e-8)
- if not shift_mean:
- whitened += mean
- return whitened
-
- def whiten_pt(values, shift_mean=True, unbiased=True):
- mean, var = torch.mean(values), torch.var(values, unbiased=unbiased)
- print("mean", mean)
- print("var", var)
- whitened = (values - mean) * torch.rsqrt(var + 1e-8)
- if not shift_mean:
- whitened += mean
- return whitened
-
- rewards = np.array([
- [1.2, 1.3, 1.4],
- [1.5, 1.6, 1.7],
- [1.8, 1.9, 2.0],
- ])
-
- with tf.Session() as sess:
- print(sess.run(whiten_tf(tf.constant(rewards, dtype=tf.float32), shift_mean=False)))
- print(whiten_pt(torch.tensor(rewards), shift_mean=False, unbiased=True))
- print(whiten_pt(torch.tensor(rewards), shift_mean=False, unbiased=False))
- ```
-
-    ```
- mean[1.5999999]
- var[0.0666666627]
- [[0.05080712 0.4381051 0.8254035 ]
- [1.2127019 1.6000004 1.9872988 ]
- [2.3745968 2.7618952 3.1491938 ]]
- mean tensor(1.6000, dtype=torch.float64)
- var tensor(0.0750, dtype=torch.float64)
- tensor([[0.1394, 0.5046, 0.8697],
- [1.2349, 1.6000, 1.9651],
- [2.3303, 2.6954, 3.0606]], dtype=torch.float64)
- mean tensor(1.6000, dtype=torch.float64)
- var tensor(0.0667, dtype=torch.float64)
- tensor([[0.0508, 0.4381, 0.8254],
- [1.2127, 1.6000, 1.9873],
- [2.3746, 2.7619, 3.1492]], dtype=torch.float64)
-
- ```
-
-10. **Clipped value function**
- 1. As done in the original PPO ([baselines/ppo2/model.py#L68-L75](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L68-L75)), the value function is clipped ([lm_human_preferences/train_policy.py#L343-L348](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L343-L348)) in a similar fashion as the policy objective.
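-
-    A minimal PyTorch-style sketch of such a clipped value loss (following the PPO2 recipe linked above rather than copying either codebase; all tensors below are hypothetical):
-
-    ```python
-    import torch
-
-    cliprange_value = 0.2
-    values = torch.tensor([1.2, 0.4, -0.3])        # value predictions under the current policy
-    old_values = torch.tensor([1.0, 0.5, -0.1])    # value predictions recorded at rollout time
-    returns = torch.tensor([1.5, 0.2, 0.0])        # empirical returns
-
-    # clip the new value predictions so they stay close to the rollout-time predictions
-    values_clipped = torch.clamp(values, old_values - cliprange_value, old_values + cliprange_value)
-    vf_loss1 = (values - returns) ** 2
-    vf_loss2 = (values_clipped - returns) ** 2
-    vf_loss = 0.5 * torch.max(vf_loss1, vf_loss2).mean()   # pessimistic (max) of clipped / unclipped losses
-    print(vf_loss)
-    ```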
-11. **Adaptive KL**
- - The KL divergence penalty coefficient \\(\beta\\) is modified adaptively based on the KL divergence between the current policy and the previous policy. If the KL divergence is outside a predefined target range, the penalty coefficient is adjusted to bring it closer to the target range ([lm_human_preferences/train_policy.py#L115-L124](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L115-L124)). It’s implemented as follows:
-
- ```python
-    import numpy as np
-
-    class AdaptiveKLController:
- def __init__(self, init_kl_coef, hparams):
- self.value = init_kl_coef
- self.hparams = hparams
-
- def update(self, current, n_steps):
- target = self.hparams.target
- proportional_error = np.clip(current / target - 1, -0.2, 0.2)
- mult = 1 + proportional_error * n_steps / self.hparams.horizon
- self.value *= mult
- ```
-
- - For the `sentiment` and `descriptiveness` tasks examined in this work, we have `init_kl_coef=0.15, hparams.target=6, hparams.horizon=10000`.
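-
-    As a usage sketch of the controller defined above (with a hypothetical `Hparams` container standing in for OpenAI's hparams object):
-
-    ```python
-    from dataclasses import dataclass
-
-    @dataclass
-    class Hparams:
-        target: float = 6.0    # target KL
-        horizon: int = 10000   # controls how quickly beta adapts
-
-    kl_ctl = AdaptiveKLController(init_kl_coef=0.15, hparams=Hparams())
-    kl_ctl.update(current=8.3, n_steps=512)   # hypothetical batch KL of 8.3 over a batch of 512
-    beta = kl_ctl.value                       # slightly larger than 0.15, since the observed KL exceeds the target
-    ```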
-
-## **PyTorch Adam optimizer numerical issues w.r.t RLHF**
-
-- This implementation detail is so interesting that it deserves a full section.
-- PyTorch Adam optimizer ([torch.optim.Adam.html](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html)) has a different implementation compared to TensorFlow’s Adam optimizer (TF1 Adam at [tensorflow/v1.15.2/adam.py](https://github.com/tensorflow/tensorflow/blob/v1.15.2/tensorflow/python/training/adam.py), TF2 Adam at [keras/adam.py#L26-L220](https://github.com/keras-team/keras/blob/v2.13.1/keras/optimizers/adam.py#L26-L220)). In particular, **PyTorch follows Algorithm 1** of Kingma and Ba’s Adam paper, but **TensorFlow uses the formulation just before Section 2.1** of the paper, and its `epsilon` referred to here is `epsilon hat` in the paper. In a pseudocode comparison, we have the following:
-
-```python
-### pytorch adam implementation:
-bias_correction1 = 1 - beta1 ** step
-bias_correction2 = 1 - beta2 ** step
-step_size = lr / bias_correction1
-bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)
-denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
-param.addcdiv_(exp_avg, denom, value=-step_size)
-
-### tensorflow adam implementation:
-lr_t = lr * _dispatch_sqrt((1 - beta2 ** step)) / (1 - beta1 ** step)
-denom = exp_avg_sq.sqrt().add_(eps)
-param.addcdiv_(exp_avg, denom, value=-lr_t)
-```
-
-- Let’s compare the update equations of pytorch-style and tensorflow-style adam. Following the notation of the adam paper [(Kingma and Ba, 2014)](https://arxiv.org/abs/1412.6980), we have the gradient update rules for pytorch adam (Algorithm 1 of Kingma and Ba’s paper) and tensorflow-style adam (the formulation just before Section 2.1 of Kingma and Ba’s paper) as below:
-
-$$\begin{aligned}\text{pytorch adam :}\quad \theta_t & =\theta_{t-1}-\alpha \cdot \hat{m}_t /\left(\sqrt{\hat{v}_t}+\varepsilon\right) \\& =\theta_{t-1}- \alpha \underbrace{\left[m_t /\left(1-\beta_1^t\right)\right]}_{=\hat{m}_t} /\left[\sqrt{\underbrace{v_t /\left(1-\beta_2^t\right)}_{=\hat{v}_t} }+\varepsilon\right]\\& =\theta_{t-1}- \alpha\left[m_t /\left(1-\beta_1^t\right)\right]\frac{\sqrt{1-\beta_2^t}}{\sqrt{v_t}+\color{green}{\varepsilon \sqrt{1-\beta_2^t}}}\end{aligned}$$
-
-$$\begin{aligned}\text{tensorflow adam:}\quad \theta_t & =\theta_{t-1}-\alpha_t m_t /\left(\sqrt{v_t}+\hat{\varepsilon}\right) \\& =\theta_{t-1}-\underbrace{\left[\alpha \sqrt{1-\beta_2^t} /\left(1-\beta_1^t\right)\right]}_{=\alpha_t} m_t /\left(\sqrt{v_t}+\hat{\varepsilon}\right) \\& =\theta_{t-1}- \alpha\left[m_t /\left(1-\beta_1^t\right)\right] \frac{\sqrt{1-\beta_2^t}}{\sqrt{v_t}+\color{green}{\hat{\varepsilon}}} \end{aligned}$$
-
-
-- The equations above highlight that the distinction between the pytorch and tensorflow implementations is their **normalization terms**, \\(\color{green}{\varepsilon \sqrt{1-\beta_2^t}}\\) and \\(\color{green}{\hat{\varepsilon}}\\). The two versions are equivalent if we set \\(\hat{\varepsilon} =\varepsilon \sqrt{1-\beta_2^t}\\). However, in the pytorch and tensorflow APIs, we can only set \\(\varepsilon\\) (pytorch) and \\(\hat{\varepsilon}\\) (tensorflow) via the `eps` argument, causing differences in their update equations. What if we set \\(\varepsilon\\) and \\(\hat{\varepsilon}\\) to the same value, say, 1e-5? Then for tensorflow adam, the normalization term \\(\hat{\varepsilon} = \text{1e-5}\\) is just a constant. But for pytorch adam, the normalization term \\({\varepsilon \sqrt{1-\beta_2^t}}\\) changes over time: it starts out much smaller than 1e-5 when the timestep \\(t\\) is small and gradually approaches 1e-5 as timesteps increase. The plot below compares these two normalization terms over timesteps:
-
-{% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/norma_const_comparison.png" class="img-fluid" %}
-
-- The above figure shows that, if we set the same `eps` in pytorch adam and tensorflow adam, then pytorch-adam uses a much smaller normalization term than tensorflow-adam in the early phase of training. In other words, pytorch adam goes for **more aggressive gradient updates early in the training**. Our experiments support this finding, as we will demonstrate below.
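-
-    The comparison can be reproduced numerically with a short standalone sketch (not taken from either codebase):
-
-    ```python
-    import numpy as np
-
-    eps = 1e-5
-    beta2 = 0.999
-    timesteps = np.array([1, 10, 100, 1000, 10000])
-
-    pytorch_term = eps * np.sqrt(1 - beta2 ** timesteps)   # epsilon * sqrt(1 - beta2^t), grows over time
-    tensorflow_term = np.full(timesteps.shape, eps)        # epsilon-hat stays constant
-
-    for t, pt, tf_ in zip(timesteps, pytorch_term, tensorflow_term):
-        print(f"t={t:>5}  pytorch={pt:.2e}  tensorflow={tf_:.2e}")
-    # At t=1 the pytorch term is ~3.2e-07, roughly 30x smaller than 1e-5,
-    # which is why early updates are more aggressive with the same eps.
-    ```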
-- How does this impact reproducibility and performance? To align settings, we record the original query, response, and rewards from [https://github.com/openai/lm-human-preferences](https://github.com/openai/lm-human-preferences) and save them. We also record the metrics of the first two epochs of training with TF1’s `AdamOptimizer` optimizer as the ground truth. Below are some key metrics:
-
-
- | | OpenAI’s TF1 Adam | PyTorch’s Adam | Our custom Tensorflow-style Adam |
- | --- | --- | --- | --- |
- | policy/approxkl | 0.00037167023 | 0.0023672834504395723 | 0.000374998344341293 |
- | policy/clipfrac | 0.0045572915 | 0.02018229104578495 | 0.0052083334885537624 |
- | ratio_mean | 1.0051285 | 1.0105520486831665 | 1.0044583082199097 |
- | ratio_var | 0.0007716546 | 0.005374275613576174 | 0.0007942612282931805 |
- | ratio_max | 1.227216 | 1.8121057748794556 | 1.250215768814087 |
- | ratio_min | 0.7400441 | 0.4011387825012207 | 0.7299948930740356 |
- | logprob_diff_mean | 0.0047487603 | 0.008101251907646656 | 0.004073789343237877 |
- | logprob_diff_var | 0.0007207897 | 0.004668936599045992 | 0.0007334011606872082 |
- | logprob_diff_max | 0.20474821 | 0.594489574432373 | 0.22331619262695312 |
- | logprob_diff_min | -0.30104542 | -0.9134478569030762 | -0.31471776962280273 |
-- **PyTorch’s `Adam` produces a more aggressive update** for some reason. Here is some evidence:
- - **PyTorch’s `Adam`'s `logprob_diff_var`** **is 6x higher**. Here `logprobs_diff = new_logprobs - logprobs` is the difference between the log probability of tokens between the initial and current policy after two epochs of training. Having a larger `logprob_diff_var` means the scale of the log probability changes is larger than that in OpenAI’s TF1 Adam.
- - **PyTorch’s `Adam` presents a more extreme ratio max and min.** Here `ratio = torch.exp(logprobs_diff)`. Having a `ratio_max=1.8121057748794556` means that for some token, the probability of sampling that token is 1.8x more likely under the current policy, as opposed to only 1.2x with OpenAI’s TF1 Adam.
- - **Larger `policy/approxkl` `policy/clipfrac`.** Because of the aggressive update, the ratio gets clipped **4.4x more often, and the approximate KL divergence is 6x larger.**
-    - The aggressive update is likely to cause further issues. E.g., `logprob_diff_mean` is 1.7x larger in PyTorch’s `Adam`, which would correspond to a 1.7x larger KL penalty in the next reward calculation; this could get compounded. In fact, this might be related to the famous KL divergence issue — the KL penalty becomes much larger than it should be, the model pays more attention to optimizing it, and this can result in negative KL divergence.
-- **Larger models get affected more.** We conducted experiments comparing PyTorch’s `Adam` (codename `pt_adam`) and our custom TensorFlow-style Adam (codename `tf_adam`) with `gpt2` and `gpt2-xl`. We found that performance is roughly similar under `gpt2`; however, with `gpt2-xl`, we observed more aggressive updates, meaning that larger models are affected more by this issue.
-    - When the initial policy updates are more aggressive in `gpt2-xl`, the training dynamics get affected. For example, we see much larger spikes in `objective/kl` and `objective/scores` with `pt_adam`, especially with `sentiment` — *the biggest KL was as large as 17.5* in one of the random seeds, suggesting undesirable over-optimization.
- - Furthermore, because of the larger KL, many other training metrics are affected as well. For example, we see a much larger `clipfrac` (the fraction of time the `ratio` gets clipped by PPO’s objective clip coefficient 0.2) and `approxkl`.
-
-
-
-{% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2.png" class="img-fluid" %}
-
-
-{% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2_xl.png" class="img-fluid" %}
-
-
-
-# Limitations
-
-Note that this work does not attempt to reproduce the summarization work on CNN/DM or TL;DR. This was because we found the training to be time-consuming and brittle.
-
-The particular training run we had showed poor GPU utilization (around 30%), so it took almost 4 days to perform a training run, which is highly expensive (only AWS sells the p3dn.24xlarge, and it costs $31.212 per hour).
-
-Additionally, training was brittle. While the reward goes up, we found it difficult to reproduce the “smart copier” behavior reported by Ziegler et al. (2019). Below are some sample outputs — clearly, the agent overfits somehow.
-
-```
-QUERY: The modern take on Death & Taxes still plays it. I know as a Tron player I find Mindcensor
-pretty frustrating. ⏎ However, the answer to your question is probably that the decks that
-are currently taking up the lion's share of the metagame don't really care about its effect.
-It has a chance to act as an instant speed Stone Rain + Gut Shot in response to a fetch
-crack, but after that it's a 2/1 flyer that dies to all relevant removal. ⏎ It's especially
-powerful against Tron since so much of the deck's consistency is in its tutor effects --
-Sylvan Scrying, Expedition Map, and Eye of Ugin. This combined with Pyroclasm and Oblivion
-Stone being the only creature removal the deck has available to it before Karn make
-Mindcensor fairly potent. ⏎ I'm having trouble thinking of any other popular decks (ignoring
-the interaction with fetches) that Mindcensor really sticks it to, however.
-
-PPO GENERATED RESPONSE: Mindcensor is a very good card, but it's not really worth playing.<|endoftext|>Image
-caption The family was forced from their home after they were not allowed to leave the
-house⏎⏎A family of four has been jailed for five years after admitting breaching home rule
-by refusing to leave their home.⏎
-
-reward: -0.48
-kl: +11.37
-total reward: -0.82
-
-QUERY: To make a long story short, my SO (27/m) and I live in my hometown. We've been together for
-4 years. His family lives across the country. A lot of the time we spend with my parents but
-besides that, we aren't very social. His family not being here doesn't make me feel like I
-can't be with mine. I just don't like to. ⏎ I find my extended family (so not my parents)
-extremely irritating. Especially the ones around my age. They're dramatic and rude. Every
-family event is some Jerry Springer episode I don't want to participate in. I don't have any
-issues with anyone in particular. I just find them annoying. I'm never uncivil or openly
-telling them how I feel. If there's get togethers, I tend to avoid them. I don't ask anyone
-to hang out or come over. ⏎ My SO and I have experienced multiple miscarriages due to
-unknown reasons. This has made me very sensitive and a bit of a loaner when it comes to the
-girls in my family my age. I'm not jealous that my family has children. I just don't seem to
-be able to relate to them. When I explained our issues, they reacted very unsympathetic and
-just weren't very supportive. I guess that was the time that I started to pull away and not
-try to keep up any relationships. ⏎ My SO and I are pregnant, half way! This is something
-we're so joyed about. However, when we told everyone, they just didn't really give a
-reaction. No congrats or that they were happy for us. Since then, I completely pulled away.
-I talked to my parents about this and I just don't feel the need to contact anyone. The
-reason I'm posting is because I discussed deleting some relatives off social media with my
-mom and she was telling me how it's really rude. I don't think it is. We don't talk. There's
-no relationship. Honestly, I think they're just friends with me for gossip purposes. A few
-times I tried reaching out to some of my cousins with no reply. ⏎ I guess what I'm asking is
-if I should be trying harder to establish something with my family or if it's perfectly OK
-just to delete them and not talk to them again? Our primary contact is Facebook even though
-we live in the same town.
-
-PPO GENERATED RESPONSE: We're dating now. We've met, we're dating. We're dating now. We're dating now. We're dating
-now. We're dating now. We're dating now.⏎⏎We've met⏎⏎We're dating now.⏎⏎We're dating now.⏎
-
-reward: +2.09
-kl: +19.45
-total reward: +1.50
-
-QUERY: This tower totally decided to screw me over on the final stage. So a bit of background I
-don't have the strongest squad but pretty much I have: ⏎ Guardian Loch (Sacred Jewel) ⏎
-Anima Tiara (Flesh Armor) ⏎ Lord Cayena (sacred Jewel) ⏎ Oracle Twins (sacred Jewel) ⏎
-Guardian Logan (Burny soul stone) ⏎ and was using a friends maxed michelle. ⏎ So throughout
-the earlier stages I've seen this guy about 5 times but have never caught him. So I had a
-little trouble with stage 9 but felt like with the additional 50% atk boost and bb from
-michelle I could do it. Also on stage 9 the reason I died was because I didn't take any
-stimulants and this guy managed to paralyze 3 members of my squad. So I replaced mega cures
-for stimulants. ⏎ Round 3 I meet him. He gets a lucky crit on my caynea and takes her down.
-I have to use my one revive, take him down and capture him. I'm pretty excited. ⏎ Round 7
-I see the warning danger sign. I face him again and manage to take him down without any of
-my units dying no capture. At this point I'm starting to get low on healing items. Also I've
-only heard of meeting him twice on a run. But then I checked this thread and saw that he
-always appears on the last stage. I thought that maybe it glitched and he wouldn't appear on
-the final round. ⏎ Nope. ⏎ This guy shows up on the final round showing up 3 times in one
-run. Worst part was I manged to get him down to a sliver of health so that as long as one
-unit hit him I would've defeated him with the next turn. This was also the turn he decided
-to use soul desecration twice. I definitely felt like my soul got desecrated watching all my
-units die. So at this point it was give up the captured unit and try all over again (knowing
-my luck never capture him again) or gem for the last hit. I was really looking forward to
-getting the gem for completion (f2p gotta scrap up everything right?) but I let it go.
-
-PPO GENERATED RESPONSE: ⏎⏎When you capture a unit try to kill it.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill
-them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎
-
-reward: +0.79
-kl: +10.82
-total reward: +0.46
-```
-
-# Conclusion
-
-In this work, we took a deep dive into OpenAI’s original RLHF codebase and compiled a list of its implementation details. We also created a minimal codebase that reproduces the same learning curves as OpenAI’s original RLHF codebase when the dataset and hyperparameters are controlled. Furthermore, we identified surprising implementation details, such as the Adam optimizer’s epsilon handling, which causes aggressive updates in early RLHF training.
-
-
diff --git a/_posts/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective.md b/_posts/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective.md
deleted file mode 100644
index 07074cf8..00000000
--- a/_posts/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective.md
+++ /dev/null
@@ -1,197 +0,0 @@
----
-layout: distill
-title: Understanding gradient inversion attacks from the prior knowledge perspective
-description: In this blog post, we survey multiple works on gradient inversion attacks, point out the challenges that remain to be solved in GIAs, and offer a prior knowledge perspective to understand the logic behind recent papers.
-date: 2024-11-27
-future: true
-htmlwidgets: true
-
-#Anonymize when submitting
-authors:
- - name: Yanbo Wang
- affiliations:
- name: School of AI, UCAS $\n$ CRIPAC & MAIS, CASIA
- - name: Jian Liang
- affiliations:
- name: School of AI, UCAS $\n$ CRIPAC & MAIS, CASIA
- - name: Ran He
- affiliations:
- name: School of AI, UCAS $\n$ CRIPAC & MAIS, CASIA
-
-
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Fundamental pipeline of GIAs
- - name: The tough challenge in GIAs
- - subsections:
- - name: A simple example of information discards
- - name: Understanding GIAs from the prior knowledge perspective
- - subsections:
- - name: Unparameterized regularization terms
- - name: Generative models
- - name: End-to-end networks
- - name: Limitation and future directions
- - name: Conclusions
-
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-Federated learning, a way to collaboratively train a deep model, was originally developed to enhance training efficiency and protect data privacy. In a federated learning paradigm, whether horizontal or vertical, data are processed locally, and the central server only gets access to processed information such as trained model weights or intermediate gradients. By avoiding direct access to private local data, federated learning is believed to protect clients' data privacy: the central server can only use the uploaded information to train a global model and does not know exactly what the training dataset contains. However, in horizontal federated learning, researchers found that the central server can still recover input data from the training gradients, which threatens training data privacy. Such a privacy attack is called a gradient inversion attack (or gradient leakage attack).
-
-## Fundamental pipeline of Gradient inversion attacks (GIAs)
-Gradient inversion attacks (GIAs) aim at reconstructing clients' private input data from the gradients produced during deep neural network training. They threaten the federated learning framework, especially the horizontal setting, where an honest-but-curious central server collects gradients from multiple clients, computes the optimal parameter update, and sends back the updated model in one step. Setting the mathematical formalism aside, a GIA is essentially a matching process: the attacker (the central server in the most common setting) expects randomly initialized dummy data to eventually generate the same gradients as the ground truth, so it measures the difference (or distance) between gradients and optimizes the input data pixel-wise. The smaller the distance between gradients, the better the private data are reconstructed.
-{% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/Picture1.jpg" class="img-fluid" %}
-This is a **white-box** attack, since it requires full model parameters to conduct backpropagation. In such a process, with fixed model parameters, the distance between gradients depends entirely on the attacker's dummy data. The GIA's goal is to minimize the distance below, where $x^\ast$ and $y^\ast$ represent the dummy data-label tuple, $\mathcal{D}$ represents the distance function, $\theta$ represents the model weights, and $\mathcal{L}$ represents the CE loss.
-
-$$\arg\min \limits_{(x^*,y^*)} {\mathcal{D}}\left(\nabla_\theta\mathcal{L}_\theta\left( x,y\right),\nabla_\theta\mathcal{L}_\theta\left( x^*,y^*\right)\right)$$
-
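-To make this matching process concrete, below is a minimal, hypothetical PyTorch sketch of the basic pixel-space attack: the toy network, data shapes, optimizer settings, and the assumption that the label is known are illustrative choices of ours, not those of any specific paper.
-
-```python
-import torch
-import torch.nn as nn
-
-torch.manual_seed(0)
-net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU(), nn.Linear(64, 10))
-criterion = nn.CrossEntropyLoss()
-
-# Gradients shared by the client for one private sample (simulated here).
-x_true = torch.rand(1, 3, 32, 32)
-y_true = torch.tensor([3])
-target_grads = torch.autograd.grad(criterion(net(x_true), y_true), net.parameters())
-
-# The attacker optimizes randomly initialized dummy data to match those gradients.
-x_dummy = torch.rand(1, 3, 32, 32, requires_grad=True)
-opt = torch.optim.Adam([x_dummy], lr=0.1)
-for step in range(1000):
-    dummy_grads = torch.autograd.grad(criterion(net(x_dummy), y_true),
-                                      net.parameters(), create_graph=True)
-    dist = sum(((dg - tg) ** 2).sum() for dg, tg in zip(dummy_grads, target_grads))
-    opt.zero_grad()
-    dist.backward()
-    opt.step()
-```
-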
-Since this problem was raised, a few research topics have emerged in this field. iDLG provides a way to recover the input label analytically. Following this, a series of works was proposed to recover labels from batches, and it is generally believed that simply optimizing input images with ground-truth labels achieves better performance than optimizing image-label tuples simultaneously. Besides label recovery, attack evaluations and defense methods also attract much attention. However, recovering high-quality images remains the key focus.
-## The tough challenge in GIAs
-In GIAs, the tough challenge, which has not been solved yet, is the reconstruction of batched input data where **multiple samples share the same labels**. Previous works approached this goal in steps: they first recovered single inputs, then extended the attack to batches with known labels, and finally added an algorithm to recover batched one-hot labels before recovering the input images. However, to the best of our knowledge, these attacks are still limited to the situation where **each class appears at most once in a batch**. Batched data recovery with repeated labels still fails for all current algorithms. The key reason for this failure lies in the information discarded by averaged gradients.
-### A simple example of information discards
-Let's first take a look at a simple neural network: an MLP. A given layer takes in intermediate features $\mathbf{x}$ and outputs the affine transformation $\mathbf{z}=\mathbf{Wx}+\mathbf{b}$. To recover the input from gradients, we can simply use the bias attack:
-
-$$\frac{\partial \mathcal{L}}{\partial {\mathbf{W}}}=\frac{\partial \mathcal{L}}{\partial \mathbf{z}} \times \frac{\partial \mathbf{z}}{\partial {\mathbf{W}}}=\frac{\partial \mathcal{L}}{\partial {b}}\mathbf{x}^\mathrm{T}$$
-
-In the above equation, it is clear that for a single input, with full access to model weights and gradients, the MLP's gradients contain all the information needed for single-image recovery.
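-
-As a quick sanity check of this identity, here is a small, hypothetical PyTorch snippet (layer sizes and loss are arbitrary choices of ours) that recovers a single input exactly from the first layer's gradients:
-
-```python
-import torch
-import torch.nn as nn
-
-torch.manual_seed(0)
-D, H = 8, 16
-x = torch.rand(D)                      # private input, batchsize = 1
-layer = nn.Linear(D, H)                # first layer: z = Wx + b
-head = nn.Linear(H, 1)
-
-loss = head(torch.relu(layer(x))).pow(2).mean()      # any scalar loss works
-grad_W, grad_b = torch.autograd.grad(loss, (layer.weight, layer.bias))
-
-# dL/dW = (dL/db) x^T, so any row with a non-zero bias gradient reveals x.
-row = grad_b.abs().argmax()
-x_recovered = grad_W[row] / grad_b[row]
-print(torch.allclose(x_recovered, x, atol=1e-5))      # True, if grad_b[row] != 0
-```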
-
-Here, we conduct a simple experiment to illustrate the existence of information discards. We pick a 4-layer MLP as the target neural network and randomly select a few images from the Flowers-17 dataset as the private input data to recover. We take the $l_2$ loss as the gradient matching function without any prior knowledge (regularization terms). First, we provide an example of input image recovery when **`batchsize=1` with known labels**.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs1_l2_fc.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs1_l2_1.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs1_l2_2.gif" class="img-fluid rounded z-depth-1" %}
-
-
-
- Image reconstruction with $l_2$ loss on the MLP. No regularization terms are adopted.
-
-
-It is not surprising that the $l_2$ gradient matching function recovers the input data well. Such good performance is mainly because the MLP's gradients contain enough information about the intermediate features of single inputs. With proper labels, we can conclude that GIAs work well on MLPs when `batchsize=1`.
-
-However, when it comes to CNNs, such inversion gets harder. For convolution layers, the gradients of the convolution kernels are aggregated over the whole feature map, so even with `batchsize=1` gradients may still suffer information discards, hurting attack performance. This problem is also mentioned in R-GAP, which frames GIAs from an equation-solving perspective: if the equations are "rank-deficient", we cannot obtain a unique solution, indicating obvious information discards. Here, for better illustration, we first show CIFAR-10 image reconstructions on LeNet with `batchsize=1`. Ground-truth one-hot labels are provided.
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs1_l2_f.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs1_cos_gt.jpg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs1_cos_f.gif" class="img-fluid rounded z-depth-1" %}
-
-
- Image reconstruction on LeNet with the CIFAR-10 dataset when `batchsize=1`. We show the ground-truth image in the middle and attach the reconstruction process on both sides ($l_2$ loss on the left and cosine similarity loss on the right).
-
-
-
-It is clear that even though both loss functions can recover the image, some pixels are not perfectly optimized, indicating the existence of information discards. If we change the batchsize, even slightly enlarging it to `batchsize=2`, the reconstruction ends up failing.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs1_cos_gt.jpg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs1_cos_gt_2.jpg" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos1.gif" class="img-fluid rounded z-depth-1" %}
-
-
-
- Image reconstruction with cosine similarity loss on LeNet; no regularization terms are adopted. In the middle, we show the ground-truth images in the batch.
-
-
-For a given network, the size of the gradients is fixed. Therefore, as the batchsize increases, GIAs experience more severe information discards. This is easy to understand, and researchers have designed a few ways to compensate for this lost information.
-## Understanding GIAs from the prior knowledge perspective
-With these information discards in mind, reviewing recent papers through the prior knowledge perspective may help us understand their logic better. To achieve better image reconstruction quality, it is natural to use prior knowledge about images to compensate. Here, this prior knowledge can be grouped into three categories.
-
-### Unparameterized regularization terms
-In IG, the authors utilize total variation as a regularization term because they believe a real image taken from nature should have a small total variation. That is the first prior knowledge term utilized in the gradient matching function, and it turns out to work well. After that, in GradInversion this regularization term is extended to include batch normalization supervision, $$l_2$$ norms and group consistency. This is a stronger prior, implying that a real input image, or a batch of real images, should not only have a small total variation but also possess a low $$l_2$$ norm and intermediate means and variances that match the batch normalization layers. Apart from that, all reconstructions from different random initializations ought to reach group consistency. These terms are unparameterized, and their ablation experiments clearly demonstrate that they matter significantly in reconstructing high-quality images.
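-
-As a concrete illustration, a common anisotropic total variation penalty can be sketched as follows (a minimal, hypothetical PyTorch version; the exact form and weighting differ across papers):
-
-```python
-import torch
-
-def total_variation(x: torch.Tensor) -> torch.Tensor:
-    """Anisotropic total variation of a batch of images with shape (B, C, H, W)."""
-    dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
-    dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
-    return dh + dw
-
-# Used as an additive term in the attack objective, e.g.:
-# loss = gradient_matching_distance + tv_weight * total_variation(x_dummy)
-```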
-
-To further illustrate the benefit such regularization terms bring to the data reconstruction process, here is an example of adding total variation for the `batchsize=2` image reconstruction. The weight of the total variation term ranges from $$10^{-4}$$ to $$10^{-1}$$.
-
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos_tv0.0001.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos_tv0.001.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos_tv0.01.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos_tv0.1.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos1_tv0.0001.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos1_tv0.001.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos1_tv0.01.gif" class="img-fluid rounded z-depth-1" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/bs2_cos1_tv0.1.gif" class="img-fluid rounded z-depth-1" %}
-
-
-
- Image reconstruction with cosine similarity loss and total variation on LeNet. The weight of the total variation term starts at $10^{-4}$ in the leftmost column and increases by a factor of 10 per column, up to $10^{-1}$.
-
-
-With an identical learning rate, images reconstructed with a larger total variation weight converge faster. Because total variation penalizes sharp differences between adjacent pixels, these images are also more blurred. On the other hand, reconstructions with an insufficient total variation weight fail to produce recognizable images.
-### Generative models
-Following the logic that recent works need additional prior knowledge to compensate for the information discarded from gradients, generative models, especially GANs, can serve as a strong tool to encode what "real images" should look like. Adding a GAN generator to the gradient matching process is simple: instead of optimizing image pixels directly, we backpropagate through the generator into the latent space and then update the latent code, as well as the parameters of the generator, to produce the recovered images. Pre-trained generators naturally encode a likely distribution of the input data, which is stronger prior knowledge than the previous unparameterized regularization terms.
-{% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/Picture2.jpg" class="img-fluid" %}
-
-The recent work GIFD extends this method by optimizing the GAN layer-wise. Instead of directly optimizing the GAN weights and the latent vector in one step, GIFD optimizes the intermediate layers iteratively, making the process more stable. In summary, gradients here serve more as an indicator that lets attackers select the best image from the distribution modeled by a pre-trained GAN.
-
-### End-to-end networks
-Arguably the most intuitive way to conduct a GIA is to design a function that takes gradients as input and outputs recovered images. For a target network, image-gradient tuples are easy to collect, so the prior knowledge can be encoded in such an end-to-end neural network through training.
-{% include figure.html path="assets/img/2024-05-07-understanding-gradient-inversion-attacks-from-the-prior-knowledge-perspective/Picture3.jpg" class="img-fluid" %}
-
-Here, the neural network resembles a GAN generator: it takes in representation vectors and outputs a synthesized image. However, instead of abstract latent codes, the network receives gradient vectors. In their implementation, Wu et al. use *feature hashing* to reduce the dimension of the gradient vectors. For the architecture, they use a simple 3-layer MLP that outputs flattened images, which differs from widely used GAN structures.
-However, such a method faces multiple difficulties, such as large input sizes and limited structural flexibility. Even for one specific model, once the model weights change, the end-to-end network must be retrained to construct a new mapping from gradients to images. Besides, there is still room for exploring the network design. Does the network structure influence image reconstruction performance on identical datasets? How can we construct a mapping from gradients to images with a varying batchsize? Could the network infer an optimal batchsize from the gradients? These questions are all worth further exploration.
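-
-As a toy illustration of this idea, the following hypothetical PyTorch sketch trains a small MLP to map hashed, flattened gradients to flattened images; the random projection is a crude stand-in for the feature hashing mentioned above, and all sizes and training details are arbitrary choices of ours.
-
-```python
-import torch
-import torch.nn as nn
-
-torch.manual_seed(0)
-D_img, n_cls, D_hash = 3 * 32 * 32, 10, 512
-target = nn.Sequential(nn.Linear(D_img, 32), nn.ReLU(), nn.Linear(32, n_cls))
-criterion = nn.CrossEntropyLoss()
-
-# A fixed random projection as a stand-in for feature hashing of the gradient vector.
-n_params = sum(p.numel() for p in target.parameters())
-proj = torch.randn(n_params, D_hash) / n_params ** 0.5
-
-def hashed_gradient(x, y):
-    grads = torch.autograd.grad(criterion(target(x), y), target.parameters())
-    return torch.cat([g.flatten() for g in grads]) @ proj
-
-inverter = nn.Sequential(nn.Linear(D_hash, 256), nn.ReLU(), nn.Linear(256, D_img))
-opt = torch.optim.Adam(inverter.parameters(), lr=1e-3)
-
-for step in range(1000):                      # training on synthetic image-gradient pairs
-    x = torch.rand(1, D_img)
-    y = torch.randint(n_cls, (1,))
-    g = hashed_gradient(x, y).detach()
-    loss = ((inverter(g) - x.flatten()) ** 2).mean()
-    opt.zero_grad()
-    loss.backward()
-    opt.step()
-```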
-
-## Limitation and future directions
-For GIAs that require pre-trained models, the key limitation is the auxiliary dataset. It is somewhat unrealistic to assume that the dataset used to pretrain generative models (or end-to-end models) shares the same distribution as the unknown private input data, and with a distinct data distribution the generative performance may drop. Both GIAS and GIFD use GANs with in-distribution auxiliary data when comparing with previous state-of-the-art works, and the GIFD paper only shows reconstruction results on out-of-distribution data when `batchsize=1` with the same label space. In the most general situation, where the attacker has limited knowledge of the potential distribution of the private data, it may still be hard to recover high-quality batched data with generative networks.
-Considering these limitations, it is of great value to explore algorithms that learn more general prior knowledge, especially prior knowledge that is robust across different data distributions.
-
-## Conclusions
-1. The existence of information discards in gradient aggregation is the central challenge of GIAs.
-2. From the prior knowledge perspective, previous GIA works provide three ways to compensate for these information discards.
-3. It may still be hard to recover batched data from gradients with limited knowledge of the private data distribution.
\ No newline at end of file
diff --git a/_posts/2024-05-07-understanding-icl.md b/_posts/2024-05-07-understanding-icl.md
deleted file mode 100644
index 8950ecd9..00000000
--- a/_posts/2024-05-07-understanding-icl.md
+++ /dev/null
@@ -1,1231 +0,0 @@
----
-layout: distill
-title: Understanding in-context learning in transformers
-description: We propose a technical exploration of In-Context Learning (ICL) for linear regression tasks in transformer architectures. Focusing on the article Transformers Learn In-Context by Gradient Descent by J. von Oswald et al., published at ICML 2023, we provide detailed explanations and illustrations of the mechanisms involved. We also contribute novel analyses on ICL, discuss recent developments, and point to open questions in this area of research.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-# Anonymize when submitting
-# authors:
-# - name: Anonymous
-# affiliations:
-# name: Anonymous
-
-authors:
- - name: Simone Rossi
- url: "https://scholar.google.com/citations?user=lTt86awAAAAJ&hl=en"
- affiliations:
- name: Stellantis, France
- - name: Rui Yuan
- url: "https://scholar.google.com/citations?hl=en&user=4QZgrj0AAAAJ"
- affiliations:
- name: Stellantis, France
- - name: Thomas Hannagan
- url: "https://scholar.google.com/citations?hl=en&user=u6OFo3YAAAAJ"
- affiliations:
- name: Stellantis, France
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-understanding-icl.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: What is in-context learning?
- subsections:
- - name: From large language models to regression tasks
- - name: Objective of this blog post
- - name: Preliminaries and notations
- subsections:
- - name: Dataset construction and tokenization
- - name: A quick review of self-attention
- - name: Training details
- - name: Transformers can learn any linear function in-context
- subsections:
- - name: Linear self-attention is sufficient
- - name: What is special about linear self-attention?
- subsections:
- - name: Establishing a connection between gradient descent and data manipulation
- - name: Building a linear transformer that implements a gradient descent step
- - name: Experiments and analysis of the linear transformer
- subsections:
- - name: During training a linear transformer implements a gradient descent step
- - name: The effect of the GD learning rate
- - name: Analytical derivation of the best GD learning rate
- - name: If one layer is a GD step, what about multiple layers?
- - name: Is this just for transformers? What about LSTMs?
- - name: Concluding remarks
- subsections:
- - name: What now?
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
-
- .center {
- display: block;
- margin-left: auto;
- margin-right: auto;
- }
-
- .framed {
- border: 1px var(--global-text-color) dashed !important;
- padding: 20px;
- }
-
- d-article {
- overflow-x: visible;
- }
-
- .underline {
- text-decoration: underline;
- }
-
- .todo{
- display: block;
- margin: 12px 0;
- font-style: italic;
- color: red;
- }
- .todo:before {
- content: "TODO: ";
- font-weight: bold;
- font-style: normal;
- }
- summary {
- color: steelblue;
- font-weight: bold;
- }
-
- summary-math {
- text-align:center;
- color: black
- }
-
- [data-theme="dark"] summary-math {
- text-align:center;
- color: white
- }
-
- details[open] {
- --bg: #e2edfc;
- color: black;
- border-radius: 15px;
- padding-left: 8px;
- background: var(--bg);
- outline: 0.5rem solid var(--bg);
- margin: 0 0 2rem 0;
- font-size: 80%;
- line-height: 1.4;
- }
-
- [data-theme="dark"] details[open] {
- --bg: #112f4a;
- color: white;
- border-radius: 15px;
- padding-left: 8px;
- background: var(--bg);
- outline: 0.5rem solid var(--bg);
- margin: 0 0 2rem 0;
- font-size: 80%;
- }
- .box-note, .box-warning, .box-error, .box-important {
- padding: 15px 15px 15px 10px;
- margin: 20px 20px 20px 5px;
- border: 1px solid #eee;
- border-left-width: 5px;
- border-radius: 5px 3px 3px 5px;
- }
- d-article .box-note {
- background-color: #eee;
- border-left-color: #2980b9;
- }
- d-article .box-warning {
- background-color: #fdf5d4;
- border-left-color: #f1c40f;
- }
- d-article .box-error {
- background-color: #f4dddb;
- border-left-color: #c0392b;
- }
- d-article .box-important {
- background-color: #d4f4dd;
- border-left-color: #2bc039;
- }
- html[data-theme='dark'] d-article .box-note {
- background-color: #555555;
- border-left-color: #2980b9;
- }
- html[data-theme='dark'] d-article .box-warning {
- background-color: #7f7f00;
- border-left-color: #f1c40f;
- }
- html[data-theme='dark'] d-article .box-error {
- background-color: #800000;
- border-left-color: #c0392b;
- }
- html[data-theme='dark'] d-article .box-important {
- background-color: #006600;
- border-left-color: #2bc039;
- }
- d-article aside {
- border: 1px solid #aaa;
- border-radius: 4px;
- padding: .5em .5em 0;
- font-size: 90%;
- }
- .caption {
- font-size: 80%;
- line-height: 1.2;
- text-align: left;
- }
----
-
-
-
-
-## What is in-context learning?
-
-
-In-Context Learning (ICL) is the behavior, first observed in Large Language Models (LLMs), whereby learning occurs from prompted data without any modification of the model's weights. It is a simple technique used daily by AI practitioners around the world to improve generation quality and alignment of LLMs.
-ICL is important because it directly addresses the once widespread criticism that, for all their impressive performance, modern deep learning models are rigid systems that lack the ability to adapt quickly to novel tasks in dynamic settings - a hallmark of biological intelligence.
-Through this new form of "learning during inference", Large Language Models have shown that they can be, in some specific sense (once pretrained), surprisingly versatile few-shot learners.
-
-
-
-
-**Figure 1**: Example of a simple in-context prompt for ChatGPT.
-
-
-Interestingly, it was around the release of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [GPT-3](https://arxiv.org/abs/2005.14165) that researchers observed that an auto-regressive language model pre-trained on enough data with enough parameters is capable of performing arbitrary tasks without fine-tuning, simply by prompting the model with the task and a few examples and letting it generate the output.
-In recent months, the research community has started to investigate the phenomenon of ICL in more detail, and several papers have been published on the topic.
-
-
-
-**Figure 2**: The number of papers published on the topic of ICL (and transformers) in recent years. Data extracted from [arxiv.org](https://arxiv.org/) on November 16th, 2023. In the last year alone, the number of papers on the topic has increased by more than 200%.
-
-
-
-
-Specifically, since learning processes in biology and machine are often, if not always, understood in terms of iterative optimization, it is natural to ask what kind of iterative optimization is being realized during ICL, and how.
-
-### From large language models to regression tasks
-
-Though ICL is generally regarded as a phenomenon exhibited by LLMs, we now hasten to study it in a non-language, small-scale model that enables more control and where ICL can still be shown to emerge.
-This simpler situation is that of a transformer model trained to regress a set of numerical data points presented in the prompt, with data points generated from a distinct function for each prompt, but where all prompts sample a function from the same general class (i.e. linear) at train and at test time. We will see that to some extent, this simplification allows for a mathematical treatment of ICL.
-
-The following figure gives a visual representation of the ICL setup we will consider in this blog post.
-The model is a generic transformer pre-trained to solve generic linear regression tasks. At inference time, we can give the model a prompt with a new linear regression task, and it is able to solve it with surprisingly good performance.
-
-
-
-
- **Figure 3**: The model is pre-trained to regress linear functions, and frozen during inference. With different context (input points), the model can still recover the exact underlying function. Use the slider to change the linear function to regress.
-
-
-
-
-
-
-
-### Objective of this blog post
-
-The objective of this blog post is to understand how ICL is possible, and to present in an interactive way what is known of its underlying mechanism.
-Specifically, we will analyze the results reported in the paper *Transformers Learn In-Context by Gradient Descent* by J. von Oswald et al. recently published in ICML 2023 , which first showed that a simplified transformer model learns in-context by gradient descent. We will replicate the authors' findings and then we will complement the discussion with a number of additional insights, before pointing to open questions. We hope the reader comes out of this post with a better vision of what *fundamentally* ICL is and the open challenges that remain.
-
-
-
-
-## Preliminaries and notations
-
-First of all we need to agree on a mathematical formalization of in-context learning.
-
-Before we start, let's introduce some notation and color convention that will be used throughout the rest of the blog post.
-We will use the following colors to denote different quantities:
-
-- blue: inputs
-- green: model parameters
-- yellow: output
-
-Vectors will be denoted with bold letters, e.g. $$\mba$$, and matrices with bold capital letters, e.g. $$\mbA$$.
-Additional notation will be introduced in-line when needed.
-
-Formally, let's define $$p(\mbx)$$ as a probability distribution over inputs $$\mbx\in\cX$$ and $$\cH$$ a class of functions $$h: \cX \rightarrow \cY$$.
-You can think of $$\cH$$ as a set of functions that share some common properties, for example, the set of all linear functions, or the set of all functions that can be represented by a neural network with a given architecture.
-Also, let's define $$p(h)$$ as a probability measure over $$\cH$$.
-
-
-
-
- **Figure 4**: Visual representation of various parametric function classes (linear, sinusoidal, shallow neural network). Use the dropdown menu to select the function class.
-
-
-
-
-Following the terminology of the LLM community, let's define a *prompt* $$P$$ of length $$C$$ as a *sequence* of $$2C+1$$ points $$(\mbx_0, h(\mbx_0), \ldots, \mbx_{C-1}, h(\mbx_{C-1}), \mbx_{\text{query}})$$ where inputs ($$\mbx_i$$ and $$\mbx_{\text{query}}$$) are independently and identically drawn from $$p(\mbx)$$, and $$h$$ is drawn from $$\cH$$. In short we will also write $$P_C = \left[\{\mbx_i, h(\mbx_i)\}_{i=0}^{C-1}, \mbx_\text{query}\right]$$.
-
-
-
- **Note**: The expectation in Equation \eqref{eq:in-context-error} is taken over the randomness of the input and the function. This means that we are considering the average performance of the model over all possible inputs and functions in $$\cH$$.
-
-
-
-
-
-
- Additional details on the ICL formalism
-
-We can also define the ICL problem through the lens of statistical learning theory.
-Let $$\ell$$ be the same per-task loss function as described above.
-Let's define the following loss $$\cL:\cF\rightarrow\bbR$$:
-
-$$
-\begin{equation}
- \cL_C(f) = \mathbb{E}\left[\ell\left(f(P_C), h\left(\mbx_{\text{query}}\right)\right) \right]
-\end{equation}
-$$
-
-Let's define $$f_C$$ as the model that minimizes the loss with $$C$$ in-context examples:
-
-$$
-\begin{equation}
-f_C = \arg\min_{f\in\cF} \cL_C(f)
-\end{equation}
-$$
-
-and $$f_\infty$$ as the model that minimizes the loss with an infinite number of in-context examples:
-
-$$
-\begin{equation}
- f_\infty = \arg\min_{f\in\cF} \cL_\infty(f)
-\end{equation}
-$$
-
-We say that a class of transformer models $$\cF$$ learns in-context for a function class $$\cH$$ if, for any $$\epsilon > 0$$, there exists a model $$f\in\cF$$ such that the following inequality holds:
-
-$$
-\begin{equation}
-\mathbb{P} \left[ \cL( f_C) - \cL( f_\infty) \leq \epsilon \right] \geq 1 - \delta
-\end{equation}
-$$
-
-In other words, the last equation says that a class of transformer models $$\cF$$ learns in-context for a function class $$\cH$$ if, for any $$\epsilon > 0$$, there exists a model $$f\in\cF$$ such that the difference between the loss of the model trained with $$C$$ in-context examples and the loss of the model trained with an infinite number of in-context examples is smaller than $$\epsilon$$ with probability at least $$1-\delta$$.
-
-Additionally, we can look at the consistency property, defined as:
-
-$$
-\begin{equation}
- \lim_{C\rightarrow\infty} \mathbb{P} \left[ \cL( f_C) - \cL( f_\infty) \geq \epsilon \right] = 0
-\end{equation}
-$$
-
-This equation signifies that the difference between the loss of the model trained with $$C$$ in-context examples and the loss of the model trained with an infinite number of in-context examples converges to zero as $$C$$ goes to infinity.
-
-
-
-
-
-### Dataset construction and tokenization
-
-For our setup, we will consider a linear regression problem, where the goal is to learn a linear function $$h_{\mbw}(\mbx) = \mbw^\top\mbx$$, with $$\mbw\in\bbR^D$$, from a set of in-context examples $$\{\mbx_i, \y_i\}_{i=0}^{C-1}$$, where $$\mbx_i\in\bbR^D$$ and $$\y_i\in\bbR$$.
-So $$h_{\mbw} \in \cH$$.
-
-In order to better understand how the prompt is constructed starting from a regression task, let's consider the following visual example:
-
-
-
- **Figure 5**: Visualization of the data construction process, from the regression dataset, to the input prompt and the tokenization.
-
-
-
-
-The figure shows a visual representation of the construction of a single input prompt.
-In particular, we first sample a weight $$\mbw$$ from the distribution $$p(\mbw)$$, and then we sample $$C$$ inputs $$\mbx_i$$ from $$p(\mbx)$$, where $$C$$ is the fixed context size.
-Finally, we compute the corresponding outputs $$\y_i = \mbw^\top\mbx_i$$.
-We consider $$p(\mbx) = \cU(-1, 1)$$, where $$\cU$$ is the uniform distribution, and $$p(\mbw) = \cN(\mbzero, \alpha^2\mbI)$$, where $$\cN$$ is a multivariate Gaussian distribution of dimension $$D$$, with $$0$$ mean and $$\alpha$$ standard deviation.
-
-
-Defining $$c=C+1$$ and $$d=D+1$$, where $$C$$ is the context size and $$D$$ is the input dimension, we can represent the input as a matrix $$\mbE\in\bbR^{d\times c}$$ (also referred to as *token embeddings* or, simply, *embeddings*), where the first $$C$$ columns represent the context inputs $$\mbx_i$$ and outputs $$\y_i$$, and the last column represents the query input $$\mbx_{\text{query}}$$ with $$0$$ padding.
-
-
-To construct a batch of regression problems, we just repeat the above procedure $$N$$ times with the fixed context size $$C$$, where $$N$$ is the size of the batch.
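-
-To make the construction explicit, here is a minimal numpy sketch of how one such prompt/embedding matrix could be built (the sampling choices follow the description above; the variable names are our own):
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-D, C, alpha = 2, 10, 1.0                     # input dim, context size, prior std
-
-w = rng.normal(0.0, alpha, size=D)           # task weights  w ~ N(0, alpha^2 I)
-x = rng.uniform(-1.0, 1.0, size=(C, D))      # context inputs  x_i ~ U(-1, 1)
-y = x @ w                                    # context targets y_i = w^T x_i
-x_query = rng.uniform(-1.0, 1.0, size=D)
-
-# Embedding matrix E of shape (D + 1, C + 1): context columns hold (x_i, y_i),
-# the last column holds the query input with 0 in place of the unknown target.
-E = np.zeros((D + 1, C + 1))
-E[:D, :C] = x.T
-E[D, :C] = y
-E[:D, C] = x_query
-```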
-
-
-
-### A quick review of self-attention
-
-In this section we will briefly review the self-attention mechanism, which is the core component of the transformer architecture .
-
-Let $$\mbW^K, \mbW^Q \in \bbR^{d_k\times d}$$, $$\mbW^V \in \bbR^{d_v\times d}$$ and $$\mbW^P \in \bbR^{d \times d_v}$$ be the key, query, value and projection weight matrices, respectively.
-Given an embedding $$\mbE\in\bbR^{d\times c}$$, the softmax self-attention layer implements the following operation,
-
-$$
-\begin{equation}
-\label{eq:softmax-self-attention}
- f_\text{attn} (\mbtheta_\text{attn}, \mbE) = \mbE + \mbW^P \mbW^V \mbE \sigma\left(\frac{(\mbW^K \mbE)^\top \mbW^Q \mbE}{\sqrt{d}}\right),
-\end{equation}
-$$
-
-with $$\mbtheta_\text{attn}=\{\mbW^K, \mbW^Q, \mbW^V, \mbW^P\}$$, where for simplicity we will consider $$d_k=d_v=d$$, and $$\sigma(\cdot)$$ is the softmax function applied column-wise.
-It's simple to verify that the output dimension of $$f_\text{attn}$$ is the same as the input dimension.
-To simplify further, we can also define the value, key and query matrices as $$\mbV = \mbW^V\mbE$$, $$\mbK = \mbW^K\mbE$$, $$\mbQ = \mbW^Q\mbE$$, respectively.
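-
-For reference, the operation in Equation \eqref{eq:softmax-self-attention} can be written compactly as follows (a plain numpy sketch, with $$d_k=d_v=d$$ as above; function names are our own):
-
-```python
-import numpy as np
-
-def softmax(z, axis=0):
-    z = z - z.max(axis=axis, keepdims=True)
-    e = np.exp(z)
-    return e / e.sum(axis=axis, keepdims=True)
-
-def attn(E, WK, WQ, WV, WP):
-    """Softmax self-attention update of the embeddings E (shape d x c)."""
-    d = E.shape[0]
-    K, Q, V = WK @ E, WQ @ E, WV @ E
-    A = softmax((K.T @ Q) / np.sqrt(d), axis=0)   # column-wise softmax
-    return E + WP @ V @ A
-```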
-
-
-
-
-
-### Training details
-
-
-
- Figure 6: Visualization of the pre-training process. The model is trained to minimize the loss function defined in Equation \eqref{eq:pre-train-loss-expectation}.
-
-
-
-
-Once the dataset is created, we can train the model using the following objective:
-
-$$
-\begin{equation}
-\label{eq:pre-train-loss-expectation}
-\cL(\mbtheta) = \mathbb{E}\left\|f\left(\mbtheta, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) - \y_{\text{query}}\right\|^2,
-\end{equation}
-$$
-
-where the expectation is taken over $$p(\mbx)$$ and $$p(\mbw)$$, with $$h_{\mbw}(\mbx) = \mbw^\top\mbx$$.
-Note that the output of the model is a sequence of $$C+1$$ values, i.e. the same length as the input prompt, and the loss is computed only on the last value of the sequence, which corresponds to the predicted query output $$\widehat\y_{\text{query}}$$.
-Specifically, to read out the prediction for $$\mbx_{\text{query}}$$, we again multiply this last value by $$-1$$.
-Note that this choice is completely transparent during model training, as it is equivalent to simply changing the sign of a few elements in the projection weight matrix $$\mbW^P$$.
-The reason for this will be clear in the following sections.
-At each training iteration, we replace the expectation with an empirical average over a batch of $$N$$ regression tasks, each made of a different set of context points $$\{\mbx_i^{(n)}, \y_i^{(n)}\}_{i=0}^{C-1}$$ and a query input/target pair, $$\mbx^{(n)}_\text{query}$$ and $$\y^{(n)}_{\text{query}}$$, respectively.
-Note that because of the on-line creation of the dataset, during training the model will never see the same regression task twice.
-
-
-
-
- Code for the transformer loss
- This is the code for the loss computation, including the reading out of the query output.
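-
-Since the original listing is not reproduced here, this is a minimal numpy-style sketch of the readout and loss under the conventions above; it assumes a `model(params, E)` callable that returns an updated embedding matrix of the same shape as `E`, and both the name and the signature are our own.
-
-```python
-import numpy as np
-
-def icl_loss(model, params, E_batch, y_query_batch):
-    """Mean squared error on the query prediction, read out from the last token."""
-    losses = []
-    for E, y_query in zip(E_batch, y_query_batch):
-        out = model(params, E)          # same shape as E: (D + 1, C + 1)
-        y_pred = -out[-1, -1]           # y entry of the query token, multiplied by -1
-        losses.append((y_pred - y_query) ** 2)
-    return np.mean(losses)
-```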
-
-
-
-
-
-
-
-
-
-
-## Transformers can learn any linear function in-context
-
-
-With all the preliminaries and notations in place, we can now start to analyze some results regarding the ability of transformers to learn linear functions in-context.
-One of the first papers that studied the ability of transformers to learn linear functions in-context is *What Can Transformers Learn In-Context? A Case Study of Simple Function Classes* by S. Garg et al .
-We will first replicate their results using a simpler configuration: up to 5 layers, single-head attention, and 64 embedding units, for total parameter counts of 17K, 34K, 50K, 67K and 84K, respectively.
-
-
-
-
-
-In the figure below, we report the in-context test loss (as defined in Equation \eqref{eq:in-context-test-loss}) for each model configuration, for various context sizes $$C$$, from 2 to 100.
-
-
-
-
-
- Figure 7: Transformers can learn linear functions in-context, reasonably well. The test loss decreases as the context size increases, and as the number of layers increases.
-
-
-
-
-The experiment above shows that the test loss diminishes for larger context sizes, and also as the number of layers increases. These two main effects are clearly expected, as consequences of more data points and more compute, respectively, and they replicate the findings of Garg et al .
-
-### Linear self-attention is sufficient
-
-From this point, we will depart from the classic softmax self-attention layer, and restrict our study to a linear self-attention layer, which is the setting considered in the paper of J. von Oswald et al .
-Recently, a number of papers have drawn connections between linear transformers and *Fast Weight Programmers* and have
-shown that linearized self-attention layers can be used to replace the softmax self-attention layer in transformers, with the advantage of reducing the computational complexity of the attention operation .
-
-A **linear self-attention** layer updates the embeddings $$\mbE$$ as follows:
-
-$$
-\begin{equation}
- f_\text{linattn} (\mbtheta_\text{linattn}, \mbE) = \mbE + \frac{\mbW^P \mbV\left(\mbK^\top \mbQ \right)}{\sqrt{d}},
-\end{equation}
-$$
-
-with $$\mbV, \mbK, \mbQ$$ being the value, key and query defined right after Equation \eqref{eq:softmax-self-attention}.
-
-Now, to analyze if a linear self-attention layer is sufficient to learn linear functions in-context, we can use the same experimental setup as before, but replacing the softmax self-attention layer with a linear self-attention layer.
-
-Additionally, we also strip down the transformer to its bare minimum, i.e. we remove the normalization, the embedding layer, the feed-forward layer, and only use a single head. The only remaining component is the linear self-attention layer.
-Therefore, in the following we use the term "linear transformer" to refer to this simplified model.
-
-
- Code for the linear transformer
- This is the code for the linear transformer, without any normalization, embedding, etc., and with a single head.
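-
-As the original listing is not reproduced here, this is a minimal numpy sketch of such a stripped-down model (a single linear self-attention layer per block, no normalization, feed-forward or embedding layers; the parameter initialization is an arbitrary choice of ours):
-
-```python
-import numpy as np
-
-def linear_attn_layer(E, WK, WQ, WV, WP):
-    """Linear self-attention update: E + WP V (K^T Q) / sqrt(d)."""
-    d = E.shape[0]
-    K, Q, V = WK @ E, WQ @ E, WV @ E
-    return E + WP @ V @ (K.T @ Q) / np.sqrt(d)
-
-def linear_transformer(params, E):
-    """A stack of linear self-attention layers; params is a list of (WK, WQ, WV, WP)."""
-    for WK, WQ, WV, WP in params:
-        E = linear_attn_layer(E, WK, WQ, WV, WP)
-    return E
-
-def init_params(n_layers, d, rng, scale=0.02):
-    return [tuple(scale * rng.standard_normal((d, d)) for _ in range(4))
-            for _ in range(n_layers)]
-```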
-
-
-
-
-We test the linear transformer on the same dataset setup as before, and we will use the same number of layers as before, i.e. 1, 2, 3, 4, 5.
-
-
-
-
- **Figure 8**: Linear transformers can also learn linear functions in-context, reasonably well. The test loss decreases as the context size increases, and as the number of layers increases.
-
-
-
-
-
-
-
-
-
-
-## What is special about linear self-attention?
-
-From the previous section we have seen that a linear self-attention layer is sufficient to learn linear functions in-context.
-In this section we will try to understand why this is the case, starting from a review of least-squares regression and gradient descent.
-
-### Establishing a connection between gradient descent and data manipulation
-
-In this section, we establish an important connection that will be fundamental to understand the mechanism behind ICL with linear self-attention. To do so we need to start from a simple linear regression problem, and we will show that we can achieve the same loss after *one* gradient step by changing the inputs and the targets, and keeping the weights fixed.
-
-
-
-The loss for a linear regression problem is defined as:
-$$
-\begin{equation}
-\label{eq:linear-regression-loss}
-\cL_{\text{lin}}\left(\mbw, \{\mbx_i, {\y}_i\}_{i=0}^{C-1}\right) = \frac 1 {2C} \sum_{i=0}^{C-1} (\mbw^\top\mbx_i - \y_i)^2
-\end{equation}
-$$
-
-where $$\mbw\in\bbR^D$$, $$\mbx_i\in\bbR^D$$ and $$\y_i\in\bbR$$. With a given learning rate $$\eta$$, the gradient descent update is $$\mbw \leftarrow \mbw - \Delta \mbw$$, where
-$$
-\begin{equation}
-\label{eq:linear-regression-gd-gradient}
-\Delta \mbw = \eta \nabla_{\mbw} \cL_{\text{lin}}\left(\mbw, \{\mbx_i, {\y}_i\}_{i=0}^{C-1}\right) = \frac{\eta}{C} \sum_{i=0}^{C-1} \left(\mbw^\top\mbx_i - \y_i\right)\mbx_i
-\end{equation}
-$$
-The corresponding loss (after the update) is:
-$$
-\begin{equation}
-\label{eq:linear-regression-loss-after-gd}
-\cL_{\text{lin}}\left(\mbw - \Delta \mbw, \{\mbx_i, {\y}_i\}_{i=0}^{C-1}\right) = \frac 1 {2C} \sum_{i=0}^{C-1} \left(\mbw^\top\mbx_i - \y_i - \Delta \mbw^\top\mbx_i\right)^2
-\end{equation}
-$$
-
-It is trivial to see that if we now define $$\widehat{\mbx}_i = \mbx_i$$ and $$\widehat{\y}_i = \y_i + \Delta \mbw^\top\mbx_i$$, we can compute Equation \eqref{eq:linear-regression-loss} with the new inputs and targets, i.e. $$\cL_{\text{lin}}(\mbw, \{\widehat{\mbx}_i, \widehat{\y}_i\}_{i=0}^{C-1})$$, which is the same as the loss after the gradient descent update (Equation \eqref{eq:linear-regression-loss-after-gd}).
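-
-This identity is easy to check numerically; here is a small numpy verification with random data (sizes and naming are our own):
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-C, D, eta = 16, 3, 0.1
-x = rng.uniform(-1, 1, size=(C, D))
-w = rng.normal(size=D)
-y = x @ rng.normal(size=D)                      # targets from some other linear task
-
-def loss(w, x, y):
-    return 0.5 * np.mean((x @ w - y) ** 2)      # 1/(2C) sum_i (w^T x_i - y_i)^2
-
-delta_w = eta * np.mean((x @ w - y)[:, None] * x, axis=0)   # eta * gradient
-y_hat = y + x @ delta_w                          # transformed targets, inputs unchanged
-
-print(np.isclose(loss(w - delta_w, x, y), loss(w, x, y_hat)))  # True
-```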
-
-
-
-
-### Building a linear transformer that implements a gradient descent step
-
-As we just saw, the starting intuition is that we can build a gradient step on the linear regression loss by manipulating the inputs and the targets.
-This is the *key insight* of von Oswald et al. that allows us to draw a connection between the gradient descent dynamics and the linear transformer.
-
-Before stating the main result, recall the definitions of value, key and query as $$\mbV = \mbW^V\mbE$$, $$\mbK = \mbW^K\mbE$$, and $$\mbq_j = \mbW^Q\mbe_j$$.
-
-
-
-**Main result**:
-Given a 1-head linear attention layer and the tokens $$\mbe_j = (\mbx_j, \y_j)$$, for $$j=0,\ldots,C-1$$, we can construct key, query and value matrices $$\mbW^K, \mbW^Q, \mbW^V$$ as well as the projection matrix $$\mbW^P$$ such that a transformer step on every token $$\mbe_j \leftarrow (\mbx_j, \y_{j}) + \mbW^{P} \mbV \mbK^{T}\mbq_{j}$$ is identical to the gradient-induced dynamics $$\mbe_j \leftarrow (\mbx_j, \y_j) + (0, -\Delta \mbW \mbx_j)$$. For the query data $$(\mbx_{\text{query}}, \y_{\text{query}})$$, the dynamics are identical.
-
-
-
-
-For notation, we will identify with $$\mbtheta_\text{GD}$$ the set of parameters of the linear transformer that implements a gradient descent step.
-
-
-
-Nonetheless, we can construct a linear self-attention layer that implements a gradient descent step; one possible construction, in block form, is the following.
-
-$$
-\begin{align}
-\mbW^K = \mbW^Q = \left(\begin{array}{@{}c c@{}}
- \mbI_D & 0 \\
- 0 & 0
-\end{array}\right)
-\end{align}
-$$
-
-with $$\mbI_D$$ the identity matrix of size $$D$$, and
-
-$$
-\begin{align}
-\mbW^V = \left(\begin{array}{@{}c c@{}}
- 0
- & 0 \\
- \mbw_0^\top &
- -1
-\end{array}
- \right)
-\end{align}
-$$
-
-with $$\mbw_0 \in \bbR^{D}$$ the weight vector of the linear model and $$\mbW^P = \frac{\eta}{C}\mbI_{d}$$ with identity matrix of size $$d$$.
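-
-To make the construction concrete, here is a numpy sketch (toy data and our own variable names) that assembles these block matrices and applies one attention step, with keys and values taken from the context tokens as in the sums of the proof below. It checks that the $$\mbx$$ part of every token is left untouched while its $$\y$$ entry is shifted by $$\Delta\mbw^\top\mbx_j$$, i.e. exactly the transformed targets from the previous section.
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-D, C, eta = 3, 8, 0.1
-d = D + 1
-
-x = rng.uniform(-1, 1, size=(C, D))
-y = x @ rng.normal(size=D)
-x_query = rng.uniform(-1, 1, size=D)
-E = np.zeros((d, C + 1))                       # tokens e_j = (x_j, y_j), query has y = 0
-E[:D, :C], E[D, :C], E[:D, C] = x.T, y, x_query
-
-w0 = np.zeros(D)                               # initial weights of the linear model
-
-WK = WQ = np.block([[np.eye(D), np.zeros((D, 1))],
-                    [np.zeros((1, D)), np.zeros((1, 1))]])
-WV = np.block([[np.zeros((D, D)), np.zeros((D, 1))],
-               [w0[None, :], -np.ones((1, 1))]])
-WP = (eta / C) * np.eye(d)
-
-# One linear attention step; keys/values come from the C context tokens,
-# matching the sums over i = 0, ..., C-1 in the construction.
-K, V, Q = WK @ E[:, :C], WV @ E[:, :C], WQ @ E
-E_new = E + WP @ V @ (K.T @ Q)
-
-# The y entry of token j is shifted by (eta / C) * sum_i (w0^T x_i - y_i) x_i^T x_j,
-# i.e. by Delta_w^T x_j, while the x entries are unchanged.
-delta_w = (eta / C) * ((x @ w0 - y) @ x)
-print(np.allclose(E_new[D, :C], y + x @ delta_w))          # True
-print(np.allclose(E_new[D, C], x_query @ delta_w))          # True
-```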
-
-
-
-If you are interested in the proof of construction for the GD-equivalent transformer, you can find it in the following collapsible section.
-
-
-
- Proof of construction for the GD-equivalent transformer
-
-To verify this, first remember that if $$\mbA$$ is a matrix of size $$N\times M$$ and $$\mbB$$ is a matrix of size $$M\times P$$,
-
-$$
-\begin{align}
-\mbA\mbB = \sum_{i=1}^M \mba_i\otimes\mbb_{,i}
-\end{align}
-$$
-
-where $$\mba_i \in \bbR^{N}$$ is the $$i$$-th column of $$\mbA$$, $$\mbb_{,i} \in \bbR^{P}$$ is the $$i$$-th row of $$\mbB$$, and $$\otimes$$ is the outer product between two vectors.
-
-It is easy to verify that with this construction we obtain the following dynamics
-
-$$
-\begin{align}
-\left(\begin{array}{@{}c@{}}
-\mbx_j\\
-\y_j
-\end{array}\right)
-\leftarrow &
-\left(\begin{array}{@{}c@{}}
-\mbx_j\\
-\y_j
-\end{array}\right) + \mbW^{P} \mbV \mbK^{T}\mbq_{j} = \mbe_j + \frac{\eta}{C} \sum_{i={0}}^{C-1} \left(\begin{array}{@{}c c@{}}
-0
-& 0 \\
-\mbw_0 &
--1
-\end{array}
-\right)
-\left(\begin{array}{@{}c@{}}
-\mbx_i\\
-\y_i
-\end{array}\right)
-\otimes
-\left(
-\left(\begin{array}{@{}c c@{}}
-\mbI_D & 0 \\
-0 & 0
-\end{array}\right)
-\left(\begin{array}{@{}c@{}}
-\mbx_i\\
-\y_i
-\end{array}\right)
-\right)
-\left(\begin{array}{@{}c c@{}}
-\mbI_D & 0 \\
-0 & 0
-\end{array}\right)
-\left(\begin{array}{@{}c@{}}
-\mbx_j\\
-\y_j
-\end{array}\right)\\
-&= \left(\begin{array}{@{}c@{}}
-\mbx_j\\
-\y_j
-\end{array}\right) + \frac{\eta}{C} \sum_{i={0}}^{C-1} \left(\begin{array}{@{}c@{}}
-0\\
-\mbw_0^\top \mbx_i - \y_i
-\end{array}\right)
-\otimes
-\left(\begin{array}{@{}c@{}}
-\mbx_i\\
-0
-\end{array}\right)
-\left(\begin{array}{@{}c@{}}
-\mbx_j\\
-0
-\end{array}\right) =
-\left(\begin{array}{@{}c@{}}
-\mbx_j\\
-\y_j
-\end{array}\right) + \left(\begin{array}{@{}c@{}}
-0\\
-
-- \frac{\eta}{C}\sum_{i=0}^{C-1} \left( \left(\mbw_0^\top\mbx_i - \y_i\right)\mbx_i\right)^\top \mbx_j
- \end{array}\right).
- \end{align}
-$$
-
-Note that the update for the query token $$(\mbx_{\text{query}}, \textcolor{output}{0})$$ is identical to the update for the context tokens $$(\mbx_j, \y_j)$$ for $$j=0,\ldots,C-1$$.
-
-
-
-
-
-
-## Experiments and analysis of the linear transformer
-
-Now let's do some experiments to verify the theoretical results.
-We will work within the same experimental setup as before with the same dataset construction, training procedure and testing procedure.
-In this first section, we consider a linear transformer with a single layer, and the transformer built as described in the previous section (the GD-equivalent transformer), i.e. with a linear self-attention layer that implements a gradient descent step.
-
-### During training a linear transformer implements a gradient descent step
-
-We now study the evolution of the test loss of a linear transformer during training $$\cL(\mbtheta)$$, and compare it to the loss of a transformer implementing a gradient descent step $$\cL(\mbtheta_\text{GD})$$.
-
-
-
- **Figure 9**: The loss of a trained linear transformer converges to the loss of a transformer implementing a gradient descent step on the least-squares regression loss with the same dataset. Use the slider to change the context size.
-
-
-
-
-
-
-
-
-
-
-
-Although an empirical proof of such a functional equivalence would require checking the outputs for all possible test samples, we can gather more evidence by looking more closely at the computations that unfold in the linear transformer during one forward pass.
-
-To better understand the dynamics of the linear transformer, we now study the evolution of a few metrics during training (the *L2 error for predictions*, the *L2 error for gradients* and the *cosine similarity* between models).
-
-
-Metrics details
-
-The metrics introduced above are defined as follows:
-
-- **L2 error (predictions)** measures the difference between the predictions of the linear transformer and the predictions of the transformer implementing a gradient descent step and it is defined as $$\left\|f\left(\mbtheta, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) - f\left(\mbtheta_\text{GD}, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) \right\|^2$$;
-
-- **L2 error (gradients w.r.t. inputs)** measures the difference between the gradients of the linear transformer and the gradients of the transformer implementing a gradient descent step and it is defined as $$\left\|\nabla_{\mbx_\text{query}} f\left(\mbtheta, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) - \nabla_{\mbx_\text{query}} f\left(\mbtheta_\text{GD}, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) \right\|^2$$;
-
-- **Model cosine similarity (gradients w.r.t. inputs)** measures the cosine similarity between the gradients of the linear transformer and the gradients of the transformer implementing a gradient descent step and it is defined as $$\cos\left(\nabla_{\mbx_\text{query}} f\left(\mbtheta, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right), \nabla_{\mbx_\text{query}} f\left(\mbtheta_\text{GD}, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right)\right)$$.
-
-
-
-
-
-
-
- **Figure 10**: Comparison between the linear transformer and the GD-transformer during training. The predictions of the linear transformer converge to the predictions of the GD-transformer and the gradients of the linear transformer converge to the gradients of the GD-transformer. Use the slider to change the context size.
-
-
-
-
-
-From this figure, we see that the predictions of the linear transformer converge to the predictions of the GD-transformer, and the gradients of the linear transformer converge to the gradients of the GD-transformer.
-Notably, this is true for all context sizes, though the convergence is faster for larger $$C$$.
-
-As a final visualization, we can also look at the evolution of the gradients of the linear transformer during training, as shown in the figure below. In this animation, we take six different regression tasks and we plot the gradients of the linear transformer during training and the exact gradients of the least-squares regression loss.
-
-
-
-
- Figure 11: Animation of the gradients of the linear transformer during training. The loss landscape visualized is the least-squares regression loss (each task has its own loss). The gradients of the linear transformer are shown in red, while the gradients of the least-squares regression loss are shown in orange.
-
-
-
-
-To reiterate, the loss landscape visualized is the least-squares regression loss, and each task is a different linear regression problem with its own loss landscape.
-Once more, this visualization shows that the linear transformer is not learning a single regression model, but rather learning how to solve linear regression problems.
-
-### The effect of the GD learning rate
-
-Next, we study the effect of the GD learning rate on the test loss of the GD-equivalent transformer.
-We believe this is an important point of discussion which was covered only briefly in the paper.
-
-
-
-
-
-
-
-Indeed, this is the same procedure we have used to find the optimal GD learning rate for our previous experiments.
-We now show what happens if we use a different GD learning rate than the one found with line search.
-In the following experiment, we visualize this behavior, by plotting the metrics described above for different values of the GD learning rate.
-
-
-
-
- Figure 12: Effect of the GD learning rate on the alignment between the linear transformer and the GD-transformer. The agreement between the two is maximized for a specific GD learning rate, which must be found by line search. Use the slider to manually change the GD learning rate.
-
-
-
-
-
-
-
-### Analytical derivation of the best GD learning rate
-
-It turns out that having a line search to find the best GD learning rate is not necessary.
-
-
-
-The analytical solution is provided below with its derivation reported in the collapsible section immediately following.
-
-
-Analytical derivation of the best GD learning rate
-
-We are interested in finding the optimal learning rate for the GD-transformer, which, by construction (see the main result above), is equivalent to finding the optimal GD learning rate for the least-squares regression problem. Consequently, the analysis can be carried out on the least-squares regression problem \eqref{eq:linear-regression-loss}.
-
-Recall the GD update of the least-squares regression in \eqref{eq:linear-regression-gd-gradient}, without taking the learning rate into account. That is,
-
-$$
-\begin{equation}
-\label{eq:linear-regression-gd-gradient-no-lr}
-\Delta \mbw = \nabla_{\mbw}
-\cL_{\text{lin}}\left(\mbw, \{\mbx_i, \y_i\}_{i=0}^{C-1}\right) =
-\frac{1}{C} \sum_{i=0}^{C-1} \left(\mbw^\top\mbx_i - \y_i\right)\mbx_i.
-\end{equation}
-$$
-
-Now we consider the test loss of the least-squares regression defined as
-
-$$
-\begin{equation}
-\cL_\mathrm{lin, te}(\{\mbw^{(n)}\}_{n=0}^{N-1}) = \frac{1}{N} \sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query})^2,
-\end{equation}
-$$
-
-where $$N$$ is the number of queries, which equals the number of regression tasks in the in-context test dataset.
-Similar to \eqref{eq:linear-regression-loss-after-gd}, after one step of the GD update \eqref{eq:linear-regression-gd-gradient-no-lr}, the corresponding test loss becomes
-
-$$
-\begin{align}
-&\quad \ \ \cL_\mathrm{lin, te}(\{\mbw^{(n)} - \eta \Delta \mbw^{(n)}\}_{n=0}^{N-1}) \nonumber \\
-&= \frac{1}{N} \sum_{n=0}^{N-1} \left((\mbx^{(n)}_\text{query})^\top (\mbw^{(n)} - \eta \Delta \mbw^{(n)}) - \y^{(n)}_\text{query}\right)^2 \nonumber \\
-&= \frac{1}{N} \sum_{n=0}^{N-1} \left((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query} - \eta (\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)} \right)^2 \nonumber \\
-&= \frac{\eta^2}{N} \sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)})^2
-+ \cL_\mathrm{lin, te}(\{\mbw^{(n)}\}_{n=0}^{N-1}) \nonumber \\
-&\quad \ - \frac{2\eta}{N} \sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query})(\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)}. \label{eq:loss_query_W1}
-\end{align}
-$$
-
-One can choose the optimum learning rate $$\eta^*$$ such that $$\cL_\mathrm{lin, te}(\{\mbw^{(n)} - \eta \Delta \mbw^{(n)}\}_{n=0}^{N-1})$$ achieves its minimum with respect to the learning rate $$\eta$$. That is,
-
-$$
-\begin{align}
-\eta^* \in \arg\min_{\eta > 0} \cL_\mathrm{lin, te}(\{\mbw^{(n)} - \eta \Delta \mbw^{(n)}\}_{n=0}^{N-1}).
-\end{align}
-$$
-
-To obtain $$\eta^*$$, it suffices to solve
-
-$$
-\begin{align}
-\nabla_\eta \cL_\mathrm{lin, te}(\{\mbw^{(n)} - \eta \Delta \mbw^{(n)}\}_{n=0}^{N-1}) = 0.
-\end{align}
-$$
-From \eqref{eq:loss_query_W1}, plugging in $$\Delta \mbw^{(n)}$$ from \eqref{eq:linear-regression-gd-gradient-no-lr}, we obtain
-$$
-\begin{align}
-\eta^* &= \frac{\sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query})(\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)} }
-{\sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)})^2} \nonumber \\
-&= C \frac{\sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query}) \sum_{i=0}^{C-1} ((\mbw^{(n)})^\top \mbx_i^{(n)} - \y_i^{(n)})(\mbx_i^{(n)})^\top \mbx^{(n)}_\text{query}}
-{\sum_{n=0}^{N-1} \left( \sum_{i=0}^{C-1} ((\mbw^{(n)})^\top \mbx_i^{(n)} - \y_i^{(n)})(\mbx_i^{(n)})^\top \mbx^{(n)}_\text{query} \right)^2}.
-\end{align}
-$$
-Finally, for the initialization $$\mbw^{(n)} = 0$$ for $$n = 0, \ldots, N-1$$, the optimal learning rate simplifies to
-$$
-\begin{align}
-\eta^* = C \frac{\sum_{n=0}^{N-1} \y^{(n)}_\text{query} \left(\sum_{i=0}^{C-1}\left( \y^{(n)}_i{\left(\mbx^{(n)}_i\right)}^\top \mbx_\text{query}^{(n)}\right)\right)
-}{\sum_{n=0}^{N-1} \left(\sum_{i=0}^{C-1}\left(\y^{(n)}_i {\left(\mbx^{(n)}_i\right)}^\top \mbx_\text{query}^{(n)}\right)\right)^2}.
-\end{align}
-$$
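-
-To make this closed-form expression concrete, the following is a minimal NumPy sketch that evaluates $$\eta^*$$ for zero-initialized weights. The array names and shapes (`X` of shape `(N, C, d)`, `y` of shape `(N, C)`, `x_query` of shape `(N, d)`, `y_query` of shape `(N,)`) are illustrative assumptions of ours, not the notation of the original implementation.
-
-```python
-import numpy as np
-
-def optimal_gd_learning_rate(X, y, x_query, y_query):
-    """Closed-form optimal GD learning rate for zero-initialized weights."""
-    N, C, d = X.shape
-    # s[n] = sum_i y_i^(n) * <x_i^(n), x_query^(n)> for each task n.
-    s = np.einsum("nc,ncd,nd->n", y, X, x_query)
-    return C * np.sum(y_query * s) / np.sum(s ** 2)
-
-# Quick shape check with synthetic linear-regression tasks.
-rng = np.random.default_rng(0)
-N, C, d = 10_000, 10, 5
-w = rng.normal(size=(N, d))
-X = rng.normal(size=(N, C, d))
-y = np.einsum("ncd,nd->nc", X, w)
-x_query = rng.normal(size=(N, d))
-y_query = np.einsum("nd,nd->n", x_query, w)
-print(optimal_gd_learning_rate(X, y, x_query, y_query))
-```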
-
-
-
-
-#### Some comments on the analytical solution
-
-The optimal GD learning rate $$\eta^*$$ obtained from this derivation agrees well with the one found by the line search procedure (up to the numerical precision of the line search itself).
-While this is expected, let's take a moment to understand why this is the case.
-
-1. The analytical solution is obtained starting from the linear regression loss, while the line search procedure uses the loss $$\cL(\mbtheta_\text{GD})$$ defined in Equation \eqref{eq:pre-train-loss-expectation}.
-However, the two losses are equivalent by construction, hence the two procedures are equivalent.
-
-1. Because the construction of the GD-transformer is not unique, it is not immediately obvious how the GD learning rate manifests itself when we compare with the trained linear transformer.
-Recall that, due to its parametrization, the linear transformer does not have an explicit $$\eta$$ parameter: it can be absorbed into any of the weight matrices of the linear self-attention layer.
-Yet, the linear transformer converges to the exact same loss as the GD-transformer with the optimal GD learning rate $$\eta^*$$.
-This is expected because, fundamentally, the loss used for the line search and the one used for the analytical solution are both equivalent to the loss in Equation \eqref{eq:pre-train-loss-expectation} used during the transformer training.
-
-
-
-Said differently, what we did in two steps for the GD-transformer (first build the $$\mbW^K, \mbW^Q, \mbW^V$$ matrices, then find the optimal GD learning rate) is done implicitly during the training of the linear transformer.
-
-The following table summarizes the three different procedures we have discussed so far.
-
-| | Loss function | GD learning rate |
-| ------------------------ | ------------------------------------ | -------------------------------------------- |
-| Least-squares regression | $$\cL_\text{lin}(\mbw-\Delta \mbw)$$ | Explicit $$\eta^*$$ by analytical solution |
-| GD-transformer | $$\cL(\mbtheta_\text{GD})$$ | Explicit $$\eta^*$$ by line search |
-| Linear transformer | $$\cL(\mbtheta)$$ | Implicit $$\eta^*$$ by training $$\mbtheta$$ |
-
-
-Finally, one comment on the computational complexity of the two procedures.
-It doesn't come as a surprise that the analytical solution is faster to compute than the line search: the line search requires on average 10 seconds to find the optimal GD learning rate, while the analytical solution requires only 10 milliseconds (both with JAX's JIT compilation turned on, run on the same GPU).
-
-
-
-
-
-### If one layer is a GD step, what about multiple layers?
-
-It is only natural to ask if the same behavior is observed for a linear transformer with multiple layers.
-In particular, if we take a trained linear transformer with a single layer (which we now know implements a gradient descent step) and repeat the same layer update multiple times recursively, will we observe the same behavior?
-
-As we now show in the following experiment, the answer is no.
-In fact, the test loss for both the linear transformer and the transformer implementing a gradient descent step diverges as we increase the number of layers.
-
-To stabilize this behavior, we use a dampening factor $$\lambda$$, which is a scalar in $$[0, 1]$$, and we update the linear transformer as follows:
-
-$$
-\begin{equation}
-\label{eq:linear-transformer-update}
-\mbE^{(l+1)} = \mbE^{(l)} + \lambda \mbW^P \mbV\left(\mbK^\top \mbQ \right),
-\end{equation}
-$$
-
-where $$\mbE^{(l)}$$ is the embedding matrix at layer $$l$$, and $$\mbW^P, \mbV, \mbK, \mbQ$$ are the projection, value, key and query matrices as defined before.
-Effectively, this is equivalent to applying a gradient descent step with scaled learning rate.
-
-
- Code for the recurrent transformer
- This is the code for the recurrent transformer, with a dampening factor \(\lambda\). Note that the attention layer is the same as before, but we now apply it multiple times.
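-
-A minimal JAX sketch of this recurrent, dampened update (Equation \eqref{eq:linear-transformer-update}) could look as follows; the convention of storing tokens as columns of the embedding matrix E and the single-head parametrization are simplifying assumptions, not the exact code used in our experiments.
-
-```python
-import jax.numpy as jnp
-
-def linear_attention_step(E, W_K, W_Q, W_V, W_P, lam):
-    """One dampened application of the linear self-attention layer.
-
-    E: (d_e, C + 1) token embeddings (in-context pairs plus the query).
-    W_K, W_Q, W_V, W_P: weight matrices of the single-head linear layer.
-    lam: dampening factor in [0, 1].
-    """
-    K, Q, V = W_K @ E, W_Q @ E, W_V @ E
-    # Residual update of the equation above, scaled by the dampening factor.
-    return E + lam * W_P @ V @ (K.T @ Q)
-
-def recurrent_transformer(E0, params, lam, num_layers):
-    """Apply the same pre-trained layer recursively, num_layers times."""
-    E = E0
-    for _ in range(num_layers):
-        E = linear_attention_step(E, *params, lam)
-    return E
-```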
-
-
-
-
-
-
-
-
- Figure 13: A pre-trained transformer with a single layer can be used recursively to implement multiple gradient descent steps, after applying a dampening factor \(\lambda\) to the self-attention layer. Use the slider to change the value of \(\lambda\).
-
-
-
-
-
-Note that in the original paper, the authors suggest that a dampening factor of $$\lambda=0.75$$ is generally sufficient to obtain the same behavior as a single-layer linear transformer. As the figure above shows, our investigations do not confirm this.
-In our experiments, we need at least $$\lambda=0.70$$ to obtain the same behavior as a single-layer linear transformer, which suggests that the effect of the dampening factor can vary across setups.
-
-
-
-
-## Is this just for transformers? What about LSTMs?
-
-Transformers are not the only architecture that can implement sequence-to-sequence models.
-Notably, *recurrent neural networks* (RNNs) have long been used to implement sequence-to-sequence models, and in particular *long short-term memory* (LSTM) networks have been shown to be very effective in many tasks.
-
-Indeed, from a modeling perspective, nothing prevents us from using an LSTM to implement in-context learning for regression tasks.
-In fact, we can use the same experimental setup as before, replacing the transformer with an LSTM.
-The main architectural difference between an LSTM and a transformer is that LSTM layers are causal by design, i.e., they can only attend to previous tokens in the sequence, while transformers can attend to any token in the sequence.
-While this is a desirable property for tasks where order matters, such as language modeling, it is not for the regression task we consider, since the input sequence is unordered (i.e., shuffling the in-context examples does not change the output of the linear regression model).
-For this reason, together with the classic uni-directional LSTM, we also consider a bi-directional LSTM, which can attend to both previous and future tokens in the sequence.
-This provides a fairer comparison between the LSTMs and the transformers.
-
-In this first experiment, we analyze the ability of the uni-directional and the bi-directional LSTM to learn linear functions in-context.
-Note that, because of the intrinsically non-linear nature of the LSTM layers, we cannot manually construct an LSTM that implements a gradient descent step, as we did for the transformer.
-Nonetheless, we can still compare the LSTMs with the GD-equivalent transformer (which we now know implements a gradient descent step on the least-squares regression loss).
-
-
-
- Figure 14: LSTMs cannot learn linear functions in-context as effectively as transformers and bi-directional LSTMs can learn linear functions in-context better than uni-directional LSTMs. Use the slider to change the number of layers.
-
-
-
-
-
-
-In this figure we can see that a single-layer LSTM is not sufficient to learn linear functions in-context. For the uni-directional LSTM, the test loss is always higher than that of the transformer implementing a gradient descent step, even as we increase the number of layers.
-For the bi-directional LSTM, on the contrary, the test loss approaches that of the GD-equivalent transformer as we increase the number of layers.
-
-The poor performance of the uni-directional LSTM is not surprising. Additional evidence is provided in the figure below, where, as we did for the transformer, we plot the L2 error (predictions), the L2 error (gradients w.r.t. inputs) and the model cosine similarity (gradients w.r.t. inputs) comparing the LSTM with the GD-equivalent transformer.
-
-
-
-
-
- Figure 15: Uni-directional LSTMs cannot learn linear functions in-context as effectively as transformers. Use the slider to change the number of layers.
-
-
-
-
-Regardless of the number of layers, we see that the uni-directional LSTM is not implementing a gradient descent step, as the L2 error (predictions) and the L2 error (gradients w.r.t. inputs) do not converge to 0, and the model cosine similarity (gradients w.r.t. inputs) remains well below 1.
-The picture changes for the bi-directional LSTM, as we can see in the figure below.
-
-
-
-
-
-
- Figure 16: Bi-directional LSTMs align better with the GD-equivalent transformer as we increase the number of layers. Use the slider to change the number of layers.
-
-
-
-
-
-While for a single layer we can comfortably say that the bi-directional LSTM is also not equivalent to a GD step, for **2 or more layers** we cannot reject the hypothesis that the bi-directional LSTM is equivalent to a GD step (use the slider to change the number of layers in Figures 14-16).
-Note that if we compare this result with **Figure 10**, while we do not see exactly the same behavior (e.g., the cosine similarity is slightly below 1), it is still remarkably similar.
-This is not a conclusive result, but it is interesting to see that the bi-directional LSTM can learn linear functions in-context *similarly* to a transformer implementing a gradient descent step.
-
-
-
-
-## Concluding remarks
-
-In this blog post, we have presented a series of experiments to understand the mechanistic behavior of transformers and self-attention layers through the lens of optimization theory.
-In particular, we analyzed the results of the paper *Transformers Learn In-Context by Gradient Descent*, replicating some of its experiments and providing additional insights.
-We also derived an analytical solution for the optimal GD learning rate, which is faster to compute than the line search procedure used in the original paper.
-Finally, we empirically showed that LSTMs behave differently from transformers, and that single-layer LSTMs do not in fact implement a gradient descent step.
-The results on deep LSTMs are less conclusive, showing behavior similar to, but not exactly matching, that of the GD-equivalent transformer.
-
-
-
-### What now?
-
-The results presented in this blog post, while confirming the main findings of the original paper, also raise a number of questions and suggest possible future research directions.
-
-1. To reiterate, what we have done so far is to try to understand the behavior of transformers and self-attention layers through the lens of optimization theory.
-This is the common approach in the literature, including very recent additions, and it is the approach we have followed in this blog post.
-However, it can pose significant limitations regarding the generalization of the results and the applicability of the findings to other architectures (notably, causal self-attention layers).
-Phenomena like the emergent abilities or the memorization behavior of large language models may indicate that fundamentally different mechanisms are at play in these models, and that the optimization perspective might not be sufficient to understand them.
-
-1. On the other hand, nothing prevents us from working in the opposite direction, i.e., starting from specific learning algorithms and trying to design neural networks that implement them.
-From an alignment perspective, for example, this is desirable because it allows us to start by designing objective functions and learning algorithms that are more interpretable and more aligned with our objectives, rather than starting from a black-box neural network and trying to understand its behavior.
-In this quest, the developing theory of mesa-optimization can provide a useful framework for understanding these large models.
-
-1. Finally, we want to highlight that the main results shown in this blog post are consequences of the simplifying assumptions and the experimental setup we have considered (linear functions, least-squares regression loss, linear self-attention layers).
-In an equally recent paper, for example, the authors take a completely different route: by representing transformers as interacting particle systems, they show that tokens tend to cluster toward limiting objects that depend on the input context.
-This suggests that other interpretations of the behavior of transformers are not only possible, but perhaps necessary to understand how these models learn in context.
-
-
-
-
-
-
-## Appendix
-
-
-### Connection with meta-learning
-
-From a learning point of view, ICL seems closely related to *meta-learning*, where the goal is to learn a model that can quickly adapt to new tasks.
-If we consider the function class $$\cH$$ as an uncountable set of tasks, then the model is learning *how* to adapt to a new function by observing a few examples of that function.
-The main difference between the classic formulation of meta-learning and the formulation of in-context learning is that in the latter the model is not allowed to change its weights; it can only change its internal state (e.g., the hidden activations of the transformer).
-Indeed, meta-learning relies on the assumption that the model can quickly adapt to new tasks by changing its weights (i.e. by taking one or more gradient steps).
-
-#### Connection with MAML (Model-Agnostic Meta-Learning)
-
-In the meta-learning setup, we need to define a generic base model $$m:\cX\rightarrow\cY$$, parameterized by $$\mbw$$, that works at the sample level.
-Let us now relax the assumption that $$\cF$$ is a class of transformer models and build $$f$$ as follows:
-
-$$
-\begin{equation}
-\label{eq:meta-learning-model}
-f(\mbw, P_C) = m\left(\mbw - \eta \nabla_{\mbw} \sum_{i=0}^{C-1}\ell\left(m(\mbw,\mbx_i), \y_i\right),\mbx_\text{query}\right)
-\end{equation}
-$$
-
-where $$\eta$$ is the learning rate of the meta-learning algorithm.
-Equation \eqref{eq:meta-learning-model} represents the inner optimization loop of a simplified version of the MAML algorithm, where the model is updated with a single gradient step.
-
-Putting it all together, we can define the meta-learning loss as:
-
-$$
-\begin{equation}
-\label{eq:meta-learning-loss}
-\cL_{\text{MAML}}(\mbw) = \mathbb{E}\left[\ell\left(f(\mbw, P_C), h\left(\mbx_{\text{query}}\right)\right) \right]
-\end{equation}
-$$
-
-which is now optimized w.r.t. the base-model's parameters $$\mbw$$.
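-
-As a concrete illustration of Equations \eqref{eq:meta-learning-model} and \eqref{eq:meta-learning-loss}, here is a minimal JAX sketch of the single-inner-step objective for the linear regression setting. The linear base model, the squared-error inner loss, and the array shapes are simplifying assumptions of ours, not a reproduction of the original MAML implementation.
-
-```python
-import jax
-import jax.numpy as jnp
-
-def base_model(w, x):
-    """Assumed linear base model m(w, x) = w^T x."""
-    return x @ w
-
-def inner_loss(w, X_ctx, y_ctx):
-    """Sum of squared errors over the C in-context examples of one task."""
-    return jnp.sum((base_model(w, X_ctx) - y_ctx) ** 2)
-
-def f(w, X_ctx, y_ctx, x_query, eta):
-    """Eq. (meta-learning-model): predict on the query after one inner GD step."""
-    w_adapted = w - eta * jax.grad(inner_loss)(w, X_ctx, y_ctx)
-    return base_model(w_adapted, x_query)
-
-def maml_loss(w, X_ctx, y_ctx, x_query, y_query, eta=0.01):
-    """Eq. (meta-learning-loss), with the expectation replaced by a batch mean."""
-    preds = jax.vmap(f, in_axes=(None, 0, 0, 0, None))(w, X_ctx, y_ctx, x_query, eta)
-    return jnp.mean((preds - y_query) ** 2)
-
-# The outer loop would then update w with jax.grad(maml_loss), e.g.:
-# w = w - outer_lr * jax.grad(maml_loss)(w, X_ctx, y_ctx, x_query, y_query)
-```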
-
-The resemblance between Equation \eqref{eq:in-context-error} and Equation \eqref{eq:meta-learning-loss} is now clear, and it justifies the interpretation of in-context learning as a form of meta-learning.
-
-In particular, it is interesting to study under which conditions the model $$f$$ defined in Equation \eqref{eq:meta-learning-model} is equivalent to a transformer model.
-
-
-
-
-### Testing details
-
-In order to test whether a model learns in-context for a given function class, we need to define a dataset of in-context examples.
-In this case we will only consider in-distribution test examples, i.e. examples that are drawn from the same distribution as the training examples.
-Specifically, we will use the same distribution for the test inputs $$p(\mbx)$$ and the same distribution for the test weights $$p(\mbw)$$ as those used during training.
-Various papers have also considered the case where the inputs are drawn from a different distribution than the training examples (also known as out-of-distribution, or OOD), but to keep the discussion focused we only consider the in-distribution case.
-
-We define the in-context test loss as:
-
-$$
-\begin{equation}
-\label{eq:in-context-test-loss}
-\cL_\text{te}(\mbtheta) = \frac 1 N \sum_{n=0}^{N-1} \left\|f\left(\mbtheta, \left[\{\mbx_i^{(n)}, \y_i^{(n)}\}_{i=0}^{C-1}, \mbx^{(n)}_\text{query}\right]\right) - \y^{(n)}_{\text{query}}\right\|^2.
-\end{equation}
-$$
-
-Specifically, we will consider a fixed dataset of $$N=10000$$ regression tasks, where each task is defined by a set of in-context examples $$\{\mbx_i^{(n)}, \y_i^{(n)}\}_{i=0}^{C-1}$$ and a query pair $$\mbx^{(n)}_{\text{query}}$$ and $$\y^{(n)}_{\text{query}}$$.
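-
-As a minimal sketch, the test loss above could be evaluated as follows, assuming a model callable `f(theta, X_ctx, y_ctx, x_query)` that returns the query prediction for a single task (the name and signature are illustrative, not those of our actual implementation):
-
-```python
-import jax
-import jax.numpy as jnp
-
-def in_context_test_loss(theta, f, X_ctx, y_ctx, X_query, y_query):
-    """Mean squared error over N held-out regression tasks.
-
-    Assumed shapes: X_ctx (N, C, d), y_ctx (N, C), X_query (N, d), y_query (N,).
-    """
-    preds = jax.vmap(lambda xc, yc, xq: f(theta, xc, yc, xq))(X_ctx, y_ctx, X_query)
-    return jnp.mean((preds - y_query) ** 2)
-```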
-
-
-
\ No newline at end of file
diff --git a/_posts/2024-05-07-update-frequency-in-mbrl.md b/_posts/2024-05-07-update-frequency-in-mbrl.md
deleted file mode 100644
index 613ce4bb..00000000
--- a/_posts/2024-05-07-update-frequency-in-mbrl.md
+++ /dev/null
@@ -1,409 +0,0 @@
----
-layout: distill
-title: Fair Model-Based Reinforcement Learning Comparisons with Explicit and Consistent Update Frequency
-# description: Model-based reinforcement learning has emerged as a promising approach to achieve both state-of-the-art performance and sample-efficiency.However, ensuring fair benchmark comparisons can be challenging due to the implicit design choices made by the different algorithms. This article focuses on one such choice, the update frequency of the model and the agent. While the update frequency can sometimes be optimized to improve performance, real-world applications often impose constraints, allowing updates only between deployments on the actual system. We emphasize the need for more evaluations using consistent update frequencies across different algorithms. This will provide researchers and practitioners with clearer comparisons under realistic constraints.
-description: Implicit update frequencies can introduce ambiguity in the interpretation of model-based reinforcement learning benchmarks, obscuring the real objective of the evaluation. While the update frequency can sometimes be optimized to improve performance, real-world applications often impose constraints, allowing updates only between deployments on the actual system. This blog post emphasizes the need for evaluations using consistent update frequencies across different algorithms to provide researchers and practitioners with clearer comparisons under realistic constraints.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-authors:
- - name: Albert Thomas
- url: https://albertcthomas.github.io/
- affiliations:
- name: Huawei Noah's Ark Lab
- - name: Abdelhakim Benechehab
- url: https://scholar.google.com/citations?user=JxgqOKwAAAAJ
- affiliations:
- name: Huawei Noah's Ark Lab - Department of Data Science, EURECOM, France
- - name: Giuseppe Paolo
- url: https://www.giupaolo.com
- affiliations:
- name: Huawei Noah's Ark Lab
- - name: Balázs Kégl
- url: https://twitter.com/balazskegl
- affiliations:
- name: Huawei Noah's Ark Lab
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-update-frequency-in-mbrl.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Introduction
- - name: Three popular model-based reinforcement learning algorithms
- subsections:
- - name: MBPO
- - name: PETS
- - name: BREMEN
- - name: Making the update frequency more accessible
- - name: Comparisons with fixed update frequency
- - name: Ablation studies
- subsections:
- - name: Varying the update frequency in MBPO
- - name: Conclusion
- - name: Appendix
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-## Introduction
-
-In reinforcement learning , an agent learns to make decisions by interacting with an environment, receiving feedback, or a reward, after each action it takes to move from one state of the environment to another. The objective is to learn a policy, a mapping from states to actions, that maximizes the expected cumulative reward over successive interactions.
-
-There are two main approaches when designing a reinforcement learning algorithm: model-based or model-free. Model-based reinforcement learning (MBRL) algorithms first learn a model of the environment dynamics which, given a state of the environment and an action, predicts the next state of the environment. This model can then be used in place of the real environment to learn or decide how to act. Model-free algorithms skip this step and directly try to learn a policy. As MBRL algorithms can rely on the learned dynamics model instead of the real environment, they are known to be more sample-efficient than model-free algorithms. MBRL is thus a good choice when interactions with the environment are limited, which is often the case in real applications such as controlling engineering systems.
-
-We discuss here one of the design choices of MBRL algorithms: the *update frequency* of the agent. As shown in the figure below (inspired by Figure 1 of the BREMEN paper, introduced below), the frequency at which algorithms update their agent varies widely: some algorithms update their agent after each step on the real system, while others update only after thousands of steps. At the far end of the spectrum, the pure offline setting considers only a single training of the agent from an initial dataset. (Similar differences in update frequency exist in the model-free literature, but we focus here on model-based algorithms.)
-
-{% include figure.html path="assets/img/2024-05-07-update-frequency-in-mbrl/bremen.png" class="img-fluid" %}
-
-The update frequency is often viewed as yet another hyperparameter of the complex MBRL pipeline. However, in practice the update frequency may be imposed by real-life deployment constraints, motivating the discussions of this blog post. It is often the case that, for safety reasons, system engineers agree to run a new agent on their system for a given period of time but prefer the agent to be fixed during this deployment, as noted in previous studies. System engineers are then able to investigate the fixed solution before deciding to deploy it, knowing that it will not change during the deployment. It can also happen that the system on which the agent is deployed does not have the computational resources required to support agent updates. Such real-life constraints can thus rule out state-of-the-art MBRL algorithms that require updating their agent too frequently to perform well.
-
-Given the importance of the update frequency in real-life applications, this blog post advocates for:
-- explicitly specifying the update frequency employed by each algorithm in a benchmark, as this remains implicit and hard to find in many existing benchmarks,
-- conducting additional experiments that compare algorithms under a given update frequency, mirroring the constraints often encountered in real-life applications, and
-- performing more ablation studies on update frequency, evaluating its impact on algorithm performance.
-
-For the rest of this blog post, we define a *deployment* as a data collection campaign realized with a fixed agent. The agent is thus updated between two consecutive deployments but not within a deployment. The *update frequency* is the number of steps realized at each deployment (which we assume is the same for all deployments). We use the term *agent* to refer to all the components of the model-based algorithm that are used to act on the system. For instance, in a Dyna-style algorithm, where a model-free algorithm is applied on the model instead of the real system, *agent* refers to both the dynamics model and the policy learned with the model-free algorithm.
-
-We begin by introducing three popular MBRL algorithms (MBPO, PETS and BREMEN) as we will often refer to them to illustrate our arguments.
-
-## Three popular MBRL algorithms
-
-The following table gives an overview of the update frequencies of the three algorithms we discuss below, along with a few others. This table is not meant to be an exhaustive list of MBRL algorithms but rather to give an idea of the different training schedules used in the literature.
-
-| Algorithm | Agent update frequency | Policy update frequency | Model update frequency |
-|-----------|----------------------|---------------------------|------------------------|
-| MBPO | 1 step | 1 step | 250 steps |
-| PETS | Task Horizon | No policy | Task Horizon |
-| PILCO | Task Horizon | Task Horizon | Task Horizon |
-| BREMEN | 100k or 200k steps | 100k or 200k steps | 100k or 200k steps |
-| ME-TRPO | 3k or 6k steps | 3k or 6k steps | 3k or 6k steps |
-
-
-### MBPO
-Model-based Policy Optimization (MBPO), whose original code is available at https://github.com/jannerm/mbpo, is one of the most well-known model-based algorithms. The algorithm trains an ensemble of probabilistic neural networks for the dynamics model and trains a model-free agent, Soft Actor-Critic (SAC), using short rollouts on the model to avoid error accumulation. The agent is updated at each step: the model is updated every 250 steps, but the SAC policy is updated at every step. This highly frequent update schedule rules out MBPO even for small deployments on real systems.
-
-### PETS
-Probabilistic Ensembles with Trajectory Sampling (PETS), whose original code is available at https://github.com/kchua/handful-of-trials, is another popular model-based algorithm, known for its use of an ensemble of probabilistic neural networks for the dynamics model (MBPO uses the dynamics model introduced by PETS). PETS relies on the learned model and the cross-entropy method (CEM) to search for the best action sequence at decision time. Therefore, it does not have to learn (nor update) a policy, as MBPO does with SAC. The only component that needs learning is the dynamics model. Compared to MBPO, the dynamics model is updated less frequently, at the end of each episode (usually 1000 steps).
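-
-For readers unfamiliar with decision-time planning, here is a minimal NumPy sketch of cross-entropy-method planning over action sequences with a learned dynamics model. The `dynamics_model` and `reward_fn` callables (assumed to be vectorized over candidates), the shapes, and the hyperparameter values are illustrative assumptions, not PETS's actual implementation, which additionally propagates particles through a probabilistic ensemble.
-
-```python
-import numpy as np
-
-def cem_plan(state, dynamics_model, reward_fn, horizon=20, act_dim=4,
-             pop_size=400, n_elites=40, n_iters=5, rng=None):
-    """Search for a good action sequence by iteratively refitting a Gaussian."""
-    rng = rng or np.random.default_rng()
-    mean = np.zeros((horizon, act_dim))
-    std = np.ones((horizon, act_dim))
-    for _ in range(n_iters):
-        # Sample candidate action sequences: (pop_size, horizon, act_dim).
-        actions = rng.normal(mean, std, size=(pop_size, horizon, act_dim))
-        returns = np.zeros(pop_size)
-        s = np.repeat(state[None, :], pop_size, axis=0)
-        for t in range(horizon):
-            returns += reward_fn(s, actions[:, t])
-            s = dynamics_model(s, actions[:, t])  # predicted next states
-        # Refit the sampling distribution on the best candidates.
-        elites = actions[np.argsort(returns)[-n_elites:]]
-        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-6
-    return mean[0]  # execute only the first action (MPC-style)
-```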
-
-
-### BREMEN
-Behavior-Regularized Model-ENsemble (BREMEN), whose original code is available at https://github.com/matsuolab/BREMEN, considers the setting where only a few deployments (between 5 and 10) are possible on the real system. However, large datasets can be collected at each deployment (the authors assume 100,000 or 200,000 transitions per deployment, far more than a single episode, which is usually on the order of 1000 transitions). The algorithm relies on an ensemble of deterministic dynamics models and a policy learned on the model, à la Dyna. It only updates the policy and the model between two consecutive deployments. The update frequency is very clear here, as it is motivated by real-life applications where deployments are limited. In this algorithm it is therefore not a hyperparameter that can be tuned for better performance but rather a parameter imposed by the application. One of the goals of this blog post is to emphasize and develop this idea of a constrained update frequency.
-
-We now detail the main arguments of our blog post: making the update frequency more accessible, designing benchmarks with fixed update frequencies and running ablation studies on the update frequency.
-
-## Making the update frequency more accessible
-
-Experiments reported in popular papers do not always make explicit the update frequencies used for each of the algorithms they run. When nothing is said, it is very likely that the benchmarks use the original implementation of the algorithms, shared by the authors of the algorithms in the best case. For instance, the MBPO paper does not mention the update frequencies that the authors used in their experiments. The update frequency of MBPO can be found in the code shared by the authors. However, it is harder to find the update frequency that the authors used for PETS. We thus assume that they used the original PETS update frequency, which updates the agent at the end of each episode. We also looked at one of the most exhaustive benchmarks of MBRL algorithms. Nothing is said in the paper about the update frequency, and a careful investigation of the code provided by the authors is required (more on this later).
-
-The difficulty of knowing the update frequencies used in benchmarks makes it harder for researchers and practitioners to take this parameter into account when assessing the performance of the algorithms and whether they would be good candidates for their real-life applications. It also demands much more investigation from the reader to find out what the authors used.
-
-MBRL algorithms have an order of magnitude more meaningful hyperparameters than supervised models, and managing and reporting all of them usually falls outside the scope of research papers. The practice of sharing the code alleviates this issue somewhat, and should be saluted, since we can always dig up the parameters in the code. However, choices that drastically change the performance of the algorithms should ideally be made as explicit as possible in the research papers and the ablation studies.
-
-## Comparisons with fixed update frequency
-
-We want to make the community aware of the importance of the update frequency when comparing algorithms and when designing benchmarks. Running benchmarks without any constraints allows using a different update frequency for each algorithm. We believe that such benchmarks are valuable for the community. However, it would also be very informative to have benchmarks with comparable update frequencies across algorithms. This would, for instance, help identify the best candidate algorithms for real applications with constraints on the update frequency.
-
-Coming back to the experiments run in MBPO's paper: as the default MBPO implementation updates the model every 250 steps, it might also make sense to allow PETS to be updated every 250 steps to obtain comparable results. We also note that the MBRL-Lib paper compares the MBRL-Lib implementations of PETS and MBPO with their respective original update frequencies. We do not think that this has a big impact for these two algorithms, but it would be fairer to use the same update frequency. Finally, looking at the code of the MBRL benchmark mentioned above, it is not clear whether the same update frequency is used for all the algorithms of the benchmark: for instance, the update frequency on Acrobot seems to be 3000 for RS (time_step_per_batch in https://github.com/WilsonWangTHU/mbbl/blob/master/scripts/exp_1_performance_curve/rs.sh) but 5000 for ME-TRPO (num_path_onpol $\times$ env_horizon in https://github.com/WilsonWangTHU/mbbl-metrpo/blob/master/configs/params_acrobot.json).
-
-The BREMEN paper has a benchmark comparing different algorithms under fixed update frequencies. This gives valuable insights into the performance of existing algorithms under these deployment constraints. The next step would be to evaluate the performance with a different number of deployments and a different number of steps per deployment, which we argue for in the next section.
-
-## Ablation studies
-
-Comparisons of different update frequencies are very rare in existing benchmarks and papers. Even without real-life constraints, it would be valuable to know how sensitive the performance of a given algorithm is to the update frequency. The issue for authors is that this could be asked for many other hyperparameters and represents additional computational budget and time. However, we often find ablations on the number of models (if the model is an ensemble), the rollout length, or the number of gradient updates for the model-free policy, but very rarely on the update frequency. It is very likely that agents that are good for small deployments would be bad for large deployments, a setting closer to the pure offline setting (for the same total budget of real-system interactions). We perform such an ablation study using MBPO in the next section, showing that MBPO's performance degrades with larger update frequencies.
-
-
-### Varying the update frequency in MBPO
-
-Using the MBPO implementation and the examples provided by MBRL-Lib, we ran MBPO on Gym-Halfcheetah-v4, Gym-Hopper-v4 and Gym-Walker2d-v4 with different update frequencies: updating the agent at each step (the default implementation described above), every 1000 steps, every 5000 steps and every 10,000 steps. Each curve shows the mean episode return obtained with at least 10 seeds. We did not run Hopper and Walker with an update frequency of 10,000 steps as the performance obtained with 5000 steps was already poor. The lightly shaded areas indicate the 95% bootstrap confidence intervals.
-
-{% include figure.html path="assets/img/2024-05-07-update-frequency-in-mbrl/update_frequency_cheetah.png" class="img-fluid" %}
-
-{% include figure.html path="assets/img/2024-05-07-update-frequency-in-mbrl/update_frequency_hopper.png" class="img-fluid" %}
-
-{% include figure.html path="assets/img/2024-05-07-update-frequency-in-mbrl/update_frequency_walker.png" class="img-fluid" %}
-
-Except for the update frequency of 1000 steps on Halfcheetah and Walker, which achieves performance similar to the default configuration that updates the agent at each step, the results indicate a decline in asymptotic performance with larger update frequencies. Although MBPO exhibits good performance across environments for the default update frequency, this is not the case for the other update frequencies we consider here. We note that 1000 steps is the usual maximum episode length and therefore a reasonable value to try for the update frequency. One insight from this experiment is that, even though MBPO is one of the state-of-the-art MBRL algorithms, practical constraints like the update frequency can severely degrade its performance in real-world applications.
-
-When trying these update frequencies, we adjusted the number of gradient steps to maintain a constant ratio of gradient steps per step on the real system (see the short sketch at the end of this section). For the maximum buffer size of SAC, we used the rule provided in MBPO's code. The table below shows the values obtained for the maximum buffer size. As shown in the figure below, using a smaller buffer size negatively impacts the performance for the update frequencies of 1000 steps and 10,000 steps. While better values for the other hyperparameters could possibly be found, we did what appeared to be the natural way to adapt them when increasing the update frequency. See the Appendix for the complete description of the hyperparameters used in these experiments.
-
-| Agent update frequency | Model update frequency | Policy update frequency | Max SAC buffer size |
-|------------------|--------------------------|-----------------------------------|-------------|
-|default (1 step) | 250 | 1 | 400 000 |
-| 1 000 steps | 1000 | 1000 | 400 000 |
-| 5 000 steps | 5000 | 5000 | 2 million |
-|10 000 steps | 10 000 | 10 000 | 4 million |
-
-{% include figure.html path="assets/img/2024-05-07-update-frequency-in-mbrl/buffer_size.png" class="img-fluid" %}
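-
-As an illustration of this scaling rule (our own notation, not a snippet from MBRL-Lib), keeping the number of SAC gradient updates per environment step constant amounts to:
-
-```python
-def scaled_sac_updates(updates_per_env_step, update_frequency):
-    """Number of SAC updates to perform at each (less frequent) agent update."""
-    return updates_per_env_step * update_frequency
-
-# With the 10 updates per environment step implied by the Halfcheetah configs
-# in the Appendix, an update frequency of 5000 steps gives:
-print(scaled_sac_updates(10, 5000))  # 50000
-```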
-
-
-## Conclusion
-
-The goal of this blog post is to shed light on a frequently overlooked hyperparameter in MBRL: the update frequency. Despite its importance for real-life applications, this parameter is rarely discussed or analyzed. We emphasize the importance of running more evaluations using consistent update frequencies across different algorithms, together with more ablation studies; we show, for instance, how the update frequency impacts the performance of MBPO. Beyond the update frequency, several other hyperparameters deserve more attention when benchmarking MBRL algorithms. A typical example is continual training (of the model and/or policy) versus retraining from scratch (related to the primacy bias discussed in previous work). We believe this blog post offers valuable insights to researchers, providing directions worth investigating to explain the differences between MBRL algorithms and whether these differences really affect existing comparisons.
-
-
-## Appendix
-
-We provide here the configuration files we used to run the different experiments.
-#### Halfcheetah
-* Update frequency of 1000 steps
-
-```yaml
-# @package _group_
-env: "gym___HalfCheetah-v4"
-term_fn: "no_termination"
-
-num_steps: 400000
-epoch_length: 1000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 1000
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 150, 1, 1]
-num_sac_updates_per_step: 10000
-sac_updates_every_steps: 1000
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 1
-sac_automatic_entropy_tuning: true
-sac_target_entropy: -1
-sac_hidden_size: 512
-sac_lr: 0.0003
-sac_batch_size: 256
-```
-
-* Update frequency of 5000 steps
-
-```yaml
-# @package _group_
-env: "gym___HalfCheetah-v4"
-term_fn: "no_termination"
-
-num_steps: 400000
-epoch_length: 5000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 5000
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 150, 1, 1]
-num_sac_updates_per_step: 50000
-sac_updates_every_steps: 5000
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 1
-sac_automatic_entropy_tuning: true
-sac_target_entropy: -1
-sac_hidden_size: 512
-sac_lr: 0.0003
-sac_batch_size: 256
-```
-
-* Update frequency of 10000 steps
-
-```yaml
-# @package _group_
-env: "gym___HalfCheetah-v4"
-term_fn: "no_termination"
-
-num_steps: 400000
-epoch_length: 10000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 10000
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 150, 1, 1]
-num_sac_updates_per_step: 100000
-sac_updates_every_steps: 10000
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 1
-sac_automatic_entropy_tuning: true
-sac_target_entropy: -1
-sac_hidden_size: 512
-sac_lr: 0.0003
-sac_batch_size: 256
-```
-
-#### Hopper
-* Update frequency of 1000 steps
-
-```yaml
-# @package _group_
-env: "gym___Hopper-v4"
-term_fn: "hopper"
-
-num_steps: 125000
-epoch_length: 1000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 1000
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 150, 1, 15]
-num_sac_updates_per_step: 40_000
-sac_updates_every_steps: 1000
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: false
-sac_target_entropy: 1 # ignored, since entropy tuning is false
-sac_hidden_size: 512
-sac_lr: 0.0003
-sac_batch_size: 256
-```
-
-* Update frequency of 5000 steps
-
-```yaml
-# @package _group_
-env: "gym___Hopper-v4"
-term_fn: "hopper"
-
-num_steps: 125000
-epoch_length: 1000
-num_elites: 5
-patience: 5
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 5000
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 150, 1, 15]
-num_sac_updates_per_step: 200000
-sac_updates_every_steps: 5000
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: false
-sac_target_entropy: 1 # ignored, since entropy tuning is false
-sac_hidden_size: 512
-sac_lr: 0.0003
-sac_batch_size: 256
-```
-
-#### Walker
-* Update frequency of 1000 steps
-
-```yaml
-# @package _group_
-env: "gym___Walker2d-v4"
-term_fn: "walker2d"
-
-num_steps: 300000
-epoch_length: 1000
-num_elites: 5
-patience: 10
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 1000
-effective_model_rollouts_per_step: 400
-rollout_schedule: [20, 150, 1, 1]
-num_sac_updates_per_step: 20000
-sac_updates_every_steps: 1000
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: false
-sac_target_entropy: -1 # ignored, since entropy tuning is false
-sac_hidden_size: 1024
-sac_lr: 0.0001
-sac_batch_size: 256
-```
-
-* Update frequency of 5000 steps
-We only used a maximum buffer size of 1 million to limit the memory usage of this experiment.
-
-```yaml
-# @package _group_
-env: "gym___Walker2d-v4"
-term_fn: "walker2d"
-
-num_steps: 300000
-epoch_length: 1000
-num_elites: 5
-patience: 10
-model_lr: 0.001
-model_wd: 0.00001
-model_batch_size: 256
-validation_ratio: 0.2
-freq_train_model: 5000
-effective_model_rollouts_per_step: 200
-rollout_schedule: [20, 150, 1, 1]
-num_sac_updates_per_step: 100000
-sac_updates_every_steps: 5000
-num_epochs_to_retain_sac_buffer: 1
-
-sac_gamma: 0.99
-sac_tau: 0.005
-sac_alpha: 0.2
-sac_policy: "Gaussian"
-sac_target_update_interval: 4
-sac_automatic_entropy_tuning: false
-sac_target_entropy: -1 # ignored, since entropy tuning is false
-sac_hidden_size: 1024
-sac_lr: 0.0001
-sac_batch_size: 256
-```
diff --git a/_posts/2024-05-07-what-exactly-has-tabpfn-learned-to-do.md b/_posts/2024-05-07-what-exactly-has-tabpfn-learned-to-do.md
deleted file mode 100644
index 06ece9e7..00000000
--- a/_posts/2024-05-07-what-exactly-has-tabpfn-learned-to-do.md
+++ /dev/null
@@ -1,218 +0,0 @@
----
-layout: distill
-title: What exactly has TabPFN learned to do?
-description: TabPFN [Hollmann et al., 2023], a Transformer model pretrained to perform in-context learning on fresh tabular classification problems, was presented at the last ICLR conference. To better understand its behavior, we treat it as a black-box function approximator generator and observe its generated function approximations on a varied selection of training datasets. Exploring its learned inductive biases in this manner, we observe behavior that is at turns either brilliant or baffling. We conclude this post with thoughts on how these results might inform the development, evaluation, and application of prior-data fitted networks (PFNs) in the future.
-date: 2024-05-07
-future: true
-htmlwidgets: true
-
-authors:
- - name: Calvin McCarter
- url: "https://calvinmccarter.com/"
- affiliations:
- name: BigHat Biosciences
-
-# must be the exact same name as your blogpost
-bibliography: 2024-05-07-what-exactly-has-tabpfn-learned-to-do.bib
-
-# Add a table of contents to your post.
-# - make sure that TOC names match the actual section names
-# for hyperlinks within the post to work correctly.
-# - please use this format rather than manually creating a markdown table of contents.
-toc:
- - name: Introduction
- - name: 1d binary classification
- - name: 2d multiclass classification
- - name: Cancer status classification from high-dimensional gene expressions
- - name: Computer vision as a tabular classification problem
- - name: Closing thoughts
-
-# Below is an example of injecting additional post-specific styles.
-# This is used in the 'Layouts' section of this post.
-# If you use this post as a template, delete this _styles block.
-_styles: >
- .fake-img {
- background: #bbb;
- border: 1px solid rgba(0, 0, 0, 0.1);
- box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1);
- margin-bottom: 12px;
- }
- .fake-img p {
- font-family: monospace;
- color: white;
- text-align: left;
- margin: 12px 0;
- text-align: center;
- font-size: 16px;
- }
----
-
-
-## Introduction
-
-TabPFN is a deep learning model pretrained to perform in-context learning for tabular classification.
-Since its introduction, it has attracted attention both for its high predictive performance on small-dataset benchmarks and for its unique meta-learning approach.
-This meta-learning approach, which builds upon earlier work on prior-data fitted networks (PFNs), requires only synthetically generated data: structural causal models (SCMs) are randomly generated, then training datasets are sampled from each SCM.
-On fresh classification tasks, no training (i.e., weight updating) is needed; instead, the training data is given as context to TabPFN, a Transformer model with self-attention among training samples and cross-attention from test samples to training samples.
-TabPFN can optionally be used with ensembling, wherein the forward pass is repeated with random permutations of features and class labels, and with a power transformation applied to random subsets of features.
-Subsequent works have reproduced its classification performance on other tabular benchmarks and analyzed its theoretical foundations.
-
-At the same time, TabPFN has received criticism from within the applied ML community, around concerns that its "one large neural network is all you need" approach is fundamentally flawed and that its performance on public benchmarks may be due to overfitting.
-
-{% twitter https://twitter.com/tunguz/status/1583417038965334017 %}
-
-{% twitter https://twitter.com/predict_addict/status/1726286748173385732 %}
-
-
-In this article, we will attempt to demystify TabPFN's behavior in order to move towards a resolution to these questions.
-With this goal, we will take a different tack to analyzing TabPFN than previous works:
-we will neither theoretically analyze its meta-learning pre-training approach, nor run it on yet another dataset-of-datasets, nor even mechanistically interpret the meaning of specific model weights or subnetworks.
-
-Instead, we will first explore its holistic behavior on two simple settings, in order to develop an intuition about TabPFN as a function approximation generator.
-This is motivated by the observation that TabPFN, once fitted on fresh training data (even though "fitting" is merely storing the training data), is not mathematically different from any other fitted model: it is simply a function $$f_{\mathcal{D}, \theta}: x \rightarrow y$$ from test input $$x$$ to prediction $$y$$,
-where $$\mathcal{D} = (X_{\textrm{train}}, y_{\textrm{train}})$$ is the training data and $$\theta$$ are the TabPFN model weights.
-By plotting $$f$$ for various case studies of $$(X_{\textrm{train}}, y_{\textrm{train}})$$, we aim to better understand what statistical knowledge has been represented in model parameters $$\theta$$.
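-
-As a sketch of this black-box probing procedure (the exact constructor arguments, such as `N_ensemble_configurations`, may differ across TabPFN versions, so treat the snippet as illustrative rather than as our exact code):
-
-```python
-import numpy as np
-from tabpfn import TabPFNClassifier
-
-# "Fitting" only stores the training data; the function approximation is
-# produced on the fly by the pretrained Transformer at prediction time.
-X_train = np.array([[-1.0], [1.0]])
-y_train = np.array([0, 1])
-
-clf = TabPFNClassifier(N_ensemble_configurations=1)  # single forward pass, no ensembling
-clf.fit(X_train, y_train)
-
-# Plot f_{D, theta} by sweeping test inputs over a 1d grid.
-X_test = np.linspace(-4, 4, 401).reshape(-1, 1)
-probs = clf.predict_proba(X_test)  # shape (401, 2)
-```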
-
-Next, we will evaluate TabPFN on two non-standard tabular ML classification tasks, comparing its performance with other methods.
-These atypical tasks can be thought of as out-of-distribution relative to the synthetic pretraining datasets upon which TabPFN was pretrained.
-This analysis will help indicate whether TabPFN was overfit to the statistical peculiarities of publicly-available small tabular datasets, or whether it has learned generalizable principles that lead to sensible behavior even in out-of-domain settings.
-
-## 1d binary classification
-
-We begin by examining the case of binary classification with 1d inputs. To better illustrate the inductive biases of the base TabPFN model, we do not use ensembling in this section unless otherwise indicated.
-
-Below, we show the predictions for two training samples located at +1 and -1, labeled green and red, respectively. We see that the probabilities are non-monotonic, unlike what one would obtain from a sigmoid function; not only does the model have higher uncertainty on the far sides of the training points, but between them there is also a small wiggle. We also see that the decision boundary is biased below 0.5; likely this is because TabPFN has learned that features tend to have right-skewed distributions.
-
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/plusminus1-nonmonotone.png" class="img-fluid" %}
-
-These wiggles and asymmetry more-or-less disappear once we incorporate ensembling, shown below.
-However, the general shape of the predicted probability function is similar regardless of the number of ensembles.
-
-
- {% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/plusminus1-ensembles2.png" class="img-fluid" %}
-
-
- TabPFN predicted probabilities for test data, in red and green, for varying number of ensembles. Also shown are the predicted probabilities from using inverse-square-root of Euclidean distance within softmax, in orange and lime-green.
-
-
-The above results raise the question of what parametric attention function might have been learned by TabPFN.
-No simple dot-product-based or Euclidean distance-based function (used within the softmax operation) exactly recapitulated the observed predicted probabilities.
-However, the general shape of inverse-square-root of Euclidean distance matched reasonably well, particularly between the two training points.
-Still, it appears that TabPFN has meta-learned an attention function that outperforms previously-known attention functions on small datasets.
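-
-To make this comparison baseline explicit, here is a small NumPy sketch of the inverse-square-root-of-distance attention we plot alongside TabPFN's output; the temperature-free softmax form and the epsilon constant are our own illustrative choices.
-
-```python
-import numpy as np
-
-def inv_sqrt_distance_probs(X_train, y_train, X_test, eps=1e-6):
-    """Class probabilities from a softmax over 1/sqrt(Euclidean distance)."""
-    # Pairwise Euclidean distances, shape (n_test, n_train).
-    d = np.linalg.norm(X_test[:, None, :] - X_train[None, :, :], axis=-1)
-    scores = 1.0 / np.sqrt(d + eps)
-    scores -= scores.max(axis=1, keepdims=True)  # numerically stable softmax
-    attn = np.exp(scores)
-    attn /= attn.sum(axis=1, keepdims=True)
-    # Attention-weighted average of one-hot training labels.
-    onehot = np.eye(y_train.max() + 1)[y_train]
-    return attn @ onehot
-
-X_train = np.array([[-1.0], [1.0]])
-y_train = np.array([0, 1])
-X_test = np.linspace(-4, 4, 401).reshape(-1, 1)
-probs = inv_sqrt_distance_probs(X_train, y_train, X_test)
-```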
-
-Next, we look at the effect of duplicating features. We tried repeating the +1 and -1 inputs for a total of 1, 4, 16, and 64 copies, as shown below. The effect is to push the predicted probabilities away from 0.5, although we observe diminishing marginal effects as the number of repeats increases.
-
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/plusminus1-repeats.png" class="img-fluid" %}
-
-Meanwhile, there is no discernible effect from replicating samples, when both red and green samples are replicated. Below we show the predicted probabilities, when both red and green samples are each copied for a total of 1, 4, 16, and 64 times.
-
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/plusminus1-repeatboth.png" class="img-fluid" %}
-
-In contrast, there is an impact when only the red sample is repeated, as shown below.
-While this unsurprisingly increases the probability of red for $$X < 0$$, it bizarrely increases the probability of green for $$X > 0$$.
-This is especially strange because repeating green samples in the previous setting did not have the same effect.
-This behavior of TabPFN seems suboptimal; it remains to be seen whether this behavior was optimal for its pretraining data, or whether this is some kind of artifact of TabPFN's architecture or training.
-
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/plusminus1-repeatred.png" class="img-fluid" %}
-
-Finally, we were unable to find evidence that TabPFN is able to detect periodic patterns in the training data, as exemplified for three different training patterns shown below.
-This behavior of TabPFN suggests that it does not support either periodic interpolation or extrapolation.
-Furthermore, we observe that as the number of observed cycles in the data increases, the predicted probabilities trend toward 0.5, which also seems suboptimal.
-We also notice that there is marked left-right asymmetry in these settings.
-
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/plusminus1-periodic-2pair.png" class="img-fluid" %}
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/plusminus1-periodic-3pair.png" class="img-fluid" %}
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/plusminus1-periodic-2skippair.png" class="img-fluid" %}
-
-
-## 2d multiclass classification
-
-Here, we examine the behavior of TabPFN on 2d input data, on problems with as many samples as classes.
-Below we show results for both randomly-spaced and grid-spaced inputs, and for both ensembling and no-ensembling settings of TabPFN.
-In each plot, we show the training data, their corresponding Voronoi diagrams, and finally the model predictions for the test inputs.
-We see that, without ensembling, TabPFN performs quite poorly, partitioning the input space in a nonsensical manner.
-The results markedly improve when we use 32 ensembles.
-Particularly for the randomly-spaced training points, the model predictions clearly resemble the Voronoi diagram, suggesting that (ensembled) TabPFN has meta-learned to perform 1-nearest-neighbor classification in the setting where each class has a single training sample.
-
-On the other hand, that this behavior relies upon ensembling suggests that the base TabPFN model could be further improved.
-In the original paper, Hollmann et al. express the hope that a future better version of TabPFN would not need to rely upon ensembling for permutation invariance, by having internalized that behavior through better architecture and training.
-The aforementioned observed behavior suggests that ensembling improves performance not only by (approximately) enforcing permutation invariance, but also by producing lower variance estimators; if so, the base model could also be trained to do the latter directly.
-
-
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/voronois.png" class="img-fluid" %}
-
-
- TabPFN predictions on randomly-spaced points (left) and grid-spaced points (right). The training points are depicted as $\times$s. The yellow lines depict the Voronoi diagram of the training points. The test points are colored by TabPFN's predictions, using the same color scheme as the training points. We see that, without ensembling, TabPFN's predicted classes do not form contiguous regions over the input space.
-
-
-
-## Cancer status classification from high-dimensional gene expressions
-
-We now turn towards a comparison of TabPFN with logistic regression (LR), support vector classification (SVC), and XGBoost on the BladderBatch cancer status classification task.
-The bladderbatch dataset consists of 57 samples, 22,283 gene expression features, and 3 classes ("normal" vs "biopsy" vs "cancer").
-This is an extremely high-dimensional problem compared to TabPFN's intended use for $$d \le 100$$; also, linear models tend to be sufficient for predicting cancer status given gene expressions.
-Thus, this setting is far outside the domain on which we would expect TabPFN to perform well, particularly if it had been overfit to small tabular datasets.
-Furthermore, the 57 samples come from 5 different batches of gene microarray measurements.
-This adds additional difficulty to the task, because there is confounded shift between the technical batch effect and the unequal proportions of cancer status in the different batches.
-
-For all methods, we do not perform hyperparameter search, in order to simulate the scenario where there are too few samples to perform cross-validation without the risk of overfitting.
-We use the scikit-learn implementations of LR and SVC with their default hyperparameters.
-For TabPFN, we use the default hyperparameter of 32 ensembles; we also enable feature subsampling as is required for $$d > 100$$ problems.
-
-Results are shown below, aggregated over 10 random 75-25 train-test splits, and evaluated via both accuracy and macro-averaged F1-score.
-TabPFN has a surprisingly strong showing, handily beating SVC and XGBoost, while almost matching logistic regression.
-This pattern holds both when we use all features and also when we use only the first 1k features.
-
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/bladderbatch-comparison.png" class="img-fluid" %}
-
-We also evaluate the different methods on a more realistic setting, where we train on 4 out of 5 batches of data and evaluate on all samples from the remaining unseen batch.
-Results are shown below, with scatterplot labels used to indicate the identity of the test batch.
-While all methods perform worse in this setting, TabPFN still almost matches LR while beating the other baselines.
-
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/bladderbatch-batch-scatterplot.png" class="img-fluid" %}
-
-We also verify that TabPFN is not simply memorizing the class imbalance in favor of cancer.
-We compute confusion matrices, shown below, for each train-test split.
-Even though cancer is the most common class in every training split, there does not appear to be any systematic bias across the splits in favor of predicting cancer.
-
-{% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/bladderbatch-batch-confusion.png" class="img-fluid" %}
-
-
-## Computer vision as a tabular classification problem
-
-Finally, we compare TabPFN with other methods on two computer vision (CV) tasks.
-As in the previous section, we use the default hyperparameter settings for all methods.
-We treat MNIST and CIFAR-10 as tabular ML problems with $$28^2 = 784$$ and $$3 \times 32^2 = 3072$$ features, respectively.
-We aggregate over 10 train-test splits, where the test set is always the full MNIST / CIFAR-10 test set and the training set is a random subsample of size 30, 100, 300, or 1000.
-In this experiment, TabPFN was competitive for smaller training set sizes, but lagged as we trained on more samples.
-Interestingly, while SVC performed poorly on cancer classification, it performed well on the CV tasks at larger sample sizes.
-Meanwhile, while logistic regression (LR) performed well on cancer classification, it struggled in the current setting.
-It remains to be seen whether the shared behavioral characteristics of TabPFN and LR in these tasks hold more generally.
-If so, this could motivate future work on meta-learning TabPFN to perform robust classification with a hinge-type loss.
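A sketch of this setup for MNIST is below (CIFAR-10 is analogous, with 3,072 features); it pulls the flattened pixels from OpenML and reuses the hypothetical `models` dict from earlier. Note that TabPFN inference over the full 10,000-sample test set is slow on CPU, so one may want to subsample the test set when just experimenting.

```python
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score

# MNIST as a 784-feature tabular problem (28x28 pixels, flattened by OpenML).
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
y = y.astype(int)
X_train_full, X_test = X[:60000], X[60000:]   # standard 60k/10k split
y_train_full, y_test = y[:60000], y[60000:]

rng = np.random.default_rng(0)
for n in (30, 100, 300, 1000):
    idx = rng.choice(len(X_train_full), size=n, replace=False)
    for name, model in models.items():
        model.fit(X_train_full[idx], y_train_full[idx])
        acc = accuracy_score(y_test, model.predict(X_test))
        print(f"n={n:4d} | {name}: accuracy={acc:.3f}")
```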
-
-
-
- {% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/mnist-vs-samples.png" class="img-fluid" %}
-
-
- {% include figure.html path="assets/img/2024-05-07-what-exactly-has-tabpfn-learned-to-do/cifar10-vs-samples.png" class="img-fluid" %}
-
-
-
- Test accuracy on MNIST (left) and CIFAR-10 (right).
-
-
-## Closing thoughts
-
-Taken together, our preliminary results are suggestive of future developments in tabular PFNs. Currently, an applied ML practitioner will likely choose between training a model on their own small dataset and using the TabPFN "one size fits all" model. Our results suggest that the TabPFN model will likely perform quite well, even outside its intended domain. However, it still came in second place to logistic regression on our cancer classification task, and last or second-to-last on the CV classification problems. This suggests that the future will not look like a binary choice between training a non-PFN and selecting a single state-of-the-art tabular PFN. Rather, we suspect that there will exist PFNs for specific modalities of data (e.g. gene expression), or for specific settings (e.g. robust classification), that bridge the gap between the two extremes.
-
-In such a future, we believe our approach to evaluating TabPFN will become increasingly essential. In the burgeoning field of large language models (LLMs), evaluation on various public benchmarks is widely considered necessary but insufficient. LLM researchers and users will also evaluate a newly-announced model by trying their favorite personal examples on the new LLM. When the LLM fails on a prompt, one modifies the prompt slightly to see whether the LLM simply expected a different prompting style. When the LLM succeeds, one tries variants to see whether its satisfactory response was in fact brittle to the prompt. By interacting with an LLM, one gets a sense for its expected prompting style and the type of outputs it generates. In particular, providing out-of-distribution (adversarial) inputs (e.g. "poem poem poem") to an LLM tells us something useful about how it will operate on future unanticipated out-of-distribution inputs.
-
-By analogy, we argue that, while open tabular benchmarks are valuable resources, these should not be fully determinative for researchers and users of tabular ML methods. Benchmarks do allow us to quickly discover which methods are Pareto-dominated and can therefore be safely ignored. However, as we move into a world with multiple available PFN options, with different sorts of inductive priors, it will become increasingly useful to interact with them on simple problems to gain an intuition for whether their priors match one's own use-case. For our analysis on 1d inputs, it is important to notice that there is not necessarily one "right answer". Thus, evaluations of tabular ML approaches will need to be more granular than to describe TabPFN as state of the art for all of tabular ML. Instead, evaluations should aim at identifying specific tabular PFN checkpoints, based on different inductive priors and synthetic datasets, as being best suited for specific classes of problem settings.
-
-Furthermore, our results illuminate a key practical difference between TabPFN, which relies on in-context learning, and other neural network models for tabular ML. Skepticism around neural networks for tabular ML has been justified by problems stemming from the non-convexity of neural network training. In the small-dataset context, the problem with this non-convexity is not fundamentally that one may have missed a global optimum with better performance. Rather, it is that deep learning requires babysitting during training runs and optimization of training hyperparameters, neither of which is related to one's beliefs about the nature of one's specific problem. Thus, a modified architecture, preprocessing method, or data selection approach might be better matched to a particular dataset, yet end up performing worse due to problematic training dynamics -- which one might be unable to fix without risking overfitting. In the small dataset regime, the maximum performance (over all training hyperparameter settings) matters less than the performance under the default hyperparameter settings.
-
-Because TabPFN sidesteps this problem by relying on pure in-context learning, these fundamental weaknesses of other neural network approaches do not apply. For example, our 1d experiments would not have been straightforwardly possible if we had retrained a neural network on each reconfiguration of the training data. Had we done so while keeping the training hyperparameters fixed, the results would not represent how people actually use such a neural network. On the other hand, had we plotted results for carefully optimized hyperparameters, it would not be clear whether the results illustrate the general inductive biases of the neural network architecture or merely the behavior of an optimally-trained neural network. However, the flip side of this advantage of TabPFN is that our analysis applies not so much to TabPFN-the-method as it does to [prior_diff_real_checkpoint_n_0_epoch_42.cpkt](https://github.com/automl/TabPFN/blob/d76f4ac7/tabpfn/models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt)-the-checkpoint.
-
-Finally, we believe our evaluation helps address some of the popular skepticism around TabPFN. While our results indicate that there remains substantial room for improvement, we found no evidence that would suggest that TabPFN's results were solely the result of overfitting a large neural network to public benchmarks. Rather, our results suggest that TabPFN learns a simple "world model" of small-n statistical learning for tabular classification. This, in itself, makes TabPFN worthy of further careful empirical study.
diff --git a/_projects/1_project.md b/_projects/1_project.md
deleted file mode 100644
index 3f7cf783..00000000
--- a/_projects/1_project.md
+++ /dev/null
@@ -1,80 +0,0 @@
----
-layout: page
-title: project 1
-description: a project with a background image
-img: assets/img/12.jpg
-importance: 1
-category: work
----
-
-Every project has a beautiful feature showcase page.
-It's easy to include images in a flexible 3-column grid format.
-Make your photos 1/3, 2/3, or full width.
-
-To give your project a background in the portfolio page, just add the img tag to the front matter like so:
-
- ---
- layout: page
- title: project
- description: a project with a background image
- img: /assets/img/12.jpg
- ---
-
-
- Caption photos easily. On the left, a road goes through a tunnel. Middle, leaves artistically fall in a hipster photoshoot. Right, in another hipster photoshoot, a lumberjack grasps a handful of pine needles.
-
- This image can also have a caption. It's like magic.
-
-
-You can also put regular text between your rows of images.
-Say you wanted to write a little bit about your project before you posted the rest of the images.
-You describe how you toiled, sweated, *bled* for your project, and then... you reveal its glory in the next row of images.
-
-
-
- You can also have artistically styled 2/3 + 1/3 images, like these.
-
-
-
-The code is simple.
-Just wrap your images with `<div class="col-sm">` and place them inside `<div class="row">` (read more about the Bootstrap Grid system).
-To make images responsive, add `img-fluid` class to each; for rounded corners and shadows use `rounded` and `z-depth-1` classes.
-Here's the code for the last row of images above:
-
-{% raw %}
-```html
-
-```
-{% endraw %}
diff --git a/_projects/2_project.md b/_projects/2_project.md
deleted file mode 100644
index bebf7961..00000000
--- a/_projects/2_project.md
+++ /dev/null
@@ -1,80 +0,0 @@
----
-layout: page
-title: project 2
-description: a project with a background image
-img: assets/img/3.jpg
-importance: 2
-category: work
----
-
-Every project has a beautiful feature showcase page.
-It's easy to include images in a flexible 3-column grid format.
-Make your photos 1/3, 2/3, or full width.
-
-To give your project a background in the portfolio page, just add the img tag to the front matter like so:
-
- ---
- layout: page
- title: project
- description: a project with a background image
- img: /assets/img/12.jpg
- ---
-
-
- Caption photos easily. On the left, a road goes through a tunnel. Middle, leaves artistically fall in a hipster photoshoot. Right, in another hipster photoshoot, a lumberjack grasps a handful of pine needles.
-
- This image can also have a caption. It's like magic.
-
-
-You can also put regular text between your rows of images.
-Say you wanted to write a little bit about your project before you posted the rest of the images.
-You describe how you toiled, sweated, *bled* for your project, and then... you reveal its glory in the next row of images.
-
-
-
- You can also have artistically styled 2/3 + 1/3 images, like these.
-
-
-
-The code is simple.
-Just wrap your images with `<div class="col-sm">` and place them inside `<div class="row">` (read more about the Bootstrap Grid system).
-To make images responsive, add `img-fluid` class to each; for rounded corners and shadows use `rounded` and `z-depth-1` classes.
-Here's the code for the last row of images above:
-
-{% raw %}
-```html
-
-```
-{% endraw %}
diff --git a/_projects/3_project.md b/_projects/3_project.md
deleted file mode 100644
index 3f3cbf70..00000000
--- a/_projects/3_project.md
+++ /dev/null
@@ -1,81 +0,0 @@
----
-layout: page
-title: project 3
-description: a project that redirects to another website
-img: assets/img/7.jpg
-redirect: https://unsplash.com
-importance: 3
-category: work
----
-
-Every project has a beautiful feature showcase page.
-It's easy to include images in a flexible 3-column grid format.
-Make your photos 1/3, 2/3, or full width.
-
-To give your project a background in the portfolio page, just add the img tag to the front matter like so:
-
- ---
- layout: page
- title: project
- description: a project with a background image
- img: /assets/img/12.jpg
- ---
-
-
- Caption photos easily. On the left, a road goes through a tunnel. Middle, leaves artistically fall in a hipster photoshoot. Right, in another hipster photoshoot, a lumberjack grasps a handful of pine needles.
-
- This image can also have a caption. It's like magic.
-
-
-You can also put regular text between your rows of images.
-Say you wanted to write a little bit about your project before you posted the rest of the images.
-You describe how you toiled, sweated, *bled* for your project, and then... you reveal its glory in the next row of images.
-
-
-
- You can also have artistically styled 2/3 + 1/3 images, like these.
-
-
-
-The code is simple.
-Just wrap your images with `<div class="col-sm">` and place them inside `<div class="row">` (read more about the Bootstrap Grid system).
-To make images responsive, add `img-fluid` class to each; for rounded corners and shadows use `rounded` and `z-depth-1` classes.
-Here's the code for the last row of images above:
-
-{% raw %}
-```html
-
-```
-{% endraw %}
diff --git a/_projects/4_project.md b/_projects/4_project.md
deleted file mode 100644
index edb5dd25..00000000
--- a/_projects/4_project.md
+++ /dev/null
@@ -1,80 +0,0 @@
----
-layout: page
-title: project 4
-description: another without an image
-img:
-importance: 3
-category: fun
----
-
-Every project has a beautiful feature showcase page.
-It's easy to include images in a flexible 3-column grid format.
-Make your photos 1/3, 2/3, or full width.
-
-To give your project a background in the portfolio page, just add the img tag to the front matter like so:
-
- ---
- layout: page
- title: project
- description: a project with a background image
- img: /assets/img/12.jpg
- ---
-
-
- Caption photos easily. On the left, a road goes through a tunnel. Middle, leaves artistically fall in a hipster photoshoot. Right, in another hipster photoshoot, a lumberjack grasps a handful of pine needles.
-
- This image can also have a caption. It's like magic.
-
-
-You can also put regular text between your rows of images.
-Say you wanted to write a little bit about your project before you posted the rest of the images.
-You describe how you toiled, sweated, *bled* for your project, and then... you reveal its glory in the next row of images.
-
-
-
- You can also have artistically styled 2/3 + 1/3 images, like these.
-
-
-
-The code is simple.
-Just wrap your images with `<div class="col-sm">` and place them inside `<div class="row">` (read more about the Bootstrap Grid system).
-To make images responsive, add `img-fluid` class to each; for rounded corners and shadows use `rounded` and `z-depth-1` classes.
-Here's the code for the last row of images above:
-
-{% raw %}
-```html
-
-```
-{% endraw %}
diff --git a/_projects/5_project.md b/_projects/5_project.md
deleted file mode 100644
index efd9b6cf..00000000
--- a/_projects/5_project.md
+++ /dev/null
@@ -1,80 +0,0 @@
----
-layout: page
-title: project 5
-description: a project with a background image
-img: assets/img/1.jpg
-importance: 3
-category: fun
----
-
-Every project has a beautiful feature showcase page.
-It's easy to include images in a flexible 3-column grid format.
-Make your photos 1/3, 2/3, or full width.
-
-To give your project a background in the portfolio page, just add the img tag to the front matter like so:
-
- ---
- layout: page
- title: project
- description: a project with a background image
- img: /assets/img/12.jpg
- ---
-
-
- Caption photos easily. On the left, a road goes through a tunnel. Middle, leaves artistically fall in a hipster photoshoot. Right, in another hipster photoshoot, a lumberjack grasps a handful of pine needles.
-
- This image can also have a caption. It's like magic.
-
-
-You can also put regular text between your rows of images.
-Say you wanted to write a little bit about your project before you posted the rest of the images.
-You describe how you toiled, sweated, *bled* for your project, and then... you reveal its glory in the next row of images.
-
-
-
- You can also have artistically styled 2/3 + 1/3 images, like these.
-
-
-
-The code is simple.
-Just wrap your images with `<div class="col-sm">` and place them inside `<div class="row">` (read more about the Bootstrap Grid system).
-To make images responsive, add `img-fluid` class to each; for rounded corners and shadows use `rounded` and `z-depth-1` classes.
-Here's the code for the last row of images above:
-
-{% raw %}
-```html
-
-```
-{% endraw %}
diff --git a/_projects/6_project.md b/_projects/6_project.md
deleted file mode 100644
index 9a95d6e8..00000000
--- a/_projects/6_project.md
+++ /dev/null
@@ -1,80 +0,0 @@
----
-layout: page
-title: project 6
-description: a project with no image
-img:
-importance: 4
-category: fun
----
-
-Every project has a beautiful feature showcase page.
-It's easy to include images in a flexible 3-column grid format.
-Make your photos 1/3, 2/3, or full width.
-
-To give your project a background in the portfolio page, just add the img tag to the front matter like so:
-
- ---
- layout: page
- title: project
- description: a project with a background image
- img: /assets/img/12.jpg
- ---
-
-
- Caption photos easily. On the left, a road goes through a tunnel. Middle, leaves artistically fall in a hipster photoshoot. Right, in another hipster photoshoot, a lumberjack grasps a handful of pine needles.
-
- This image can also have a caption. It's like magic.
-
-
-You can also put regular text between your rows of images.
-Say you wanted to write a little bit about your project before you posted the rest of the images.
-You describe how you toiled, sweated, *bled* for your project, and then... you reveal its glory in the next row of images.
-
-
-
- You can also have artistically styled 2/3 + 1/3 images, like these.
-
-
-
-The code is simple.
-Just wrap your images with `<div class="col-sm">` and place them inside `<div class="row">` (read more about the Bootstrap Grid system).
-To make images responsive, add `img-fluid` class to each; for rounded corners and shadows use `rounded` and `z-depth-1` classes.
-Here's the code for the last row of images above:
-
-{% raw %}
-```html
-
-```
-{% endraw %}
The deadlines for submitting an abstract and your final blog post have been updated to December 11th and 17th (00:00 GMT), respectively!
ICLR 2024 Blogposts Track
The Machine Learning community is currently experiencing a reproducibility crisis and a reviewing crisis [Littman, 2021]. Because of the highly competitive and noisy reviewing process of ML conferences [Tran et al., 2020], researchers have an incentive to oversell their results, slowing down progress and diminishing the integrity of the scientific community. Moreover, with the growing number of papers published and submitted at the main ML conferences [Lin et al., 2020], it has become more challenging to keep track of the latest advances in the field.
Blog posts are becoming an increasingly popular and useful way to talk about science [Brown and Woolston, 2018]. They offer substantial value to the scientific community by providing a flexible platform to foster open, human, and transparent discussions about new insights or limitations of a scientific publication. However, because they are not as recognized as standard scientific publications, only a minority of researchers manage to maintain an active blog and get visibility for their efforts. Many are well-established researchers (Francis Bach, Ben Recht, Ferenc Huszár, Lilian Weng) or big corporations that leverage entire teams of graphic designers and writers to polish their blogs (Facebook AI, Google AI, DeepMind, OpenAI). As a result, the incentives for writing scientific blog posts are largely personal; it is unreasonable to expect a significant portion of the machine learning community to contribute to such an initiative when everyone is trying to establish themselves through publications.
Last year, we ran the second iteration of the Blogpost track at ICLR 2023!
It was very successful, with accepted posts presented in person at the main conference.
Our goal is to create a formal call for blog posts at ICLR to incentivize and reward researchers to review past work and summarize the outcomes, develop new intuitions, or highlight some shortcomings. A very influential initiative of this kind happened after the Second World War in France. Because of the lack of up-to-date textbooks, a collective of mathematicians writing under the pseudonym Nicolas Bourbaki [Halmos, 1957] decided to start a series of textbooks about the foundations of mathematics [Bourbaki, 1939]. In the same vein, we aim to provide a new way to summarize scientific knowledge in the ML community.
Due to the large diversity of topics that can be discussed in a blog post, we decided to restrict the range of topics for this call for blog posts. We identified that the blog posts that would bring the most value to the community and the conference are those that distill and discuss previously published papers.
Key Dates
Abstract deadline: December 11th 00:00GMT, 2023 (submit to OpenReview - to be announced soon).
Submission deadline: December 17th 00:00GMT, 2023 (any modifications to your blog post, via a pull request on GitHub).
Decision Notification: January 30th, 2024 UPDATED: February 15th, 2024
Camera-ready merge: March 15th, 2024
A call for blog posts discussing work previously published at ICLR
Content
Write a post on a subject that has been published at a top-tier venue (ICLR, ICML, NeurIPS, AAAI, UAI, CVPR, SIGGRAPH, ECCV, ICCV, etc.) relatively recently.
Conflict of interest
The authors of the blog posts will have to declare their conflicts of interest (positive or negative) with the paper (and the paper’s authors) they write about. Conflicts of interest include:
Recent collaborators (less than 3 years)
Current institution
Reviewers will be asked to judge if the submission is sufficiently critical and objective of the papers addressed in the blog post.
Blog Posts must not be used to highlight or advertise past publications of the **authors or their lab**.
We will only ask the authors to report if they have a conflict of interest. If so, reviewers will be asked to judge if the submission is sufficiently critical and objective of the papers addressed in the blog post.
Publication
Blog post
The posts will be created and published under a unified template; see the submission instructions and the sample post hosted on the blog of this website.
Poster
Additionally, accepted posts will have the option to present their work as a poster during the main poster session. For more information about the main poster session (time, poster format, etc.) please refer to the ICLR homepage.
Submissions
Our goal is to avoid heavily engineered, professionally-made blog posts (such as the “100+ hours” mentioned as a standard by the Distill guidelines) and to encourage ideas and clear writing rather than dynamic visualizations or embedded javascript engines. Please check our submission instructions for more details. We accept submissions in both Markdown and HTML; we believe this is a good trade-off between complexity and flexibility.